1319 lines
63 KiB
Python
1319 lines
63 KiB
Python
|
|
# coding=utf-8
|
||
|
|
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
|
||
|
|
#
|
||
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
|
# you may not use this file except in compliance with the License.
|
||
|
|
# You may obtain a copy of the License at
|
||
|
|
#
|
||
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
|
#
|
||
|
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
|
# See the License for the specific language governing permissions and
|
||
|
|
# limitations under the License.
|
||
|
|
"""PyTorch Idefics3 model."""
|
||
|
|
|
||
|
|
from dataclasses import dataclass
|
||
|
|
from typing import List, Optional, Tuple, Union
|
||
|
|
|
||
|
|
import torch
|
||
|
|
import torch.utils.checkpoint
|
||
|
|
from torch import nn
|
||
|
|
from torch.nn import CrossEntropyLoss
|
||
|
|
|
||
|
|
from ... import PreTrainedModel
|
||
|
|
from ...activations import ACT2FN
|
||
|
|
from ...cache_utils import Cache, DynamicCache
|
||
|
|
from ...generation import GenerationMixin
|
||
|
|
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
|
||
|
|
from ...modeling_outputs import BaseModelOutput, ModelOutput
|
||
|
|
from ...utils import (
|
||
|
|
add_start_docstrings,
|
||
|
|
add_start_docstrings_to_model_forward,
|
||
|
|
is_flash_attn_2_available,
|
||
|
|
is_flash_attn_greater_or_equal_2_10,
|
||
|
|
logging,
|
||
|
|
replace_return_docstrings,
|
||
|
|
)
|
||
|
|
from ..auto import AutoModel
|
||
|
|
from .configuration_idefics3 import Idefics3Config, Idefics3VisionConfig
|
||
|
|
|
||
|
|
|
||
|
|
if is_flash_attn_2_available():
|
||
|
|
from ...modeling_flash_attention_utils import _flash_attention_forward
|
||
|
|
|
||
|
|
|
||
|
|
logger = logging.get_logger(__name__)
|
||
|
|
|
||
|
|
_CONFIG_FOR_DOC = "Idefics3Config"
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass
class Idefics3BaseModelOutputWithPast(ModelOutput):
    """
    Base class for Idefics3 model's outputs that may also contain past key/values (to speed up sequential decoding).

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
            If `past_key_values` is used, only the last hidden-state of the sequences, of shape `(batch_size, 1,
            hidden_size)`, is output.
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)` and optionally, if
            `config.is_encoder_decoder=True`, 2 additional tensors of shape `(batch_size, num_heads,
            encoder_sequence_length, embed_size_per_head)`.
            Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally, if
            `config.is_encoder_decoder=True`, in the cross-attention blocks) that can be used (see `past_key_values`
            input) to speed up sequential decoding.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.
            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
            Tuple of `torch.FloatTensor` (one for the output of the image embeddings) of shape `(batch_size,
            num_images, sequence_length, hidden_size)`.
            Image hidden states of the model, produced by the vision encoder.
    """

    # Field order is part of the interface of this output container; do not reorder.
    last_hidden_state: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass
class Idefics3CausalLMOutputWithPast(ModelOutput):
    """
    Base class for Idefics3 causal language model (or autoregressive) outputs.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction).
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`.
            Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
            `past_key_values` input) to speed up sequential decoding.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.
            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
            Tuple of `torch.FloatTensor` (one for the output of the image embeddings) of shape `(batch_size,
            num_images, sequence_length, hidden_size)`.
            Image hidden states of the model, produced by the vision encoder.
    """

    # Field order is part of the interface of this output container; do not reorder.
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    # NOTE(review): annotated as a `List` here while the docstring describes a tuple of tuples — confirm which
    # one callers actually receive.
    past_key_values: Optional[List[torch.FloatTensor]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||
|
|
|
||
|
|
|
||
|
|
# Adapted from transformers.models.idefics2.modeling_idefics2.Idefics2VisionEmbeddings with Idefics2->Idefics3
# (restyled, so no longer a verbatim copy).
class Idefics3VisionEmbeddings(nn.Module):
    """
    Patch and position embeddings for images of variable resolution; a modified version of
    `siglip.modeling_siglip.SiglipVisionEmbeddings`.

    The modifications are adapted from [Patch n' Pack: NaViT, a Vision Transformer for any Aspect Ratio and Resolution](https://arxiv.org/abs/2307.06304)
    which allows treating images in their native aspect ratio and without the need to resize them to the same
    fixed size. In particular, we start from the original pre-trained SigLIP model (which uses fixed-size square
    images) and adapt it by training on images of variable resolutions.
    """

    def __init__(self, config: Idefics3VisionConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        # Non-overlapping patches: kernel and stride are both the patch size.
        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
        )

        # One learned position per cell of the (image_size // patch_size)^2 reference grid.
        self.num_patches_per_side = self.image_size // self.patch_size
        self.num_patches = self.num_patches_per_side**2
        self.num_positions = self.num_patches
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)

    @staticmethod
    def _bucketize_axis(nb_patches, boundaries):
        """Spread `nb_patches` evenly over [0, 1) and return the reference-grid bucket of each one."""
        fractional_coords = torch.arange(0, 1 - 1e-6, 1 / nb_patches)
        return torch.bucketize(fractional_coords, boundaries, right=True)

    def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor) -> torch.Tensor:
        """
        Embed the patches of `pixel_values` and add resolution-aware position embeddings.

        Args:
            pixel_values (`torch.FloatTensor`):
                4-D image batch; the last two dimensions are read as the (padded) height and width.
            patch_attention_mask (`torch.BoolTensor`):
                One 2-D patch mask per image. The number of set entries in its first column / first row is used
                as the image's patch count along height / width (presumably set entries mark real, non-padding
                patches — confirm with the caller).

        Returns:
            `torch.Tensor`: patch embeddings with position embeddings added, one row per patch.
        """
        batch_size, _, max_im_h, max_im_w = pixel_values.shape

        # Conv output is flattened over its spatial dims and moved to (batch, patches, embed_dim).
        embeddings = self.patch_embedding(pixel_values).flatten(2).transpose(1, 2)

        grid_side = self.num_patches_per_side
        max_nb_patches_h = max_im_h // self.patch_size
        max_nb_patches_w = max_im_w // self.patch_size
        boundaries = torch.arange(1 / grid_side, 1.0, 1 / grid_side)
        # Positions not selected by the mask keep position id 0.
        position_ids = torch.full(size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0)

        for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
            bucket_coords_h = self._bucketize_axis(p_attn_mask[:, 0].sum(), boundaries)
            bucket_coords_w = self._bucketize_axis(p_attn_mask[0].sum(), boundaries)

            # Row-major index into the reference position grid.
            pos_ids = (bucket_coords_h[:, None] * grid_side + bucket_coords_w).flatten()
            image_row = position_ids[batch_idx]
            image_row[p_attn_mask.view(-1).cpu()] = pos_ids

        position_ids = position_ids.to(self.position_embedding.weight.device)
        return embeddings + self.position_embedding(position_ids)
