508 lines
22 KiB
Python
508 lines
22 KiB
Python
# coding=utf-8
|
|
# Copyright 2024 Mistral and the HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""PyTorch Pixtral model."""
|
|
|
|
from typing import List, Optional, Tuple, Union
|
|
|
|
import torch
|
|
import torch.utils.checkpoint
|
|
from torch import nn
|
|
|
|
from ... import PreTrainedModel
|
|
from ...activations import ACT2FN
|
|
from ...modeling_outputs import BaseModelOutput
|
|
from ...utils import (
|
|
add_start_docstrings,
|
|
add_start_docstrings_to_model_forward,
|
|
logging,
|
|
)
|
|
from .configuration_pixtral import PixtralVisionConfig
|
|
|
|
|
|
logger = logging.get_logger(__name__)
|
|
|
|
|
|
def position_ids_in_meshgrid(patch_embeds_list, max_width):
    """
    Build the flat position id of every patch of every image.

    Args:
        patch_embeds_list: iterable of tensors whose last two dimensions are the patch-grid height and width of one
            image.
        max_width (`int`): row stride of the position table; the patch at `(row, col)` gets id
            `row * max_width + col`.

    Returns:
        `torch.Tensor`: 1-D integer tensor with the ids of all images concatenated in list order, each image
        enumerated row by row.
    """

    def _flat_ids(patch):
        n_rows, n_cols = patch.shape[-2:]
        rows, cols = torch.meshgrid(torch.arange(n_rows), torch.arange(n_cols), indexing="ij")
        # Row-major flattening of the grid of ids.
        return (rows * max_width + cols).reshape(-1)

    return torch.cat([_flat_ids(patch) for patch in patch_embeds_list])
|
|
|
|
|
|
class PixtralRotaryEmbedding(nn.Module):
    """
    Two-dimensional rotary position embedding for a grid of image patches.

    A table of rotation angles is precomputed for every `(row, column)` cell of the largest supported patch grid
    (`config.image_size // config.patch_size` cells per side) and flattened row by row, so the position id
    `row * max_patches_per_side + column` selects the angles of that cell. Half of the base frequencies are combined
    with the row index and the other half with the column index.

    `forward` looks the angles up for the given position ids and returns their cosine and sine, each with one row
    per position id and `config.head_dim` columns.
    """

    def __init__(self, config, device):
        # NOTE: `device` is accepted but not used by this constructor.
        super().__init__()
        self.rope_type = "default"
        self.dim = config.head_dim
        self.base = config.rope_theta
        grid_side = config.image_size // config.patch_size

        exponents = torch.arange(0, self.dim, 2).float() / self.dim
        base_freqs = 1.0 / (self.base**exponents)

        row_idx = torch.arange(grid_side, device=base_freqs.device)
        col_idx = torch.arange(grid_side, device=base_freqs.device)

        # Even-indexed frequencies are driven by the row, odd-indexed ones by the column.
        row_angles = torch.outer(row_idx, base_freqs[::2]).float()
        col_angles = torch.outer(col_idx, base_freqs[1::2]).float()

        per_cell = torch.cat(
            [
                row_angles[:, None, :].expand(-1, grid_side, -1),
                col_angles[None, :, :].expand(grid_side, -1, -1),
            ],
            dim=-1,
        )
        # Flatten the grid so the table is indexed by a single position id instead of a (row, column) pair.
        angle_table = per_cell.reshape(-1, self.dim // 2)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation

        # TODO maybe make it torch compatible later on. We can also just slice
        self.register_buffer("inv_freq", torch.cat((angle_table, angle_table), dim=-1), persistent=False)

    @torch.no_grad()
    def forward(self, x, position_ids):
        """
        Return `(cos, sin)` of the precomputed angles at `position_ids`, cast to the dtype of `x`.

        `x` is only used for its device type and dtype.
        """
        if "dynamic" in self.rope_type:
            self._dynamic_frequency_update(position_ids, device=x.device)

        angles = self.inv_freq[position_ids]

        # Autocast is switched off around the trigonometric calls
        # (see https://github.com/huggingface/transformers/pull/29285)
        autocast_device = x.device.type
        if not isinstance(autocast_device, str) or autocast_device == "mps":
            autocast_device = "cpu"
        with torch.autocast(device_type=autocast_device, enabled=False):
            cos, sin = angles.cos(), angles.sin()
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)

    def _dynamic_frequency_update(self, position_ids, device):
        """
        dynamic RoPE layers should recompute `inv_freq` in the following situations:
        1 - growing beyond the cached sequence length (allow scaling)
        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)

        NOTE: this method reads `max_seq_len_cached`, `rope_init_fn`, `config`, `rope_kwargs`,
        `original_max_seq_len` and `original_inv_freq`, none of which `__init__` sets. It is only reached when
        `rope_type` contains "dynamic", and `__init__` always sets it to "default".
        """
        seq_len = torch.max(position_ids) + 1

        # Growth: recompute the frequencies for the longer sequence.
        if seq_len > self.max_seq_len_cached:
            inv_freq, self.attention_scaling = self.rope_init_fn(
                self.config, device, seq_len=seq_len, **self.rope_kwargs
            )
            self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: may break with compilation
            self.max_seq_len_cached = seq_len

        # Reset: back within the original range after having grown beyond it.
        if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:
            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
            self.max_seq_len_cached = self.original_max_seq_len
