608 lines
30 KiB
Python
608 lines
30 KiB
Python
# coding=utf-8
|
|
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""PyTorch PaliGemmamodel."""
|
|
|
|
from dataclasses import dataclass
|
|
from typing import List, Optional, Tuple, Union
|
|
|
|
import torch
|
|
import torch.utils.checkpoint
|
|
from torch import nn
|
|
|
|
from ...cache_utils import Cache, StaticCache
|
|
from ...generation import GenerationMixin
|
|
from ...modeling_utils import PreTrainedModel
|
|
from ...utils import (
|
|
ModelOutput,
|
|
add_start_docstrings,
|
|
add_start_docstrings_to_model_forward,
|
|
is_flash_attn_2_available,
|
|
logging,
|
|
replace_return_docstrings,
|
|
)
|
|
from .configuration_paligemma import PaliGemmaConfig
|
|
|
|
|
|
if is_flash_attn_2_available():
|
|
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
|
|
|
from ..auto import AutoModel, AutoModelForCausalLM
|
|
|
|
|
|
# Module-level logger obtained from the library's `logging` utilities.
logger = logging.get_logger(__name__)

# Config class name handed to `replace_return_docstrings` when generating the `forward` docstring.
_CONFIG_FOR_DOC = "PaliGemmaConfig"
|
|
|
|
|
|
# Adapted from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
|
|
# But Paligemma has no causal mask on prefix
|
|
def _prepare_4d_causal_attention_mask_with_cache_position(
|
|
attention_mask: torch.Tensor,
|
|
sequence_length: int,
|
|
target_length: int,
|
|
dtype: torch.dtype,
|
|
device: torch.device,
|
|
min_dtype: float,
|
|
cache_position: torch.Tensor,
|
|
batch_size: int,
|
|
is_training: bool = False,
|
|
token_type_ids: torch.Tensor = None,
|
|
**kwargs,
|
|
):
|
|
"""
|
|
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
|
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
|
|
|
Args:
|
|
attention_mask (`torch.Tensor`):
|
|
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
|
|
sequence_length (`int`):
|
|
The sequence length being processed.
|
|
target_length (`int`):
|
|
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
|
|
dtype (`torch.dtype`):
|
|
The dtype to use for the 4D attention mask.
|
|
device (`torch.device`):
|
|
The device to plcae the 4D attention mask on.
|
|
min_dtype (`float`):
|
|
The minimum value representable with the dtype `dtype`.
|
|
cache_position (`torch.Tensor`):
|
|
Indices depicting the position of the input sequence tokens in the sequence.
|
|
batch_size (`torch.Tensor`):
|
|
Batch size.
|
|
is_training (`bool`):
|
|
Whether the model is in training mode or in inference. The condition is checked by presence/absence of `token_type_ids/labels`
|
|
"""
|
|
if attention_mask is not None and attention_mask.dim() == 4:
|
|
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
|
causal_mask = attention_mask
|
|
else:
|
|
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
|
|
# Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below
|
|
if sequence_length != 1:
|
|
if is_training:
|
|
causal_mask = torch.triu(causal_mask, diagonal=1)
|
|
else:
|
|
causal_mask[:, :sequence_length] = 0.0
|
|
|
|
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
|
|
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
|
if attention_mask is not None:
|
|
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
|
mask_length = attention_mask.shape[-1]
|
|
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device)
|
|
padding_mask = padding_mask == 0
|
|
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
|
padding_mask, min_dtype
|
|
)
|
|
# we are training thus we need to create a full mask on the image + prefix but causal on suffix
|
|
if is_training:
|
|
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
|
token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0
|
|
)
|
|
return causal_mask
|
|
|
|
|
|
@dataclass
class PaliGemmaCausalLMOutputWithPast(ModelOutput):
    """
    Base class for PaliGemma causal language model (or autoregressive) outputs.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction).
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.text_config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Either a [`~cache_utils.Cache`] instance or, in the legacy format, a tuple with one entry per decoder
            layer, each holding 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
            `past_key_values` input) to speed up sequential decoding.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        image_hidden_states (`torch.FloatTensor`, *optional*):
            A `torch.FloatTensor` of shape `(num_images, image_sequence_length, hidden_size)`.
            image_hidden_states of the model produced by the vision encoder after projecting last hidden state.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    image_hidden_states: Optional[torch.FloatTensor] = None