|
||
|
|
|
||
|
|
|
||
|
|
# Adapted from transformers.models.siglip.modeling_siglip.SiglipAttention with Siglip->Idefics3Vision
# (restyled, so no longer a verbatim copy).
class Idefics3VisionAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim, leftover = divmod(self.embed_dim, self.num_heads)
        if leftover:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout

        # Creation order (k, v, q, out) is kept so parameter order and RNG consumption stay the same.
        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)

        self.is_causal = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Input shape: Batch x Time x Channel

        Args:
            hidden_states: input of shape `(batch_size, q_len, embed_dim)`.
            attention_mask: optional additive mask; must have shape `(batch_size, 1, q_len, k_v_seq_len)`.
            output_attentions: accepted for interface compatibility; the weights are returned regardless.

        Returns:
            `(attn_output, attn_weights)`: the projected output of shape `(batch_size, q_len, embed_dim)` and the
            post-softmax, post-dropout weights of shape `(batch_size, num_heads, q_len, k_v_seq_len)`.

        Raises:
            ValueError: if the weights, the mask or the attention output do not have the expected shape.
        """
        batch_size, q_len, _ = hidden_states.size()
        per_head_shape = (batch_size, q_len, self.num_heads, self.head_dim)

        projected_q = self.q_proj(hidden_states)
        projected_k = self.k_proj(hidden_states)
        projected_v = self.v_proj(hidden_states)

        # (batch, seq, heads, head_dim) -> (batch, heads, seq, head_dim)
        query_states = projected_q.view(*per_head_shape).transpose(1, 2)
        key_states = projected_k.view(*per_head_shape).transpose(1, 2)
        value_states = projected_v.view(*per_head_shape).transpose(1, 2)

        k_v_seq_len = key_states.shape[-2]
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale

        expected_weights_size = (batch_size, self.num_heads, q_len, k_v_seq_len)
        if attn_weights.size() != expected_weights_size:
            raise ValueError(
                f"Attention weights should be of size {expected_weights_size}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            expected_mask_size = (batch_size, 1, q_len, k_v_seq_len)
            if attention_mask.size() != expected_mask_size:
                raise ValueError(
                    f"Attention mask should be of size {expected_mask_size}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask

        # Softmax runs in fp32, then the weights go back to the query dtype.
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        expected_output_size = (batch_size, self.num_heads, q_len, self.head_dim)
        if attn_output.size() != expected_output_size:
            raise ValueError(
                f"`attn_output` should be of size {expected_output_size}, but is"
                f" {attn_output.size()}"
            )

        # Merge the heads back: (batch, heads, seq, head_dim) -> (batch, seq, embed_dim)
        merged = attn_output.transpose(1, 2).contiguous().reshape(batch_size, q_len, self.embed_dim)
        return self.out_proj(merged), attn_weights
|
||
|
|
|
||
|
|
|
||
|
|
# Adapted from transformers.models.idefics2.modeling_idefics2.Idefics2VisionFlashAttention2 with Idefics2->Idefics3
# (diverges from that source: the dead KV-cache length bookkeeping in `forward` was removed).
class Idefics3VisionFlashAttention2(Idefics3VisionAttention):
    """
    Idefics3Vision flash attention module. This module inherits from `Idefics3VisionAttention` as the weights of the
    module stay untouched. The only required change is in the forward pass, which needs to correctly call the public
    API of flash attention and deal with padding tokens in case the input contains any of them.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment,
        # that was made default for flash_attn>=2.1. This attribute is used to handle this difference.
        # Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a
        # wrong mask (top-left).
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Self-attention over `hidden_states` computed with `_flash_attention_forward`.

        Args:
            hidden_states: input of shape `(batch_size, q_len, embed_dim)`.
            attention_mask: forwarded unchanged to `_flash_attention_forward`.
            position_ids, past_key_value, use_cache, **kwargs: accepted for signature compatibility and ignored.
                `past_key_value` used to be dereferenced through `self.layer_idx`, an attribute this class never
                defines, only to compute a value that was never read; that code has been removed.
            output_attentions: ignored; this implementation never returns attention weights.

        Returns:
            `(attn_output, None)` where `attn_output` has shape `(batch_size, q_len, embed_dim)`.
        """
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim], which is exactly
        # what the view below yields, so no transpose is needed.
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
        key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim)
        value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim)

        dropout_rate = self.dropout if self.training else 0.0

        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
        # therefore the input hidden states gets silently casted in float32. Hence, we need
        # cast them back in the correct dtype just to be sure everything works as expected.
        # This might slowdown training & inference so it is recommended to not cast the LayerNorms
        # in fp32.
        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            # Handle the case where the model is quantized
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype

            logger.warning_once(
                "The input hidden states seems to be silently casted in float32, this might be related to"
                " the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}."
            )

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        attn_output = _flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            dropout=dropout_rate,
            is_causal=self.is_causal,
            use_top_left_mask=self._flash_attn_uses_top_left_mask,
        )

        attn_output = attn_output.reshape(bsz, q_len, self.embed_dim).contiguous()
        attn_output = self.out_proj(attn_output)

        # No attention weights are produced on this path, so the second element is always None.
        return attn_output, None
|
||
|
|
|
||
|
|
|
||
|
|
# Maps an attention-implementation name to the vision attention class implementing it.
# NOTE(review): there is no "sdpa" entry although `Idefics3PreTrainedModel` sets `_supports_sdpa = True` —
# confirm the vision config can never resolve to "sdpa", otherwise the lookup raises KeyError.
IDEFICS_VISION_ATTENTION_CLASSES = dict(
    eager=Idefics3VisionAttention,
    flash_attention_2=Idefics3VisionFlashAttention2,
)
|
||
|
|
|
||
|
|
|
||
|
|
# Adapted from transformers.models.siglip.modeling_siglip.SiglipMLP with Siglip->Idefics3Vision
# (restyled, so no longer a verbatim copy).
class Idefics3VisionMLP(nn.Module):
    """Two-layer feed-forward block: `hidden_size -> intermediate_size -> hidden_size`."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        # The activation is looked up by name from the config.
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(in_features=config.hidden_size, out_features=config.intermediate_size)
        self.fc2 = nn.Linear(in_features=config.intermediate_size, out_features=config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply `fc1`, the activation, then `fc2` to `hidden_states`."""
        expanded = self.activation_fn(self.fc1(hidden_states))
        return self.fc2(expanded)
|
||
|
|
|
||
|
|
|
||
|
|
class Idefics3SimpleMLP(nn.Module):
    """Single bias-free linear projection from the vision feature width to the text hidden size."""

    def __init__(self, config):
        super().__init__()
        # Input width is the vision hidden size multiplied by scale_factor squared.
        vision_width = config.vision_config.hidden_size * (config.scale_factor**2)
        self.proj = nn.Linear(
            in_features=vision_width,
            out_features=config.text_config.hidden_size,
            bias=False,
        )

    def forward(self, x):
        """Project `x` with the linear layer."""
        return self.proj(x)
|
||
|
|
|
||
|
|
|
||
|
|
# Adapted from transformers.models.idefics2.modeling_idefics2.Idefics2EncoderLayer with Idefics2->Idefics3
# (restyled, so no longer a verbatim copy).
class Idefics3EncoderLayer(nn.Module):
    """Pre-norm transformer encoder layer: self-attention and MLP, each wrapped in a residual connection."""

    def __init__(self, config: Idefics3VisionConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        # The attention class is selected by the configured attention implementation name.
        attention_cls = IDEFICS_VISION_ATTENTION_CLASSES[config._attn_implementation]
        self.self_attn = attention_cls(config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = Idefics3VisionMLP(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor]:
        """
        Args:
            hidden_states (`torch.FloatTensor`):
                Input to the layer of shape `(batch, seq_len, embed_dim)`.
            attention_mask (`torch.FloatTensor`):
                Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
            output_attentions (`bool`, *optional*, defaults to `False`):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.

        Returns:
            `(hidden_states,)`, or `(hidden_states, attn_weights)` when `output_attentions` is truthy.
        """
        # Attention sub-block: normalize, attend, add back the input.
        attn_output, attn_weights = self.self_attn(
            hidden_states=self.layer_norm1(hidden_states),
            attention_mask=attention_mask,
            output_attentions=output_attentions,
        )
        hidden_states = hidden_states + attn_output

        # Feed-forward sub-block with its own residual.
        hidden_states = hidden_states + self.mlp(self.layer_norm2(hidden_states))

        if output_attentions:
            return (hidden_states, attn_weights)
        return (hidden_states,)
|
||
|
|
|
||
|
|
|
||
|
|
# Adapted from transformers.models.siglip.modeling_siglip.SiglipEncoder with Siglip->Idefics3
# (restyled, so no longer a verbatim copy).
class Idefics3Encoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`Idefics3EncoderLayer`].

    Args:
        config: Idefics3Config
    """

    def __init__(self, config: Idefics3Config):
        super().__init__()
        self.config = config
        stacked_layers = [Idefics3EncoderLayer(config) for _ in range(config.num_hidden_layers)]
        self.layers = nn.ModuleList(stacked_layers)
        self.gradient_checkpointing = False

    def forward(
        self,
        inputs_embeds,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
                than the model's internal embedding lookup matrix.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        # Unset flags fall back to the values stored on the config.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        collected_states = [] if output_hidden_states else None
        collected_attentions = [] if output_attentions else None

        hidden_states = inputs_embeds
        for encoder_layer in self.layers:
            # Record the input of every layer; the final output is appended after the loop.
            if output_hidden_states:
                collected_states.append(hidden_states)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    encoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    output_attentions,
                )
            else:
                layer_outputs = encoder_layer(
                    hidden_states,
                    attention_mask,
                    output_attentions=output_attentions,
                )

            hidden_states = layer_outputs[0]
            if output_attentions:
                collected_attentions.append(layer_outputs[1])

        if output_hidden_states:
            collected_states.append(hidden_states)

        encoder_states = None if collected_states is None else tuple(collected_states)
        all_attentions = None if collected_attentions is None else tuple(collected_attentions)

        if return_dict:
            return BaseModelOutput(
                last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
            )
        # Tuple form drops the entries that were not requested.
        return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
|
||
|
|
|
||
|
|
|
||
|
|
# Adapted from transformers.models.llama.modeling_llama.repeat_kv (restyled, so no longer a verbatim copy).
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every key/value head `n_rep` times along the head dimension.

    Equivalent to `torch.repeat_interleave(hidden_states, dim=1, repeats=n_rep)`: the input goes from
    `(batch, num_key_value_heads, seqlen, head_dim)` to `(batch, num_key_value_heads * n_rep, seqlen, head_dim)`.
    When `n_rep == 1` the input tensor itself is returned.
    """
    # Unpacked before the early return so that a non-4D input is rejected even when n_rep == 1.
    bsz, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis after the head axis, broadcast it, then fold it into the head axis.
    expanded = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq_len, dim)
    return expanded.reshape(bsz, kv_heads * n_rep, seq_len, dim)
|
||
|
|
|
||
|
|
|
||
|
|
# Adapted from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Idefics3
# (restyled, so no longer a verbatim copy).
class Idefics3RMSNorm(nn.Module):
    """Root-mean-square layer norm with a learned per-channel scale and no bias."""

    def __init__(self, hidden_size, eps=1e-6):
        """
        Idefics3RMSNorm is equivalent to T5LayerNorm

        Args:
            hidden_size: size of the normalized (last) dimension.
            eps: constant added to the mean square before the reciprocal square root.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        """Normalize `hidden_states` over its last dimension and scale by `weight`."""
        incoming_dtype = hidden_states.dtype
        # Statistics are computed in float32; the result is cast back before the scale is applied.
        upcast = hidden_states.to(torch.float32)
        mean_square = upcast.pow(2).mean(-1, keepdim=True)
        normalized = upcast * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalized.to(incoming_dtype)

    def extra_repr(self):
        """Show the weight shape and epsilon in the module repr."""
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
||
|
|
|
||
|
|
|
||
|
|
class Idefics3Connector(nn.Module):
    """Shrinks the image token sequence with a pixel shuffle, then projects it with `Idefics3SimpleMLP`."""

    def __init__(self, config):
        super().__init__()
        self.scale_factor = config.scale_factor
        self.modality_projection = Idefics3SimpleMLP(config)

    def pixel_shuffle(self, x, scale_factor=2):
        """
        Fold each `scale_factor x scale_factor` neighbourhood of tokens into a single token.

        Args:
            x: tensor of shape `(batch, seq, embed_dim)`; `seq` is treated as a square grid of side
                `int(seq**0.5)`.
            scale_factor: side length of the neighbourhood merged into one token.

        Returns:
            Tensor of shape `(batch, seq / scale_factor**2, embed_dim * scale_factor**2)`.
        """
        batch, num_tokens, channels = x.size()
        side = int(num_tokens**0.5)
        reduced_w = int(side / scale_factor)
        reduced_h = int(side / scale_factor)
        merged_channels = channels * (scale_factor**2)

        # Lay the tokens out as a (side x side) grid, then merge horizontally adjacent tokens.
        grid = x.view(batch, side, side, channels)
        grid = grid.view(batch, side, reduced_w, channels * scale_factor)
        # Swap the spatial axes so the same fold can merge vertically adjacent rows.
        grid = grid.permute(0, 2, 1, 3)
        grid = grid.reshape(batch, reduced_w, reduced_h, merged_channels)
        # Restore row-major order and flatten the grid back into a sequence.
        grid = grid.permute(0, 2, 1, 3)
        return grid.reshape(batch, int(num_tokens / (scale_factor**2)), merged_channels)

    def forward(self, image_hidden_states):
        """Pixel-shuffle `image_hidden_states` with the configured scale factor, then project them."""
        shuffled = self.pixel_shuffle(image_hidden_states, self.scale_factor)
        return self.modality_projection(shuffled)
|
||
|
|
|
||
|
|
|
||
|
|
IDEFICS3_START_DOCSTRING = r"""
|
||
|
|
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
||
|
|
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
||
|
|
etc.)
|
||
|
|
|
||
|
|
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
||
|
|
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
||
|
|
and behavior.
|
||
|
|
|
||
|
|
Parameters:
|
||
|
|
config ([`Idefics3Config`] or [`Idefics3VisionConfig`]):
|
||
|
|
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
||
|
|
load the weights associated with the model, only the configuration. Check out the
|
||
|
|
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
||
|
|
"""
|
||
|
|
|
||
|
|
|
||
|
|
@add_start_docstrings(
    "The bare Idefics3 Model outputting raw hidden-states without any specific head on top.",
    IDEFICS3_START_DOCSTRING,
)
class Idefics3PreTrainedModel(PreTrainedModel):
    # No class docstring on purpose: `add_start_docstrings` builds `__doc__` from the strings above.
    config_class = Idefics3Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Idefics3VisionAttention", "Idefics3DecoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True

    # Adapted from transformers.models.idefics2.modeling_idefics2.Idefics2PreTrainedModel._init_weights
    # (diverges from that source: the config's own `initializer_range` is honoured when it has one).
    def _init_weights(self, module):
        """
        Initialize the weights of `module` in place from a normal distribution.

        The standard deviation is `self.config.initializer_range` when the config defines it, otherwise
        `self.config.text_config.initializer_range`. Previously both branches read the text config, so a config
        that has `initializer_range` but no `text_config` raised `AttributeError`.

        Args:
            module: the submodule to initialize. `nn.Linear` / `nn.Conv2d` get normal weights and zero bias,
                `nn.Embedding` gets normal weights with the padding row zeroed, and any module exposing a
                `class_embedding` attribute has it drawn from the same distribution.
        """
        if hasattr(self.config, "initializer_range"):
            std = self.config.initializer_range
        else:
            std = self.config.text_config.initializer_range

        if hasattr(module, "class_embedding"):
            module.class_embedding.data.normal_(mean=0.0, std=std)

        if isinstance(module, (nn.Linear, nn.Conv2d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                # The padding embedding must stay an all-zero vector.
                module.weight.data[module.padding_idx].zero_()
|
||
|
|
|
||
|
|
|
||
|
|
IDEFICS3_VISION_START_DOCSTRING = r"""
|
||
|
|
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
||
|
|
library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
|
||
|
|
etc.)
|
||
|
|
|
||
|
|
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
||
|
|
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
||
|
|
and behavior.
|
||
|
|
|
||
|
|
Parameters:
|
||
|
|
config ([`Idefics3VisionConfig`]):
|
||
|
|
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
||
|
|
load the weights associated with the model, only the configuration. Check out the
|
||
|
|
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
||
|
|
"""
|
||
|
|
|
||
|
|
|
||
|
|
@add_start_docstrings(
    "The Idefics3 Vision Transformer Model outputting raw image embedding.",
    IDEFICS3_VISION_START_DOCSTRING,
)
class Idefics3VisionTransformer(Idefics3PreTrainedModel):
    # Vision tower: patch embeddings -> encoder -> final LayerNorm.
    config_class = Idefics3VisionConfig
    _supports_sdpa = False

    def __init__(self, config: Idefics3VisionConfig):
        """Build the embeddings, the encoder and the post-encoder LayerNorm from `config`."""
        super().__init__(config)
        embed_dim = config.hidden_size

        self.embeddings = Idefics3VisionEmbeddings(config)
        self.encoder = Idefics3Encoder(config)
        self.patch_size = config.patch_size
        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
        # Remembered so `forward` knows whether the 2D mask must be expanded to 4D.
        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"

    # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2VisionTransformer.get_input_embeddings
    def get_input_embeddings(self):
        return self.embeddings

    # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2VisionTransformer.set_input_embeddings
    def set_input_embeddings(self, value):
        self.embeddings = value

    def forward(
        self,
        pixel_values,
        patch_attention_mask: Optional[torch.BoolTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        """Encode `pixel_values` and return the layer-normalised encoder output.

        Args:
            pixel_values: Image tensor; dimension 0 is used as the batch size and dimensions 2 and 3
                are divided by `self.patch_size` to size the default mask.
            patch_attention_mask: Optional boolean mask over patches. When omitted, an all-`True`
                mask is created on the device of `pixel_values`.
            output_attentions: Forwarded to the encoder; defaults to `config.output_attentions`.
            output_hidden_states: Forwarded to the encoder; defaults to `config.output_hidden_states`.
            return_dict: Return a `BaseModelOutput` instead of a tuple; defaults to
                `config.use_return_dict`.

        Returns:
            A `BaseModelOutput`, or the tuple `(last_hidden_state, *encoder_outputs[1:])` when
            `return_dict` is false.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        batch_size = pixel_values.size(0)
        if patch_attention_mask is None:
            patch_size = self.patch_size
            # Allocate the default mask directly as bool on the target device instead of
            # creating a float CPU tensor and casting/moving it afterwards.
            patch_attention_mask = torch.ones(
                (
                    batch_size,
                    pixel_values.size(2) // patch_size,
                    pixel_values.size(3) // patch_size,
                ),
                dtype=torch.bool,
                device=pixel_values.device,
            )

        hidden_states = self.embeddings(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask)

        patch_attention_mask = patch_attention_mask.view(batch_size, -1)
        # The call to `_upad_input` in `_flash_attention_forward` is expensive
        # So when the `patch_attention_mask` is full of 1s (i.e. attending to the whole sequence),
        # avoiding passing the attention_mask, which is equivalent to attending to the full sequence
        if not torch.any(~patch_attention_mask):
            patch_attention_mask = None
        elif not self._use_flash_attention_2:
            patch_attention_mask = _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype)

        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            attention_mask=patch_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = encoder_outputs[0]
        last_hidden_state = self.post_layernorm(last_hidden_state)

        if not return_dict:
            return (last_hidden_state,) + encoder_outputs[1:]

        return BaseModelOutput(
            last_hidden_state=last_hidden_state,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
|
||
|
|
|
||
|
|
|
||
|
|
IDEFICS3_INPUTS_DOCSTRING = r"""
|
||
|
|
Args:
|
||
|
|
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||
|
|
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
||
|
|
it.
|
||
|
|
|
||
|
|
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
||
|
|
[`PreTrainedTokenizer.__call__`] for details.
|
||
|
|
|
||
|
|
[What are input IDs?](../glossary#input-ids)
|
||
|
|
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||
|
|
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||
|
|
|
||
|
|
- 1 for tokens that are **not masked**,
|
||
|
|
- 0 for tokens that are **masked**.
|
||
|
|
|
||
|
|
[What are attention masks?](../glossary#attention-mask)
|
||
|
|
|
||
|
|
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
||
|
|
[`PreTrainedTokenizer.__call__`] for details.
|
||
|
|
|
||
|
|
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
|
||
|
|
`past_key_values`).
|
||
|
|
|
||
|
|
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
|
||
|
|
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
|
||
|
|
information on the default strategy.
|
||
|
|
|
||
|
|
- 1 indicates the head is **not masked**,
|
||
|
|
- 0 indicates the head is **masked**.
|
||
|
|
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||
|
|
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
||
|
|
config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
|
||
|
|
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||
|
|
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
||
|
|
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
|
||
|
|
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
|
||
|
|
|
||
|
|
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
||
|
|
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
|
||
|
|
|
||
|
|
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
|
||
|
|
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
|
||
|
|
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
|
||
|
|
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
||
|
|
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
||
|
|
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
||
|
|
model's internal embedding lookup matrix.
|
||
|
|
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
|
||
|
|
The tensors corresponding to the input images. Pixel values can be obtained using
|
||
|
|
[`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details ([`LlavaProcessor`] uses
|
||
|
|
[`CLIPImageProcessor`] for processing images).
|
||
|
|
pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
|
||
|
|
Mask to avoid performing attention on padding pixel indices.
|
||
|
|
image_hidden_states (`torch.FloatTensor` of shape `(num_images, image_seq_len, hidden_size)`):
|
||
|
|
The hidden states of the image encoder after modality projection.
|
||
|
|
use_cache (`bool`, *optional*):
|
||
|
|
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
||
|
|
`past_key_values`).
|
||
|
|
output_attentions (`bool`, *optional*):
|
||
|
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
||
|
|
tensors for more detail.
|
||
|
|
output_hidden_states (`bool`, *optional*):
|
||
|
|
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||
|
|
more detail.
|
||
|
|
return_dict (`bool`, *optional*):
|
||
|
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||
|
|
"""
|
||
|
|
|
||
|
|
|
||
|
|
@add_start_docstrings(
    """Idefics3 model consisting of a SIGLIP vision encoder and Llama3 language decoder""",
    IDEFICS3_START_DOCSTRING,
)
class Idefics3Model(Idefics3PreTrainedModel):
    # Vision tower + connector + text model; the user-facing documentation is supplied
    # through the decorator above.

    def __init__(self, config: Idefics3Config):
        """Instantiate the vision model, the connector and the text model described by `config`."""
        super().__init__(config)
        self.padding_idx = self.config.text_config.pad_token_id
        self.vocab_size = self.config.text_config.vocab_size

        self.vision_model = Idefics3VisionTransformer._from_config(config.vision_config)
        self.connector = Idefics3Connector(config)
        self.text_model = AutoModel.from_config(config.text_config)

        # Number of patches per image divided by scale_factor**2.
        self.image_seq_len = int(
            ((config.vision_config.image_size // config.vision_config.patch_size) ** 2) / (config.scale_factor**2)
        )
        self.image_token_id = self.config.image_token_id

        self._use_flash_attention_2 = config.text_config._attn_implementation == "flash_attention_2"

        self.post_init()

    # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2Model.enable_input_require_grads
    def enable_input_require_grads(self):
        """
        Enables the gradients for the input embeddings.

        This is useful for lora when using gradient checkpointing.
        c.f. https://github.com/huggingface/peft/issues/1402#issuecomment-1913675032

        Override to set output.requires_grad = True for both the decoder's and vision model's embeddings.
        """

        def get_lowest_module(module):
            if len(list(module.children())) == 0:
                # If the module has no children, it is a leaf module (e.g., Linear, Conv2d, etc.)
                return module
            else:
                # Recursively call the function on each child module
                return get_lowest_module(list(module.children())[0])

        def make_inputs_require_grads(module, input, output):
            output.requires_grad_(True)

        self._text_require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads)
        self._vision_require_grads_hook = get_lowest_module(self.vision_model).register_forward_hook(
            make_inputs_require_grads
        )

    # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2Model.disable_input_require_grads
    def disable_input_require_grads(self):
        self._text_require_grads_hook.remove()
        self._vision_require_grads_hook.remove()

    # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2Model.get_input_embeddings
    def get_input_embeddings(self):
        return self.text_model.get_input_embeddings()

    # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2Model.set_input_embeddings
    def set_input_embeddings(self, value):
        self.text_model.set_input_embeddings(value)

    def inputs_merger(
        self,
        input_ids: torch.LongTensor,
        inputs_embeds: Optional[torch.Tensor],
        image_hidden_states: Optional[torch.Tensor],
    ):
        """
        This method aims at merging the token embeddings with the image hidden states into one single sequence of vectors that are fed to the transformer LM.
        The merging happens as follows:
        - The text token sequence is: `tok_1 tok_2 tok_3 <fake_token_around_image> <image> <image> ... <image> <fake_token_around_image> tok_4`.
        - We get the image hidden states for the image through the vision encoder and that hidden state, after a pixel shuffle operation, is then projected into the text embedding space.
        We thus have a sequence of image hidden states of size (1, image_seq_len, hidden_dim), where 1 is for batch_size of 1 image and hidden_dim is the hidden_dim of the LM transformer.
        - The merging happens so that we obtain the following sequence: `vector_tok_1 vector_tok_2 vector_tok_3 vector_fake_tok_around_image {sequence of image_seq_len image hidden states} vector_fake_tok_around_image vector_tok_4`. That sequence is fed to the LM.
        - Concretely, a copy of `inputs_embeds` is returned in which every position where `input_ids == self.image_token_id` has been overwritten with the corresponding image hidden state; `inputs_embeds` itself is left untouched.
        """
        # The unpacking doubles as a check that the image features are 3-dimensional.
        _, _, vision_hidden_size = image_hidden_states.shape
        # The mask must live on the same device as the tensor it indexes.
        special_image_token_mask = (input_ids == self.image_token_id).to(inputs_embeds.device)
        # Fixes RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.
        new_inputs_embeds = inputs_embeds.clone()
        # `reshape` (rather than `view`) also accepts non-contiguous image features.
        reshaped_image_hidden_states = image_hidden_states.reshape(-1, vision_hidden_size)
        # cast to the dtype of the input_embeds to support quantized models, and move to its
        # device so the masked assignment also works when the two live on different devices
        reshaped_image_hidden_states = reshaped_image_hidden_states.to(
            device=inputs_embeds.device, dtype=inputs_embeds.dtype
        )
        new_inputs_embeds[special_image_token_mask] = reshaped_image_hidden_states
        return new_inputs_embeds

    @add_start_docstrings_to_model_forward(
        """
        Inputs fed to the model can have an arbitrary number of images. To account for this, pixel_values fed to
        the model have image padding -> (batch_size, max_num_images, 3, max_heights, max_widths) where
        max_num_images is the maximum number of images among the batch_size samples in the batch.
        Padding images are not needed beyond padding the pixel_values at the entrance of the model.
        For efficiency, we only pass through the vision_model's forward the real images by
        discarding the padding images i.e. pixel_values of size (image_batch_size, 3, height, width) where
        image_batch_size would be 7 when num_images_per_sample=[1, 3, 1, 2] and max_num_images would be 3.
        """,
        IDEFICS3_INPUTS_DOCSTRING,
    )
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        pixel_attention_mask: Optional[torch.BoolTensor] = None,
        image_hidden_states: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, Idefics3BaseModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if self.training and self.text_model.gradient_checkpointing and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
            )
            use_cache = False

        # retrieve input_ids and inputs_embeds (the unpacking validates their rank)
        if input_ids is not None:
            batch_size, _ = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, _, _ = inputs_embeds.shape
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        past_seen_tokens = 0
        if use_cache:
            if past_key_values is None:
                past_key_values = DynamicCache()
            past_seen_tokens = past_key_values.get_seq_length()

        if inputs_embeds is not None and input_ids is None and past_seen_tokens == 0:
            raise ValueError("When first calling the model, if input_embeds are passed, input_ids should not be None.")

        if inputs_embeds is None:
            inputs_embeds = self.text_model.get_input_embeddings()(input_ids).to(self.device)

        # START VISUAL INPUTS INTEGRATION
        if pixel_values is not None and image_hidden_states is not None:
            raise ValueError("You cannot specify both pixel_values and image_hidden_states at the same time")
        elif pixel_values is not None:
            # Expects 5 dimensions; only the first two are needed by name.
            batch_size, num_images, _, _, _ = pixel_values.shape
            pixel_values = pixel_values.to(dtype=self.dtype)  # fp16 compatibility
            pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:])

            # Remove padding images - padding images are full 0.
            nb_values_per_image = pixel_values.shape[1:].numel()
            real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image
            pixel_values = pixel_values[real_images_inds].contiguous()

            # Handle the vision attention mask
            if pixel_attention_mask is None:
                pixel_attention_mask = torch.ones(
                    size=(pixel_values.size(0), pixel_values.size(2), pixel_values.size(3)),
                    dtype=torch.bool,
                    device=pixel_values.device,
                )
            else:
                # Remove padding images from the mask
                pixel_attention_mask = pixel_attention_mask.view(
                    batch_size * num_images, *pixel_attention_mask.shape[2:]
                )
                pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous()

            # A patch is attended to as soon as at least one of its pixels is unmasked.
            patch_size = self.config.vision_config.patch_size
            patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size)
            patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size)
            patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()

            # Get sequence from the vision encoder
            image_hidden_states = self.vision_model(
                pixel_values=pixel_values,
                patch_attention_mask=patch_attention_mask,
            ).last_hidden_state

            # Modality projection & resampling
            image_hidden_states = self.connector(image_hidden_states)

        elif image_hidden_states is not None:
            # `input_ids` may legitimately be None here (cached step driven by `inputs_embeds`),
            # in which case the embeddings' device is used instead of dereferencing None.
            target_device = input_ids.device if input_ids is not None else inputs_embeds.device
            image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=target_device)

        if past_seen_tokens == 0 and inputs_embeds is not None and image_hidden_states is not None:
            # When we generate, we don't want to replace the potential image_token_id that we generated by images
            # that simply don't exist
            inputs_embeds = self.inputs_merger(
                input_ids=input_ids,
                inputs_embeds=inputs_embeds,
                image_hidden_states=image_hidden_states,
            )

        outputs = self.text_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if not return_dict:
            return tuple(v for v in [*outputs, image_hidden_states] if v is not None)

        return Idefics3BaseModelOutputWithPast(
            last_hidden_state=outputs.last_hidden_state,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            image_hidden_states=image_hidden_states,
        )
|
||
|
|
|
||
|
|
|
||
|
|
@add_start_docstrings(
|
||
|
|
"""The Idefics3 Model with a language modeling head. It is made up a SigLIP vision encoder, with a language modeling head on top. """,
|
||
|
|
IDEFICS3_START_DOCSTRING,
|
||
|
|
)
|
||
|
|
class Idefics3ForConditionalGeneration(Idefics3PreTrainedModel, GenerationMixin):
|
||
|
|
_tied_weights_keys = ["lm_head.weight"]
|
||
|
|
|
||
|
|
# Copied from transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration.__init__ with Idefics2->Idefics3
|
||
|
|
def __init__(self, config):
    """Create the `Idefics3Model` backbone and a bias-free linear language-modeling head.

    The head maps `config.text_config.hidden_size` features to
    `config.text_config.vocab_size` logits.
    """
    super().__init__(config)
    text_config = config.text_config

    self.model = Idefics3Model(config)
    self.image_token_id = self.config.image_token_id

    self.lm_head = nn.Linear(text_config.hidden_size, text_config.vocab_size, bias=False)
    self.vocab_size = text_config.vocab_size

    # Initialize weights and apply final processing
    self.post_init()
|
||
|
|
|
||
|
|
# Copied from transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration.enable_input_require_grads
|
||
|
|
def enable_input_require_grads(self):
    """
    Register forward hooks that call `requires_grad_(True)` on the outputs of the text input
    embeddings and of the vision model's input embeddings.

    The two hook handles are stored on the instance as `_text_require_grads_hook` and
    `_vision_require_grads_hook`.
    """

    def _flag_output_for_grad(module, input, output):
        output.requires_grad_(True)

    text_embeddings = self.get_input_embeddings()
    self._text_require_grads_hook = text_embeddings.register_forward_hook(_flag_output_for_grad)

    vision_embeddings = self.model.vision_model.get_input_embeddings()
    self._vision_require_grads_hook = vision_embeddings.register_forward_hook(_flag_output_for_grad)
|
||
|
|
|
||
|
|
# Copied from transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration.disable_input_require_grads
|
||
|
|
def disable_input_require_grads(self):
    """Remove the two forward hooks stored by `enable_input_require_grads` (text first, then vision)."""
    text_hook = self._text_require_grads_hook
    text_hook.remove()
    vision_hook = self._vision_require_grads_hook
    vision_hook.remove()
|
||
|
|
|
||
|
|
# Copied from transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration.get_input_embeddings
|
||
|
|
def get_input_embeddings(self):
    """Return the input-embedding module reported by the wrapped text model."""
    text_model = self.model.text_model
    return text_model.get_input_embeddings()
|
||
|
|
|
||
|
|
# Copied from transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration.set_input_embeddings
|
||
|
|
def set_input_embeddings(self, value):
    """Hand `value` to the wrapped text model's `set_input_embeddings`."""
    text_model = self.model.text_model
    text_model.set_input_embeddings(value)
|
||
|
|
|
||
|
|
# Copied from transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration.get_output_embeddings
|
||
|
|
def get_output_embeddings(self):
    """Return the language-modeling head (`self.lm_head`)."""
    output_head = self.lm_head
    return output_head
|
||
|
|
|
||
|
|
# Copied from transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration.set_output_embeddings
|
||
|
|
def set_output_embeddings(self, new_embeddings):
    """Replace the language-modeling head (`self.lm_head`) with `new_embeddings`."""
    self.lm_head = new_embeddings
|
||
|
|
|
||
|
|
# Copied from transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration.tie_weights
|
||
|
|
def tie_weights(self):
    """
    Share the input-embedding weight with the output embeddings.

    Nothing is tied when `config.tie_word_embeddings` is falsy; a config without that attribute
    is treated as if it were `True`.
    """
    lm_head = self.get_output_embeddings()
    token_embeddings = self.get_input_embeddings()

    if not getattr(self.config, "tie_word_embeddings", True):
        return
    lm_head.weight = token_embeddings.weight
|
||
|
|
|
||
|
|
@add_start_docstrings_to_model_forward(IDEFICS3_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Idefics3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    pixel_values: Optional[torch.FloatTensor] = None,
    pixel_attention_mask: Optional[torch.BoolTensor] = None,
    image_hidden_states: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, Idefics3CausalLMOutputWithPast]:
    r"""
    Args:
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or `-100`. Tokens with indices set to `-100` are ignored (masked) by the
            `CrossEntropyLoss` used here; the loss is only computed for the tokens with labels in
            `[0, ..., config.vocab_size]`. Labels equal to `model.image_token_id` are **not** ignored
            automatically: set them to `-100` if image positions should not contribute to the loss.
            When `attention_mask` is given, positions whose (shifted) mask value is 0 are dropped before
            the loss is computed.
    Returns:

    Example:

    ```python
    >>> import requests
    >>> import torch
    >>> from PIL import Image
    >>> from io import BytesIO

    >>> from transformers import AutoProcessor, AutoModelForVision2Seq
    >>> from transformers.image_utils import load_image

    >>> # Note that passing the image urls (instead of the actual pil images) to the processor is also possible
    >>> image1 = load_image("https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg")
    >>> image2 = load_image("https://cdn.britannica.com/59/94459-050-DBA42467/Skyline-Chicago.jpg")
    >>> image3 = load_image("https://cdn.britannica.com/68/170868-050-8DDE8263/Golden-Gate-Bridge-San-Francisco.jpg")

    >>> processor = AutoProcessor.from_pretrained("HuggingFaceM4/Idefics3-8B-Llama3")
    >>> model = AutoModelForVision2Seq.from_pretrained("HuggingFaceM4/Idefics3-8B-Llama3", torch_dtype=torch.bfloat16, device_map="auto")

    >>> # Create inputs
    >>> messages = [
    ...     {
    ...         "role": "user",
    ...         "content": [
    ...             {"type": "image"},
    ...             {"type": "text", "text": "In this image, we can see the city of New York, and more specifically the Statue of Liberty."},
    ...             {"type": "image"},
    ...             {"type": "text", "text": "What can we see in this image?"},
    ...         ]
    ...     },
    ...     {
    ...         "role": "user",
    ...         "content": [
    ...             {"type": "image"},
    ...             {"type": "text", "text": "In which city is that bridge located?"},
    ...         ]
    ...     }
    ... ]

    >>> prompts = [processor.apply_chat_template([message], add_generation_prompt=True) for message in messages]
    >>> images = [[image1, image2], [image3]]
    >>> inputs = processor(text=prompts, images=images, padding=True, return_tensors="pt").to(model.device)

    >>> # Generate
    >>> generated_ids = model.generate(**inputs, max_new_tokens=256)
    >>> generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=True)

    >>> print(generated_texts[0])
    Assistant: There are buildings, trees, lights, and water visible in this image.

    >>> print(generated_texts[1])
    Assistant: The bridge is in San Francisco.
    ```"""
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        pixel_values=pixel_values,
        pixel_attention_mask=pixel_attention_mask,
        image_hidden_states=image_hidden_states,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    hidden_states = outputs[0]
    logits = self.lm_head(hidden_states)

    loss = None
    if labels is not None:
        # Upcast to float if we need to compute the loss to avoid potential precision issues
        logits = logits.float()
        labels = labels.to(logits.device)
        # Shift so that tokens < n predict n
        if attention_mask is not None:
            # we use the input attention mask to shift the logits and labels, because it is 2D.
            # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
            shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
            shift_logits = logits[..., :-1, :][shift_attention_mask != 0].contiguous()
            shift_labels = labels[..., 1:][shift_attention_mask != 0].contiguous()
        else:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        loss_fct = CrossEntropyLoss()
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return Idefics3CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        image_hidden_states=outputs.image_hidden_states,
    )
|
||
|
|
|
||
|
|
# Copied from transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration.prepare_inputs_for_generation
|
||
|
|
def prepare_inputs_for_generation(
    self,
    input_ids,
    past_key_values=None,
    attention_mask=None,
    inputs_embeds=None,
    cache_position=None,
    pixel_values=None,
    pixel_attention_mask=None,
    image_hidden_states=None,
    num_logits_to_keep=None,
    **kwargs,
):
    """
    Assemble the keyword arguments handed to the model for one generation step.

    Overwritten -- there are mutually exclusive inputs (if the logic to make `image_hidden_states` take
    precedence is moved to the model, we can remove this fn).

    Args:
        input_ids: Token ids, indexed as `[batch, sequence]`. Sliced down to the positions selected by
            `cache_position` when a cache is present.
        past_key_values: Cache from previous steps, or `None`.
        attention_mask: Mask used to derive `position_ids` when none are passed through `kwargs`.
        inputs_embeds: Embeddings; only forwarded when `cache_position[0] == 0`.
        cache_position: Index tensor of the positions that still have to be processed. It is dereferenced
            whenever `past_key_values` or `inputs_embeds` is given, so it must not be `None` in those cases.
        pixel_values: Image inputs; dropped when `image_hidden_states` is given.
        pixel_attention_mask: Image mask; dropped when `image_hidden_states` is given.
        image_hidden_states: Precomputed image features, forwarded as-is.
        num_logits_to_keep: Forwarded only when it is not `None`.
        **kwargs: Only `position_ids` and `use_cache` are read.

    Returns:
        `dict` with the keys `input_ids`, `inputs_embeds`, `position_ids`, `past_key_values`, `use_cache`,
        `attention_mask`, `pixel_values`, `pixel_attention_mask`, `image_hidden_states`, plus
        `num_logits_to_keep` when it was provided.
    """
    # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
    if past_key_values is not None:
        if inputs_embeds is not None:  # Exception 1
            input_ids = input_ids[:, -cache_position.shape[0] :]
        elif input_ids.shape[1] != cache_position.shape[0]:
            input_ids = input_ids[:, cache_position]

    position_ids = kwargs.get("position_ids")
    if attention_mask is not None and position_ids is None:
        # create position_ids on the fly for batch generation
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        # Truthiness (not `is not None`) is deliberate: the slice is skipped for a falsy cache object.
        if past_key_values:
            position_ids = position_ids[:, -input_ids.shape[1] :]

    # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
    # but IDEFICS requires both ids and embeds to be present
    if inputs_embeds is not None and cache_position[0] == 0:
        model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": input_ids}
    else:
        # Hand the model a contiguous copy of `input_ids` rather than the sliced tensor built above.
        model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}

    if num_logits_to_keep is not None:
        model_inputs["num_logits_to_keep"] = num_logits_to_keep

    # Precomputed image features and raw pixel inputs are mutually exclusive: the features win.
    if image_hidden_states is not None:
        pixel_values = None
        pixel_attention_mask = None

    model_inputs.update(
        {
            "position_ids": position_ids,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache"),
            "attention_mask": attention_mask,
            "pixel_values": pixel_values,
            "pixel_attention_mask": pixel_attention_mask,
            "image_hidden_states": image_hidden_states,
        }
    )
    return model_inputs
|
||
|
|
|
||
|
|
# Copied from transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration._update_model_kwargs_for_generation
|
||
|
|
def _update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_decoder, **kwargs):
    """
    Run the parent class's kwargs update, then attach `outputs.image_hidden_states` to the result.

    Args:
        outputs: Model outputs of the step that just finished; must expose `image_hidden_states`.
        model_kwargs: Keyword arguments used for the step that just finished.
        is_encoder_decoder: Forwarded unchanged to the parent implementation.
        **kwargs: Forwarded unchanged to the parent implementation.

    Returns:
        The mapping returned by the parent implementation, with `image_hidden_states` set.
    """
    parent_update = super()._update_model_kwargs_for_generation
    next_kwargs = parent_update(
        outputs=outputs,
        model_kwargs=model_kwargs,
        is_encoder_decoder=is_encoder_decoder,
        **kwargs,
    )
    # Carry the image features over so the following step can reuse them.
    next_kwargs["image_hidden_states"] = outputs.image_hidden_states
    return next_kwargs
|