|
|
|
|
|
|
# Adapted from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
    """Swap the two halves of the last dimension of `x`, negating the half that moves to the front."""
    midpoint = x.shape[-1] // 2
    lower, upper = x[..., :midpoint], x[..., midpoint:]
    return torch.cat((-upper, lower), dim=-1)
|
|
|
|
|
|
# Adapted from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Rotate the query and key tensors with the given rotary cosine/sine tables.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Dimension at which a singleton axis is inserted into `cos` and `sin` so that they broadcast against `q`
            and `k`. For instance, with `cos`/`sin` of shape [batch_size, seq_len, head_dim], use `unsqueeze_dim=1`
            when `q` and `k` are [batch_size, heads, seq_len, head_dim] and `unsqueeze_dim=2` when they are
            [batch_size, seq_len, heads, head_dim].
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos, sin = cos.unsqueeze(unsqueeze_dim), sin.unsqueeze(unsqueeze_dim)

    def _rotate(states):
        return (states * cos) + (rotate_half(states) * sin)

    return _rotate(q), _rotate(k)
|
|
|
|
|
|
class PixtralAttention(nn.Module):
    """Multi-head self-attention with rotary position embeddings applied to the queries and keys."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads

        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout

        # All four projections are square and bias-free.
        width = self.embed_dim
        self.k_proj = nn.Linear(width, width, bias=False)
        self.v_proj = nn.Linear(width, width, bias=False)
        self.q_proj = nn.Linear(width, width, bias=False)
        self.o_proj = nn.Linear(width, width, bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_embeddings: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Input shape: Batch x Time x Channel

        `position_embeddings` is unpacked as a `(cos, sin)` pair and `attention_mask`, when given, is added to the
        scaled attention scores. The attention weights are always returned as the second element;
        `output_attentions` is not consulted here.
        """
        batch_size, patches, _ = hidden_states.size()

        def _split_heads(projected):
            # (batch, patches, embed) -> (batch, heads, patches, head_dim)
            return projected.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)

        query_states = _split_heads(self.q_proj(hidden_states))
        key_states = _split_heads(self.k_proj(hidden_states))
        value_states = _split_heads(self.v_proj(hidden_states))

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, unsqueeze_dim=0)

        scores = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
        if attention_mask is not None:
            scores = scores + attention_mask

        # Softmax is computed in fp32, then cast back to the working dtype.
        probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query_states.dtype)
        probs = nn.functional.dropout(probs, p=self.dropout, training=self.training)

        context = torch.matmul(probs, value_states)
        # (batch, heads, patches, head_dim) -> (batch, patches, embed)
        context = context.transpose(1, 2).contiguous().reshape(batch_size, patches, -1)

        return self.o_proj(context), probs
|
|
|
|
|
|
# Adapted from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Pixtral
class PixtralMLP(nn.Module):
    """Gated feed-forward block computing `down_proj(act_fn(gate_proj(x)) * up_proj(x))` with bias-free linears."""

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        narrow, wide = self.hidden_size, self.intermediate_size
        self.gate_proj = nn.Linear(narrow, wide, bias=False)
        self.up_proj = nn.Linear(narrow, wide, bias=False)
        self.down_proj = nn.Linear(wide, narrow, bias=False)
        # Activation is looked up by name from the config.
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_state):
        gate = self.act_fn(self.gate_proj(hidden_state))
        gated = gate * self.up_proj(hidden_state)
        return self.down_proj(gated)