|
|
|
|
|
|
class PaliGemmaMultiModalProjector(nn.Module):
    """Single biased linear layer mapping vision-tower features to `vision_config.projection_dim`."""

    def __init__(self, config: PaliGemmaConfig):
        super().__init__()
        vision_cfg = config.vision_config
        in_features, out_features = vision_cfg.hidden_size, vision_cfg.projection_dim
        self.linear = nn.Linear(in_features, out_features, bias=True)

    def forward(self, image_features):
        """Apply the projection over the last dimension of `image_features`."""
        return self.linear(image_features)
|
|
|
|
|
|
# Shared class docstring prepended via `add_start_docstrings` to the PaliGemma model classes.
PALIGEMMA_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`PaliGemmaConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
|
|
|
|
|
|
@add_start_docstrings(
    "The bare PaliGemma Model outputting raw hidden-states without any specific head on top.",
    PALIGEMMA_START_DOCSTRING,
)
class PaliGemmaPreTrainedModel(PreTrainedModel):
    config_class = PaliGemmaConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["PaliGemmaMultiModalProjector"]
    _skip_keys_device_placement = "past_key_values"
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def _init_weights(self, module):
        """
        Initialize the weights of `module` in place.

        Values are drawn from a normal distribution with mean 0 and std `config.initializer_range` (falling back to
        `config.text_config.initializer_range` when the top-level config lacks it). This applies to `class_embedding`
        parameters and to `nn.Linear` / `nn.Conv2d` / `nn.Embedding` weights. Biases and the embedding row at
        `padding_idx` are zeroed.
        """
        # important: this ported version of PaliGemma isn't meant for training from scratch - only
        # inference and fine-tuning
        std = (
            self.config.initializer_range
            if hasattr(self.config, "initializer_range")
            else self.config.text_config.initializer_range
        )

        if hasattr(module, "class_embedding"):
            module.class_embedding.data.normal_(mean=0.0, std=std)

        if isinstance(module, (nn.Linear, nn.Conv2d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
|
|
|
|
|
|
# Shared `forward` argument documentation injected via `add_start_docstrings_to_model_forward`.
PALIGEMMA_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
            The tensors corresponding to the input images. Pixel values can be obtained using
            [`AutoImageProcessor`]. See [`SiglipImageProcessor.__call__`] for details ([`PaliGemmaProcessor`] uses
            [`SiglipImageProcessor`] for processing images).
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
            `past_key_values`).
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. When omitted they are
            computed as `cache_position + 1`, since PaliGemma positions are 1-indexed.
            [What are position IDs?](../glossary#position-ids)
        past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Either a [`~cache_utils.Cache`] instance or, in the legacy format, a tuple with one entry per decoder
            layer, each holding 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
            `past_key_values` input) to speed up sequential decoding.

            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `input_ids` of shape `(batch_size, sequence_length)`.
        token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Marks prefix tokens (image tokens and prompt) with `0`, and suffix tokens with a non-zero value. When both
            `token_type_ids` and `labels` are provided the model runs in training mode: prefix positions receive full
            (bidirectional) attention while suffix positions are attended causally.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
            Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
            this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
            the complete sequence length.
"""
|
|
|
|
|
|
@add_start_docstrings(
    """The PALIGEMMA model which consists of a vision backbone and a language model.""",
    PALIGEMMA_START_DOCSTRING,
)
class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixin):
    def __init__(self, config: PaliGemmaConfig):
        super().__init__(config)
        self.vision_tower = AutoModel.from_config(config=config.vision_config)
        self.multi_modal_projector = PaliGemmaMultiModalProjector(config)
        self.vocab_size = config.text_config.vocab_size

        language_model = AutoModelForCausalLM.from_config(config=config.text_config)

        # Re-prefix the language model's tied-weight keys so they resolve from this wrapper's namespace.
        if language_model._tied_weights_keys is not None:
            self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys]
        self.language_model = language_model

        # Fall back to -1 when the config defines no pad token.
        self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
        self.post_init()

    # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_input_embeddings with Llava->PaliGemma
    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_input_embeddings with Llava->PaliGemma
    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_output_embeddings with Llava->PaliGemma
    def get_output_embeddings(self):
        return self.language_model.get_output_embeddings()

    # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_output_embeddings with Llava->PaliGemma
    def set_output_embeddings(self, new_embeddings):
        self.language_model.set_output_embeddings(new_embeddings)

    # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_decoder with Llava->PaliGemma
    def set_decoder(self, decoder):
        self.language_model.set_decoder(decoder)

    # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_decoder with Llava->PaliGemma
    def get_decoder(self):
        return self.language_model.get_decoder()

    # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.tie_weights with Llava->PaliGemma
    def tie_weights(self):
        return self.language_model.tie_weights()

    def _update_causal_mask(
        self, attention_mask, token_type_ids, inputs_embeds, past_key_values, cache_position, is_training: bool = False
    ):
        """
        Build the attention mask handed to the language model.

        With flash attention 2, the 2D `attention_mask` is forwarded only when it contains a `0.0` entry (padding);
        otherwise `None` is returned. For other attention implementations the 4D mask is built by the module-level
        `_prepare_4d_causal_attention_mask_with_cache_position`. Its key length is the static cache's max length
        when `past_key_values` is a `StaticCache`, else the 2D mask's length (or `cache_position[0] + seq_len + 1`
        when no mask tensor is given).
        """
        if self.config.text_config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None

        sequence_length = inputs_embeds.shape[1]
        if isinstance(past_key_values, StaticCache):
            target_length = past_key_values.get_max_length()
        else:
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else cache_position[0] + sequence_length + 1
            )

        dtype = inputs_embeds.dtype
        # Shared helper: returns a 4D `attention_mask` untouched, otherwise builds the PaliGemma prefix/causal mask.
        return _prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            device=cache_position.device,
            min_dtype=torch.finfo(dtype).min,
            cache_position=cache_position,
            batch_size=inputs_embeds.shape[0],
            is_training=is_training,
            token_type_ids=token_type_ids,
        )

    def get_image_features(self, pixel_values: torch.FloatTensor):
        """
        Obtains image last hidden states from the vision tower and apply multimodal projection.

        Args:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`)
               The tensors corresponding to the input images.
        Returns:
            image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
            The projected features are divided by `sqrt(config.text_config.hidden_size)`.
        """
        image_outputs = self.vision_tower(pixel_values)
        selected_image_feature = image_outputs.last_hidden_state
        image_features = self.multi_modal_projector(selected_image_feature)
        # The features are scattered into the text embedding stream, so normalize by the text model's width
        # (`config.hidden_size` is a separate top-level value that need not match it).
        image_features = image_features / (self.config.text_config.hidden_size**0.5)
        return image_features

    @add_start_docstrings_to_model_forward(PALIGEMMA_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=PaliGemmaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        pixel_values: torch.FloatTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        num_logits_to_keep: int = 0,
    ) -> Union[Tuple, PaliGemmaCausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.

            num_logits_to_keep (`int`, *optional*):
                Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
                `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
                token can save memory, which becomes pretty significant for long sequences or large vocabulary size.

        Returns:

        Example:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, PaliGemmaForConditionalGeneration

        >>> model = PaliGemmaForConditionalGeneration.from_pretrained("google/PaliGemma-test-224px-hf")
        >>> processor = AutoProcessor.from_pretrained("google/PaliGemma-test-224px-hf")

        >>> prompt = "answer en Where is the cow standing?"
        >>> url = "https://huggingface.co/gv-hf/PaliGemma-test-224px-hf/resolve/main/cow_beach_1.png"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, text=prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(**inputs, max_length=30)
        >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "answer en Where is the cow standing?\nbeach"
        ```"""

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if pixel_values is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
            )

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Training mode (prefix-LM mask) is inferred from the presence of both token types and labels.
        is_training = token_type_ids is not None and labels is not None

        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0) + 1  # Paligemma positions are 1-indexed

        # Merge text and images
        image_features = None
        if pixel_values is not None:
            image_features = self.get_image_features(pixel_values)

            # `pixel_values` excludes `inputs_embeds` (checked above), so `input_ids` is available here.
            special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
            special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
            if inputs_embeds[special_image_mask].numel() != image_features.numel():
                image_tokens_in_text = torch.sum(input_ids == self.config.image_token_index)
                raise ValueError(
                    f"Number of images does not match number of special image tokens in the input text. "
                    f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
                    "tokens from image embeddings."
                )
            image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

        # mask out pad-token-ids in labels for BC (needs `input_ids` to locate the pad positions)
        if labels is not None and input_ids is not None and self.pad_token_id in labels:
            # A single message string: extra positional args would be treated as %-format arguments by logging.
            logger.warning_once(
                "`labels` contains `pad_token_id` which will be masked with `config.ignore_index`. "
                "You have to mask out `pad_token_id` when preparing `labels`, this behavior will be removed in v.4.46."
            )
            labels = torch.where(input_ids == self.pad_token_id, self.config.ignore_index, labels)

        causal_mask = self._update_causal_mask(
            attention_mask, token_type_ids, inputs_embeds, past_key_values, cache_position, is_training
        )

        outputs = self.language_model(
            attention_mask=causal_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            num_logits_to_keep=num_logits_to_keep,
        )

        # Index instead of `.logits` so this also works when `return_dict=False` and `outputs` is a plain tuple
        # (no labels are passed to the language model, so the logits come first in both cases).
        logits = outputs[0]
        loss = None
        if labels is not None:
            # Upcast to float if we need to compute the loss to avoid potential precision issues
            logits = logits.float()
            shift_logits = logits[..., :-1, :]
            shift_labels = labels[..., 1:]
            if attention_mask is not None:
                # we use the input attention mask to shift the logits and labels, because it is 2D.
                # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
                shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device)
                shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous()
                shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
            else:
                shift_logits = shift_logits.contiguous()
                shift_labels = shift_labels.contiguous()
            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss()

            flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
            flat_labels = shift_labels.view(-1).to(shift_logits.device)
            loss = loss_fct(flat_logits, flat_labels)
        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return PaliGemmaCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            image_hidden_states=image_features,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        pixel_values=None,
        attention_mask=None,
        token_type_ids=None,
        use_cache=True,
        num_logits_to_keep=None,
        **kwargs,
    ):
        # Overwritten -- custom `position_ids` and `pixel_values` handling
        model_inputs = self.language_model.prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            cache_position=cache_position,
            use_cache=use_cache,
            num_logits_to_keep=num_logits_to_keep,
            token_type_ids=token_type_ids,
            **kwargs,
        )

        # position_ids in Paligemma are 1-indexed. Shift out-of-place so a tensor that may be shared with the
        # caller (or reused across generation steps) is never mutated.
        if model_inputs.get("position_ids") is not None:
            model_inputs["position_ids"] = model_inputs["position_ids"] + 1

        # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
        # Otherwise we need pixel values to be passed to model. NOTE: use_cache=False needs pixel_values always
        if cache_position[0] == 0:
            model_inputs["pixel_values"] = pixel_values

        return model_inputs
|