|
|
|
|
|
|
# Adapted from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Pixtral
class PixtralRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        Root-mean-square layer norm with a learned per-feature scale (equivalent to T5LayerNorm).

        Args:
            hidden_size: size of the normalized (last) dimension; the scale is initialized to ones.
            eps: constant added to the mean square before taking the reciprocal square root.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        original_dtype = hidden_states.dtype
        # Statistics are computed in fp32 and the result is cast back before scaling.
        upcast = hidden_states.to(torch.float32)
        mean_square = upcast.pow(2).mean(-1, keepdim=True)
        normalized = upcast * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalized.to(original_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
|
|
|
|
|
class PixtralAttentionLayer(nn.Module):
    """Pre-norm transformer block: RMSNorm -> attention -> residual, then RMSNorm -> gated MLP -> residual."""

    def __init__(self, config):
        super().__init__()
        self.attention_norm = PixtralRMSNorm(config.hidden_size, eps=1e-5)
        self.feed_forward = PixtralMLP(config)
        self.attention = PixtralAttention(config)
        self.ffn_norm = PixtralRMSNorm(config.hidden_size, eps=1e-5)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        position_embeddings: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor]:
        """
        Args:
            hidden_states (`torch.FloatTensor`):
                Input to the layer of shape `(batch, seq_len, embed_dim)`.
            attention_mask (`torch.FloatTensor`):
                Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
            position_embeddings (*optional*):
                Passed through unchanged to the attention module.
            output_attentions (`bool`, *optional*, defaults to `False`):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.

        Returns:
            A 1-tuple with the updated hidden states, extended with the attention weights when
            `output_attentions` is truthy.
        """
        # Attention sub-block with residual connection.
        attn_output, attn_weights = self.attention(
            hidden_states=self.attention_norm(hidden_states),
            attention_mask=attention_mask,
            position_embeddings=position_embeddings,
            output_attentions=output_attentions,
        )
        hidden_states = hidden_states + attn_output

        # Feed-forward sub-block with residual connection.
        hidden_states = hidden_states + self.feed_forward(self.ffn_norm(hidden_states))

        if output_attentions:
            return (hidden_states, attn_weights)
        return (hidden_states,)
|
|
|
|
|
|
class PixtralTransformer(nn.Module):
    """
    Sequential stack of `config.num_hidden_layers` [`PixtralAttentionLayer`] blocks.

    Args:
        config: configuration passed to every layer; it also provides `num_hidden_layers` and the
            `output_attentions`, `output_hidden_states` and `use_return_dict` defaults read in `forward`.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layers = torch.nn.ModuleList()
        for _ in range(config.num_hidden_layers):
            self.layers.append(PixtralAttentionLayer(config))
        self.gradient_checkpointing = False

    def forward(
        self,
        inputs_embeds,
        attention_mask: Optional[torch.Tensor] = None,
        position_embeddings: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Embeddings which serve as input to the Transformer.
            attention_mask (`torch.Tensor`, *optional*):
                Additive mask handed unchanged to every layer, where it is added to the attention scores. The
                layers document it as having shape `(batch, 1, q_len, k_v_seq_len)` with very large negative
                values at positions that must not be attended to.
            position_embeddings (*optional*):
                Rotary `(cos, sin)` pair handed unchanged to every layer.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.

        Returns:
            A [`BaseModelOutput`] when `return_dict` resolves to true, otherwise a tuple of the non-`None` values
            among the last hidden state, the collected hidden states and the collected attentions.
        """
        # Flags left as `None` fall back to the configuration defaults.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        hidden_states = inputs_embeds
        for encoder_layer in self.layers:
            if output_hidden_states:
                # Record the input of each layer; the output of the last one is appended after the loop.
                encoder_states = encoder_states + (hidden_states,)
            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    encoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    position_embeddings,
                    output_attentions,
                )
            else:
                layer_outputs = encoder_layer(
                    hidden_states,
                    attention_mask,
                    position_embeddings=position_embeddings,
                    output_attentions=output_attentions,
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)

        # Return every collected hidden state when they were requested (previously they were dropped and only
        # the final state was returned). When they were not requested, keep the earlier behaviour of exposing a
        # single-element list so that `hidden_states[-1]` still resolves to the final state.
        # NOTE(review): confirm no caller depends on this fallback before replacing it with `None`.
        collected_states = encoder_states if encoder_states is not None else [hidden_states]
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=collected_states, attentions=all_attentions
        )
|
|
|
|
|
|
PIXTRAL_START_DOCSTRING = r"""
|
|
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
|
library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
|
|
etc.)
|
|
|
|
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
|
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
|
and behavior.
|
|
|
|
Parameters:
|
|
config ([`PixtralVisionConfig`]):
|
|
Model configuration class with all the parameters of the vision encoder. Initializing with a config file does not
|
|
load the weights associated with the model, only the configuration. Check out the
|
|
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
|
"""
|
|
|
|
|
|
class PixtralPreTrainedModel(PreTrainedModel):
    """Base class tying the Pixtral vision modules to [`PixtralVisionConfig`] and defining weight initialization."""

    config_class = PixtralVisionConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    # Must name a module class that exists in this file. The transformer block is listed because it contains
    # the residual connections. The previous value, "PixtralVisionAttention", matched no class.
    _no_split_modules = ["PixtralAttentionLayer"]
    _skip_keys_device_placement = "past_key_values"
    _supports_cache_class = True

    def _init_weights(self, module):
        """
        Initialize `module` in place.

        Linear, convolution and embedding weights are drawn from a normal distribution with mean 0 and standard
        deviation `config.initializer_range`; biases and the embedding padding row are zeroed. Other module types
        are left untouched.
        """
        # Both branches of the former `hasattr` conditional read the same attribute, so read it directly.
        std = self.config.initializer_range

        if isinstance(module, (nn.Linear, nn.Conv2d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
|
|
|
|
|
|
PIXTRAL_INPUTS_DOCSTRING = r"""
|
|
Args:
|
|
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
|
Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`AutoImageProcessor.__call__`]
|
|
for details.
|
|
output_attentions (`bool`, *optional*):
|
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
|
tensors for more detail.
|
|
output_hidden_states (`bool`, *optional*):
|
|
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
|
more detail.
|
|
return_dict (`bool`, *optional*):
|
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
|
"""
|
|
|
|
|
|
def generate_block_attention_mask(patch_embeds_list, tensor):
    """
    Build an additive block-diagonal attention mask for a sequence made of concatenated images.

    Args:
        patch_embeds_list: number of patches of each image, in sequence order.
        tensor: tensor whose dtype, device, first dimension (batch) and second dimension (sequence length) define
            the mask.

    Returns:
        Tensor of shape `(tensor.shape[0], 1, seq_len, seq_len)` holding 0 where query and key belong to the same
        image and the most negative finite value of the dtype elsewhere. The batch dimension is an expanded view.
    """
    seq_len = tensor.shape[1]
    blocked = torch.finfo(tensor.dtype).min
    mask = torch.full((seq_len, seq_len), fill_value=blocked, dtype=tensor.dtype, device=tensor.device)

    # Walk the images with a running offset and open up the square block of each one.
    offset = 0
    for n_patches in patch_embeds_list:
        stop = offset + n_patches
        mask[offset:stop, offset:stop] = 0
        offset = stop

    return mask[None, None, :, :].expand(tensor.shape[0], 1, -1, -1)
|
|
|
|
|
|
@add_start_docstrings(
    "The bare Pixtral vision encoder outputting raw hidden-states without any specific head on top.",
    PIXTRAL_START_DOCSTRING,
)
class PixtralVisionModel(PixtralPreTrainedModel):
    base_model_prefix = "vision_encoder"

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        # Kernel size and stride both equal the patch size, so patches do not overlap.
        self.patch_conv = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=config.hidden_size,
            kernel_size=config.patch_size,
            stride=config.patch_size,
            bias=False,
        )
        self.ln_pre = PixtralRMSNorm(config.hidden_size, eps=1e-5)
        self.transformer = PixtralTransformer(config)
        self.patch_positional_embedding = PixtralRotaryEmbedding(config, device=self.device)

    @add_start_docstrings_to_model_forward(PIXTRAL_INPUTS_DOCSTRING)
    def forward(
        self,
        pixel_values: List[torch.Tensor],
        output_hidden_states: Optional[bool] = False,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        *args,
        **kwargs,
    ) -> Union[Tuple, BaseModelOutput]:
        """
        Encode a list of images as one concatenated sequence of patches.

        Each image in `pixel_values` is embedded on its own by the patch convolution, so images may differ in
        size. The patch embeddings of all images are concatenated into a single sequence, and a block-diagonal
        mask restricts attention to patches of the same image. Extra positional and keyword arguments are
        accepted and ignored.

        Returns:
            The output of the transformer stack: a [`BaseModelOutput`], or a plain tuple when `return_dict`
            resolves to `False`. Assuming each image is a `(num_channels, height, width)` tensor, the last hidden
            state has shape `(1, total_patches, hidden_size)`.
        """
        # pass images through initial convolution independently
        patch_embeds_list = [self.patch_conv(img.unsqueeze(0).to(self.dtype)) for img in pixel_values]

        # flatten to a single sequence
        patch_embeds = torch.cat([p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list], dim=1)
        patch_embeds = self.ln_pre(patch_embeds)

        # positional embeddings
        position_ids = position_ids_in_meshgrid(
            patch_embeds_list, max_width=self.config.image_size // self.config.patch_size
        ).to(self.device)

        position_embedding = self.patch_positional_embedding(patch_embeds, position_ids)
        attention_mask = generate_block_attention_mask(
            [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], patch_embeds
        )

        # Forward the output flags instead of silently dropping them. `output_hidden_states` defaults to `False`
        # in the signature, so a falsy value is passed on as `None`, which lets the transformer fall back to the
        # configuration exactly as it did before.
        return self.transformer(
            patch_embeds,
            attention_mask,
            position_embedding,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states if output_hidden_states else None,
            return_dict=return_dict,
        )
|