# Diffusers-BAGEL / pipeline.py
# Author: para-lost — commit c387723 ("fix device")
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import os, sys, importlib
from typing import Optional, Dict, List
import torch
from functools import partial
from diffusers import DiffusionPipeline
from diffusers.utils import logging
from accelerate import (
init_empty_weights,
infer_auto_device_map,
load_checkpoint_and_dispatch,
)
from huggingface_hub import snapshot_download
from tqdm import tqdm
from copy import deepcopy
import random
import cv2
import numpy as np
from torchvision import transforms
from torchvision.transforms import functional as TF
from torchvision.transforms import InterpolationMode
from dataclasses import dataclass
from types import SimpleNamespace
from einops import rearrange
from torch import Tensor, nn
from safetensors.torch import load_file as load_sft
import copy
from typing import List, Tuple, Optional
import torch.nn.functional as F
from torch import nn
from torch.nn.attention.flex_attention import create_block_mask
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_utils import PreTrainedModel
from dataclasses import asdict, fields
from diffusers.models.modeling_utils import ModelMixin
from diffusers.configuration_utils import ConfigMixin
import math
from transformers.activations import ACT2FN
from torch import nn
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.nn.attention.flex_attention import flex_attention
from torch.nn.functional import scaled_dot_product_attention
from transformers.utils import ModelOutput
from flash_attn import flash_attn_varlen_func
# Raise torch._dynamo's recompilation cache limits. Presumably this lets the compiled flex_attention below
# specialise for many distinct input shapes without hitting the limit — TODO confirm the motivation.
torch._dynamo.config.cache_size_limit = 512
torch._dynamo.config.accumulated_cache_size_limit = 4096
# flex_attention = torch.compile(flex_attention) # , dynamic=True, mode='max-autotune'
# Rebind the imported `flex_attention` to its torch.compile'd version; later references use the compiled one.
flex_attention = torch.compile(flex_attention)
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation
from transformers.utils import logging
from typing import List, Optional, Tuple, Union
import torch.utils.checkpoint
from torch import nn
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
)
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
from transformers.utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
logging,
replace_return_docstrings,
)
from typing import Optional, Tuple
from transformers.tokenization_utils import AddedToken
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
import json
import unicodedata
from functools import lru_cache
import regex as re
from transformers.tokenization_utils import PreTrainedTokenizer
from typing import Union
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
import string
import warnings
from shutil import copyfile
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
import sentencepiece as spm
from transformers.convert_slow_tokenizer import import_protobuf
from transformers.tokenization_utils import PreTrainedTokenizer
from transformers.tokenization_utils_base import AddedToken
if TYPE_CHECKING:
from transformers.tokenization_utils_base import TextInput
from transformers.utils import logging, requires_backends
# Default vocabulary filename for SentencePiece-based tokenizers (presumably keyed by the tokenizer's
# `vocab_file` init argument, following the transformers convention — confirm against the tokenizer class).
VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"}
# SentencePiece word-boundary marker character (U+2581 LOWER ONE EIGHTH BLOCK).
SPIECE_UNDERLINE = "▁"
from typing import Dict, List, Optional, Union
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from transformers.image_transforms import (
convert_to_rgb,
resize,
to_channel_dimension_format,
)
from transformers.image_utils import (
IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD,
ChannelDimension,
ImageInput,
PILImageResampling,
infer_channel_dimension_format,
is_scaled_image,
make_list_of_images,
to_numpy_array,
valid_images,
validate_preprocess_arguments,
)
from transformers.utils import TensorType, filter_out_non_signature_kwargs, is_vision_available, logging
# Module-level logger. At this point `logging` is bound to `transformers.utils.logging` (imported just above),
# so this is a transformers-style logger that supports `warning_once`.
logger = logging.get_logger(__name__)
import warnings
from dataclasses import dataclass
from typing import Any, Optional, Tuple, Union
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from torch.nn.init import _calculate_fan_in_and_fan_out
from transformers.activations import ACT2FN
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
from transformers.modeling_outputs import BaseModelOutputWithPooling, ImageClassifierOutput
from transformers.utils import (
ModelOutput,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
logging,
replace_return_docstrings,
torch_int,
)
from typing import List, Optional, Union
from transformers.feature_extraction_utils import BatchFeature
from transformers.processing_utils import ProcessorMixin
from transformers.tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
from transformers.utils import TensorType
from PIL import Image
from torch.nn.attention.flex_attention import or_masks, and_masks
def create_sparse_mask(document_lens, split_lens, attn_modes, device):
def causal_mask(b, h, q_idx, kv_idx):
return q_idx >= kv_idx
def full_and_noise_mask(b, h, q_idx, kv_idx):
return (full_and_noise_seq_id[q_idx] == full_and_noise_seq_id[kv_idx]) & (full_and_noise_seq_id[q_idx] >= 0)
def remove_noise_mask(b, h, q_idx, kv_idx):
return (~((noise_seq_id[kv_idx] >= 0) & (noise_seq_id[q_idx] != noise_seq_id[kv_idx])))
def sample_mask(b, h, q_idx, kv_idx):
return document_id[q_idx] == document_id[kv_idx]
full_and_noise_tmp = []
noise_tmp = []
for i, (length, model) in enumerate(zip(split_lens, attn_modes)):
value = i if model in ['full', 'noise'] else -1
full_and_noise_tmp.extend([value] * length)
value_noise = i if model == 'noise' else -1
noise_tmp.extend([value_noise] * length)
full_and_noise_seq_id = torch.Tensor(full_and_noise_tmp).to(device)
noise_seq_id = torch.Tensor(noise_tmp).to(device)
document_id = torch.cat([torch.full((l,), i) for i, l in enumerate(document_lens, start=1)]).to(device)
return and_masks(or_masks(causal_mask, full_and_noise_mask), remove_noise_mask, sample_mask)
def patchify(image, patch_size):
    """Cut a ``(C, H, W)`` image into flattened, non-overlapping ``patch_size`` x ``patch_size`` patches.

    Returns a tensor of shape ``((H // p) * (W // p), p * p * C)``. Patches are in row-major grid order, and
    each patch vector is laid out as (row within patch, column within patch, channel).

    Raises:
        AssertionError: if H or W is not a multiple of ``patch_size``.
    """
    channels, height, width = image.shape
    assert height % patch_size == 0 and width % patch_size == 0
    grid = image.reshape(channels, height // patch_size, patch_size, width // patch_size, patch_size)
    # (c, grid_h, p_h, grid_w, p_w) -> (grid_h, grid_w, p_h, p_w, c): patch grid outermost, channels innermost.
    grid = torch.einsum("chpwq->hwpqc", grid)
    return grid.reshape(-1, patch_size * patch_size * channels)
def get_flattened_position_ids_extrapolate(img_h, img_w, patch_size, max_num_patches_per_side):
    """Return row-major position ids for an image's patch grid, laid out on a grid that is
    ``max_num_patches_per_side`` wide.

    The patch at grid cell (r, c) gets id ``r * max_num_patches_per_side + c``. Positions are used as-is (no
    rescaling), so the patch grid occupies the top-left corner of the reference grid.
    """
    row_idx = torch.arange(img_h // patch_size)
    col_idx = torch.arange(img_w // patch_size)
    grid_ids = row_idx[:, None] * max_num_patches_per_side + col_idx
    return grid_ids.flatten()
def get_flattened_position_ids_interpolate(img_h, img_w, patch_size, max_num_patches_per_side):
    """Return row-major position ids for an image's patch grid, rescaled onto a
    ``max_num_patches_per_side`` x ``max_num_patches_per_side`` reference grid.

    Each patch's fractional coordinate (index / num_patches, in [0, 1)) is bucketed into
    ``max_num_patches_per_side`` equal-width bins. The cell (r, c) then gets id
    ``r * max_num_patches_per_side + c``.
    """
    bin_width = 1 / max_num_patches_per_side
    # Interior bin edges; with right=True a coordinate equal to an edge lands in the upper bin.
    edges = torch.arange(bin_width, 1.0, bin_width)

    def _to_bins(num_patches):
        fractions = torch.arange(0, 1 - 1e-6, 1 / num_patches)
        return torch.bucketize(fractions, edges, right=True)

    row_bins = _to_bins(img_h // patch_size)
    col_bins = _to_bins(img_w // patch_size)
    return (row_bins[:, None] * max_num_patches_per_side + col_bins).flatten()
def prepare_attention_mask_per_sample(split_lens, attn_modes, device="cpu"):
    """
    Build the dense additive attention mask for one packed sample.

    Args:
        split_lens: list of ints; each int is the length of one split within the sample.
        attn_modes: list with one entry per split, each 'causal', 'full' or 'noise'.
        device: device on which the mask is allocated.

    Rules applied:
        * every split attends to all tokens of the splits before it;
        * within a split, 'causal' is lower-triangular while 'full' and 'noise' are bidirectional;
        * tokens of a 'noise' split are then hidden from every query, except queries of that same split.

    Returns:
        float tensor of shape (sum(split_lens), sum(split_lens)) on ``device``: 0.0 where attention is
        allowed, -inf where it is masked.

    Raises:
        AssertionError: if an attention mode is not one of the three above.
    """
    sample_len = sum(split_lens)
    attention_mask = torch.zeros((sample_len, sample_len), dtype=torch.bool, device=device)
    csum = 0
    for s, attn_mode in zip(split_lens, attn_modes):
        assert attn_mode in ['causal', 'full', 'noise']
        if attn_mode == "causal":
            # Build the triangle directly on `device`, so no host tensor is copied into the mask.
            attention_mask[csum:csum + s, csum:csum + s] = torch.ones(
                (s, s), dtype=torch.bool, device=device
            ).tril()
        else:
            attention_mask[csum:csum + s, csum:csum + s] = True
        attention_mask[csum:csum + s, :csum] = True
        csum += s
    # Second pass: hide noise tokens from every query, then re-open the noise split's own block.
    csum = 0
    for s, attn_mode in zip(split_lens, attn_modes):
        if attn_mode == "noise":
            attention_mask[:, csum:csum + s] = False
            attention_mask[csum:csum + s, csum:csum + s] = True
        csum += s
    return torch.zeros_like(attention_mask, dtype=torch.float).masked_fill_(
        ~attention_mask, float("-inf")
    )
def split_integer_exp_decay(S, ng_sample_decay=1.0):
    """Randomly split the integer ``S`` into N positive parts that sum to ``S``.

    The number of parts N is drawn from 1..S: uniformly when ``ng_sample_decay == 1.0``, otherwise with
    probability proportional to ``ng_sample_decay ** (N - 1)``. The N - 1 interior cut points are then sampled
    without replacement from 1..S-1.

    Returns:
        ``(result, cumsum)``, where ``result`` holds the part sizes and ``cumsum`` the boundaries
        ``[0, ..., S]`` (so ``result[i] == cumsum[i + 1] - cumsum[i]``).
    """
    if ng_sample_decay != 1.0:
        normaliser = (1 - ng_sample_decay) / (1 - math.pow(ng_sample_decay, S))
        weights = [normaliser * math.pow(ng_sample_decay, k) for k in range(S)]
        num_parts = random.choices(list(range(1, S + 1)), weights, k=1)[0]
    else:
        num_parts = random.randint(1, S)
    cut_points = sorted(random.sample(range(1, S), num_parts - 1))
    cumsum = [0, *cut_points, S]
    result = [upper - lower for lower, upper in zip(cumsum, cumsum[1:])]
    return result, cumsum
def pil_img2rgb(image):
    """Convert a PIL image to RGB, flattening any transparency onto a white background.

    Images with an alpha band ("RGBA", "LA"), and images whose ``info`` declares a ``transparency`` entry
    (e.g. palette images with a transparent index), are converted to RGBA. They are then pasted onto an opaque
    white canvas with the alpha band as the mask, so transparent regions come out white. Every other image is
    converted directly with ``convert("RGB")``.

    Previously only "RGBA" was composited. "LA" images had their alpha silently dropped by ``convert("RGB")``.
    """
    has_alpha_band = image.mode in ("RGBA", "LA")
    if has_alpha_band or image.info.get("transparency", None) is not None:
        image = image.convert("RGBA")
        white = Image.new(mode="RGB", size=image.size, color=(255, 255, 255))
        # Band 3 of an RGBA image is alpha; use it as the paste mask.
        white.paste(image, mask=image.split()[3])
        return white
    return image.convert("RGB")
def add_special_tokens(tokenizer):
    """Ensure the chat/vision marker tokens exist on ``tokenizer`` and look up their ids.

    Each of ``<|im_start|>``, ``<|im_end|>``, ``<|vision_start|>`` and ``<|vision_end|>`` is passed to
    ``tokenizer.add_tokens`` unless it already appears in ``tokenizer.special_tokens_map`` (as a string value
    or inside a list value).

    Returns:
        ``(tokenizer, new_token_ids, num_new_tokens)``. ``new_token_ids`` maps ``bos_token_id``,
        ``eos_token_id``, ``start_of_image`` and ``end_of_image`` to the ids of the four markers above, in
        that order. ``num_new_tokens`` is whatever ``tokenizer.add_tokens`` returned.
    """
    known_special = []
    for entry in tokenizer.special_tokens_map.values():
        if isinstance(entry, str):
            known_special.append(entry)
        elif isinstance(entry, list):
            known_special.extend(entry)

    marker_tokens = ('<|im_start|>', '<|im_end|>', '<|vision_start|>', '<|vision_end|>')
    missing = [tok for tok in marker_tokens if tok not in known_special]
    num_new_tokens = tokenizer.add_tokens(missing)

    bos_id, eos_id, image_start_id, image_end_id = (
        tokenizer.convert_tokens_to_ids(tok) for tok in marker_tokens
    )
    new_token_ids = dict(
        bos_token_id=bos_id,
        eos_token_id=eos_id,
        start_of_image=image_start_id,
        end_of_image=image_end_id,
    )
    return tokenizer, new_token_ids, num_new_tokens
def len2weight(x, loss_reduction='square'):
    """Per-token loss weight for a sequence of length ``x``.

    Args:
        x: sequence length. A zero length is returned unchanged, whatever the reduction.
        loss_reduction: 'token' -> 1, 'sample' -> 1 / x, 'square' -> 1 / sqrt(x).

    Raises:
        NotImplementedError: for any other ``loss_reduction`` when ``x`` is non-zero.
    """
    if x == 0:
        return x
    if loss_reduction == 'square':
        return 1 / (x ** 0.5)
    if loss_reduction == 'sample':
        return 1 / x
    if loss_reduction == 'token':
        return 1
    raise NotImplementedError(loss_reduction)
class NaiveCache:
    """Minimal per-layer key/value store.

    ``key_cache`` and ``value_cache`` are two separate dicts mapping a layer index (0 .. num_layers - 1)
    to a cached tensor, or ``None`` until that layer has been filled.
    """

    def __init__(self, num_layers):
        self.key_cache = {layer: None for layer in range(num_layers)}
        self.value_cache = {layer: None for layer in range(num_layers)}

    @property
    def num_layers(self):
        """Number of layer slots in the cache."""
        return len(self.key_cache)

    @property
    def seq_lens(self):
        """Size of dim 0 of layer 0's cached keys, or 0 while layer 0 is still empty."""
        first_layer_keys = self.key_cache[0]
        if first_layer_keys is None:
            return 0
        return first_layer_keys.shape[0]
class _Qwen2Config(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Qwen2Model`]. It is used to instantiate a
    Qwen2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of
    Qwen2-7B-beta [Qwen/Qwen2-7B-beta](https://huggingface.co/Qwen/Qwen2-7B-beta).

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 151936):
            Vocabulary size of the Qwen2 model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`Qwen2Model`]
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 22016):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer in the Transformer encoder.
        num_key_value_heads (`int`, *optional*, defaults to 32):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details checkout [this
            paper](https://arxiv.org/pdf/2305.13245.pdf). If explicitly passed as `None`, it falls back to
            `num_attention_heads`.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 32768):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether the model's input and output word embeddings should be tied.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
            and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
            accordingly. A legacy `type` key is copied into `rope_type` (the passed dict is modified in place).
            Expected contents:
                `rope_type` (`str`):
                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
                    'llama3'], with 'default' being the original RoPE implementation.
                `factor` (`float`, *optional*):
                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *
                    original maximum pre-trained length.
                `original_max_position_embeddings` (`int`, *optional*):
                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
                    pretraining.
                `attention_factor` (`float`, *optional*):
                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
                    computation. If unspecified, it defaults to value recommended by the implementation, using the
                    `factor` field to infer the suggested value.
                `beta_fast` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
                    ramp function. If unspecified, it defaults to 32.
                `beta_slow` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
                    ramp function. If unspecified, it defaults to 1.
                `short_factor` (`List[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2
                `long_factor` (`List[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to long contexts (>
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2
                `low_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
                `high_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
        use_sliding_window (`bool`, *optional*, defaults to `False`):
            Whether to use sliding window attention.
        sliding_window (`int`, *optional*, defaults to 4096):
            Sliding window attention (SWA) window size. Only kept when `use_sliding_window=True`; otherwise
            `config.sliding_window` is stored as `None`.
        max_window_layers (`int`, *optional*, defaults to 28):
            The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        is_causal (`bool`, *optional*, defaults to `True`):
            Stored as `config.is_causal`; copied to the `is_causal` attribute of the attention modules in this file.
        _attn_implementation (`str`, *optional*, defaults to `"flash_attention_2"`):
            Name of the attention backend, assigned to `config._attn_implementation` before
            `PretrainedConfig.__init__` runs.
        kwargs:
            Forwarded to [`PretrainedConfig`].

    ```python
    >>> # `_Qwen2Config` is defined in this module; it is not exported by `transformers`.
    >>> from transformers import Qwen2Model

    >>> # Initializing a Qwen2 style configuration
    >>> configuration = _Qwen2Config()

    >>> # Initializing a model from the Qwen2-7B style configuration
    >>> model = Qwen2Model(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "qwen2"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=151936,
        hidden_size=4096,
        intermediate_size=22016,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=32,
        hidden_act="silu",
        max_position_embeddings=32768,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        use_sliding_window=False,
        sliding_window=4096,
        max_window_layers=28,
        attention_dropout=0.0,
        is_causal=True,
        _attn_implementation="flash_attention_2",
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.use_sliding_window = use_sliding_window
        # The window size is only kept when sliding-window attention is enabled.
        self.sliding_window = sliding_window if use_sliding_window else None
        self.max_window_layers = max_window_layers

        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_dropout = attention_dropout
        self.is_causal = is_causal
        # NOTE(review): this is assigned before `super().__init__()`. Depending on the installed transformers
        # version, `PretrainedConfig.__init__` may re-initialise the attention implementation from its own
        # `attn_implementation` kwarg and override the value set here — confirm against the pinned version.
        self._attn_implementation = _attn_implementation
        # Validate the correctness of rotary position embeddings parameters
        # BC: if there is a 'type' field, copy it to 'rope_type'. The old key is kept, and the caller's dict is
        # mutated in place.
        if self.rope_scaling is not None and "type" in self.rope_scaling:
            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
        rope_config_validation(self)

        super().__init__(
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
# Optional dependency: only import transformers' flash-attention helper when flash-attn 2 is installed.
if is_flash_attn_2_available():
    from transformers.modeling_flash_attention_utils import _flash_attention_forward

# Documentation constants in the style of transformers modeling files. They are not referenced in the code
# visible here; presumably intended for docstring decorators such as `replace_return_docstrings`.
_CHECKPOINT_FOR_DOC = "Qwen/Qwen2-7B"
_CONFIG_FOR_DOC = "_Qwen2Config"
class Qwen2RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension, with a learned per-channel scale.

    Qwen2RMSNorm is equivalent to T5LayerNorm: no mean subtraction and no bias. Statistics are computed in
    float32; the normalised values are cast back to the input dtype before scaling by ``weight``.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        orig_dtype = hidden_states.dtype
        x32 = hidden_states.to(torch.float32)
        inv_rms = torch.rsqrt(x32.pow(2).mean(-1, keepdim=True) + self.variance_epsilon)
        normalized = x32 * inv_rms
        return self.weight * normalized.to(orig_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Qwen2
# (plus the device fix on the dynamic-RoPE reset path that later upstream versions apply).
class Qwen2RotaryEmbedding(nn.Module):
    """Produces the rotary position embedding ``(cos, sin)`` tensors.

    Inverse frequencies come from ``ROPE_INIT_FUNCTIONS[rope_type]``. For rope types containing "dynamic",
    they are recomputed when positions grow past the cached length, and restored to the original values
    once positions fall back under the original maximum.

    Args:
        dim, max_position_embeddings, base, scaling_factor, rope_type: legacy arguments. They are used only when
            ``config`` is None, in which case a deprecation warning is logged.
        device: device on which the initial ``inv_freq`` buffer is computed.
        config: model config. When given, the rope type and maximum position count are read from it.
    """

    def __init__(
        self,
        dim=None,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
        rope_type="default",
        config: Optional[_Qwen2Config] = None,
    ):
        super().__init__()
        # TODO (joao): remove the `if` below, only used for BC
        self.rope_kwargs = {}
        if config is None:
            logger.warning_once(
                "`Qwen2RotaryEmbedding` can now be fully parameterized by passing the model config through the "
                "`config` argument. All other arguments will be removed in v4.46"
            )
            self.rope_kwargs = {
                "rope_type": rope_type,
                "factor": scaling_factor,
                "dim": dim,
                "base": base,
                "max_position_embeddings": max_position_embeddings,
            }
            self.rope_type = rope_type
            self.max_seq_len_cached = max_position_embeddings
            self.original_max_seq_len = max_position_embeddings
        else:
            # BC: "rope_type" was originally "type"
            if config.rope_scaling is not None:
                self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
            else:
                self.rope_type = "default"
            self.max_seq_len_cached = config.max_position_embeddings
            self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Plain tensor attribute (not a registered buffer): `module.to(...)` does NOT move it.
        self.original_inv_freq = self.inv_freq

    def _dynamic_frequency_update(self, position_ids, device):
        """
        dynamic RoPE layers should recompute `inv_freq` in the following situations:
        1 - growing beyond the cached sequence length (allow scaling)
        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
        """
        seq_len = torch.max(position_ids) + 1
        if seq_len > self.max_seq_len_cached:  # growth
            inv_freq, self.attention_scaling = self.rope_init_fn(
                self.config, device, seq_len=seq_len, **self.rope_kwargs
            )
            self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: may break with compilation
            self.max_seq_len_cached = seq_len

        if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:  # reset
            # `original_inv_freq` stays on the construction device after the module is moved, so bring it to
            # the current device before re-registering it. Otherwise `inv_freq` ends up on the wrong device.
            self.original_inv_freq = self.original_inv_freq.to(device)
            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
            self.max_seq_len_cached = self.original_max_seq_len

    @torch.no_grad()
    def forward(self, x, position_ids):
        """Return ``(cos, sin)`` for 2-D ``position_ids`` of shape ``(batch, seq_len)``.

        ``x`` is only used for its device and dtype; the outputs are cast to ``x.dtype``. Each output has
        shape ``(batch, seq_len, 2 * len(inv_freq))``, since the frequencies are duplicated along the last axis.
        """
        if "dynamic" in self.rope_type:
            self._dynamic_frequency_update(position_ids, device=x.device)

        # Core RoPE block
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()
        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
        device_type = x.device.type
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()

        # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
        cos = cos * self.attention_scaling
        sin = sin * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
    """Rotates half the hidden dims of the input: ``[a, b] -> [-b, a]`` along the last axis.

    The first half is ``x[..., :d // 2]``; for an odd last dimension the second half is the larger one.
    """
    split_at = x.shape[-1] // 2
    front = x[..., :split_at]
    back = x[..., split_at:]
    return torch.cat((-back, front), dim=-1)
# Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Qwen2
class Qwen2MLP(nn.Module):
    """Gated feed-forward block: ``down_proj(act(gate_proj(x)) * up_proj(x))``.

    All three projections are bias-free linear layers. The activation is looked up in ``ACT2FN`` by
    ``config.hidden_act``.
    """

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        # Keep the creation order gate -> up -> down (parameter registration / init order).
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_state):
        gate = self.act_fn(self.gate_proj(hidden_state))
        up = self.up_proj(hidden_state)
        return self.down_proj(gate * up)
# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat each key/value head ``n_rep`` times along dim 1, like ``torch.repeat_interleave(x, n_rep, dim=1)``.

    Input is ``(batch, num_key_value_heads, seqlen, head_dim)``; output is
    ``(batch, num_key_value_heads * n_rep, seqlen, head_dim)``. When ``n_rep == 1``, the input object itself
    is returned.
    """
    # Unpack first so a non-4-D input is rejected even when n_rep == 1.
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep != 1:
        widened = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
        return widened.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
    return hidden_states
class Qwen2Attention(nn.Module):
    """
    Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
    and "Generating Long Sequences with Sparse Transformers".

    This is the eager implementation: ``softmax(Q K^T / sqrt(head_dim) + mask) V``, with grouped-query
    key/value heads repeated up to ``num_heads``. NOTE(review): no sliding window is applied in this
    forward; the description above is inherited from upstream.
    """

    def __init__(self, config: _Qwen2Config, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
                "lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = config.is_causal
        self.attention_dropout = config.attention_dropout

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Attend over ``hidden_states`` of shape ``(bsz, q_len, hidden_size)``.

        Args:
            attention_mask: 4-D additive mask added to the ``(bsz, num_heads, q_len, kv_len)`` scores; only its
                first ``kv_len`` columns are used.
            position_ids: token positions, used only when ``position_embeddings`` is not given.
            past_key_value: cache whose ``update`` returns the full key/value states for this layer.
            output_attentions: when False, the returned attention weights are None.
            use_cache: unused here.
            cache_position: forwarded to ``past_key_value.update``.
            position_embeddings: precomputed ``(cos, sin)`` rotary tensors.

        Returns:
            ``(attn_output, attn_weights or None, past_key_value)``.

        Raises:
            ValueError: if ``position_embeddings`` is None and no ``rotary_emb`` module is attached, or if the
                attention output has an unexpected shape.
        """
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        if position_embeddings is None:
            # `__init__` does not create a `rotary_emb` submodule, so the legacy position_ids path only works if
            # one has been attached externally; fail with a clear message instead of an opaque AttributeError.
            rotary_emb = getattr(self, "rotary_emb", None)
            if rotary_emb is None:
                raise ValueError(
                    f"{self.__class__.__name__} has no `rotary_emb` module; pass "
                    "`position_embeddings=(cos, sin)` explicitly."
                )
            logger.warning_once(
                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
                "removed and `position_embeddings` will be mandatory."
            )
            cos, sin = rotary_emb(value_states, position_ids)
        else:
            cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
class Qwen2FlashAttention2(Qwen2Attention):
    """
    Qwen2 flash attention module. It inherits from `Qwen2Attention` because the weights of the module stay
    untouched; the only change is in the forward pass, which calls the public flash attention API and handles
    padding tokens if the input contains any.

    Sliding-window attention (SWA) is used only when `config.use_sliding_window` is set, `config.sliding_window`
    is not None, and `self.layer_idx >= config.max_window_layers`.
    NOTE(review): that condition selects the layers *at or above* `max_window_layers`, not the bottom ones --
    confirm this matches what the checkpoint expects before enabling SWA.

    Flash attention never materializes the attention probability matrix, so the returned attention weights are
    always `None`, even when `output_attentions=True`.
    """

    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
        # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right
        # alignment, which became the default in flash_attn>=2.1. This attribute is used to handle this difference.
        # Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1)
        # produces a wrong mask (top-left).
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
    ):
        """
        Run flash-attention self-attention over `hidden_states`.

        Args:
            hidden_states: input of shape (bsz, q_len, hidden) -- unpacked via `.size()` below.
            attention_mask: padding mask forwarded unchanged to `_flash_attention_forward`.
            position_ids: used for RoPE when `position_embeddings` is not given, and forwarded to flash attention.
            past_key_value: optional cache; updated in place via `past_key_value.update`.
            output_attentions: accepted for interface compatibility; flash attention cannot return weights.
            use_cache: unused here.
            cache_position: forwarded to the cache update.
            position_embeddings: precomputed `(cos, sin)` RoPE tensors.

        Returns:
            `(attn_output, None, past_key_value)`, where `attn_output` is (bsz, q_len, hidden_size) after `o_proj`.
        """
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # -> (bsz, heads, q_len, head_dim)
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        if position_embeddings is None:
            logger.warning_once(
                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
                "removed and `position_embeddings` will be mandatory."
            )
            cos, sin = self.rotary_emb(value_states, position_ids)
        else:
            cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)
        dropout_rate = 0.0 if not self.training else self.attention_dropout

        # In PEFT, layer norms are usually cast to float32 for training stability, which silently casts the
        # hidden states to float32 too. Flash attention needs half precision, so cast them back.
        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            # Handle the case where the model is quantized
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype
            logger.warning_once(
                f"The input hidden states seems to be silently casted in float32, this might be related to"
                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}."
            )
            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        # Reshape to the layout expected by Flash Attention: (bsz, seq, heads, head_dim)
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        # See the class docstring: SWA is selected for layer_idx >= max_window_layers.
        if (
            self.config.use_sliding_window
            and getattr(self.config, "sliding_window", None) is not None
            and self.layer_idx >= self.config.max_window_layers
        ):
            sliding_window = self.config.sliding_window
        else:
            sliding_window = None

        attn_output = _flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            position_ids=position_ids,
            dropout=dropout_rate,
            sliding_window=sliding_window,
            is_causal=self.is_causal,
            use_top_left_mask=self._flash_attn_uses_top_left_mask,
        )
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
        attn_output = self.o_proj(attn_output)

        # Flash attention never computes the probability matrix, so there are no weights to return. Binding the
        # name unconditionally avoids an UnboundLocalError when output_attentions=True.
        attn_weights = None
        return attn_output, attn_weights, past_key_value
# Lookup table from attention-implementation name to the attention module class that implements it.
QWEN2_ATTENTION_CLASSES = dict(
    eager=Qwen2Attention,
    flash_attention_2=Qwen2FlashAttention2,
)
# Shared docstring prefix for the Qwen2 model classes. It is injected through `add_start_docstrings`
# (see `Qwen2PreTrainedModel`), so its text becomes part of the class `__doc__` at runtime.
QWEN2_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`_Qwen2Config`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
@add_start_docstrings(
    "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.",
    QWEN2_START_DOCSTRING,
)
class Qwen2PreTrainedModel(PreTrainedModel):
    # Base class wiring Qwen2 into the `PreTrainedModel` machinery: config class, checkpoint key prefix,
    # capability flags, and weight initialization. (Documented through the decorator above rather than a
    # class docstring, so `__doc__` stays exactly what `add_start_docstrings` produces.)
    config_class = _Qwen2Config
    base_model_prefix = "model"
    # Capability / placement flags consumed by the `PreTrainedModel` base class.
    supports_gradient_checkpointing = True
    _no_split_modules = ["Qwen2DecoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = True

    def _init_weights(self, module):
        """Initialize one submodule in place.

        `nn.Linear` and `nn.Embedding` weights are drawn from a normal distribution with mean 0 and standard
        deviation `config.initializer_range`. Linear biases are zeroed, and so is the embedding row at
        `padding_idx` when one is set. Any other module type is left untouched.
        """
        init_std = self.config.initializer_range
        if not isinstance(module, (nn.Linear, nn.Embedding)):
            return
        module.weight.data.normal_(mean=0.0, std=init_std)
        if isinstance(module, nn.Linear):
            if module.bias is not None:
                module.bias.data.zero_()
        elif module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
# Argument documentation for a Qwen2 model forward pass (not referenced within this chunk of the file).
# NOTE(review): the two bullets "- 1 indicates the head is **not masked**" / "- 0 indicates the head is
# **masked**" at the end of the `attention_mask` entry appear to be leftovers from a `head_mask` argument;
# no `head_mask` entry is documented here. Consider removing them.
QWEN2_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
`past_key_values`).
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
information on the default strategy.
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.n_positions - 1]`.
[What are position IDs?](../glossary#position-ids)
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
Two formats are allowed:
- a [`~cache_utils.Cache`] instance, see our
[kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
cache format.
The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
legacy cache format will be returned.
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
of shape `(batch_size, sequence_length)`.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
model's internal embedding lookup matrix.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
the complete sequence length.
"""
# File names of the tokenizer assets inside a pretrained directory; `Qwen2Tokenizer.vocab_files_names` points here
# and `Qwen2Tokenizer.save_vocabulary` writes the "vocab_file" and "merges_file" entries.
VOCAB_FILES_NAMES = dict(
    vocab_file="vocab.json",
    merges_file="merges.txt",
    tokenizer_file="tokenizer.json",
)

# Maximum model input length, keyed by checkpoint name.
MAX_MODEL_INPUT_SIZES = {"qwen/qwen-tokenizer": 32768}

# Pre-tokenization split pattern: English contractions, letter runs (with one optional leading non-letter,
# non-digit character), single digits, punctuation runs, and whitespace/newline runs.
# NOTE(review): `\p{L}` / `\p{N}` are Unicode property classes that the stdlib `re` module rejects
# ("bad escape \p"); compiling this requires the third-party `regex` module -- confirm `re` is bound to
# `regex` in this file.
PRETOKENIZE_REGEX = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
# Copied from transformers.models.gpt2.tokenization_gpt2.bytes_to_unicode
@lru_cache()
def bytes_to_unicode():
    """
    Return a dict mapping every byte value (0-255) to a distinct, printable unicode character.

    Byte-level BPE operates on unicode strings, so each raw byte needs a visible stand-in that is not
    whitespace or a control character. Bytes that are already printable (`!`..`~`, `¡`..`¬`, `®`..`ÿ`) map to
    themselves; every remaining byte is assigned, in ascending order, to code points 256, 257, ... This keeps
    the mapping reversible without needing thousands of extra vocabulary entries to avoid UNKs.

    The result is cached, so every call returns the same dict object; do not mutate it.
    """
    printable = [
        *range(ord("!"), ord("~") + 1),
        *range(ord("¡"), ord("¬") + 1),
        *range(ord("®"), ord("ÿ") + 1),
    ]
    printable_set = set(printable)
    byte_values = list(printable)
    code_points = list(printable)
    next_extra = 0
    for byte in range(256):
        if byte in printable_set:
            continue
        # Shift non-printable bytes past the single-byte range so they stay distinct and visible.
        byte_values.append(byte)
        code_points.append(256 + next_extra)
        next_extra += 1
    return {byte: chr(point) for byte, point in zip(byte_values, code_points)}
# Adapted from transformers.models.gpt2.tokenization_gpt2.get_pairs
def get_pairs(word):
    """
    Return the set of adjacent symbol pairs in a word.

    `word` is a sliceable sequence of symbols (typically a tuple of variable-length strings). A word with fewer
    than two symbols -- including an empty word, which previously raised IndexError -- has no pairs, and an
    empty set is returned.
    """
    return set(zip(word, word[1:]))
class Qwen2Tokenizer(PreTrainedTokenizer):
    """
    Construct a Qwen2 tokenizer. Based on byte-level Byte-Pair-Encoding.

    Like GPT2Tokenizer, this tokenizer has been trained to treat spaces like parts of the tokens, so a word will
    be encoded differently depending on whether it is at the beginning of the sentence (without space) or not:

    ```python
    >>> from transformers import Qwen2Tokenizer

    >>> tokenizer = Qwen2Tokenizer.from_pretrained("Qwen/Qwen-tokenizer")
    >>> tokenizer("Hello world")["input_ids"]
    [9707, 1879]

    >>> tokenizer(" Hello world")["input_ids"]
    [21927, 1879]
    ```
    This is expected.

    You should not use GPT2Tokenizer instead, because of the different pretokenization rules.

    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
    this superclass for more information regarding those methods.

    Args:
        vocab_file (`str`):
            Path to the vocabulary file (JSON object mapping token strings to ids).
        merges_file (`str`):
            Path to the merges file. A first line starting with `#version:` and blank lines are skipped; every
            other line holds one whitespace-separated merge pair, in priority order.
        errors (`str`, *optional*, defaults to `"replace"`):
            Paradigm to follow when decoding bytes to UTF-8. See
            [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
        unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        bos_token (`str`, *optional*):
            The beginning of sequence token. Not applicable for this tokenizer.
        eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
            The end of sequence token.
        pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
            The token used for padding, for example when batching sequences of different lengths.
        clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
            Whether or not the model should cleanup the spaces that were added when splitting the input text during the
            tokenization process. Not applicable to this tokenizer, since tokenization does not add spaces.
        split_special_tokens (`bool`, *optional*, defaults to `False`):
            Whether or not the special tokens should be split during the tokenization process. The default behavior is
            to not split special tokens. This means that if `<|endoftext|>` is the `eos_token`, then
            `tokenizer.tokenize("<|endoftext|>") == ['<|endoftext|>']`. Otherwise, if `split_special_tokens=True`,
            then `tokenizer.tokenize("<|endoftext|>")` will give `['<', '|', 'endo', 'ft', 'ext', '|', '>']`.
            This argument is only supported for `slow` tokenizers for the moment.
        add_prefix_space (`bool`, *optional*):
            Not supported; passing a truthy value only logs a warning.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file,
        merges_file,
        errors="replace",
        unk_token="<|endoftext|>",
        bos_token=None,
        eos_token="<|endoftext|>",
        pad_token="<|endoftext|>",
        clean_up_tokenization_spaces=False,
        split_special_tokens=False,
        **kwargs,
    ):
        # Qwen vocab does not contain control tokens; added tokens need to be special
        bos_token = (
            AddedToken(bos_token, lstrip=False, rstrip=False, special=True, normalized=False)
            if isinstance(bos_token, str)
            else bos_token
        )
        eos_token = (
            AddedToken(eos_token, lstrip=False, rstrip=False, special=True, normalized=False)
            if isinstance(eos_token, str)
            else eos_token
        )
        unk_token = (
            AddedToken(unk_token, lstrip=False, rstrip=False, special=True, normalized=False)
            if isinstance(unk_token, str)
            else unk_token
        )
        pad_token = (
            AddedToken(pad_token, lstrip=False, rstrip=False, special=True, normalized=False)
            if isinstance(pad_token, str)
            else pad_token
        )

        with open(vocab_file, encoding="utf-8") as vocab_handle:
            self.encoder = json.load(vocab_handle)
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.errors = errors  # how to handle errors in decoding
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}

        bpe_merges = []
        with open(merges_file, encoding="utf-8") as merges_handle:
            for i, line in enumerate(merges_handle):
                line = line.strip()
                # Skip the optional "#version:" header and blank lines.
                if (i == 0 and line.startswith("#version:")) or not line:
                    continue
                bpe_merges.append(tuple(line.split()))
        # Lower rank = higher merge priority (file order).
        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))

        # NOTE: the cache can grow without bound and will get really large for long running processes
        # (esp. for texts of language that do not use space between word, e.g. Chinese); technically
        # not a memory leak but appears as one.
        # GPT2Tokenizer has the same problem, so let's be consistent.
        self.cache = {}

        self.pat = re.compile(PRETOKENIZE_REGEX)

        if kwargs.get("add_prefix_space", False):
            # `__name__`, not `__name`: inside a class body `__name` is name-mangled to
            # `_Qwen2Tokenizer__name`, which raised AttributeError here.
            logger.warning_once(
                f"{self.__class__.__name__} does not support `add_prefix_space`, setting it to True has no effect."
            )

        super().__init__(
            errors=errors,
            bos_token=bos_token,
            eos_token=eos_token,
            pad_token=pad_token,
            unk_token=unk_token,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            split_special_tokens=split_special_tokens,
            **kwargs,
        )

    @property
    def vocab_size(self) -> int:
        """Number of entries in the base vocabulary (excluding added tokens)."""
        return len(self.encoder)

    # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.get_vocab
    def get_vocab(self):
        """Return the base vocabulary merged with the added-token mapping."""
        return dict(self.encoder, **self.added_tokens_encoder)

    # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.bpe
    def bpe(self, token):
        """
        Apply BPE merges to one pre-tokenized, byte-encoded token.

        Repeatedly merges the adjacent pair with the lowest rank in `self.bpe_ranks` until no ranked pair
        remains. Returns the resulting symbols joined by single spaces; results are memoized in `self.cache`.
        """
        if token in self.cache:
            return self.cache[token]
        word = tuple(token)
        pairs = get_pairs(word)

        if not pairs:
            return token

        while True:
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                except ValueError:
                    new_word.extend(word[i:])
                    break
                else:
                    new_word.extend(word[i:j])
                    i = j

                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = " ".join(word)
        self.cache[token] = word
        return word

    # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._tokenize
    def _tokenize(self, text):
        """Tokenize a string."""
        bpe_tokens = []
        for token in re.findall(self.pat, text):
            token = "".join(
                self.byte_encoder[b] for b in token.encode("utf-8")
            )  # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case)
            bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))
        return bpe_tokens

    # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._convert_token_to_id
    def _convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab; unknown tokens map to the id of `unk_token`."""
        return self.encoder.get(token, self.encoder.get(self.unk_token))

    # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._convert_id_to_token
    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab; returns None for unknown ids."""
        return self.decoder.get(index)

    # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.convert_tokens_to_string
    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (string) in a single string, undoing the byte-to-unicode mapping."""
        text = "".join(tokens)
        text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
        return text

    def decode(
        self,
        token_ids,
        skip_special_tokens: bool = False,
        clean_up_tokenization_spaces: Optional[bool] = False,
        spaces_between_special_tokens: bool = False,
        **kwargs,
    ) -> str:
        """Decode ids to text, defaulting `spaces_between_special_tokens` to False for Qwen2."""
        # `spaces_between_special_tokens` defaults to True for _decode in slow tokenizers
        # and cannot be configured elsewhere, but it should default to False for Qwen2Tokenizer
        return super().decode(
            token_ids,
            skip_special_tokens=skip_special_tokens,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            spaces_between_special_tokens=spaces_between_special_tokens,
            **kwargs,
        )

    # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.save_vocabulary
    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """
        Write `vocab.json` and `merges.txt` (optionally prefixed) into `save_directory`.

        Returns the `(vocab_file, merge_file)` paths, or None (after logging an error) if `save_directory` is
        not a directory.
        """
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return
        vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
        )
        merge_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
        )

        with open(vocab_file, "w", encoding="utf-8") as f:
            f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")

        index = 0
        with open(merge_file, "w", encoding="utf-8") as writer:
            writer.write("#version: 0.2\n")
            for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
                if index != token_index:
                    logger.warning(
                        f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
                        " Please check that the tokenizer is not corrupted!"
                    )
                    index = token_index
                writer.write(" ".join(bpe_tokens) + "\n")
                index += 1

        return vocab_file, merge_file

    def prepare_for_tokenization(self, text, **kwargs):
        """NFC-normalize the text before tokenization so composed/decomposed forms encode identically."""
        text = unicodedata.normalize("NFC", text)
        return (text, kwargs)
class SiglipTextConfig(PretrainedConfig):
    r"""
    Configuration of the SigLIP text encoder ([`SiglipTextModel`]).

    The stored values define the encoder architecture. With the defaults you get a configuration similar to the
    text tower of [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224).
    Options shared by every configuration (output control etc.) come from [`PretrainedConfig`]; see its
    documentation.

    Args:
        vocab_size (`int`, *optional*, defaults to 32000):
            Number of distinct token ids the text model can embed.
        hidden_size (`int`, *optional*, defaults to 768):
            Width of the encoder layers and the pooler layer.
        intermediate_size (`int`, *optional*, defaults to 3072):
            Width of the feed-forward ("intermediate") layer inside each Transformer block.
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of Transformer encoder layers.
        num_attention_heads (`int`, *optional*, defaults to 12):
            Number of attention heads per attention layer.
        max_position_embeddings (`int`, *optional*, defaults to 64):
            Longest sequence length the model supports.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
            Non-linearity used in the encoder and pooler; string names such as `"gelu"`, `"relu"`, `"selu"`,
            `"gelu_new"` and `"quick_gelu"` are accepted.
        layer_norm_eps (`float`, *optional*, defaults to 1e-06):
            Epsilon of the layer-normalization layers.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            Dropout probability applied to attention probabilities.
        pad_token_id (`int`, *optional*, defaults to 1):
            Id of the padding token.
        bos_token_id (`int`, *optional*, defaults to 49406):
            Id of the beginning-of-sequence token.
        eos_token_id (`int`, *optional*, defaults to 49407):
            Id of the end-of-sequence token.

    Example:

    ```python
    >>> configuration = SiglipTextConfig()
    >>> configuration.vocab_size, configuration.max_position_embeddings
    (32000, 64)
    ```"""

    model_type = "siglip_text_model"

    def __init__(
        self,
        vocab_size=32000,
        hidden_size=768,
        intermediate_size=3072,
        num_hidden_layers=12,
        num_attention_heads=12,
        max_position_embeddings=64,
        hidden_act="gelu_pytorch_tanh",
        layer_norm_eps=1e-6,
        attention_dropout=0.0,
        # This differs from `CLIPTokenizer`'s default and from openai/siglip
        # See https://github.com/huggingface/transformers/pull/24773#issuecomment-1632287538
        pad_token_id=1,
        bos_token_id=49406,
        eos_token_id=49407,
        **kwargs,
    ):
        # Special-token ids are owned by the base class.
        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)

        # Vocabulary and sequence length.
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings

        # Transformer geometry.
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # Activation, normalization and regularization.
        self.hidden_act = hidden_act
        self.layer_norm_eps = layer_norm_eps
        self.attention_dropout = attention_dropout

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
        """Load this config from a checkpoint, unwrapping `text_config` if the checkpoint holds a full SigLIP config."""
        cls._set_token_in_kwargs(kwargs)
        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)

        # A composite SigLIP config nests the text-tower settings under "text_config".
        if config_dict.get("model_type") == "siglip":
            config_dict = config_dict["text_config"]

        type_mismatch = (
            "model_type" in config_dict
            and hasattr(cls, "model_type")
            and config_dict["model_type"] != cls.model_type
        )
        if type_mismatch:
            logger.warning(
                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )

        return cls.from_dict(config_dict, **kwargs)
class _SiglipVisionConfig(PretrainedConfig):
    r"""
    Configuration of the SigLIP vision encoder (`SiglipVisionModel`).

    The stored values define the encoder architecture. With the defaults you get a configuration similar to the
    vision tower of [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224).
    Options shared by every configuration come from [`PretrainedConfig`]; see its documentation.

    Args:
        hidden_size (`int`, *optional*, defaults to 768):
            Width of the encoder layers and the pooler layer.
        intermediate_size (`int`, *optional*, defaults to 3072):
            Width of the feed-forward ("intermediate") layer inside each Transformer block.
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of Transformer encoder layers.
        num_attention_heads (`int`, *optional*, defaults to 12):
            Number of attention heads per attention layer.
        num_channels (`int`, *optional*, defaults to 3):
            Number of channels in the input images.
        image_size (`int`, *optional*, defaults to 224):
            Resolution (side length) of the input images.
        patch_size (`int`, *optional*, defaults to 16):
            Resolution (side length) of each image patch.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
            Non-linearity used in the encoder and pooler; string names such as `"gelu"`, `"relu"`, `"selu"`,
            `"gelu_new"` and `"quick_gelu"` are accepted.
        layer_norm_eps (`float`, *optional*, defaults to 1e-06):
            Epsilon of the layer-normalization layers.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            Dropout probability applied to attention probabilities.

    Example (`_SiglipVisionConfig` is defined in this module, not exported by `transformers`):

    ```python
    >>> configuration = _SiglipVisionConfig()
    >>> configuration.image_size, configuration.patch_size
    (224, 16)
    ```"""

    model_type = "siglip_vision_model"

    def __init__(
        self,
        hidden_size=768,
        intermediate_size=3072,
        num_hidden_layers=12,
        num_attention_heads=12,
        num_channels=3,
        image_size=224,
        patch_size=16,
        hidden_act="gelu_pytorch_tanh",
        layer_norm_eps=1e-6,
        attention_dropout=0.0,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Transformer geometry.
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # Input image layout.
        self.num_channels = num_channels
        self.image_size = image_size
        self.patch_size = patch_size

        # Activation, normalization and regularization.
        self.hidden_act = hidden_act
        self.layer_norm_eps = layer_norm_eps
        self.attention_dropout = attention_dropout

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
        """Load this config from a checkpoint, unwrapping `vision_config` if the checkpoint holds a full SigLIP config."""
        cls._set_token_in_kwargs(kwargs)
        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)

        # A composite SigLIP config nests the vision-tower settings under "vision_config".
        if config_dict.get("model_type") == "siglip":
            config_dict = config_dict["vision_config"]

        type_mismatch = (
            "model_type" in config_dict
            and hasattr(cls, "model_type")
            and config_dict["model_type"] != cls.model_type
        )
        if type_mismatch:
            logger.warning(
                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )

        return cls.from_dict(config_dict, **kwargs)
class SiglipConfig(PretrainedConfig):
    r"""
    [`SiglipConfig`] stores the configuration of a full SigLIP model: a text tower ([`SiglipTextConfig`]) and a
    vision tower ([`_SiglipVisionConfig`]). With the defaults you get a configuration similar to
    [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224).

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        text_config (`dict` or `SiglipTextConfig`, *optional*):
            Options used to initialize [`SiglipTextConfig`]; a ready-made `SiglipTextConfig` is also accepted.
        vision_config (`dict` or `_SiglipVisionConfig`, *optional*):
            Options used to initialize [`_SiglipVisionConfig`]; a ready-made `_SiglipVisionConfig` is also accepted.
        kwargs (*optional*):
            Dictionary of keyword arguments forwarded to [`PretrainedConfig`].

    Example:

    ```python
    >>> # Initializing a SiglipConfig with google/siglip-base-patch16-224 style configuration
    >>> configuration = SiglipConfig()

    >>> # We can also build a SiglipConfig from a SiglipTextConfig and a _SiglipVisionConfig
    >>> # (both defined in this module)
    >>> config_text = SiglipTextConfig()
    >>> config_vision = _SiglipVisionConfig()
    >>> config = SiglipConfig.from_text_vision_configs(config_text, config_vision)
    ```"""

    model_type = "siglip"

    def __init__(self, text_config=None, vision_config=None, **kwargs):
        super().__init__(**kwargs)

        if text_config is None:
            text_config = {}
            logger.info("`text_config` is `None`. Initializing the `SiglipTextConfig` with default values.")
        elif isinstance(text_config, SiglipTextConfig):
            # Accept an already-built sub-config as well as a plain dict (previously `**text_config` failed).
            text_config = text_config.to_dict()

        if vision_config is None:
            vision_config = {}
            logger.info("`vision_config` is `None`. initializing the `_SiglipVisionConfig` with default values.")
        elif isinstance(vision_config, _SiglipVisionConfig):
            vision_config = vision_config.to_dict()

        self.text_config = SiglipTextConfig(**text_config)
        self.vision_config = _SiglipVisionConfig(**vision_config)

        self.initializer_factor = 1.0

    @classmethod
    def from_text_vision_configs(cls, text_config: SiglipTextConfig, vision_config: _SiglipVisionConfig, **kwargs):
        r"""
        Instantiate a [`SiglipConfig`] (or a derived class) from a SigLIP text model configuration and a SigLIP
        vision model configuration.

        Returns:
            [`SiglipConfig`]: An instance of a configuration object
        """
        return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
# Import Pillow only when `is_vision_available()` reports it usable, so this module still imports without it.
if is_vision_available():
    import PIL
class SiglipImageProcessor(BaseImageProcessor):
    r"""
    Constructs a SigLIP image processor.

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by
            `do_resize` in the `preprocess` method.
        size (`Dict[str, int]` *optional*, defaults to `{"height": 224, "width": 224}`):
            Size of the image after resizing. Can be overridden by `size` in the `preprocess` method.
        resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
            Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
        do_rescale (`bool`, *optional*, defaults to `True`):
            Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
            the `preprocess` method.
        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
            Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
            method.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the image by the specified mean and standard deviation. Can be overridden by
            `do_normalize` in the `preprocess` method.
        image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
            Mean to use if normalizing the image. This is a float or list of floats the length of the number of
            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
        image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
        do_convert_rgb (`bool`, *optional*, defaults to `None`):
            Whether to convert the image to RGB. The default `None` is falsy, so no RGB conversion is performed
            unless it is enabled here or via `do_convert_rgb` in the `preprocess` method.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        do_resize: bool = True,
        size: Optional[Dict[str, int]] = None,
        resample: PILImageResampling = PILImageResampling.BICUBIC,
        do_rescale: bool = True,
        rescale_factor: Union[int, float] = 1 / 255,
        do_normalize: bool = True,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_convert_rgb: Optional[bool] = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        # Resolve `None` defaults here so instances always carry concrete values.
        size = size if size is not None else {"height": 224, "width": 224}
        image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
        image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD

        self.do_resize = do_resize
        self.size = size
        self.resample = resample
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize
        self.image_mean = image_mean
        self.image_std = image_std
        self.do_convert_rgb = do_convert_rgb

    @filter_out_non_signature_kwargs()
    def preprocess(
        self,
        images: ImageInput,
        do_resize: Optional[bool] = None,
        size: Optional[Dict[str, int]] = None,
        resample: Optional[PILImageResampling] = None,
        do_rescale: Optional[bool] = None,
        rescale_factor: Optional[float] = None,
        do_normalize: Optional[bool] = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        do_convert_rgb: Optional[bool] = None,
    ) -> BatchFeature:
        """
        Preprocess an image or batch of images.

        Every argument left as `None` falls back to the value stored on the processor instance.

        Args:
            images (`ImageInput`):
                Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
                passing in images with pixel values between 0 and 1, set `do_rescale=False`.
            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
                Whether to resize the image.
            size (`Dict[str, int]`, *optional*, defaults to `self.size`):
                Size of the image after resizing. The `"height"` and `"width"` entries are used.
            resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
                Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`.
            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
                Whether to rescale the image.
            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
                Rescale factor to rescale the image by if `do_rescale` is set to `True`.
            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
                Whether to normalize the image.
            image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
                Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
            image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
                Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
                `True`.
            return_tensors (`str` or `TensorType`, *optional*):
                The type of tensors to return. Can be one of:
                - Unset: Return a list of `np.ndarray`.
                - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
                - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
                - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
                - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
                The channel dimension format for the output image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - Unset: Use the channel dimension format of the input image.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the first input image and assumed for all of them. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
            do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
                Whether to convert the image to RGB.

        Returns:
            `BatchFeature`: a batch whose `"pixel_values"` entry holds the processed images, converted according
            to `return_tensors`.

        Raises:
            ValueError: if `images` contains an unsupported image type.
        """
        do_resize = do_resize if do_resize is not None else self.do_resize
        size = size if size is not None else self.size
        size = get_size_dict(size, param_name="size", default_to_square=False)
        resample = resample if resample is not None else self.resample
        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
        rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
        image_mean = image_mean if image_mean is not None else self.image_mean
        image_std = image_std if image_std is not None else self.image_std
        do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb

        images = make_list_of_images(images)

        if not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )
        validate_preprocess_arguments(
            do_rescale=do_rescale,
            rescale_factor=rescale_factor,
            do_normalize=do_normalize,
            image_mean=image_mean,
            image_std=image_std,
            do_resize=do_resize,
            size=size,
            resample=resample,
        )
        # All transformations expect numpy arrays.
        images = [to_numpy_array(image) for image in images]

        if do_convert_rgb:
            images = [convert_to_rgb(image) for image in images]

        if is_scaled_image(images[0]) and do_rescale:
            logger.warning_once(
                "It looks like you are trying to rescale already rescaled images. If the input"
                " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
            )

        if input_data_format is None:
            # We assume that all images have the same channel dimension format.
            input_data_format = infer_channel_dimension_format(images[0])

        if do_resize:
            # Resize to an exact (height, width); aspect ratio is not preserved.
            height, width = size["height"], size["width"]
            images = [
                resize(image=image, size=(height, width), resample=resample, input_data_format=input_data_format)
                for image in images
            ]

        if do_rescale:
            images = [
                self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
                for image in images
            ]

        if do_normalize:
            images = [
                self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
                for image in images
            ]

        images = [
            to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
        ]

        data = {"pixel_values": images}
        return BatchFeature(data=data, tensor_type=return_tensors)
if is_flash_attn_2_available():
from transformers.modeling_flash_attention_utils import _flash_attention_forward
# Documentation-only constants: the reference checkpoint id and the config class name
# used when rendering SigLIP docstrings.
_CHECKPOINT_FOR_DOC = "google/siglip-base-patch16-224"
_CONFIG_FOR_DOC = "SiglipConfig"
def _trunc_normal_(tensor, mean, std, a, b):
    """Fill `tensor` in place with N(mean, std^2) samples truncated to [a, b]; returns None.

    Uses inverse-CDF sampling: draw uniformly in the (scaled) CDF range of the bounds, map back through
    `erfinv`, then rescale and clamp. Adapted from PyTorch's reference implementation; see
    https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf for the method.
    """

    def _std_normal_cdf(value):
        # Standard normal CDF expressed through the error function.
        return (1.0 + math.erf(value / math.sqrt(2.0))) / 2.0

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn(
            "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
            "The distribution of values may be incorrect.",
            stacklevel=2,
        )

    # CDF values of the standardized lower / upper bounds.
    cdf_lo = _std_normal_cdf((a - mean) / std)
    cdf_hi = _std_normal_cdf((b - mean) / std)

    # uniform on [2*lo - 1, 2*hi - 1] -> erfinv gives a truncated standard normal (scaled by 1/sqrt(2))
    # -> rescale to (mean, std) -> clamp away numerical overshoot at the bounds.
    (
        tensor.uniform_(2 * cdf_lo - 1, 2 * cdf_hi - 1)
        .erfinv_()
        .mul_(std * math.sqrt(2.0))
        .add_(mean)
        .clamp_(min=a, max=b)
    )
def trunc_normal_tf_(
    tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0
) -> torch.Tensor:
    r"""Fill `tensor` in place from a truncated normal distribution with TensorFlow/JAX-style bounds.

    A standard normal :math:`\mathcal{N}(0, 1)` truncated to :math:`[a, b]` is sampled first, and the
    samples are then scaled by `std` and shifted by `mean`. The bounds therefore apply *before*
    scaling: for `std > 0` the result lies in
    :math:`[\text{mean} + \text{std} \cdot a, \text{mean} + \text{std} \cdot b]`. This is closer to the
    TensorFlow / JAX implementations than `_trunc_normal_`, which clamps the final values to `[a, b]`.

    Args:
        tensor: an n-dimensional `torch.Tensor`, modified in place.
        mean: the shift applied after sampling.
        std: the scale applied after sampling.
        a: the minimum cutoff value of the *standard* normal.
        b: the maximum cutoff value of the *standard* normal.

    Returns:
        The same `tensor` object (filled in place), matching the `-> torch.Tensor` annotation.
    """
    with torch.no_grad():
        _trunc_normal_(tensor, 0, 1.0, a, b)
        tensor.mul_(std).add_(mean)
    return tensor
def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
    """Fill `tensor` in place from a distribution whose variance is `scale / denom`.

    Args:
        tensor: tensor to initialize; fan-in / fan-out come from `_calculate_fan_in_and_fan_out`.
        scale: numerator of the target variance.
        mode: how `denom` is chosen: `"fan_in"`, `"fan_out"`, or `"fan_avg"` (mean of the two).
        distribution: `"truncated_normal"` (std divided by 0.8796... to compensate for truncation to
            (-2, 2)), `"normal"`, or `"uniform"` (on `[-sqrt(3 * variance), sqrt(3 * variance)]`).

    Raises:
        ValueError: if `mode` or `distribution` is not one of the supported values.
    """
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    if mode == "fan_in":
        denom = fan_in
    elif mode == "fan_out":
        denom = fan_out
    elif mode == "fan_avg":
        denom = (fan_in + fan_out) / 2
    else:
        # Previously fell through and crashed with UnboundLocalError on `denom`.
        raise ValueError(f"invalid mode {mode}")

    variance = scale / denom

    if distribution == "truncated_normal":
        # constant is stddev of standard normal truncated to (-2, 2)
        trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
    elif distribution == "normal":
        with torch.no_grad():
            tensor.normal_(std=math.sqrt(variance))
    elif distribution == "uniform":
        bound = math.sqrt(3 * variance)
        with torch.no_grad():
            tensor.uniform_(-bound, bound)
    else:
        raise ValueError(f"invalid distribution {distribution}")
def lecun_normal_(tensor):
    """LeCun-normal init in place: truncated normal with target variance 1 / fan_in."""
    variance_scaling_(tensor, distribution="truncated_normal", mode="fan_in")
def default_flax_embed_init(tensor):
    """In-place (untruncated) normal init with variance 1 / fan_in; applied to `nn.Embedding` weights in `_init_weights`."""
    variance_scaling_(tensor, distribution="normal", mode="fan_in")
@dataclass
# Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->Siglip
class SiglipVisionModelOutput(ModelOutput):
    """
    Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.

    Args:
        image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when model is initialized with `with_projection=True`):
            The image embeddings obtained by applying the projection layer to the pooler_output.
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    # NOTE(review): field declaration order presumably defines ModelOutput's tuple/index order —
    # keep it stable. All fields default to None.
    image_embeds: Optional[torch.FloatTensor] = None
    last_hidden_state: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
@dataclass
# Copied from transformers.models.clip.modeling_clip.CLIPTextModelOutput with CLIP->Siglip
class SiglipTextModelOutput(ModelOutput):
    """
    Base class for text model's outputs that also contains a pooling of the last hidden states.

    Args:
        text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when model is initialized with `with_projection=True`):
            The text embeddings obtained by applying the projection layer to the pooler_output.
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    # NOTE(review): field declaration order presumably defines ModelOutput's tuple/index order —
    # keep it stable. All fields default to None.
    text_embeds: Optional[torch.FloatTensor] = None
    last_hidden_state: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
@dataclass
# Copied from transformers.models.clip.modeling_clip.CLIPOutput with CLIP->Siglip
class SiglipOutput(ModelOutput):
    """
    Output of the joint SigLIP image-text model.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
            Contrastive loss for image-text similarity.
        logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
            The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
            similarity scores.
        logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
            The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
            similarity scores.
        text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
            The text embeddings obtained by applying the projection layer to the pooled output of [`SiglipTextModel`].
        image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
            The image embeddings obtained by applying the projection layer to the pooled output of [`SiglipVisionModel`].
        text_model_output (`BaseModelOutputWithPooling`):
            The output of the [`SiglipTextModel`].
        vision_model_output (`BaseModelOutputWithPooling`):
            The output of the [`SiglipVisionModel`].
    """

    # NOTE(review): field declaration order presumably defines ModelOutput's tuple/index order —
    # keep it stable.
    loss: Optional[torch.FloatTensor] = None
    logits_per_image: torch.FloatTensor = None
    logits_per_text: torch.FloatTensor = None
    text_embeds: torch.FloatTensor = None
    image_embeds: torch.FloatTensor = None
    text_model_output: BaseModelOutputWithPooling = None
    vision_model_output: BaseModelOutputWithPooling = None

    def to_tuple(self) -> Tuple[Any, ...]:
        """Return the populated fields as a tuple, in key order.

        The nested `text_model_output` / `vision_model_output` entries are themselves converted with
        their own `to_tuple()`; every other value is returned as-is.
        """
        return tuple(
            self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
            for k in self.keys()
        )
# Adapted from transformers.models.clip.modeling_clip.CLIPTextEmbeddings with CLIP->Siglip
class SiglipTextEmbeddings(nn.Module):
    """Token embeddings plus learned absolute position embeddings for the SigLIP text tower."""

    def __init__(self, config: "SiglipTextConfig"):
        super().__init__()
        embed_dim = config.hidden_size

        self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
        self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)

        # Default position ids 0..max_position_embeddings-1 with shape (1, max_position_embeddings).
        # Registered with persistent=False, so the buffer is NOT written to the state_dict.
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        """Return `inputs_embeds + position_embedding(position_ids)`.

        Args:
            input_ids: token ids whose last dimension is the sequence length. Embedded with
                `token_embedding` only when `inputs_embeds` is not given; if both are given, the
                sequence length is still taken from `input_ids`.
            position_ids: explicit position ids; defaults to the first `seq_length` entries of the
                `position_ids` buffer.
            inputs_embeds: precomputed token embeddings whose second-to-last dimension is the
                sequence length.

        Raises:
            ValueError: if neither `input_ids` nor `inputs_embeds` is given, or if the default
                position ids are needed but the sequence is longer than `max_position_embeddings`.
        """
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]

        if position_ids is None:
            max_positions = self.position_ids.shape[-1]
            # Slicing the buffer past its end would silently return too few ids and fail later
            # with an opaque broadcasting error; report the real cause instead.
            if seq_length > max_positions:
                raise ValueError(
                    f"Sequence length must not exceed max_position_embeddings (got `sequence length`: "
                    f"{seq_length} and max_position_embeddings: {max_positions})"
                )
            position_ids = self.position_ids[:, :seq_length]

        if inputs_embeds is None:
            inputs_embeds = self.token_embedding(input_ids)

        position_embeddings = self.position_embedding(position_ids)
        embeddings = inputs_embeds + position_embeddings

        return embeddings
class SiglipAttention(nn.Module):
    """Eager multi-head self-attention (scaled dot-product over `num_attention_heads` heads).

    `forward` always returns `(output, attention_probabilities)`; the probabilities are returned
    regardless of `output_attentions`.
    """

    # Adapted from transformers.models.clip.modeling_clip.CLIPAttention.__init__
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_dim = width = config.hidden_size
        self.num_heads = heads = config.num_attention_heads
        self.head_dim = width // heads

        if self.head_dim * heads != width:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )

        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout

        # Creation order (k, v, q, out) is intentional: it fixes the RNG draw order during
        # default initialization and the submodule registration order.
        self.k_proj = nn.Linear(width, width)
        self.v_proj = nn.Linear(width, width)
        self.q_proj = nn.Linear(width, width)
        self.out_proj = nn.Linear(width, width)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Attend over `hidden_states` of shape (batch, time, channels).

        Args:
            hidden_states: input of shape (batch, time, embed_dim).
            attention_mask: optional tensor of shape (batch, 1, time, time), added to the raw
                attention scores before the softmax.
            output_attentions: accepted for interface compatibility; not consulted here.

        Returns:
            `(output, probs)` with `output` of shape (batch, time, embed_dim) and `probs` of shape
            (batch, heads, time, time).

        Raises:
            ValueError: if an intermediate or the mask has an unexpected shape.
        """
        bsz, seq_len, _ = hidden_states.size()

        def to_heads(projected):
            # (batch, time, embed_dim) -> (batch, heads, time, head_dim)
            return projected.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        q = to_heads(self.q_proj(hidden_states))
        k = to_heads(self.k_proj(hidden_states))
        v = to_heads(self.v_proj(hidden_states))
        kv_len = k.shape[-2]

        scores = torch.matmul(q, k.transpose(2, 3)) * self.scale
        expected_scores = (bsz, self.num_heads, seq_len, kv_len)
        if scores.size() != expected_scores:
            raise ValueError(
                f"Attention weights should be of size {expected_scores}, but is"
                f" {scores.size()}"
            )

        if attention_mask is not None:
            expected_mask = (bsz, 1, seq_len, kv_len)
            if attention_mask.size() != expected_mask:
                raise ValueError(
                    f"Attention mask should be of size {expected_mask}, but is {attention_mask.size()}"
                )
            scores = scores + attention_mask

        # Softmax is computed in float32 for stability, then cast back to the activation dtype.
        probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(q.dtype)
        probs = nn.functional.dropout(probs, p=self.dropout, training=self.training)

        context = torch.matmul(probs, v)
        expected_context = (bsz, self.num_heads, seq_len, self.head_dim)
        if context.size() != expected_context:
            raise ValueError(
                f"`attn_output` should be of size {expected_context}, but is"
                f" {context.size()}"
            )

        merged = context.transpose(1, 2).contiguous().reshape(bsz, seq_len, self.embed_dim)
        return self.out_proj(merged), probs
class SiglipSdpaAttention(SiglipAttention):
    """
    `SiglipAttention` variant computed with `torch.nn.functional.scaled_dot_product_attention`.

    The projection layers are inherited unchanged from `SiglipAttention`; only `forward` differs. SDPA
    does not expose attention probabilities, so when they are requested the call is delegated to the
    parent (eager) implementation; otherwise `None` is returned in their place.
    """

    is_causal = False

    # Adapted from SiglipAttention.forward and transformers.models.llama.modeling_llama.LlamaSdpaAttention.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Return `(output, None)`, or the eager parent's result when `output_attentions` is truthy."""
        if output_attentions:
            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
            logger.warning_once(
                "SiglipModel is using SiglipSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                output_attentions=output_attentions,
            )

        bsz, seq_len, _ = hidden_states.size()

        def split_heads(projected):
            # (batch, time, embed_dim) -> (batch, heads, time, head_dim)
            return projected.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        query = split_heads(self.q_proj(hidden_states))
        key = split_heads(self.k_proj(hidden_states))
        value = split_heads(self.v_proj(hidden_states))

        # SDPA's memory-efficient backend was buggy (torch==2.1.2) with non-contiguous inputs plus a
        # custom attn_mask: https://github.com/pytorch/pytorch/issues/112577
        if query.device.type == "cuda" and attention_mask is not None:
            query, key, value = query.contiguous(), key.contiguous(), value.contiguous()

        # Kept as a standalone statement rather than an inline conditional inside the SDPA call, so
        # torch.compile can handle both dynamic shapes and full-graph mode.
        is_causal = True if self.is_causal and seq_len > 1 else False

        context = torch.nn.functional.scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask=attention_mask,
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=is_causal,
        )

        merged = context.transpose(1, 2).contiguous().view(bsz, seq_len, self.embed_dim)
        return self.out_proj(merged), None
class SiglipPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = SiglipConfig
    base_model_prefix = "siglip"
    supports_gradient_checkpointing = True

    # Module class names that should not be split across devices (presumably consumed by the
    # HF / accelerate device-map logic). Each name appears once.
    _no_split_modules = [
        "SiglipTextEmbeddings",
        "SiglipEncoderLayer",
        "SiglipVisionEmbeddings",
        "SiglipMultiheadAttentionPoolingHead",
    ]
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def _init_weights(self, module):
        """Initialize the weights of `module` in place according to its type.

        The `isinstance` checks run top to bottom and only the first match is applied, so the
        SigLIP-specific module types are handled before the generic `nn.Embedding`,
        `nn.Linear`/`nn.Conv2d` and `nn.LayerNorm` fallbacks.
        """
        if isinstance(module, SiglipVisionEmbeddings):
            # The config may be the composite SiglipConfig or a bare vision config.
            width = (
                self.config.vision_config.hidden_size
                if isinstance(self.config, SiglipConfig)
                else self.config.hidden_size
            )
            nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width))
        elif isinstance(module, nn.Embedding):
            default_flax_embed_init(module.weight)
        elif isinstance(module, SiglipAttention):
            nn.init.xavier_uniform_(module.q_proj.weight)
            nn.init.xavier_uniform_(module.k_proj.weight)
            nn.init.xavier_uniform_(module.v_proj.weight)
            nn.init.xavier_uniform_(module.out_proj.weight)
            nn.init.zeros_(module.q_proj.bias)
            nn.init.zeros_(module.k_proj.bias)
            nn.init.zeros_(module.v_proj.bias)
            nn.init.zeros_(module.out_proj.bias)
        elif isinstance(module, SiglipMLP):
            nn.init.xavier_uniform_(module.fc1.weight)
            nn.init.xavier_uniform_(module.fc2.weight)
            nn.init.normal_(module.fc1.bias, std=1e-6)
            nn.init.normal_(module.fc2.bias, std=1e-6)
        elif isinstance(module, SiglipMultiheadAttentionPoolingHead):
            nn.init.xavier_uniform_(module.probe.data)
            nn.init.xavier_uniform_(module.attention.in_proj_weight.data)
            nn.init.zeros_(module.attention.in_proj_bias.data)
        elif isinstance(module, SiglipModel):
            # log(1.0) == 0, so the learned logit scale starts at exp(0) == 1.
            logit_scale_init = torch.log(torch.tensor(1.0))
            module.logit_scale.data.fill_(logit_scale_init)
            module.logit_bias.data.zero_()
        # NOTE(review): SiglipForImageClassification must be defined somewhere in this module, or
        # this check raises NameError for any module reaching it — confirm.
        elif isinstance(module, SiglipForImageClassification):
            nn.init.normal_(
                module.classifier.weight,
                std=self.config.vision_config.hidden_size**-0.5 * self.config.initializer_factor,
            )
        elif isinstance(module, (nn.Linear, nn.Conv2d)):
            lecun_normal_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            # Affine-less LayerNorms (elementwise_affine=False, or bias=False) have None here.
            if module.bias is not None:
                module.bias.data.zero_()
            if module.weight is not None:
                module.weight.data.fill_(1.0)
# Docstring templates for the SigLIP models. These are plain runtime strings spliced into
# method/class docstrings by the HF docstring decorators; the text itself is part of the
# rendered documentation.

# Shared class-level introduction (presumably attached to model classes with `add_start_docstrings`).
SIGLIP_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`SiglipConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
# Arguments of the text tower's `forward` (used by `SiglipTextTransformer.forward`).
SIGLIP_TEXT_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.max_position_embeddings - 1]`.
[What are position IDs?](../glossary#position-ids)
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
# Arguments of the vision tower's `forward` (presumably used by vision model classes later in the file).
# NOTE(review): the `pixel_values` text ("Padding will be ignored...", `CLIPImageProcessor`) appears to be
# inherited from CLIP docs — confirm before relying on it.
SIGLIP_VISION_INPUTS_DOCSTRING = r"""
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
[`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
Whether to interpolate the pre-trained position encodings.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
# Arguments of the joint image-text model's `forward` (presumably used by `SiglipModel` later in the file).
SIGLIP_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.max_position_embeddings - 1]`.
[What are position IDs?](../glossary#position-ids)
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
[`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
return_loss (`bool`, *optional*):
Whether or not to return the contrastive loss.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
Whether to interpolate the pre-trained position encodings.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
class SiglipTextTransformer(nn.Module):
    """SigLIP text tower: embeddings -> bidirectional encoder -> final LayerNorm -> linear head on the last token."""

    def __init__(self, config: SiglipTextConfig):
        super().__init__()
        self.config = config
        width = config.hidden_size

        # Submodule creation order (embeddings, encoder, final_layer_norm, head) is preserved.
        self.embeddings = SiglipTextEmbeddings(config)
        self.encoder = SiglipEncoder(config)
        self.final_layer_norm = nn.LayerNorm(width, eps=config.layer_norm_eps)
        self.head = nn.Linear(width, width)

        # With FlashAttention-2 the 2D padding mask is handed to the encoder unexpanded.
        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"

    @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipTextConfig)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        Returns:
        """
        # Unset flags fall back to the config.
        cfg = self.config
        if output_attentions is None:
            output_attentions = cfg.output_attentions
        if output_hidden_states is None:
            output_hidden_states = cfg.output_hidden_states
        if return_dict is None:
            return_dict = cfg.use_return_dict

        if input_ids is None:
            raise ValueError("You have to specify input_ids")

        # Collapse any leading dimensions so the embedding sees (N, seq_len).
        input_ids = input_ids.view(-1, input_ids.size()[-1])
        embedded = self.embeddings(input_ids=input_ids, position_ids=position_ids)

        # SigLIP's text encoder uses no causal mask (unlike CLIP). Outside FlashAttention-2 the padding
        # mask is expanded from [batch_size, seq_len] to [batch_size, 1, tgt_seq_len, src_seq_len].
        if attention_mask is not None and not self._use_flash_attention_2:
            attention_mask = _prepare_4d_attention_mask(attention_mask, embedded.dtype)

        encoded = self.encoder(
            inputs_embeds=embedded,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = self.final_layer_norm(encoded[0])
        # "Sticky" EOS tokenization is assumed: the last position always holds EOS, so it is pooled.
        pooled = self.head(sequence_output[:, -1, :])

        if return_dict:
            return BaseModelOutputWithPooling(
                last_hidden_state=sequence_output,
                pooler_output=pooled,
                hidden_states=encoded.hidden_states,
                attentions=encoded.attentions,
            )
        return (sequence_output, pooled) + encoded[1:]
@add_start_docstrings(
    """The text model from SigLIP without any head or projection on top.""",
    SIGLIP_START_DOCSTRING,
)
class SiglipTextModel(SiglipPreTrainedModel):
    # PreTrainedModel wrapper that owns a SiglipTextTransformer and delegates forward() to it.
    # No class docstring on purpose: @add_start_docstrings composes __doc__ from the strings above.
    config_class = SiglipTextConfig

    def __init__(self, config: SiglipTextConfig):
        super().__init__(config)
        self.text_model = SiglipTextTransformer(config)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        """Return the token embedding module of the wrapped text transformer."""
        embeddings = self.text_model.embeddings
        return embeddings.token_embedding

    def set_input_embeddings(self, value):
        """Replace the token embedding module of the wrapped text transformer."""
        embeddings = self.text_model.embeddings
        embeddings.token_embedding = value

    @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipTextConfig)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        Returns:
        Examples:
        ```python
        >>> from transformers import AutoTokenizer, SiglipTextModel
        >>> model = SiglipTextModel.from_pretrained("google/siglip-base-patch16-224")
        >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224")
        >>> # important: make sure to set padding="max_length" as that's how the model was trained
        >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt")
        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        >>> pooled_output = outputs.pooler_output # pooled (EOS token) states
        ```"""
        # Only return_dict is resolved here; output_attentions / output_hidden_states are passed through
        # unresolved and SiglipTextTransformer.forward falls back to its own config for them.
        forwarded = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "position_ids": position_ids,
            "output_attentions": output_attentions,
            "output_hidden_states": output_hidden_states,
            "return_dict": return_dict if return_dict is not None else self.config.use_return_dict,
        }
        return self.text_model(**forwarded)
class SiglipMultiheadAttentionPoolingHead(nn.Module):
    """Multihead Attention Pooling."""

    def __init__(self, config: _SiglipVisionConfig):
        super().__init__()
        dim = config.hidden_size
        # One learned query ("probe") that cross-attends over the input tokens; created first, as before,
        # so random-initialisation order is unchanged.
        self.probe = nn.Parameter(torch.randn(1, 1, dim))
        self.attention = torch.nn.MultiheadAttention(dim, config.num_attention_heads, batch_first=True)
        self.layernorm = nn.LayerNorm(dim, eps=config.layer_norm_eps)
        self.mlp = SiglipMLP(config)

    def forward(self, hidden_state):
        # NOTE(review): batch_first=True implies hidden_state is (batch, seq, hidden) — confirm with callers.
        queries = self.probe.repeat(hidden_state.shape[0], 1, 1)
        attended, _ = self.attention(queries, hidden_state, hidden_state)
        # Pre-norm MLP block with a residual connection around it.
        pooled = attended + self.mlp(self.layernorm(attended))
        # Single probe => one pooled vector per batch element.
        return pooled[:, 0]
@add_start_docstrings(SIGLIP_START_DOCSTRING)
class SiglipModel(SiglipPreTrainedModel):
    # Dual-tower SigLIP model: a text tower and a vision tower whose pooled outputs are L2-normalised and
    # compared with a learned scale/bias to produce sigmoid logits.
    config_class = SiglipConfig

    def __init__(self, config: SiglipConfig):
        super().__init__(config)
        if not isinstance(config.text_config, SiglipTextConfig):
            raise TypeError(
                "config.text_config is expected to be of type SiglipTextConfig but is of type"
                f" {type(config.text_config)}."
            )
        if not isinstance(config.vision_config, _SiglipVisionConfig):
            raise TypeError(
                "config.vision_config is expected to be of type _SiglipVisionConfig but is of type"
                f" {type(config.vision_config)}."
            )
        text_config = config.text_config
        vision_config = config.vision_config
        # First, initialize the text and vision models with proper attention implementation
        text_model = SiglipTextModel._from_config(text_config)
        vision_model = SiglipVisionModel._from_config(vision_config)
        # Second, get the text and vision submodules (for backward compatibility)
        self.text_model = text_model.text_model
        self.vision_model = vision_model.vision_model
        # Learned scalars for the logits: forward() multiplies by logit_scale.exp() and adds logit_bias.
        self.logit_scale = nn.Parameter(torch.randn(1))
        self.logit_bias = nn.Parameter(torch.randn(1))
        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING)
    def get_text_features(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> torch.FloatTensor:
        r"""
        Returns:
            text_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The text embeddings obtained by
            applying the projection layer to the pooled output of [`SiglipTextModel`].
        Examples:
        ```python
        >>> from transformers import AutoTokenizer, AutoModel
        >>> import torch
        >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
        >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224")
        >>> # important: make sure to set padding="max_length" as that's how the model was trained
        >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt")
        >>> with torch.no_grad():
        ...     text_features = model.get_text_features(**inputs)
        ```"""
        # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Index 1 is the pooled output in both the tuple and the ModelOutput form.
        pooled_output = text_outputs[1]
        return pooled_output

    @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)
    def get_image_features(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
    ) -> torch.FloatTensor:
        r"""
        Returns:
            image_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The image embeddings obtained by
            applying the projection layer to the pooled output of [`SiglipVisionModel`].
        Examples:
        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, AutoModel
        >>> import torch
        >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
        >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)
        >>> inputs = processor(images=image, return_tensors="pt")
        >>> with torch.no_grad():
        ...     image_features = model.get_image_features(**inputs)
        ```"""
        # Use SiglipModel's config for some fields (if specified) instead of those of vision & text components.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )
        # Index 1 is the pooled output in both the tuple and the ModelOutput form.
        pooled_output = vision_outputs[1]
        return pooled_output

    @add_start_docstrings_to_model_forward(SIGLIP_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=SiglipOutput, config_class=SiglipConfig)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        return_loss: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
    ) -> Union[Tuple, SiglipOutput]:
        r"""
        Returns:
        Examples:
        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, AutoModel
        >>> import torch
        >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
        >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)
        >>> texts = ["a photo of 2 cats", "a photo of 2 dogs"]
        >>> # important: we pass `padding=max_length` since the model was trained with this
        >>> inputs = processor(text=texts, images=image, padding="max_length", return_tensors="pt")
        >>> with torch.no_grad():
        ...     outputs = model(**inputs)
        >>> logits_per_image = outputs.logits_per_image
        >>> probs = torch.sigmoid(logits_per_image) # these are the probabilities
        >>> print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'")
        31.9% that image 0 is 'a photo of 2 cats'
        ```"""
        # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )
        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        image_embeds = vision_outputs[1]
        text_embeds = text_outputs[1]
        # normalized features
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
        text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
        # Cosine similarity as logits. When the model is dispatched across devices (accelerate / model
        # parallelism), the two towers and the scalar logit parameters may live on different devices, so
        # everything is aligned to the text embeddings' device before combining.
        target_device = text_embeds.device
        logits_per_text = torch.matmul(text_embeds, image_embeds.t().to(target_device))
        logit_scale = self.logit_scale.to(target_device)
        logit_bias = self.logit_bias.to(target_device)
        logits_per_text = logits_per_text * logit_scale.exp() + logit_bias
        logits_per_image = logits_per_text.t()
        loss = None
        if return_loss:
            # Adapted from https://github.com/google-research/big_vision/blob/01edb81a4716f93a48be43b3a4af14e29cdb3a7f/big_vision/trainers/proj/image_text/siglip.py#L287
            # Pairwise sigmoid loss: label +1 on the diagonal (text i paired with image i), -1 elsewhere.
            # Assumes a square logits matrix (equal numbers of texts and images).
            eye = torch.eye(logits_per_text.size(0), device=logits_per_text.device)
            m1_diag1 = -torch.ones_like(logits_per_text) + 2 * eye
            loglik = torch.nn.functional.logsigmoid(m1_diag1 * logits_per_text)
            nll = -torch.sum(loglik, dim=-1)
            loss = nll.mean()
        if not return_dict:
            output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
            return ((loss,) + output) if loss is not None else output
        return SiglipOutput(
            loss=loss,
            logits_per_image=logits_per_image,
            logits_per_text=logits_per_text,
            text_embeds=text_embeds,
            image_embeds=image_embeds,
            text_model_output=text_outputs,
            vision_model_output=vision_outputs,
        )
@add_start_docstrings(
    """
    SigLIP vision encoder with an image classification head on top (a linear layer on top of the pooled final hidden states of
    the patch tokens) e.g. for ImageNet.
    """,
    SIGLIP_START_DOCSTRING,
)
class SiglipForImageClassification(SiglipPreTrainedModel):
    main_input_name = "pixel_values"

    def __init__(self, config: SiglipConfig) -> None:
        super().__init__(config)
        self.num_labels = config.num_labels
        # Build a full SiglipVisionModel (so the configured attention implementation is honoured) and keep only
        # its inner vision_model submodule, for backward compatibility of checkpoint keys.
        self.vision_model = SiglipVisionModel._from_config(config.vision_config).vision_model
        # Classifier head; an Identity when there are no labels.
        if config.num_labels > 0:
            self.classifier = nn.Linear(config.vision_config.hidden_size, config.num_labels)
        else:
            self.classifier = nn.Identity()
        # Initialize weights and apply final processing
        self.post_init()

    def _compute_loss(self, logits, labels):
        """Infer (once) and apply the loss matching ``config.problem_type``; returns None for unknown types."""
        # move labels to correct device to enable model parallelism
        labels = labels.to(logits.device)
        if self.config.problem_type is None:
            if self.num_labels == 1:
                inferred = "regression"
            elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                inferred = "single_label_classification"
            else:
                inferred = "multi_label_classification"
            self.config.problem_type = inferred
        problem = self.config.problem_type
        if problem == "regression":
            mse = MSELoss()
            if self.num_labels == 1:
                return mse(logits.squeeze(), labels.squeeze())
            return mse(logits, labels)
        if problem == "single_label_classification":
            return CrossEntropyLoss()(logits.view(-1, self.num_labels), labels.view(-1))
        if problem == "multi_label_classification":
            return BCEWithLogitsLoss()(logits, labels)
        return None

    @add_start_docstrings_to_model_forward(SIGLIP_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=ImageClassifierOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
    ) -> Union[tuple, ImageClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        Returns:
        Examples:
        ```python
        >>> from transformers import AutoImageProcessor, SiglipForImageClassification
        >>> import torch
        >>> from PIL import Image
        >>> import requests
        >>> torch.manual_seed(3) # doctest: +IGNORE_RESULT
        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)
        >>> # note: we are loading a `SiglipModel` from the hub here,
        >>> # so the head will be randomly initialized, hence the predictions will be random if seed is not set above.
        >>> image_processor = AutoImageProcessor.from_pretrained("google/siglip-base-patch16-224")
        >>> model = SiglipForImageClassification.from_pretrained("google/siglip-base-patch16-224")
        >>> inputs = image_processor(images=image, return_tensors="pt")
        >>> outputs = model(**inputs)
        >>> logits = outputs.logits
        >>> # model predicts one of the two classes
        >>> predicted_class_idx = logits.argmax(-1).item()
        >>> print("Predicted class:", model.config.id2label[predicted_class_idx])
        Predicted class: LABEL_1
        ```"""
        cfg = self.config
        if output_attentions is None:
            output_attentions = cfg.output_attentions
        if output_hidden_states is None:
            output_hidden_states = cfg.output_hidden_states
        if return_dict is None:
            return_dict = cfg.use_return_dict

        outputs = self.vision_model(
            pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )
        # Average-pool the patch tokens over the sequence dimension, then classify.
        logits = self.classifier(outputs[0].mean(dim=1))
        loss = None if labels is None else self._compute_loss(logits, labels)

        if return_dict:
            return ImageClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )
        output = (logits,) + outputs[2:]
        return output if loss is None else (loss,) + output
class SiglipProcessor(ProcessorMixin):
    r"""
    Constructs a Siglip processor which wraps a Siglip image processor and a Siglip tokenizer into a single processor.
    [`SiglipProcessor`] offers all the functionalities of [`SiglipImageProcessor`] and [`SiglipTokenizer`]. See the
    [`~SiglipProcessor.__call__`] and [`~SiglipProcessor.decode`] for more information.
    Args:
        image_processor ([`SiglipImageProcessor`]):
            The image processor is a required input.
        tokenizer ([`SiglipTokenizer`]):
            The tokenizer is a required input.
    """

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "SiglipImageProcessor"
    tokenizer_class = "SiglipTokenizer"

    def __init__(self, image_processor, tokenizer):
        super().__init__(image_processor, tokenizer)

    def __call__(
        self,
        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
        images: ImageInput = None,
        padding: Union[bool, str, PaddingStrategy] = False,
        truncation: Union[bool, str, TruncationStrategy] = None,
        max_length: int = None,
        return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
        **kwargs,
    ) -> BatchFeature:
        """
        Prepare one or several sequence(s) and/or image(s) for the model.

        When `text` is not `None`, it is forwarded, together with `padding`, `truncation`, `max_length`,
        `return_tensors` and any extra keyword arguments, to SiglipTokenizer's [`~SiglipTokenizer.__call__`]. When
        `images` is not `None`, it is forwarded to SiglipImageProcessor's [`~SiglipImageProcessor.__call__`]. Please
        refer to the docstrings of those two methods for more information.
        Args:
            text (`str`, `List[str]`, `List[List[str]]`):
                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must pass
                `is_split_into_words=True` (forwarded to the tokenizer via `**kwargs`) to lift the ambiguity with a
                batch of sequences.
            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
                tensor. Both channels-first and channels-last formats are supported.
            padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
                Select a strategy to pad the returned sequences (according to the model's padding side and padding
                index) among:
                - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
                  sequence is provided).
                - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
                  acceptable input length for the model if that argument is not provided.
                - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
                  lengths).
            max_length (`int`, *optional*):
                Maximum length of the returned list and optionally padding length (see above).
            truncation (`bool`, *optional*):
                Activates truncation to cut input sequences longer than `max_length` to `max_length`.
            return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `TensorType.PYTORCH`):
                If set, will return tensors of a particular framework. Acceptable values are:
                - `'tf'`: Return TensorFlow `tf.constant` objects.
                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return NumPy `np.ndarray` objects.
                - `'jax'`: Return JAX `jnp.ndarray` objects.
            **kwargs:
                Additional keyword arguments forwarded to the tokenizer only (e.g. `is_split_into_words=True`,
                `return_attention_mask=True`). They are not passed to the image processor.
        Returns:
            [`BatchFeature`]: A [`BatchFeature`] with the following fields:
            - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
              `None`).
            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
        Raises:
            ValueError: If both `text` and `images` are `None`.
        """
        if text is None and images is None:
            raise ValueError("You have to specify either text or images. Both cannot be none.")
        if text is not None:
            encoding = self.tokenizer(
                text,
                return_tensors=return_tensors,
                padding=padding,
                truncation=truncation,
                max_length=max_length,
                **kwargs,
            )
        if images is not None:
            image_features = self.image_processor(images, return_tensors=return_tensors)
        if text is None:
            # Images only.
            return BatchFeature(data=dict(**image_features), tensor_type=return_tensors)
        if images is not None:
            # Text and images: attach pixel values to the tokenizer's encoding.
            encoding["pixel_values"] = image_features.pixel_values
        return encoding

    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to SiglipTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to
        the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)

    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to SiglipTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
        refer to the docstring of this method for more information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    @property
    # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names with CLIP->Siglip, T5->Siglip
    def model_input_names(self):
        # Tokenizer names first, then image-processor names; dict.fromkeys de-duplicates while keeping order.
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names
        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
class SiglipTokenizer(PreTrainedTokenizer):
"""
Construct a Siglip tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece).
This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
this superclass for more information regarding those methods.
Args:
vocab_file (`str`):
[SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
contains the vocabulary necessary to instantiate a tokenizer.
eos_token (`str`, *optional*, defaults to `"</s>"`):
The end of sequence token.
unk_token (`str`, *optional*, defaults to `"<unk>"`):
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
token instead.
pad_token (`str`, *optional*, defaults to `"</s>"`):
The token used for padding, for example when batching sequences of different lengths.
additional_special_tokens (`List[str]`, *optional*):
Additional special tokens used by the tokenizer.
sp_model_kwargs (`dict`, *optional*):
Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
to set:
- `enable_sampling`: Enable subword regularization.
- `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.
- `nbest_size = {0,1}`: No sampling is performed.
- `nbest_size > 1`: samples from the nbest_size results.
- `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice)
using forward-filtering-and-backward-sampling algorithm.
- `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
BPE-dropout.
model_max_length (`int`, *optional*, defaults to 64):
The maximum length (in number of tokens) for model inputs.
do_lower_case (`bool`, *optional*, defaults to `True`):
Whether or not to lowercase the input when tokenizing.
"""
vocab_files_names = VOCAB_FILES_NAMES
model_input_names = ["input_ids", "attention_mask"]
def __init__(
self,
vocab_file,
eos_token="</s>",
unk_token="<unk>",
pad_token="</s>",
additional_special_tokens=None,
sp_model_kwargs: Optional[Dict[str, Any]] = None,
model_max_length=64,
do_lower_case=True,
**kwargs,
) -> None:
requires_backends(self, "protobuf")
pad_token = (
AddedToken(pad_token, rstrip=True, lstrip=True, normalized=False, special=True)
if isinstance(pad_token, str)
else pad_token
)
unk_token = (
AddedToken(unk_token, rstrip=True, lstrip=True, normalized=False, special=True)
if isinstance(unk_token, str)
else unk_token
)
eos_token = (
AddedToken(eos_token, rstrip=True, lstrip=True, normalized=False, special=True)
if isinstance(eos_token, str)
else eos_token
)
self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
self.do_lower_case = do_lower_case
self.vocab_file = vocab_file
self.sp_model = self.get_spm_processor()
self.vocab_file = vocab_file
super().__init__(
eos_token=eos_token,
unk_token=unk_token,
pad_token=pad_token,
additional_special_tokens=additional_special_tokens,
sp_model_kwargs=self.sp_model_kwargs,
model_max_length=model_max_length,
do_lower_case=do_lower_case,
**kwargs,
)
def get_spm_processor(self):
tokenizer = spm.SentencePieceProcessor(**self.sp_model_kwargs)
with open(self.vocab_file, "rb") as f:
sp_model = f.read()
model_pb2 = import_protobuf()
model = model_pb2.ModelProto.FromString(sp_model)
normalizer_spec = model_pb2.NormalizerSpec()
normalizer_spec.add_dummy_prefix = False
model.normalizer_spec.MergeFrom(normalizer_spec)
sp_model = model.SerializeToString()
tokenizer.LoadFromSerializedProto(sp_model)
return tokenizer
@property
# Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.vocab_size
def vocab_size(self):
return self.sp_model.get_piece_size()
# Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_vocab
def get_vocab(self):
vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
vocab.update(self.added_tokens_encoder)
return vocab
# Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_special_tokens_mask
def get_special_tokens_mask(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
) -> List[int]:
"""
Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
special tokens using the tokenizer `prepare_for_model` method.
Args:
token_ids_0 (`List[int]`):
List of IDs.
token_ids_1 (`List[int]`, *optional*):
Optional second list of IDs for sequence pairs.
already_has_special_tokens (`bool`, *optional*, defaults to `False`):
Whether or not the token list is already formatted with special tokens for the model.
Returns:
`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
"""
if already_has_special_tokens:
return super().get_special_tokens_mask(
token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
)
# normal case: some special tokens
if token_ids_1 is None:
return ([0] * len(token_ids_0)) + [1]
return ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
# Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._add_eos_if_not_present
def _add_eos_if_not_present(self, token_ids: List[int]) -> List[int]:
"""Do not add eos again if user already added it."""
if len(token_ids) > 0 and token_ids[-1] == self.eos_token_id:
warnings.warn(
f"This sequence already has {self.eos_token}. In future versions this behavior may lead to duplicated"
" eos tokens being added."
)
return token_ids
else:
return token_ids + [self.eos_token_id]
# Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.create_token_type_ids_from_sequences
def create_token_type_ids_from_sequences(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
"""
Create a mask from the two sequences passed to be used in a sequence-pair classification task. T5 does not make
use of token type ids, therefore a list of zeros is returned.
Args:
token_ids_0 (`List[int]`):
List of IDs.
token_ids_1 (`List[int]`, *optional*):
Optional second list of IDs for sequence pairs.
Returns:
`List[int]`: List of zeros.
"""
eos = [self.eos_token_id]
if token_ids_1 is None:
return len(token_ids_0 + eos) * [0]
return len(token_ids_0 + eos + token_ids_1 + eos) * [0]
# Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.build_inputs_with_special_tokens
def build_inputs_with_special_tokens(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
"""
Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
adding special tokens. A sequence has the following format:
- single sequence: `X </s>`
- pair of sequences: `A </s> B </s>`
Args:
token_ids_0 (`List[int]`):
List of IDs to which the special tokens will be added.
token_ids_1 (`List[int]`, *optional*):
Optional second list of IDs for sequence pairs.
Returns:
`List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
"""
token_ids_0 = self._add_eos_if_not_present(token_ids_0)
if token_ids_1 is None:
return token_ids_0
else:
token_ids_1 = self._add_eos_if_not_present(token_ids_1)
return token_ids_0 + token_ids_1
# Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.__getstate__
def __getstate__(self):
state = self.__dict__.copy()
state["sp_model"] = None
return state
# Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.__setstate__
def __setstate__(self, d):
self.__dict__ = d
# for backward compatibility
if not hasattr(self, "sp_model_kwargs"):
self.sp_model_kwargs = {}
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
self.sp_model.Load(self.vocab_file)
def remove_punctuation(self, text: str) -> str:
return text.translate(str.maketrans("", "", string.punctuation))
# source: https://github.com/google-research/big_vision/blob/3b8e5ab6ad4f96e32b32826f9e1b8fd277914f9c/big_vision/evaluators/proj/image_text/prompt_engineering.py#L94
def canonicalize_text(self, text, *, keep_punctuation_exact_string=None):
"""Returns canonicalized `text` (puncuation removed).
Args:
text (`str`):
String to be canonicalized.
keep_punctuation_exact_string (`str`, *optional*):
If provided, then this exact string is kept. For example providing '{}' will keep any occurrences of '{}'
(but will still remove '{' and '}' that appear separately).
"""
if keep_punctuation_exact_string:
text = keep_punctuation_exact_string.join(
self.remove_punctuation(part) for part in text.split(keep_punctuation_exact_string)
)
else:
text = self.remove_punctuation(text)
text = re.sub(r"\s+", " ", text)
text = text.strip()
return text
def tokenize(self, text: "TextInput", add_special_tokens=False, **kwargs) -> List[str]:
"""
Converts a string to a list of tokens.
"""
tokens = super().tokenize(SPIECE_UNDERLINE + text.replace(SPIECE_UNDERLINE, " "), **kwargs)
if len(tokens) > 1 and tokens[0] == SPIECE_UNDERLINE and tokens[1] in self.all_special_tokens:
tokens = tokens[1:]
return tokens
@property
# Adapted from transformers.models.t5.tokenization_t5.T5Tokenizer.unk_token_length
def unk_token_length(self):
    """Number of SentencePiece pieces that ``str(self.unk_token)`` encodes to.

    Computed on every access by running ``self.sp_model.encode``.
    """
    encoded_unk = self.sp_model.encode(str(self.unk_token))
    return len(encoded_unk)
def _tokenize(self, text, **kwargs):
"""
Returns a tokenized string.
We de-activated the `add_dummy_prefix` option, thus the sentencepiece internals will always strip any
SPIECE_UNDERLINE.
For example: `self.sp_model.encode(f"{SPIECE_UNDERLINE}Hey", out_type = str)` will give `['H', 'e', 'y']` instead of `['▁He', 'y']`.
Thus we always encode `f"{unk_token}text"` and strip the `unk_token`. Here is an example with `unk_token = "<unk>"` and `unk_token_length = 4`.
`self.tokenizer.sp_model.encode("<unk> Hey", out_type = str)[4:]`.
"""
text = self.canonicalize_text(text, keep_punctuation_exact_string=None)
tokens = self.sp_model.encode(text, out_type=str)
# 1. Encode string + prefix ex: "<unk> Hey"
tokens = self.sp_model.encode(self.unk_token + text, out_type=str)
# 2. Remove self.unk_token from ['<','unk','>', '▁Hey']
return tokens[self.unk_token_length :] if len(tokens) >= self.unk_token_length else tokens
# Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._convert_token_to_id
def _convert_token_to_id(self, token):
"""Converts a token (str) in an id using the vocab."""
return self.sp_model.piece_to_id(token)
# Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._convert_id_to_token
def _convert_id_to_token(self, index):
"""Converts an index (integer) in a token (str) using the vocab."""
token = self.sp_model.IdToPiece(index)
return token
def convert_tokens_to_string(self, tokens):
    """Converts a sequence of tokens (string) in a single string.

    Runs of ordinary pieces are decoded with ``self.sp_model.decode``. Tokens found in
    ``self.all_special_tokens`` are appended verbatim (never passed to the SentencePiece
    decoder), preceded by a single space unless the previous token was also special.
    Leading/trailing whitespace is stripped from the result.
    """
    # Read the special-token collection once (it may be a computed property on the base
    # tokenizer) and use a set for O(1) membership tests inside the loop.
    special_tokens = set(self.all_special_tokens)
    current_sub_tokens = []
    out_parts = []
    prev_is_special = False
    for token in tokens:
        # make sure that special tokens are not decoded using sentencepiece model
        if token in special_tokens:
            if not prev_is_special:
                out_parts.append(" ")
            out_parts.append(self.sp_model.decode(current_sub_tokens))
            out_parts.append(token)
            prev_is_special = True
            current_sub_tokens = []
        else:
            current_sub_tokens.append(token)
            prev_is_special = False
    out_parts.append(self.sp_model.decode(current_sub_tokens))
    return "".join(out_parts).strip()
# Adapted from transformers.models.t5.tokenization_t5.T5Tokenizer.save_vocabulary
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
    """Write the SentencePiece model file into ``save_directory``.

    If ``self.vocab_file`` exists and is not already the target path, it is copied.
    If it does not exist, the in-memory model proto is serialized to the target path.
    If it exists and already is the target path, nothing is written.

    Args:
        save_directory: Directory to write into; must already exist.
        filename_prefix: Optional prefix, joined to the vocab file name with ``"-"``.

    Returns:
        A 1-tuple holding the output path, or ``None`` (after logging an error) when
        ``save_directory`` is not a directory.
    """
    if not os.path.isdir(save_directory):
        logger.error(f"Vocabulary path ({save_directory}) should be a directory")
        return
    prefix = filename_prefix + "-" if filename_prefix else ""
    out_vocab_file = os.path.join(save_directory, prefix + VOCAB_FILES_NAMES["vocab_file"])
    same_path = os.path.abspath(self.vocab_file) == os.path.abspath(out_vocab_file)
    if not same_path and os.path.isfile(self.vocab_file):
        copyfile(self.vocab_file, out_vocab_file)
    elif not os.path.isfile(self.vocab_file):
        with open(out_vocab_file, "wb") as fi:
            fi.write(self.sp_model.serialized_model_proto())
    return (out_vocab_file,)
class BagelConfig(PretrainedConfig):
    """Configuration for the BAGEL unified understanding/generation model.

    Args:
        visual_gen: Build the image-generation components (VAE-latent projections,
            timestep embedder, latent position table).
        visual_und: Build the image-understanding components (ViT, connector, ViT
            position table).
        llm_config: Language-model config, or a dict to be converted by the model.
        vit_config: Vision-encoder config, or a dict to be converted by the model.
        vae_config: VAE description (dict or namespace); the model reads its
            ``downsample`` and ``z_channels`` fields.
        latent_patch_size: Side length of the square latent patch packed into one token.
        max_latent_size: Size of the latent position-embedding table.
        vit_max_num_patch_per_side: Size of the ViT position-embedding table.
        connector_act: Activation name used by the ViT→LLM connector MLP.
        interpolate_pos: Choose interpolated (True) or extrapolated (False) flattened
            position ids.
        timestep_shift: Shift applied to flow timesteps during training.
        **kwargs: Forwarded to ``PretrainedConfig``.
    """

    def __init__(
        self,
        visual_gen=True,
        visual_und=True,
        llm_config=None,
        vit_config=None,
        vae_config=None,
        latent_patch_size=2,
        max_latent_size=32,
        vit_max_num_patch_per_side=70,
        connector_act="gelu_pytorch_tanh",
        interpolate_pos=False,
        timestep_shift=1.0,
        **kwargs
    ):
        super().__init__(**kwargs)
        # Assigned in declaration order; dicts preserve insertion order.
        bagel_fields = dict(
            visual_gen=visual_gen,
            visual_und=visual_und,
            llm_config=llm_config,
            vit_config=vit_config,
            vae_config=vae_config,
            latent_patch_size=latent_patch_size,
            max_latent_size=max_latent_size,
            vit_max_num_patch_per_side=vit_max_num_patch_per_side,
            connector_act=connector_act,
            interpolate_pos=interpolate_pos,
            timestep_shift=timestep_shift,
        )
        for field_name, field_value in bagel_fields.items():
            setattr(self, field_name, field_value)
class Bagel(PreTrainedModel):
config_class = BagelConfig
base_model_prefix = 'bagel'
def __init__(
    self,
    config: BagelConfig,  # ← first!
    language_model: Optional[Qwen2ForCausalLM] = None,
    vit_model: Optional[SiglipVisionModel] = None,
):
    """Build the BAGEL model from ``config`` and optional pre-built sub-models.

    Args:
        config: A ``BagelConfig``. If ``llm_config`` / ``vit_config`` / ``vae_config`` are
            plain dicts they are converted in place to ``Qwen2Config``,
            ``SiglipVisionConfig`` and ``SimpleNamespace`` respectively.
        language_model: Optional pre-built ``Qwen2ForCausalLM``. If omitted, one is
            created from ``config.llm_config`` under ``init_empty_weights()``.
        vit_model: Optional pre-built ``SiglipVisionModel``. If omitted, one is created
            from ``config.vit_config`` under ``init_empty_weights()``.
    """
    # Normalise sub-configs that arrive as plain dicts (e.g. loaded from JSON).
    if isinstance(config.llm_config, dict):
        config.llm_config = Qwen2Config(**config.llm_config)
    if isinstance(config.vit_config, dict):
        config.vit_config = SiglipVisionConfig(**config.vit_config)
    if isinstance(config.vae_config, dict):
        config.vae_config = SimpleNamespace(**config.vae_config)
    # Build only the sub-models the caller did not supply, so a provided model is
    # never replaced by an empty one.
    if language_model is None or vit_model is None:
        with init_empty_weights():  # 'meta' device -> no weight memory allocated
            if language_model is None:
                language_model = Qwen2ForCausalLM(config.llm_config)
            if vit_model is None:
                vit_model = SiglipVisionModel(config.vit_config)
    super().__init__(config)
    self.language_model = language_model
    self.hidden_size = config.llm_config.hidden_size
    # MoE language-model variants are recognised by "Mo" in the layer module name.
    self.use_moe = "Mo" in config.llm_config.layer_module
    self.num_heads = config.llm_config.num_attention_heads

    if config.visual_gen:
        # Generation branch: VAE latent patches <-> LLM hidden states.
        self.latent_patch_size = config.latent_patch_size
        self.timestep_shift = config.timestep_shift
        # Pixel stride of one latent token: VAE downsample factor x latent patch size.
        self.latent_downsample = config.vae_config.downsample * config.latent_patch_size
        self.max_latent_size = config.max_latent_size
        self.latent_channel = config.vae_config.z_channels
        self.patch_latent_dim = self.latent_patch_size ** 2 * self.latent_channel
        self.time_embedder = TimestepEmbedder(self.hidden_size)
        self.vae2llm = nn.Linear(self.patch_latent_dim, self.hidden_size)
        self.llm2vae = nn.Linear(self.hidden_size, self.patch_latent_dim)
        self.latent_pos_embed = PositionEmbedding(self.max_latent_size, self.hidden_size)

    if config.visual_und:
        # Understanding branch: ViT features projected into the LLM width.
        self.vit_model = vit_model
        self.vit_patch_size = config.vit_config.patch_size
        self.vit_max_num_patch_per_side = config.vit_max_num_patch_per_side
        self.vit_hidden_size = config.vit_config.hidden_size
        self.connector = MLPconnector(self.vit_hidden_size, self.hidden_size, config.connector_act)
        self.vit_pos_embed = PositionEmbedding(self.vit_max_num_patch_per_side, self.hidden_size)
        # NOTE(review): meta=True is passed even when a real vit_model was supplied — confirm intended.
        self.vit_model.vision_model.embeddings.convert_conv2d_to_linear(config.vit_config, meta=True)

    if config.interpolate_pos:
        self.get_flattened_position_ids = get_flattened_position_ids_interpolate
    else:
        self.get_flattened_position_ids = get_flattened_position_ids_extrapolate

    self.config = config
    self._init_weights()
def _init_weights(self):
if self.config.visual_gen:
nn.init.constant_(self.llm2vae.weight, 0)
nn.init.constant_(self.llm2vae.bias, 0)
def forward(
    self,
    sequence_length: int,
    packed_text_ids: torch.LongTensor,
    packed_text_indexes: torch.LongTensor,
    sample_lens: List[int],
    packed_position_ids: torch.LongTensor,
    nested_attention_masks: List[torch.Tensor] = None,
    split_lens: List[int] = None,
    attn_modes: List[str] = None,
    # for visual understanding
    ce_loss_indexes: Optional[torch.BoolTensor] = None,
    packed_label_ids: Optional[torch.LongTensor] = None,
    packed_vit_tokens: Optional[torch.Tensor] = None,
    packed_vit_token_indexes: Optional[torch.LongTensor] = None,
    packed_vit_position_ids: Optional[torch.LongTensor] = None,
    vit_token_seqlens: Optional[torch.IntTensor] = None,
    # for visual generation
    padded_latent: Optional[torch.Tensor] = None,
    patchified_vae_latent_shapes: Optional[List[Tuple[int, int]]] = None,
    packed_latent_position_ids: Optional[torch.LongTensor] = None,
    packed_vae_token_indexes: Optional[torch.LongTensor] = None,
    packed_timesteps: Optional[torch.LongTensor] = None,
    mse_loss_indexes: Optional[torch.BoolTensor] = None,
) -> Dict[str, Optional[torch.Tensor]]:
    """
    Args:
        sequence_length: length of sequence.
        packed_text_ids: 1-D int tensor, packed text token ids.
        packed_text_indexes: 1-D int tensor, packed text token indexes in sequence.
        sample_lens: A list of N ints, length of each sample in packed_sequence.
        nested_attention_masks: A list of N 2-D float tensor, where 0.0 means attention and
            -inf means ignore.
        packed_position_ids: packed 1-D positions, an image has only one global position shared
            by all latent tokens.
        packed_vit_tokens: packed patchified image tokens for vit model.
        packed_vit_position_ids: 1-D int tensor, the position of each token for vit model.
        packed_vit_token_indexes: 1-D int tensor, packed vit token indexes in sequence.
        vit_token_seqlens: 1-D int tensor, the length of each image tokens for vit model.
        packed_label_ids: 1-D int tensor, packed label token ids.
        ce_loss_indexes: 1-D bool tensor, where to compute ce loss.
        padded_latent: padded latent from VAE encoder.
        patchified_vae_latent_shapes: A list of (h, w) tuples, patchified latent shapes of each image.
        packed_latent_position_ids: 1-D int tensor, the position of each token for latent.
        packed_vae_token_indexes: 1-D int tensor, padded image token indexes in sequence.
        packed_timesteps: 1-D float tensor, flow timesteps. 0 indicates use clean image.
        mse_loss_indexes: 1-D bool tensor, where to compute mse loss.

    Returns:
        dict with
        ``mse``: element-wise squared error between the predicted velocity and
        ``noise - clean_latent`` for latent tokens whose (sigmoid/shifted) timestep is > 0,
        or ``None`` when ``config.visual_gen`` is off;
        ``ce``: per-token cross-entropy (``reduction="none"``), or ``None`` when
        ``ce_loss_indexes`` is ``None``.
    """
    packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids)
    packed_sequence = packed_text_embedding.new_zeros(size=(sequence_length, self.hidden_size))
    packed_sequence[packed_text_indexes] = packed_text_embedding

    if nested_attention_masks is None:
        sparse_mask = create_sparse_mask(sample_lens, split_lens, attn_modes, packed_text_embedding.device)
        seqlen = sum(sample_lens)
        block_mask = create_block_mask(
            sparse_mask, B=1, H=self.num_heads, Q_LEN=seqlen, KV_LEN=seqlen,
            device=packed_text_embedding.device, BLOCK_SIZE=128, _compile=True
        )
        attention_mask = block_mask
    else:
        attention_mask = nested_attention_masks

    if self.config.visual_und:
        # Cumulative per-image ViT sequence boundaries [0, l0, l0+l1, ...] as int32.
        cu_seqlens = torch.nn.functional.pad(torch.cumsum(vit_token_seqlens, dim=0), (1, 0))
        cu_seqlens = cu_seqlens.to(torch.int32)
        max_seqlen = torch.max(vit_token_seqlens).item()
        packed_vit_token_embed = self.vit_model(
            packed_pixel_values=packed_vit_tokens,
            packed_flattened_position_ids=packed_vit_position_ids,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )
        packed_vit_token_embed = self.connector(packed_vit_token_embed)
        vit_token_pos_emb = self.vit_pos_embed(packed_vit_position_ids)
        packed_vit_token_embed = packed_vit_token_embed + vit_token_pos_emb
        # Index assignment requires matching dtypes; align to the text-embedding dtype as the
        # inference cache-update paths do (no-op when they already match).
        packed_sequence[packed_vit_token_indexes] = packed_vit_token_embed.to(packed_sequence.dtype)

    if self.config.visual_gen:
        p = self.latent_patch_size
        packed_latent = []
        for latent, (h, w) in zip(padded_latent, patchified_vae_latent_shapes):
            # Crop padding, then split into p x p patches: (c, h*p, w*p) -> (h*w, p*p*c).
            latent = latent[:, :h * p, :w * p].reshape(self.latent_channel, h, p, w, p)
            latent = torch.einsum("chpwq->hwpqc", latent).reshape(-1, p * p * self.latent_channel)
            packed_latent.append(latent)
        packed_latent_clean = torch.cat(packed_latent, dim=0)

        noise = torch.randn_like(packed_latent_clean)
        packed_timesteps = torch.sigmoid(packed_timesteps)
        packed_timesteps = self.timestep_shift * packed_timesteps / (1 + (self.timestep_shift - 1) * packed_timesteps)
        # Linear interpolation between data (t=0) and noise (t=1).
        packed_latent = (1 - packed_timesteps[:, None]) * packed_latent_clean + packed_timesteps[:, None] * noise
        packed_timestep_embeds = self.time_embedder(packed_timesteps)
        latent_token_pos_emb = self.latent_pos_embed(packed_latent_position_ids)
        packed_latent = self.vae2llm(packed_latent) + packed_timestep_embeds + latent_token_pos_emb
        packed_sequence[packed_vae_token_indexes] = packed_latent.to(packed_sequence.dtype)

    extra_inputs = {}
    if self.use_moe:
        packed_und_token_indexes = packed_text_indexes
        if packed_vit_token_indexes is not None:
            packed_und_token_indexes = torch.cat([packed_text_indexes, packed_vit_token_indexes], dim=0)
        extra_inputs.update(
            packed_und_token_indexes=packed_und_token_indexes,
            packed_gen_token_indexes=packed_vae_token_indexes,
        )

    last_hidden_state = self.language_model(
        packed_sequence=packed_sequence,
        sample_lens=sample_lens,
        attention_mask=attention_mask,
        packed_position_ids=packed_position_ids,
        **extra_inputs,
    )

    mse = None
    if self.config.visual_gen:
        packed_mse_preds = self.llm2vae(last_hidden_state[mse_loss_indexes])
        target = noise - packed_latent_clean  # NOTE: v_t=dx_t/dt=x_1-x_0, pointing from data to noise
        has_mse = packed_timesteps > 0
        mse = (packed_mse_preds - target[has_mse]) ** 2

    ce = None
    if ce_loss_indexes is not None:
        packed_ce_preds = self.language_model.lm_head(last_hidden_state[ce_loss_indexes])
        ce = F.cross_entropy(packed_ce_preds, packed_label_ids, reduction="none")

    return dict(mse=mse, ce=ce)
def prepare_prompts(self, curr_kvlens, curr_rope, prompts, tokenizer, new_token_ids):
    """Pack one text prompt per sample for appending to the KV cache.

    Each prompt is encoded with ``tokenizer.encode`` and wrapped in the
    ``bos_token_id`` / ``eos_token_id`` from ``new_token_ids``. Per sample the new text
    follows that sample's ``curr_kvlen`` cached slots in the packed (cache + query)
    index space, and its RoPE positions continue from ``curr_rope``, advancing by one
    per token. The returned dict's keys match the keyword parameters of
    ``forward_cache_update_text`` (apart from ``past_key_values``).

    Returns:
        (generation_input, newlens, new_rope): the packed tensors, plus per sample the
        KV length and the next RoPE position after appending the text.
    """
    def _long(values):
        return torch.tensor(values, dtype=torch.long)

    text_token_lens, packed_text_ids = [], []
    packed_text_position_ids, packed_text_indexes = [], []
    packed_key_value_indexes = []
    newlens, new_rope = [], []
    offset = 0  # running index in the packed cache + query sequence
    for prompt, kv_len, rope_start in zip(prompts, curr_kvlens, curr_rope):
        # This sample's cached entries occupy [offset, offset + kv_len).
        packed_key_value_indexes += range(offset, offset + kv_len)
        offset += kv_len

        body_ids = tokenizer.encode(prompt)
        ids = [new_token_ids['bos_token_id']] + body_ids + [new_token_ids['eos_token_id']]
        n_ids = len(ids)

        text_token_lens.append(n_ids)
        packed_text_ids += ids
        packed_text_position_ids += range(rope_start, rope_start + n_ids)
        packed_text_indexes += range(offset, offset + n_ids)
        newlens.append(kv_len + n_ids)
        new_rope.append(rope_start + n_ids)
        offset += n_ids

    generation_input = {
        "text_token_lens": torch.tensor(text_token_lens, dtype=torch.int),
        "packed_text_ids": _long(packed_text_ids),
        "packed_text_position_ids": _long(packed_text_position_ids),
        "packed_text_indexes": _long(packed_text_indexes),
        "packed_key_value_indexes": _long(packed_key_value_indexes),
        "key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
    }
    return generation_input, newlens, new_rope
@torch.no_grad
def forward_cache_update_text(
self,
past_key_values: NaiveCache,
packed_text_ids: torch.IntTensor,
packed_text_position_ids: torch.LongTensor,
text_token_lens: torch.LongTensor,
packed_text_indexes: torch.LongTensor,
packed_key_value_indexes: torch.LongTensor,
key_values_lens: torch.IntTensor,
):
packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids)
extra_inputs = {}
if self.use_moe:
extra_inputs = {"mode": "und"}
output = self.language_model.forward_inference(
packed_query_sequence=packed_text_embedding,
query_lens=text_token_lens,
packed_query_position_ids=packed_text_position_ids,
packed_query_indexes=packed_text_indexes,
past_key_values=past_key_values,
packed_key_value_indexes=packed_key_value_indexes,
key_values_lens=key_values_lens,
update_past_key_values=True,
is_causal=True,
**extra_inputs,
)
past_key_values = output.past_key_values
return past_key_values
def prepare_vit_images(self, curr_kvlens, curr_rope, images, transforms, new_token_ids):
    """Pack one image per sample as ViT patch tokens framed by image start/end markers.

    Per sample the query is ``[start_of_image, <ViT patch tokens>, end_of_image]``,
    placed after that sample's ``curr_kvlen`` cached slots. All of these tokens share
    the RoPE position ``curr_position_id``, and the next position is
    ``curr_position_id + 1``. The returned dict's keys match the keyword parameters of
    ``forward_cache_update_vit`` (apart from ``past_key_values``).

    Args:
        transforms: Callable mapping an input image to a tensor whose dims 1 and 2 are
            its height and width.

    Returns:
        (generation_input, newlens, new_rope)
    """
    packed_text_ids, packed_text_indexes = [], []
    packed_vit_tokens, packed_vit_position_ids = [], []
    packed_vit_token_indexes, vit_token_seqlens = [], []
    packed_seqlens, packed_position_ids, packed_indexes = [], [], []
    packed_key_value_indexes = []
    newlens, new_rope = [], []
    query_pos = 0   # index within the packed query (new tokens only)
    global_pos = 0  # index within the packed cache + query sequence
    for image, kv_len, rope_pos in zip(images, curr_kvlens, curr_rope):
        packed_key_value_indexes.extend(range(global_pos, global_pos + kv_len))
        global_pos += kv_len

        # <start_of_image> marker
        packed_text_ids.append(new_token_ids['start_of_image'])
        packed_text_indexes.append(query_pos)
        packed_indexes.append(global_pos)
        query_pos += 1
        global_pos += 1

        # ViT patch tokens
        pixels = transforms(image)
        pos_ids = self.get_flattened_position_ids(
            pixels.size(1), pixels.size(2),
            self.vit_patch_size,
            max_num_patches_per_side=self.vit_max_num_patch_per_side,
        )
        patches = patchify(pixels, self.vit_patch_size)
        n_patches = patches.shape[0]
        packed_vit_tokens.append(patches)
        packed_vit_position_ids.append(pos_ids)
        vit_token_seqlens.append(n_patches)
        packed_vit_token_indexes.extend(range(query_pos, query_pos + n_patches))
        packed_indexes.extend(range(global_pos, global_pos + n_patches))
        query_pos += n_patches
        global_pos += n_patches

        # <end_of_image> marker
        packed_text_ids.append(new_token_ids['end_of_image'])
        packed_text_indexes.append(query_pos)
        packed_indexes.append(global_pos)
        query_pos += 1
        global_pos += 1

        span = n_patches + 2
        packed_position_ids.extend([rope_pos] * span)
        packed_seqlens.append(span)
        newlens.append(kv_len + span)
        new_rope.append(rope_pos + 1)

    generation_input = {
        "packed_text_ids": torch.tensor(packed_text_ids, dtype=torch.long),
        "packed_text_indexes": torch.tensor(packed_text_indexes, dtype=torch.long),
        "vit_token_seqlens": torch.tensor(vit_token_seqlens, dtype=torch.int),
        "packed_vit_tokens": torch.cat(packed_vit_tokens, dim=0),
        "packed_vit_position_ids": torch.cat(packed_vit_position_ids, dim=0),
        "packed_vit_token_indexes": torch.tensor(packed_vit_token_indexes, dtype=torch.long),
        "packed_position_ids": torch.tensor(packed_position_ids, dtype=torch.long),
        "packed_seqlens": torch.tensor(packed_seqlens, dtype=torch.int),
        "packed_indexes": torch.tensor(packed_indexes, dtype=torch.long),
        "packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long),
        "key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
    }
    return generation_input, newlens, new_rope
@torch.no_grad
def forward_cache_update_vit(
self,
past_key_values: NaiveCache,
packed_text_ids: torch.LongTensor,
packed_text_indexes: torch.LongTensor,
packed_vit_tokens: torch.Tensor,
packed_vit_token_indexes: torch.LongTensor,
packed_vit_position_ids: torch.LongTensor,
vit_token_seqlens: torch.IntTensor,
packed_position_ids: torch.LongTensor,
packed_seqlens: torch.IntTensor,
packed_indexes: torch.LongTensor,
packed_key_value_indexes: torch.LongTensor,
key_values_lens: torch.IntTensor,
):
packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids)
packed_sequence = packed_text_embedding.new_zeros((sum(packed_seqlens), self.hidden_size))
packed_sequence[packed_text_indexes] = packed_text_embedding
cu_seqlens = torch.nn.functional.pad(torch.cumsum(vit_token_seqlens, dim=0), (1, 0))
cu_seqlens = cu_seqlens.to(torch.int32)
max_seqlen = torch.max(vit_token_seqlens).item()
packed_vit_token_embed = self.vit_model(
packed_pixel_values=packed_vit_tokens,
packed_flattened_position_ids=packed_vit_position_ids,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
packed_vit_token_embed = self.connector(packed_vit_token_embed)
pos_emb = self.vit_pos_embed(packed_vit_position_ids)
packed_vit_token_embed = packed_vit_token_embed + pos_emb
if packed_vit_token_embed.dtype != packed_sequence.dtype:
packed_vit_token_embed = packed_vit_token_embed.to(packed_sequence.dtype)
packed_sequence[packed_vit_token_indexes] = packed_vit_token_embed
extra_inputs = {}
if self.use_moe:
extra_inputs = {"mode": "und"}
output = self.language_model.forward_inference(
packed_query_sequence=packed_sequence,
query_lens=packed_seqlens,
packed_query_position_ids=packed_position_ids,
packed_query_indexes=packed_indexes,
past_key_values=past_key_values,
packed_key_value_indexes=packed_key_value_indexes,
key_values_lens=key_values_lens,
update_past_key_values=True,
is_causal=False,
**extra_inputs,
)
past_key_values = output.past_key_values
return past_key_values
def prepare_vae_images(self, curr_kvlens, curr_rope, images, transforms, new_token_ids, timestep=0):
    """Pack one image per sample as VAE latent-token slots framed by image markers.

    Per sample the query is ``[start_of_image, <h * w latent tokens>, end_of_image]``,
    placed after that sample's ``curr_kvlen`` cached slots, with
    ``h = H // latent_downsample`` and ``w = W // latent_downsample`` for the
    transformed image's dims 1 and 2. All ``h * w + 2`` tokens share the RoPE position
    ``curr_position_id``, and the next position is ``curr_position_id + 1``. The
    transformed images are zero-padded to the per-dimension maximum and stacked into
    ``padded_images``. The returned dict's keys match the keyword parameters of
    ``forward_cache_update_vae`` (apart from ``vae_model`` and ``past_key_values``).

    Args:
        timestep: Stored as the single-element ``packed_timesteps`` tensor.

    Returns:
        (generation_input, newlens, new_rope)
    """
    packed_text_ids, packed_text_indexes = [], []
    packed_vae_token_indexes, packed_vae_position_ids = [], []
    patchified_vae_latent_shapes = []
    packed_seqlens, packed_position_ids, packed_indexes = [], [], []
    packed_key_value_indexes = []
    newlens, new_rope = [], []
    vae_image_tensors = []
    query_pos = 0   # index within the packed query (new tokens only)
    global_pos = 0  # index within the packed cache + query sequence
    for image, kv_len, rope_pos in zip(images, curr_kvlens, curr_rope):
        packed_key_value_indexes.extend(range(global_pos, global_pos + kv_len))
        global_pos += kv_len

        # <start_of_image> marker
        packed_text_ids.append(new_token_ids['start_of_image'])
        packed_text_indexes.append(query_pos)
        packed_indexes.append(global_pos)
        query_pos += 1
        global_pos += 1

        # latent-token slots
        pixels = transforms(image)
        vae_image_tensors.append(pixels)
        vae_position_ids = self.get_flattened_position_ids(
            pixels.size(1), pixels.size(2),
            self.latent_downsample,
            max_num_patches_per_side=self.max_latent_size,
        )
        packed_vae_position_ids.append(vae_position_ids)
        H, W = pixels.shape[1:]
        h, w = H // self.latent_downsample, W // self.latent_downsample
        patchified_vae_latent_shapes.append((h, w))
        n_tokens = w * h
        packed_vae_token_indexes.extend(range(query_pos, query_pos + n_tokens))
        packed_indexes.extend(range(global_pos, global_pos + n_tokens))
        query_pos += n_tokens
        global_pos += n_tokens

        # <end_of_image> marker
        packed_text_ids.append(new_token_ids['end_of_image'])
        packed_text_indexes.append(query_pos)
        packed_indexes.append(global_pos)
        query_pos += 1
        global_pos += 1

        span = n_tokens + 2
        packed_position_ids.extend([rope_pos] * span)
        packed_seqlens.append(span)
        newlens.append(kv_len + span)
        new_rope.append(rope_pos + 1)

    # Zero-pad every image to the per-dimension maximum so they stack into one batch.
    max_image_size = [max(dims) for dims in zip(*(t.shape for t in vae_image_tensors))]
    padded_images = torch.zeros(size=(len(vae_image_tensors), *max_image_size))
    for slot, pixels in zip(padded_images, vae_image_tensors):
        slot[:, :pixels.shape[1], :pixels.shape[2]] = pixels

    generation_input = {
        "padded_images": padded_images,
        "patchified_vae_latent_shapes": patchified_vae_latent_shapes,
        "packed_vae_position_ids": torch.cat(packed_vae_position_ids, dim=0),
        "packed_timesteps": torch.tensor([timestep]),
        "packed_vae_token_indexes": torch.tensor(packed_vae_token_indexes, dtype=torch.long),
        "packed_text_ids": torch.tensor(packed_text_ids, dtype=torch.long),
        "packed_text_indexes": torch.tensor(packed_text_indexes, dtype=torch.long),
        "packed_position_ids": torch.tensor(packed_position_ids, dtype=torch.long),
        "packed_seqlens": torch.tensor(packed_seqlens, dtype=torch.int),
        "packed_indexes": torch.tensor(packed_indexes, dtype=torch.long),
        "packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long),
        "key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
    }
    return generation_input, newlens, new_rope
@torch.no_grad
def forward_cache_update_vae(
self,
vae_model,
past_key_values: NaiveCache,
padded_images: torch.Tensor,
patchified_vae_latent_shapes: List,
packed_vae_position_ids: torch.LongTensor,
packed_timesteps: torch.Tensor,
packed_vae_token_indexes: torch.LongTensor,
packed_text_ids: torch.LongTensor,
packed_text_indexes: torch.LongTensor,
packed_position_ids: torch.LongTensor,
packed_seqlens: torch.IntTensor,
packed_indexes: torch.LongTensor,
key_values_lens: torch.IntTensor,
packed_key_value_indexes: torch.Tensor,
):
packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids)
packed_sequence = packed_text_embedding.new_zeros((sum(packed_seqlens), self.hidden_size))
packed_sequence[packed_text_indexes] = packed_text_embedding
padded_latent = vae_model.encode(padded_images)
p = self.latent_patch_size
packed_latent = list()
for latent, (h, w) in zip(padded_latent, patchified_vae_latent_shapes):
latent = latent[:, :h * p, :w * p].reshape(self.latent_channel, h, p, w, p)
latent = torch.einsum("chpwq->hwpqc", latent).reshape(-1, p * p * self.latent_channel)
packed_latent.append(latent)
packed_latent = torch.cat(packed_latent, dim=0)
packed_pos_embed = self.latent_pos_embed(packed_vae_position_ids)
packed_timestep_embeds = self.time_embedder(packed_timesteps)
packed_latent = self.vae2llm(packed_latent) + packed_timestep_embeds + packed_pos_embed
if packed_latent.dtype != packed_sequence.dtype:
packed_latent = packed_latent.to(packed_sequence.dtype)
packed_sequence[packed_vae_token_indexes] = packed_latent
extra_inputs = {}
if self.use_moe:
extra_inputs = {
"mode": "gen",
"packed_vae_token_indexes": packed_vae_token_indexes,
"packed_text_indexes": packed_text_indexes
}
output = self.language_model.forward_inference(
packed_query_sequence=packed_sequence,
query_lens=packed_seqlens,
packed_query_position_ids=packed_position_ids,
packed_query_indexes=packed_indexes,
past_key_values=past_key_values,
key_values_lens=key_values_lens,
packed_key_value_indexes=packed_key_value_indexes,
update_past_key_values=True,
is_causal=False,
**extra_inputs,
)
past_key_values = output.past_key_values
return past_key_values
def prepare_vae_latent(self, curr_kvlens, curr_rope, image_sizes, new_token_ids):
    """Lay out the query for images to be generated and draw their initial noise.

    For each target ``(H, W)`` the query is ``[start_of_image, <h * w latent tokens>,
    end_of_image]``, with ``h = H // latent_downsample`` and
    ``w = W // latent_downsample``, placed after that sample's ``curr_kvlen`` cached
    slots. All tokens share the RoPE position ``curr_position_id``. Each image gets
    ``torch.randn(h * w, latent_channel * latent_patch_size ** 2)`` as its starting
    latent. The returned dict's keys match the corresponding keyword parameters of
    ``generate_image``.

    Returns:
        generation_input dict (no updated lengths/positions are returned).
    """
    packed_text_ids, packed_text_indexes = [], []
    packed_vae_position_ids, packed_vae_token_indexes, packed_init_noises = [], [], []
    packed_position_ids, packed_seqlens, packed_indexes = [], [], []
    packed_key_value_indexes = []
    query_pos = 0   # index within the packed query (new tokens only)
    global_pos = 0  # index within the packed cache + query sequence
    token_dim = None
    for (H, W), kv_len, rope_pos in zip(image_sizes, curr_kvlens, curr_rope):
        packed_key_value_indexes.extend(range(global_pos, global_pos + kv_len))
        global_pos += kv_len

        # <start_of_image> marker
        packed_text_ids.append(new_token_ids['start_of_image'])
        packed_text_indexes.append(query_pos)
        packed_indexes.append(global_pos)
        query_pos += 1
        global_pos += 1

        # latent tokens to be denoised
        packed_vae_position_ids.append(
            self.get_flattened_position_ids(
                H, W,
                self.latent_downsample,
                max_num_patches_per_side=self.max_latent_size,
            )
        )
        n_tokens = (H // self.latent_downsample) * (W // self.latent_downsample)
        token_dim = self.latent_channel * self.latent_patch_size ** 2
        packed_init_noises.append(torch.randn(n_tokens, token_dim))
        packed_vae_token_indexes.extend(range(query_pos, query_pos + n_tokens))
        packed_indexes.extend(range(global_pos, global_pos + n_tokens))
        query_pos += n_tokens
        global_pos += n_tokens

        # <end_of_image> marker
        packed_text_ids.append(new_token_ids['end_of_image'])
        packed_text_indexes.append(query_pos)
        packed_indexes.append(global_pos)
        query_pos += 1
        global_pos += 1

        packed_position_ids.extend([rope_pos] * (n_tokens + 2))
        packed_seqlens.append(n_tokens + 2)

    def _long(values):
        return torch.tensor(values, dtype=torch.long)

    generation_input = {
        "packed_text_ids": _long(packed_text_ids),
        "packed_text_indexes": _long(packed_text_indexes),
        "packed_init_noises": torch.cat(packed_init_noises, dim=0),
        "packed_vae_position_ids": torch.cat(packed_vae_position_ids, dim=0),
        "packed_vae_token_indexes": _long(packed_vae_token_indexes),
        "packed_seqlens": torch.tensor(packed_seqlens, dtype=torch.int),
        "packed_position_ids": _long(packed_position_ids),
        "key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
        "packed_indexes": _long(packed_indexes),
        "packed_key_value_indexes": _long(packed_key_value_indexes),
    }
    return generation_input
def prepare_vae_latent_cfg(self, curr_kvlens, curr_rope, image_sizes):
    """Build global query indexes and RoPE positions for a CFG branch over the generated image.

    Uses the same per-sample layout as ``prepare_vae_latent`` (start marker,
    ``(H // latent_downsample) * (W // latent_downsample)`` latent tokens, end marker),
    placed after each sample's ``curr_kvlen`` cached slots. Only the indexes and positions
    are produced: no token ids, noise or query-local indexes.

    Returns:
        dict with ``cfg_packed_position_ids``, ``cfg_key_values_lens``,
        ``cfg_packed_query_indexes`` and ``cfg_packed_key_value_indexes``.
    """
    packed_position_ids, packed_indexes, packed_key_value_indexes = list(), list(), list()
    curr = 0  # index within the packed cache + query sequence
    for (H, W), curr_kvlen, curr_position_id in zip(image_sizes, curr_kvlens, curr_rope):
        packed_key_value_indexes.extend(range(curr, curr + curr_kvlen))
        curr += curr_kvlen
        num_image_tokens = (H // self.latent_downsample) * (W // self.latent_downsample)
        # start marker + latent tokens + end marker occupy consecutive global slots
        span = num_image_tokens + 2
        packed_indexes.extend(range(curr, curr + span))
        curr += span
        packed_position_ids.extend([curr_position_id] * span)

    generation_input = {
        "cfg_packed_position_ids": torch.tensor(packed_position_ids, dtype=torch.long),
        "cfg_key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
        "cfg_packed_query_indexes": torch.tensor(packed_indexes, dtype=torch.long),
        "cfg_packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long),
    }
    return generation_input
@torch.no_grad
def generate_image(
    self,
    packed_text_ids: torch.LongTensor,
    packed_text_indexes: torch.LongTensor,
    packed_init_noises: torch.Tensor,
    packed_vae_position_ids: torch.LongTensor,
    packed_vae_token_indexes: torch.LongTensor,
    packed_seqlens: torch.IntTensor,
    packed_position_ids: torch.LongTensor,
    packed_indexes: torch.LongTensor,
    past_key_values: NaiveCache,
    key_values_lens: torch.IntTensor,
    packed_key_value_indexes: torch.LongTensor,
    num_timesteps: int = 24,
    timestep_shift: float = 1.0,
    cfg_renorm_min: float = 0.0,
    cfg_renorm_type: str = "global",
    cfg_interval: Optional[Tuple[float, float]] = (0, 1),
    # cfg_text
    cfg_text_scale: float = 1.0,
    cfg_text_packed_query_indexes: Optional[torch.LongTensor] = None,
    cfg_text_packed_position_ids: Optional[torch.LongTensor] = None,
    cfg_text_past_key_values: Optional[NaiveCache] = None,
    cfg_text_key_values_lens: Optional[torch.IntTensor] = None,
    cfg_text_packed_key_value_indexes: Optional[torch.LongTensor] = None,
    # cfg_img
    cfg_img_scale: float = 1.0,
    cfg_img_packed_query_indexes: Optional[torch.LongTensor] = None,
    cfg_img_packed_position_ids: Optional[torch.LongTensor] = None,
    cfg_img_past_key_values: Optional[NaiveCache] = None,
    cfg_img_key_values_lens: Optional[torch.IntTensor] = None,
    cfg_img_packed_key_value_indexes: Optional[torch.LongTensor] = None,
    cfg_type: str = "parallel",
):
    """Denoise ``packed_init_noises`` into image latents with Euler steps of a flow ODE.

    The time grid runs from 1 (the initial noise) to 0, warped by
    ``t -> s * t / (1 + (s - 1) * t)`` with ``s = timestep_shift``. At each of the
    ``num_timesteps - 1`` steps, ``self._forward_flow`` predicts a velocity ``v_t`` and
    the state is updated as ``x_t -= v_t * dt``.

    Args:
        num_timesteps: Number of grid points (one fewer Euler steps are taken).
        timestep_shift: Shift ``s`` of the time grid.
        cfg_interval: ``(lo, hi)``; the text/image CFG scales are used only when
            ``lo < t <= hi`` and are replaced by 1.0 otherwise. ``None`` is treated as
            the default ``(0, 1)``.
        The remaining arguments are forwarded to ``self._forward_flow`` unchanged.

    Returns:
        Tuple of per-image latent-token tensors, split by ``packed_seqlens - 2`` (each
        sample's sequence length without its two image markers).
    """
    if cfg_interval is None:
        cfg_interval = (0, 1)
    cfg_start, cfg_end = cfg_interval[0], cfg_interval[1]

    x_t = packed_init_noises
    timesteps = torch.linspace(1, 0, num_timesteps, device=x_t.device)
    timesteps = timestep_shift * timesteps / (1 + (timestep_shift - 1) * timesteps)
    dts = timesteps[:-1] - timesteps[1:]
    timesteps = timesteps[:-1]

    for i, t in tqdm(enumerate(timesteps), total=len(timesteps)):
        # One copy of the scalar t per latent token (same dtype/device as the grid),
        # built without a Python list of n scalar tensors.
        timestep = t.repeat(x_t.shape[0])
        if t > cfg_start and t <= cfg_end:
            cfg_text_scale_ = cfg_text_scale
            cfg_img_scale_ = cfg_img_scale
        else:
            cfg_text_scale_ = 1.0
            cfg_img_scale_ = 1.0
        v_t = self._forward_flow(
            x_t=x_t,
            timestep=timestep,
            packed_vae_token_indexes=packed_vae_token_indexes,
            packed_vae_position_ids=packed_vae_position_ids,
            packed_text_ids=packed_text_ids,
            packed_text_indexes=packed_text_indexes,
            packed_position_ids=packed_position_ids,
            packed_indexes=packed_indexes,
            packed_seqlens=packed_seqlens,
            key_values_lens=key_values_lens,
            past_key_values=past_key_values,
            packed_key_value_indexes=packed_key_value_indexes,
            cfg_renorm_min=cfg_renorm_min,
            cfg_renorm_type=cfg_renorm_type,
            # cfg_text
            cfg_text_scale=cfg_text_scale_,
            cfg_text_packed_position_ids=cfg_text_packed_position_ids,
            cfg_text_packed_query_indexes=cfg_text_packed_query_indexes,
            cfg_text_key_values_lens=cfg_text_key_values_lens,
            cfg_text_past_key_values=cfg_text_past_key_values,
            cfg_text_packed_key_value_indexes=cfg_text_packed_key_value_indexes,
            # cfg_img
            cfg_img_scale=cfg_img_scale_,
            cfg_img_packed_position_ids=cfg_img_packed_position_ids,
            cfg_img_packed_query_indexes=cfg_img_packed_query_indexes,
            cfg_img_key_values_lens=cfg_img_key_values_lens,
            cfg_img_past_key_values=cfg_img_past_key_values,
            cfg_img_packed_key_value_indexes=cfg_img_packed_key_value_indexes,
            cfg_type=cfg_type,
        )
        x_t = x_t - v_t.to(x_t.device) * dts[i]  # velocity pointing from data to noise

    unpacked_latent = x_t.split((packed_seqlens - 2).tolist())
    return unpacked_latent
@torch.no_grad
def _forward_flow(
    self,
    x_t: torch.Tensor,
    timestep: torch.LongTensor,
    packed_vae_token_indexes: torch.LongTensor,
    packed_vae_position_ids: torch.LongTensor,
    packed_text_ids: torch.LongTensor,
    packed_text_indexes: torch.LongTensor,
    packed_indexes: torch.LongTensor,
    packed_position_ids: torch.LongTensor,
    packed_seqlens: torch.IntTensor,
    key_values_lens: torch.IntTensor,
    past_key_values: NaiveCache,
    packed_key_value_indexes: torch.LongTensor,
    cfg_renorm_min: float = 0.0,
    cfg_renorm_type: str = "global",
    # cfg_text
    cfg_text_scale: float = 1.0,
    cfg_text_packed_position_ids: Optional[torch.LongTensor] = None,
    cfg_text_packed_query_indexes: Optional[torch.LongTensor] = None,
    cfg_text_key_values_lens: Optional[torch.Tensor] = None,
    cfg_text_past_key_values: Optional[NaiveCache] = None,
    cfg_text_packed_key_value_indexes: Optional[torch.LongTensor] = None,
    # cfg_img
    cfg_img_scale: float = 1.0,
    cfg_img_packed_position_ids: Optional[torch.LongTensor] = None,
    cfg_img_packed_query_indexes: Optional[torch.LongTensor] = None,
    cfg_img_key_values_lens: Optional[torch.Tensor] = None,
    cfg_img_past_key_values: Optional[NaiveCache] = None,
    cfg_img_packed_key_value_indexes: Optional[torch.LongTensor] = None,
    cfg_type: str = "parallel",
):
    """Predict the flow velocity of the latent tokens for one denoising step.

    A packed query sequence is built from text embeddings (written at
    ``packed_text_indexes``) and the projected noisy latents plus timestep and
    position embeddings (written at ``packed_vae_token_indexes``).  It is run
    non-causally through ``self.language_model.forward_inference`` against the
    conditional cache.  When ``cfg_text_scale > 1`` it is also run against the
    text-dropped (``cfg_text_*``) cache, and against the image-dropped
    (``cfg_img_*``) cache when ``cfg_img_scale > 1`` as well.  The branch
    velocities are then combined with classifier-free guidance and
    renormalised according to ``cfg_renorm_type``.  No pass updates a cache.

    NOTE(review): image guidance is only applied together with text guidance;
    with ``cfg_text_scale <= 1`` the result is the unguided velocity whatever
    ``cfg_img_scale`` is.  ``cfg_type`` is accepted but not used here.

    Args:
        x_t: noisy latent tokens, projected by ``self.vae2llm``.
        timestep: per-token timesteps; all entries must be equal (asserted).
        packed_vae_token_indexes: rows of the packed sequence holding latents.
        packed_vae_position_ids: ids fed to ``self.latent_pos_embed``.
        packed_text_ids: token ids embedded into the text rows.
        packed_text_indexes: rows of the packed sequence holding text.
        packed_indexes / packed_position_ids / key_values_lens /
        past_key_values / packed_key_value_indexes: layout and cache for the
            conditional pass.
        packed_seqlens: per-sample query lengths; their sum is the packed length.
        cfg_renorm_min: lower clamp of the renormalisation factor (upper is 1).
        cfg_renorm_type: ``"global"``, ``"channel"`` or ``"text_channel"``.
        cfg_text_*: layout, cache and scale for the text-dropped pass.
        cfg_img_*: layout, cache and scale for the image-dropped pass.

    Returns:
        Velocity rows for the latent tokens only.

    Raises:
        NotImplementedError: when text guidance is active and
            ``cfg_renorm_type`` is not one of the supported values.
    """
    # Assemble the packed query: text embeddings in text rows, latents in VAE rows.
    packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids)
    packed_sequence = packed_text_embedding.new_zeros((sum(packed_seqlens), self.hidden_size))
    packed_sequence[packed_text_indexes] = packed_text_embedding

    # A single timestep embedding is assumed for the whole step.
    assert timestep.unique().shape[0] == 1
    packed_pos_embed = self.latent_pos_embed(packed_vae_position_ids)
    packed_timestep_embeds = self.time_embedder(timestep)
    x_t = self.vae2llm(x_t) + packed_timestep_embeds + packed_pos_embed
    if x_t.dtype != packed_sequence.dtype:
        x_t = x_t.to(packed_sequence.dtype)
    packed_sequence[packed_vae_token_indexes] = x_t

    extra_inputs = {}
    if self.use_moe:
        extra_inputs = {
            "mode": "gen",
            "packed_vae_token_indexes": packed_vae_token_indexes,
            "packed_text_indexes": packed_text_indexes
        }

    def _branch_velocity(position_ids, query_indexes, cache, kv_lens, kv_indexes):
        # One non-causal pass that leaves the cache untouched; returns latent-row velocity.
        output = self.language_model.forward_inference(
            packed_query_sequence=packed_sequence,
            query_lens=packed_seqlens,
            packed_query_position_ids=position_ids,
            packed_query_indexes=query_indexes,
            past_key_values=cache,
            key_values_lens=kv_lens,
            packed_key_value_indexes=kv_indexes,
            update_past_key_values=False,
            is_causal=False,
            **extra_inputs,
        )
        return self.llm2vae(output.packed_query_sequence)[packed_vae_token_indexes]

    v_t = _branch_velocity(
        packed_position_ids, packed_indexes, past_key_values,
        key_values_lens, packed_key_value_indexes,
    )

    if cfg_text_scale <= 1.0:
        # No CFG.  The image-dropped branch would be discarded, so it is not computed.
        return v_t

    cfg_text_v_t = _branch_velocity(
        cfg_text_packed_position_ids, cfg_text_packed_query_indexes, cfg_text_past_key_values,
        cfg_text_key_values_lens, cfg_text_packed_key_value_indexes,
    )
    cfg_img_v_t = None
    if cfg_img_scale > 1.0:
        cfg_img_v_t = _branch_velocity(
            cfg_img_packed_position_ids, cfg_img_packed_query_indexes, cfg_img_past_key_values,
            cfg_img_key_values_lens, cfg_img_packed_key_value_indexes,
        )

    # Text guidance: extrapolate from the text-dropped prediction toward the conditional one.
    v_t_text_ = cfg_text_v_t + cfg_text_scale * (v_t - cfg_text_v_t)

    if cfg_renorm_type == "text_channel":
        # Renormalise per token before (optionally) applying image guidance.
        norm_v_t = torch.norm(v_t, dim=-1, keepdim=True)
        norm_v_t_text_ = torch.norm(v_t_text_, dim=-1, keepdim=True)
        scale = (norm_v_t / (norm_v_t_text_ + 1e-8)).clamp(min=cfg_renorm_min, max=1.0)
        v_t_text = v_t_text_ * scale
        if cfg_img_v_t is None:
            return v_t_text
        return cfg_img_v_t + cfg_img_scale * (v_t_text - cfg_img_v_t)

    if cfg_img_v_t is not None:
        v_t_ = cfg_img_v_t + cfg_img_scale * (v_t_text_ - cfg_img_v_t)
    else:
        v_t_ = v_t_text_
    # NOTE norm is computed over all dimensions, thus currently only supports batch_size = 1 with navit
    if cfg_renorm_type == "global":
        norm_v_t = torch.norm(v_t)
        norm_v_t_ = torch.norm(v_t_)
    elif cfg_renorm_type == "channel":
        norm_v_t = torch.norm(v_t, dim=-1, keepdim=True)
        norm_v_t_ = torch.norm(v_t_, dim=-1, keepdim=True)
    else:
        raise NotImplementedError(f"{cfg_renorm_type} is not supported")
    # Shrink (never enlarge) the guided velocity toward the conditional norm.
    scale = (norm_v_t / (norm_v_t_ + 1e-8)).clamp(min=cfg_renorm_min, max=1.0)
    return v_t_ * scale
def prepare_start_tokens(self, curr_kvlens, curr_rope, new_token_ids):
    """Build decoder inputs for the first decoding step of every sequence.

    Each sequence (paired from ``curr_kvlens`` and ``curr_rope``) starts from
    ``new_token_ids['bos_token_id']`` at its current rope position and attends
    to all of its cached entries, which are laid out back to back.

    Args:
        curr_kvlens: number of cached key/value entries per sequence.
        curr_rope: next rope position id per sequence.
        new_token_ids: mapping providing ``'bos_token_id'``.

    Returns:
        dict with ``packed_start_tokens``, ``packed_query_position_ids`` and
        ``packed_key_value_indexes`` (int64) and ``key_values_lens`` (int32).
    """
    pairs = list(zip(curr_kvlens, curr_rope))

    kv_indexes = []
    offset = 0
    for kv_len, _ in pairs:
        # Sequence entries are contiguous: [offset, offset + kv_len).
        kv_indexes.extend(range(offset, offset + kv_len))
        offset += kv_len

    return {
        "packed_start_tokens": torch.tensor(
            [new_token_ids['bos_token_id'] for _ in pairs], dtype=torch.long
        ),
        "packed_query_position_ids": torch.tensor(
            [rope_pos for _, rope_pos in pairs], dtype=torch.long
        ),
        "key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
        "packed_key_value_indexes": torch.tensor(kv_indexes, dtype=torch.long),
    }
@torch.no_grad
def generate_text(
    self,
    past_key_values: NaiveCache,
    packed_key_value_indexes: torch.LongTensor,
    key_values_lens: torch.IntTensor,
    packed_start_tokens: torch.LongTensor,
    packed_query_position_ids: torch.LongTensor,
    max_length: int,
    do_sample: bool = False,
    temperature: float = 1.0,
    end_token_id: Optional[int] = None,
):
    """Decode up to ``max_length`` tokens per sequence on top of a prefilled cache.

    Each step feeds one token per sequence through
    ``self.language_model.forward_inference`` with ``update_past_key_values=True``,
    so every sequence's cache grows by one entry.  The next token is then chosen
    by argmax, or by multinomial sampling at ``temperature`` when ``do_sample``.

    Args:
        past_key_values: cache produced by the prefill passes.
        packed_key_value_indexes: positions of each sequence's cached entries in
            the packed cache, sequence after sequence.
        key_values_lens: number of cached entries per sequence.
        packed_start_tokens: first input token per sequence.
        packed_query_position_ids: rope position of that first token per sequence.
        max_length: maximum number of decoding steps.
        do_sample: sample from the softmax instead of taking the argmax.
        temperature: logits are divided by this before the softmax when sampling.
        end_token_id: stop once the FIRST sequence emits this id (batch size 1 assumed).

    Returns:
        Token ids stacked to ``(num_steps, batch)`` on the device of
        ``packed_start_tokens``.  Row 0 is ``packed_start_tokens``; the token
        predicted at the last step (e.g. ``end_token_id``) is not included.
        An empty ``(0, batch)`` tensor is returned when ``max_length <= 0``.
    """
    generated_sequence = []
    curr_tokens = packed_start_tokens
    extra_inputs = {"mode": "und"} if self.use_moe else {}

    step = 0
    while step < max_length:
        generated_sequence.append(curr_tokens)
        packed_text_embedding = self.language_model.model.embed_tokens(curr_tokens)
        query_lens = torch.ones_like(curr_tokens)
        kv_len_list = key_values_lens.tolist()

        # Sequence i's new token is placed right after its cached entries, i.e.
        # after all earlier sequences' entries plus the i new tokens before it.
        packed_query_indexes = torch.cumsum(key_values_lens, dim=0) + torch.arange(
            0, len(key_values_lens),
            device=key_values_lens.device,
            dtype=key_values_lens.dtype
        )
        # Shift sequence i's cached entries by i to make room for those tokens.
        # Out-of-place: the split() chunks are views of the caller's tensor.
        kv_chunks = packed_key_value_indexes.split(kv_len_list, dim=0)
        packed_key_value_indexes = torch.cat(
            [chunk + i for i, chunk in enumerate(kv_chunks)], dim=0
        )

        output = self.language_model.forward_inference(
            packed_query_sequence=packed_text_embedding,
            query_lens=query_lens,
            packed_query_position_ids=packed_query_position_ids,
            packed_query_indexes=packed_query_indexes,
            past_key_values=past_key_values,
            key_values_lens=key_values_lens,
            packed_key_value_indexes=packed_key_value_indexes,
            update_past_key_values=True,
            is_causal=True,
            **extra_inputs,
        )
        past_key_values = output.past_key_values
        pred_logits = self.language_model.lm_head(output.packed_query_sequence)
        if do_sample:
            probs = nn.functional.softmax(pred_logits / temperature, dim=-1)
            curr_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
        else:
            curr_tokens = torch.argmax(pred_logits, dim=-1)

        # The token just written now belongs to each sequence's cache: append the
        # slot following its last cached entry (slicing avoids a host sync).
        kv_chunks = packed_key_value_indexes.split(kv_len_list, dim=0)
        packed_key_value_indexes = torch.cat(
            [torch.cat([chunk, chunk[-1:] + 1], dim=0) for chunk in kv_chunks], dim=0
        )
        key_values_lens = key_values_lens + 1
        packed_query_position_ids = packed_query_position_ids + 1
        step += 1
        if end_token_id is not None and curr_tokens[0] == end_token_id:  # only support batch=1
            break

    if not generated_sequence:
        return packed_start_tokens.new_empty((0, *packed_start_tokens.shape))
    output_device = generated_sequence[0].device
    return torch.stack([tokens.to(output_device) for tokens in generated_sequence], dim=0)
# for evaluation
@torch.no_grad()
def chat(
    self,
    tokenizer,
    new_token_ids,
    image_transform,
    images,
    prompt,
    max_length: int,
    do_sample: bool = False,
    temperature: float = 1.0,
):
    """Answer ``prompt`` about ``images`` with autoregressive text decoding.

    Steps:

    1. Create a fresh ``NaiveCache``.
    2. Prefill it with every image in turn (``prepare_vit_images`` +
       ``forward_cache_update_vit``), then with the prompt (``prepare_prompts`` +
       ``forward_cache_update_text``).
    3. Decode with ``generate_text`` until ``new_token_ids['eos_token_id']`` or
       ``max_length`` steps.

    Args:
        tokenizer: used to encode the prompt and decode the answer.
        new_token_ids: special-token ids (dict or tensor); tensors are moved to
            the model device on a copy, the caller's object is not modified.
        image_transform: passed to ``prepare_vit_images`` as ``transforms``.
        images: iterable of images added to the context one at a time.
        prompt: user text.
        max_length: maximum number of decoding steps.
        do_sample: sample instead of greedy decoding.
        temperature: sampling temperature.

    Returns:
        The decoded answer for the first sequence: the text before the first
        ``<|im_end|>`` and after the first ``<|im_start|>``.  If no
        ``<|im_start|>`` is present, the text before ``<|im_end|>`` is returned.
    """
    device = next(self.parameters()).device

    def _to_device(inputs):
        # New dict with every tensor value moved to the model device.
        return {k: v.to(device) if torch.is_tensor(v) else v for k, v in inputs.items()}

    if isinstance(new_token_ids, dict):
        new_token_ids = _to_device(new_token_ids)
    elif torch.is_tensor(new_token_ids):
        new_token_ids = new_token_ids.to(device)

    # prefill
    past_key_values = NaiveCache(self.config.llm_config.num_hidden_layers)
    newlens = [0]
    new_rope = [0]

    # NOTE(review): autocast is hard-coded to "cuda"; presumably it has no effect
    # when the model lives on another device type — confirm before CPU use.

    # add images
    for image in images:
        generation_input, newlens, new_rope = self.prepare_vit_images(
            curr_kvlens=newlens,
            curr_rope=new_rope,
            images=[image],
            transforms=image_transform,
            new_token_ids=new_token_ids,
        )
        generation_input = _to_device(generation_input)
        with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
            past_key_values = self.forward_cache_update_vit(past_key_values, **generation_input)

    # add text
    generation_input, newlens, new_rope = self.prepare_prompts(
        curr_kvlens=newlens,
        curr_rope=new_rope,
        prompts=[prompt],
        tokenizer=tokenizer,
        new_token_ids=new_token_ids,
    )
    generation_input = _to_device(generation_input)
    with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
        past_key_values = self.forward_cache_update_text(past_key_values, **generation_input)

    # decode
    generation_input = _to_device(self.prepare_start_tokens(newlens, new_rope, new_token_ids))
    with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
        token_ids = self.generate_text(
            past_key_values=past_key_values,
            max_length=max_length,
            do_sample=do_sample,
            temperature=temperature,
            end_token_id=new_token_ids['eos_token_id'],
            **generation_input,
        )

    text = tokenizer.decode(token_ids[:, 0])
    text = text.split('<|im_end|>')[0]
    # Text between the first and second <|im_start|>; fall back when it is absent.
    segments = text.split('<|im_start|>')
    return segments[1] if len(segments) > 1 else segments[0]
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
    """Return a fixed 2-D sine-cosine position table for a square grid.

    Args:
        embed_dim: width of each embedding row (split between the two axes).
        grid_size: number of positions per side; the grid is ``grid_size x grid_size``.
        cls_token: when truthy together with ``extra_tokens > 0``, prepend
            ``extra_tokens`` all-zero rows.
        extra_tokens: number of zero rows to prepend in that case.

    Returns:
        ``np.ndarray`` of shape ``(grid_size**2 [+ extra_tokens], embed_dim)``;
        row ``r * grid_size + c`` encodes grid cell (row ``r``, column ``c``).
    """
    coords = np.arange(grid_size, dtype=np.float32)
    # meshgrid's first plane varies along columns, the second along rows ("w goes first").
    planes = np.stack(np.meshgrid(coords, coords), axis=0)
    table = get_2d_sincos_pos_embed_from_grid(
        embed_dim, planes.reshape(2, 1, grid_size, grid_size)
    )
    if not (cls_token and extra_tokens > 0):
        return table
    prefix = np.zeros([extra_tokens, embed_dim])
    return np.concatenate([prefix, table], axis=0)
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Encode a 2-plane coordinate grid into ``(H*W, embed_dim)`` sin-cos rows.

    Half of the channels encode ``grid[0]`` and the other half ``grid[1]``, each
    via ``get_1d_sincos_pos_embed_from_grid``.

    Args:
        embed_dim: total embedding width; must be even.
        grid: array whose ``grid[0]`` and ``grid[1]`` hold the coordinates.
    """
    assert embed_dim % 2 == 0
    half_dim = embed_dim // 2
    # grid[0] fills the first half of each row, grid[1] the second half.
    halves = [get_1d_sincos_pos_embed_from_grid(half_dim, plane) for plane in (grid[0], grid[1])]
    return np.concatenate(halves, axis=1)
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """Encode scalar positions with fixed sine/cosine features.

    Args:
        embed_dim: output width per position; must be even.
        pos: array of positions of any shape; it is flattened to ``(M,)``.

    Returns:
        ``float64`` array of shape ``(M, embed_dim)``: sines in the first half,
        cosines in the second, at frequencies ``10000 ** (-k / (embed_dim / 2))``.
    """
    assert embed_dim % 2 == 0
    exponents = np.arange(embed_dim // 2, dtype=np.float64) / (embed_dim / 2.)
    frequencies = 1. / 10000 ** exponents            # (D/2,)
    flat_pos = pos.reshape(-1)                       # (M,)
    angles = flat_pos[:, None] * frequencies[None, :]  # (M, D/2) outer product
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)
class TimestepEmbedder(nn.Module):
    """
    Turn scalar (possibly fractional) timesteps into ``hidden_size`` vectors:
    a fixed sinusoidal encoding followed by a Linear-SiLU-Linear MLP.
    """
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        layers = [
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        ]
        self.mlp = nn.Sequential(*layers)
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Sinusoidal encoding of timesteps.
        :param t: 1-D tensor with one (possibly fractional) timestep per element.
        :param dim: output width; an odd ``dim`` gets one trailing zero column.
        :param max_period: sets the lowest frequency, ``1 / max_period``.
        :return: ``(N, dim)`` float32 tensor, cosines first, then sines.
        """
        half = dim // 2
        scaled_range = -math.log(max_period) * torch.arange(half, dtype=torch.float32)
        freqs = torch.exp(scaled_range / half).to(device=t.device)
        angles = t.float().unsqueeze(1) * freqs.unsqueeze(0)
        embedding = torch.cat((angles.cos(), angles.sin()), dim=-1)
        if dim % 2 == 1:
            pad_column = torch.zeros_like(embedding[:, :1])
            embedding = torch.cat((embedding, pad_column), dim=-1)
        return embedding

    def forward(self, t):
        return self.mlp(self.timestep_embedding(t, self.frequency_embedding_size))
class MLPconnector(nn.Module):
    """Two-layer projector: ``fc2(activation(fc1(x)))``, ``in_dim`` -> ``out_dim``.

    ``hidden_act`` names the activation looked up in ``ACT2FN``.
    """
    def __init__(self, in_dim: int, out_dim: int, hidden_act: str):
        super().__init__()
        self.activation_fn = ACT2FN[hidden_act]
        self.fc1 = nn.Linear(in_dim, out_dim)
        self.fc2 = nn.Linear(out_dim, out_dim)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.fc2(self.activation_fn(self.fc1(hidden_states)))
class PositionEmbedding(nn.Module):
    """Frozen 2-D sine-cosine position table looked up by flattened patch index.

    The table has ``max_num_patch_per_side ** 2`` rows of width ``hidden_size``
    and is stored as a parameter with ``requires_grad=False``.
    """
    def __init__(self, max_num_patch_per_side, hidden_size):
        super().__init__()
        self.max_num_patch_per_side = max_num_patch_per_side
        self.hidden_size = hidden_size
        placeholder = torch.zeros(max_num_patch_per_side ** 2, hidden_size)
        self.pos_embed = nn.Parameter(placeholder, requires_grad=False)
        self._init_weights()

    def _init_weights(self):
        # Overwrite the placeholder with the fixed sin-cos table (kept frozen).
        table = get_2d_sincos_pos_embed(self.hidden_size, self.max_num_patch_per_side)
        self.pos_embed.data.copy_(torch.from_numpy(table).float())

    def forward(self, position_ids):
        return self.pos_embed[position_ids]
class Qwen2Config(_Qwen2Config):
    r"""
    This is the configuration class to store the configuration of a [`Qwen2Model`]. It is used to instantiate a
    Qwen2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of
    Qwen2-7B-beta [Qwen/Qwen2-7B-beta](https://huggingface.co/Qwen/Qwen2-7B-beta).
    This subclass adds the BAGEL-specific options `qk_norm`, `layer_module` and `freeze_und`, and changes the
    defaults of `is_causal` and `_attn_implementation`. These two are forwarded to the parent constructor.
    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.
    Args:
        vocab_size (`int`, *optional*, defaults to 151936):
            Vocabulary size of the Qwen2 model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`Qwen2Model`]
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 22016):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer in the Transformer encoder.
        num_key_value_heads (`int`, *optional*, defaults to 32):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details checkout [this
            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 32768):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether the model's input and output word embeddings should be tied.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
            and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
            accordingly.
            Expected contents:
                `rope_type` (`str`):
                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
                    'llama3'], with 'default' being the original RoPE implementation.
                `factor` (`float`, *optional*):
                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *
                    original maximum pre-trained length.
                `original_max_position_embeddings` (`int`, *optional*):
                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
                    pretraining.
                `attention_factor` (`float`, *optional*):
                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
                    computation. If unspecified, it defaults to value recommended by the implementation, using the
                    `factor` field to infer the suggested value.
                `beta_fast` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
                    ramp function. If unspecified, it defaults to 32.
                `beta_slow` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
                    ramp function. If unspecified, it defaults to 1.
                `short_factor` (`List[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2
                `long_factor` (`List[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to long contexts (>
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2
                `low_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
                `high_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
        use_sliding_window (`bool`, *optional*, defaults to `False`):
            Whether to use sliding window attention.
        sliding_window (`int`, *optional*, defaults to 4096):
            Sliding window attention (SWA) window size. If not specified, will default to `4096`.
        max_window_layers (`int`, *optional*, defaults to 28):
            The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        is_causal (`bool`, *optional*, defaults to `True`):
            Forwarded unchanged to the parent config as a keyword argument.
        _attn_implementation (`str`, *optional*, defaults to `"flash_attention_2"`):
            Forwarded unchanged to the parent config as a keyword argument.
        qk_norm (`bool`, *optional*, defaults to `True`):
            When `True`, the packed attention layers apply an RMSNorm over the head dimension to queries and keys;
            otherwise they use an identity.
        layer_module (`str`, *optional*, defaults to `"Qwen2DecoderLayer"`):
            Name of the decoder-layer class, stored as given (presumably used by the model to pick the layer type).
        freeze_und (`bool`, *optional*, defaults to `False`):
            When `True`, the mixture-of-transformers attention detaches the understanding-token values, queries and
            keys during training.
    ```python
    >>> # NOTE: this example uses the upstream transformers classes, not this subclass.
    >>> from transformers import Qwen2Model, Qwen2Config
    >>> # Initializing a Qwen2 style configuration
    >>> configuration = Qwen2Config()
    >>> # Initializing a model from the Qwen2-7B style configuration
    >>> model = Qwen2Model(configuration)
    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""
    model_type = "qwen2"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=151936,
        hidden_size=4096,
        intermediate_size=22016,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=32,
        hidden_act="silu",
        max_position_embeddings=32768,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        use_sliding_window=False,
        sliding_window=4096,
        max_window_layers=28,
        attention_dropout=0.0,
        is_causal=True,
        _attn_implementation="flash_attention_2",
        qk_norm=True,
        layer_module="Qwen2DecoderLayer",
        freeze_und=False,
        **kwargs,
    ):
        super().__init__(
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            hidden_act=hidden_act,
            max_position_embeddings=max_position_embeddings,
            initializer_range=initializer_range,
            rms_norm_eps=rms_norm_eps,
            use_cache=use_cache,
            tie_word_embeddings=tie_word_embeddings,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            use_sliding_window=use_sliding_window,
            sliding_window=sliding_window,
            max_window_layers=max_window_layers,
            attention_dropout=attention_dropout,
            is_causal=is_causal,
            _attn_implementation=_attn_implementation,
            **kwargs,
        )
        # BAGEL-specific options, stored as plain attributes.
        self.qk_norm = qk_norm
        self.layer_module = layer_module
        self.freeze_und = freeze_und
@dataclass
class BaseNavitOutputWithPast(ModelOutput):
    """Output container for a packed-sequence inference pass.

    Holds the packed hidden states of the query tokens and, optionally, the
    key/value cache returned alongside them.
    """
    # Hidden states of all packed query tokens (presumably one row per token — confirm).
    packed_query_sequence: torch.FloatTensor = None
    # Key/value cache after the pass; None when not returned.
    past_key_values: Optional[NaiveCache] = None
def pad_sequence(tensor, pad_size):
    """Append ``pad_size`` zero rows along dim 1 of a 3-D ``(H, L, D)`` tensor.

    The padding has the input's dtype and device; the result is ``(H, L + pad_size, D)``.
    """
    num_heads, _, head_dim = tensor.shape
    padding = tensor.new_zeros((num_heads, pad_size, head_dim))
    return torch.cat((tensor, padding), dim=1)
class PackedAttention(Qwen2Attention):
    """Qwen2 self-attention over packed (concatenated) variable-length sequences.

    ``forward`` dispatches on ``self.training``:

    - ``forward_train`` attends within each packed sample. It uses per-sample
      SDPA when ``attention_mask`` is a list of dense masks, and one
      flex-attention call with ``attention_mask`` as the block mask otherwise.
    - ``forward_inference`` scatters the new query tokens into the packed
      key/value cache and runs ``flash_attn_varlen_func``.
    """

    def __init__(self, config, layer_idx: Optional[int] = None):
        super().__init__(config, layer_idx)
        # Optional per-head RMSNorm on queries and keys, controlled by config.qk_norm.
        if self.config.qk_norm:
            self.q_norm = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps)
            self.k_norm = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        else:
            self.q_norm = nn.Identity()
            self.k_norm = nn.Identity()

    def forward(self, *args, **kwargs):
        """Route to ``forward_train`` in training mode, else ``forward_inference``."""
        if self.training:
            return self.forward_train(*args, **kwargs)
        else:
            return self.forward_inference(*args, **kwargs)

    def forward_train(
        self,
        packed_sequence: torch.Tensor,
        sample_lens: List[int],
        attention_mask: List[torch.Tensor],
        packed_position_embeddings: Tuple[torch.Tensor, torch.Tensor],
    ):
        """Attention for a packed training batch.

        Args:
            packed_sequence: hidden states of all samples concatenated along dim 0.
            sample_lens: length of each packed sample.
            attention_mask: either a list of per-sample dense masks (SDPA path) or
                a flex-attention block mask.
            packed_position_embeddings: ``(cos, sin)`` for the rotary embedding.

        Returns:
            ``o_proj`` of the attention output, one row per packed token.
        """
        packed_query_states = self.q_proj(packed_sequence).view(-1, self.num_heads, self.head_dim)
        packed_key_states = self.k_proj(packed_sequence).view(-1, self.num_key_value_heads, self.head_dim)
        packed_value_states = self.v_proj(packed_sequence).view(-1, self.num_key_value_heads, self.head_dim)
        packed_query_states = self.q_norm(packed_query_states)
        packed_key_states = self.k_norm(packed_key_states)
        packed_cos, packed_sin = packed_position_embeddings
        packed_query_states, packed_key_states = apply_rotary_pos_emb(
            packed_query_states, packed_key_states, packed_cos, packed_sin, unsqueeze_dim=1
        )
        if isinstance(attention_mask, list):
            # Expand grouped key/value heads to the query head count for SDPA.
            packed_key_states = packed_key_states[:, :, None, :].repeat(1, 1, self.num_key_value_groups, 1)
            packed_key_states = packed_key_states.reshape(-1, self.num_heads, self.head_dim)
            packed_value_states = packed_value_states[:, :, None, :].repeat(1, 1, self.num_key_value_groups, 1)
            packed_value_states = packed_value_states.reshape(-1, self.num_heads, self.head_dim)
            unpacked_query_states = packed_query_states.transpose(0, 1).split(sample_lens, dim=1)
            unpacked_key_states = packed_key_states.transpose(0, 1).split(sample_lens, dim=1)
            unpacked_value_states = packed_value_states.transpose(0, 1).split(sample_lens, dim=1)
            unpacked_attn_output = []
            for query_states, key_states, value_states, attention_mask_per_sample in zip(
                unpacked_query_states, unpacked_key_states, unpacked_value_states, attention_mask
            ):
                with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
                    attn_output = scaled_dot_product_attention(
                        query_states.to(torch.bfloat16).unsqueeze(0),
                        key_states.to(torch.bfloat16).unsqueeze(0),
                        value_states.to(torch.bfloat16).unsqueeze(0),
                        attention_mask_per_sample.to(torch.bfloat16).unsqueeze(0),
                    )
                unpacked_attn_output.append(attn_output.squeeze(0))
            packed_attn_output = torch.cat(unpacked_attn_output, dim=1)
        else:
            # Zero-pad the token axis up to sum(sample_lens) (presumably the length the
            # block mask was built for), then drop the padding from the output.
            pad_size = sum(sample_lens) - packed_query_states.shape[0]
            packed_query_states = pad_sequence(packed_query_states.permute(1, 0, 2), pad_size)
            packed_key_states = pad_sequence(packed_key_states.permute(1, 0, 2), pad_size)
            packed_value_states = pad_sequence(packed_value_states.permute(1, 0, 2), pad_size)
            packed_attn_output = flex_attention(
                packed_query_states.unsqueeze(0),
                packed_key_states.unsqueeze(0),
                packed_value_states.unsqueeze(0),
                enable_gqa=True,
                block_mask=attention_mask,
            )
            end_index = packed_attn_output.shape[2] - pad_size
            packed_attn_output = packed_attn_output[0, :, :end_index, :]
        packed_attn_output = packed_attn_output.transpose(0, 1).reshape(-1, self.hidden_size)
        packed_attn_output = self.o_proj(packed_attn_output)
        return packed_attn_output

    def forward_inference(
        self,
        packed_query_sequence: torch.Tensor,
        query_lens: torch.Tensor,
        packed_query_position_embeddings: torch.Tensor,
        packed_query_indexes: torch.Tensor,
        past_key_values: Optional[NaiveCache] = None,
        key_values_lens: Optional[torch.Tensor] = None,
        packed_key_value_indexes: Optional[torch.Tensor] = None,
        update_past_key_values=True,
        is_causal=True,
    ):
        """Attention for new query tokens against this layer's cached keys/values.

        If this layer already has cached states, new and cached keys/values are
        scattered into one merged buffer. The new ones go to
        ``packed_query_indexes`` and the cached ones to ``packed_key_value_indexes``.
        Otherwise the new states are used alone. Attention runs with
        ``flash_attn_varlen_func`` in bfloat16.

        Args:
            packed_query_sequence: hidden states of the new tokens, packed along dim 0.
            query_lens: number of new tokens per sequence (tensor).
            packed_query_position_embeddings: ``(cos, sin)`` for the new tokens.
            packed_query_indexes: positions of the new tokens in the merged buffer.
            past_key_values: cache object; may be None.
            key_values_lens: number of cached entries per sequence.
            packed_key_value_indexes: positions of cached entries in the merged buffer.
            update_past_key_values: store the merged states back into the cache
                (ignored when ``past_key_values`` is None).
            is_causal: passed to flash attention as ``causal``.

        Returns:
            ``(o_proj output, past_key_values)``.
        """
        packed_query_states = self.q_proj(packed_query_sequence).view(-1, self.num_heads, self.head_dim)
        packed_key_states = self.k_proj(packed_query_sequence).view(-1, self.num_key_value_heads, self.head_dim)
        packed_value_states = self.v_proj(packed_query_sequence).view(-1, self.num_key_value_heads, self.head_dim)
        packed_query_states = self.q_norm(packed_query_states)
        packed_key_states = self.k_norm(packed_key_states)
        packed_cos, packed_sin = packed_query_position_embeddings
        packed_query_states, packed_key_states = apply_rotary_pos_emb(
            packed_query_states, packed_key_states, packed_cos, packed_sin, unsqueeze_dim=1
        )
        packed_query_states = packed_query_states.to(torch.bfloat16)
        packed_key_states = packed_key_states.to(torch.bfloat16)
        packed_value_states = packed_value_states.to(torch.bfloat16)
        if past_key_values is not None and past_key_values.key_cache[self.layer_idx] is not None:
            past_key_states = past_key_values.key_cache[self.layer_idx]
            past_value_states = past_key_values.value_cache[self.layer_idx]
            seqlens = sum(query_lens) + sum(key_values_lens)
            merged_key_states = past_key_states.new_zeros((seqlens, self.num_key_value_heads, self.head_dim))
            merged_value_states = past_key_states.new_zeros((seqlens, self.num_key_value_heads, self.head_dim))
            merged_key_states[packed_query_indexes] = packed_key_states
            merged_key_states[packed_key_value_indexes] = past_key_states
            merged_value_states[packed_query_indexes] = packed_value_states
            merged_value_states[packed_key_value_indexes] = past_value_states
            key_values_lens = key_values_lens + query_lens
        else:
            merged_key_states = packed_key_states
            merged_value_states = packed_value_states
            key_values_lens = query_lens
        cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(query_lens, dim=0), (1, 0))
        cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(key_values_lens, dim=0), (1, 0))
        packed_attn_output = flash_attn_varlen_func(
            q=packed_query_states,
            k=merged_key_states,
            v=merged_value_states,
            cu_seqlens_q=cu_seqlens_q.to(torch.int32),
            cu_seqlens_k=cu_seqlens_k.to(torch.int32),
            max_seqlen_q=int(query_lens.max()),
            max_seqlen_k=int(key_values_lens.max()),
            causal=is_causal,
        )
        packed_attn_output = packed_attn_output.reshape(-1, self.hidden_size)
        packed_attn_output = self.o_proj(packed_attn_output)
        # Without a cache object there is nowhere to store the states (the default
        # arguments used to crash here with an AttributeError).
        if update_past_key_values and past_key_values is not None:
            past_key_values.key_cache[self.layer_idx] = merged_key_states
            past_key_values.value_cache[self.layer_idx] = merged_value_states
        return packed_attn_output, past_key_values
class PackedAttentionMoT(Qwen2Attention):
def __init__(self, config, layer_idx: Optional[int] = None):
super().__init__(config, layer_idx)
if self.config.qk_norm:
self.q_norm = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps)
self.k_norm = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps)
self.q_norm_moe_gen = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps)
self.k_norm_moe_gen = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps)
else:
self.q_norm = nn.Identity()
self.k_norm = nn.Identity()
self.q_norm_moe_gen = nn.Identity()
self.k_norm_moe_gen = nn.Identity()
self.q_proj_moe_gen = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
self.k_proj_moe_gen = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
self.v_proj_moe_gen = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
self.o_proj_moe_gen = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
def forward(self, *args, **kwargs):
if self.training:
return self.forward_train(*args, **kwargs)
else:
return self.forward_inference(*args, **kwargs)
def forward_train(
self,
packed_sequence: torch.Tensor,
sample_lens: List[int],
attention_mask,
packed_position_embeddings: Tuple[torch.Tensor, torch.Tensor],
packed_und_token_indexes: torch.LongTensor,
packed_gen_token_indexes: torch.LongTensor,
):
packed_query_states = packed_sequence.new_zeros((packed_sequence.shape[0], self.num_heads * self.head_dim))
packed_key_states = packed_sequence.new_zeros((packed_sequence.shape[0], self.num_key_value_heads * self.head_dim))
packed_value_states = packed_sequence.new_zeros((packed_sequence.shape[0], self.num_key_value_heads * self.head_dim))
packed_sequence_und = packed_sequence[packed_und_token_indexes]
packed_sequence_gen = packed_sequence[packed_gen_token_indexes]
packed_query_states[packed_und_token_indexes] = self.q_proj(packed_sequence_und)
packed_query_states[packed_gen_token_indexes] = self.q_proj_moe_gen(packed_sequence_gen)
packed_key_states[packed_und_token_indexes] = self.k_proj(packed_sequence_und)
packed_key_states[packed_gen_token_indexes] = self.k_proj_moe_gen(packed_sequence_gen)
packed_value_states[packed_und_token_indexes] = self.v_proj(packed_sequence_und)
packed_value_states[packed_gen_token_indexes] = self.v_proj_moe_gen(packed_sequence_gen)
packed_query_states = packed_query_states.view(-1, self.num_heads, self.head_dim)
packed_key_states = packed_key_states.view(-1, self.num_key_value_heads, self.head_dim)
packed_value_states = packed_value_states.view(-1, self.num_key_value_heads, self.head_dim)
if self.config.freeze_und:
packed_value_states[packed_und_token_indexes] = packed_value_states[packed_und_token_indexes].detach()
packed_query_states_ = packed_query_states.new_zeros(packed_query_states.shape)
packed_key_states_ = packed_key_states.new_zeros(packed_key_states.shape)
packed_query_states_[packed_und_token_indexes] = self.q_norm(packed_query_states[packed_und_token_indexes])
if self.config.freeze_und:
packed_query_states_[packed_und_token_indexes] = packed_query_states_[packed_und_token_indexes].detach()
packed_query_states_[packed_gen_token_indexes] = self.q_norm_moe_gen(packed_query_states[packed_gen_token_indexes])
packed_key_states_[packed_und_token_indexes] = self.k_norm(packed_key_states[packed_und_token_indexes])
if self.config.freeze_und:
packed_key_states_[packed_und_token_indexes] = packed_key_states_[packed_und_token_indexes].detach()
packed_key_states_[packed_gen_token_indexes] = self.k_norm_moe_gen(packed_key_states[packed_gen_token_indexes])
packed_cos, packed_sin = packed_position_embeddings
packed_query_states_, packed_key_states_ = apply_rotary_pos_emb(
packed_query_states_, packed_key_states_, packed_cos, packed_sin, unsqueeze_dim=1
)
if isinstance(attention_mask, List):
packed_key_states_ = packed_key_states_[:, :, None, :].repeat(1, 1, self.num_key_value_groups, 1)
packed_key_states_ = packed_key_states_.reshape(-1, self.num_heads, self.head_dim)
packed_value_states = packed_value_states[:, :, None, :].repeat(1, 1, self.num_key_value_groups, 1)
packed_value_states = packed_value_states.reshape(-1, self.num_heads, self.head_dim)
unpacked_query_states = packed_query_states_.transpose(0, 1).split(sample_lens, dim=1)
unpacked_key_states = packed_key_states_.transpose(0, 1).split(sample_lens, dim=1)
unpacked_value_states = packed_value_states.transpose(0, 1).split(sample_lens, dim=1)
upacked_attn_output = []
for query_states, key_states, value_states, attention_mask_per_sample in zip(
unpacked_query_states, unpacked_key_states, unpacked_value_states, attention_mask
):
with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
attn_output = scaled_dot_product_attention(
query_states.to(torch.bfloat16).unsqueeze(0),
key_states.to(torch.bfloat16).unsqueeze(0),
value_states.to(torch.bfloat16).unsqueeze(0),
attention_mask_per_sample.to(torch.bfloat16).unsqueeze(0),
)
upacked_attn_output.append(attn_output.squeeze(0))
packed_attn_output = torch.cat(upacked_attn_output, dim=1)
else:
pad_size = sum(sample_lens) - packed_query_states.shape[0]
packed_query_states_ = pad_sequence(packed_query_states_.permute(1, 0, 2), pad_size)
packed_key_states_ = pad_sequence(packed_key_states_.permute(1, 0, 2), pad_size)
packed_value_states = pad_sequence(packed_value_states.permute(1, 0, 2), pad_size)
packed_attn_output = flex_attention(
packed_query_states_.unsqueeze(0), # 1, num_head, L, head_dim
packed_key_states_.unsqueeze(0),
packed_value_states.unsqueeze(0),
enable_gqa=True,
block_mask=attention_mask,
)
end_index = packed_attn_output.shape[2] - pad_size
packed_attn_output = packed_attn_output[0, :, :end_index, :]
packed_attn_output = packed_attn_output.transpose(0, 1).reshape(-1, self.num_heads * self.head_dim)
packed_attn_output_ = packed_attn_output.new_zeros(packed_attn_output.shape)
packed_attn_output_[packed_und_token_indexes] = self.o_proj(packed_attn_output[packed_und_token_indexes])
packed_attn_output_[packed_gen_token_indexes] = self.o_proj_moe_gen(packed_attn_output[packed_gen_token_indexes])
return packed_attn_output_
def forward_inference(
    self,
    packed_query_sequence: torch.Tensor,
    query_lens: torch.Tensor,
    packed_query_position_embeddings: torch.Tensor,
    packed_query_indexes: torch.Tensor,
    past_key_values: Optional[NaiveCache] = None,
    key_values_lens: Optional[torch.Tensor] = None,
    packed_key_value_indexes: Optional[torch.Tensor] = None,
    update_past_key_values=True,
    is_causal=True,
    mode="und",
    packed_vae_token_indexes=None,
    packed_text_indexes=None,
) -> Tuple[torch.Tensor, Optional[NaiveCache]]:
    """Packed variable-length attention for inference, with an optional KV cache.

    In ``"und"`` mode every query token uses ``q_proj``/``k_proj``/``v_proj``,
    ``q_norm``/``k_norm`` and ``o_proj``. In ``"gen"`` mode tokens at
    ``packed_text_indexes`` use those modules while tokens at
    ``packed_vae_token_indexes`` use the ``*_moe_gen`` counterparts.

    Args:
        packed_query_sequence: hidden states of the new (query) tokens.
        query_lens: per-sample number of query tokens (cumsum'd into ``cu_seqlens_q``).
        packed_query_position_embeddings: ``(cos, sin)`` pair for rotary embedding.
        packed_query_indexes: positions of the query tokens inside the merged K/V sequence.
        past_key_values: cache exposing per-layer ``key_cache`` / ``value_cache`` lists.
        key_values_lens: per-sample number of cached tokens (required when the cache
            already holds entries for this layer).
        packed_key_value_indexes: positions of cached tokens inside the merged K/V sequence.
        update_past_key_values: if True, store the merged K/V for this layer in
            ``past_key_values``; skipped when no cache object is supplied.
        is_causal: forwarded to ``flash_attn_varlen_func`` as ``causal``.
        mode: ``"und"`` or ``"gen"``.
        packed_vae_token_indexes: generation-token indexes (``"gen"`` mode only).
        packed_text_indexes: understanding/text-token indexes (``"gen"`` mode only).

    Returns:
        ``(packed_attn_output, past_key_values)``; the output is reshaped to
        ``(-1, hidden_size)`` before the output projection(s).

    Raises:
        ValueError: if ``mode`` is neither ``"und"`` nor ``"gen"``.
    """
    if mode not in ("und", "gen"):
        # An unknown mode used to fall through both branches and fail later with an
        # unrelated UnboundLocalError; fail fast with a clear message instead.
        raise ValueError(f"mode must be 'und' or 'gen', got {mode!r}")

    if mode == 'und':
        packed_query_states = self.q_proj(packed_query_sequence).view(-1, self.num_heads, self.head_dim)
        packed_key_states = self.k_proj(packed_query_sequence).view(-1, self.num_key_value_heads, self.head_dim)
        packed_value_states = self.v_proj(packed_query_sequence).view(-1, self.num_key_value_heads, self.head_dim)
        packed_query_states = self.q_norm(packed_query_states)
        packed_key_states = self.k_norm(packed_key_states)
    elif mode == 'gen':
        packed_query_sequence = packed_query_sequence.to(torch.bfloat16)
        packed_query_states = packed_query_sequence.new_zeros((packed_query_sequence.shape[0], self.num_heads * self.head_dim))
        packed_key_states = packed_query_sequence.new_zeros((packed_query_sequence.shape[0], self.num_key_value_heads * self.head_dim))
        packed_value_states = packed_query_sequence.new_zeros((packed_query_sequence.shape[0], self.num_key_value_heads * self.head_dim))
        packed_text_query_sequence = packed_query_sequence[packed_text_indexes]
        packed_vae_query_sequence = packed_query_sequence[packed_vae_token_indexes]
        # Route each token subset through its own projection expert.
        packed_query_states[packed_text_indexes] = self.q_proj(packed_text_query_sequence)
        packed_query_states[packed_vae_token_indexes] = self.q_proj_moe_gen(packed_vae_query_sequence)
        packed_key_states[packed_text_indexes] = self.k_proj(packed_text_query_sequence)
        packed_key_states[packed_vae_token_indexes] = self.k_proj_moe_gen(packed_vae_query_sequence)
        packed_value_states[packed_text_indexes] = self.v_proj(packed_text_query_sequence)
        packed_value_states[packed_vae_token_indexes] = self.v_proj_moe_gen(packed_vae_query_sequence)
        packed_query_states = packed_query_states.view(-1, self.num_heads, self.head_dim)
        packed_key_states = packed_key_states.view(-1, self.num_key_value_heads, self.head_dim)
        packed_value_states = packed_value_states.view(-1, self.num_key_value_heads, self.head_dim)
        # Q/K norms are applied in float32 before the rotary embedding.
        packed_query_states = packed_query_states.to(torch.float32)
        packed_query_states[packed_text_indexes] = self.q_norm(packed_query_states[packed_text_indexes])
        packed_query_states[packed_vae_token_indexes] = self.q_norm_moe_gen(packed_query_states[packed_vae_token_indexes])
        packed_key_states = packed_key_states.to(torch.float32)
        packed_key_states[packed_text_indexes] = self.k_norm(packed_key_states[packed_text_indexes])
        packed_key_states[packed_vae_token_indexes] = self.k_norm_moe_gen(packed_key_states[packed_vae_token_indexes])

    packed_cos, packed_sin = packed_query_position_embeddings
    packed_query_states, packed_key_states = apply_rotary_pos_emb(
        packed_query_states, packed_key_states, packed_cos, packed_sin, unsqueeze_dim=1
    )
    # The attention kernel is called with bfloat16 Q/K/V.
    packed_query_states = packed_query_states.to(torch.bfloat16)
    packed_key_states = packed_key_states.to(torch.bfloat16)
    packed_value_states = packed_value_states.to(torch.bfloat16)

    if past_key_values is not None and past_key_values.key_cache[self.layer_idx] is not None:
        past_key_states = past_key_values.key_cache[self.layer_idx]
        past_value_states = past_key_values.value_cache[self.layer_idx]
        # Tensor reductions instead of Python sum(): avoids a per-element loop over the tensor.
        seqlens = int(query_lens.sum() + key_values_lens.sum())
        merged_key_states = past_key_states.new_zeros(size=[seqlens, self.num_key_value_heads, self.head_dim])
        merged_value_states = past_key_states.new_zeros(size=[seqlens, self.num_key_value_heads, self.head_dim])
        # Interleave new and cached tokens into one packed K/V sequence.
        merged_key_states[packed_query_indexes] = packed_key_states
        merged_key_states[packed_key_value_indexes] = past_key_states
        merged_value_states[packed_query_indexes] = packed_value_states
        merged_value_states[packed_key_value_indexes] = past_value_states
        key_values_lens = key_values_lens + query_lens
    else:
        merged_key_states = packed_key_states
        merged_value_states = packed_value_states
        key_values_lens = query_lens

    cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(query_lens, dim=0), (1, 0))
    cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(key_values_lens, dim=0), (1, 0))
    packed_attn_output = flash_attn_varlen_func(
        q=packed_query_states,
        k=merged_key_states,
        v=merged_value_states,
        cu_seqlens_q=cu_seqlens_q.to(torch.int32),
        cu_seqlens_k=cu_seqlens_k.to(torch.int32),
        max_seqlen_q=query_lens.max().item(),
        max_seqlen_k=key_values_lens.max().item(),
        causal=is_causal,
    )
    packed_attn_output = packed_attn_output.reshape(-1, self.hidden_size)
    if mode == 'und':
        packed_attn_output = self.o_proj(packed_attn_output)
    elif mode == 'gen':
        packed_attn_output[packed_text_indexes] = self.o_proj(packed_attn_output[packed_text_indexes])
        packed_attn_output[packed_vae_token_indexes] = self.o_proj_moe_gen(packed_attn_output[packed_vae_token_indexes])

    # With the default past_key_values=None there is no cache to write into;
    # previously this raised AttributeError on NoneType.
    if update_past_key_values and past_key_values is not None:
        past_key_values.key_cache[self.layer_idx] = merged_key_states
        past_key_values.value_cache[self.layer_idx] = merged_value_states
    return packed_attn_output, past_key_values
class Qwen2DecoderLayer(nn.Module):
    """Pre-norm Qwen2 decoder layer built on ``PackedAttention``.

    Computes ``x = x + attn(input_layernorm(x))`` and then
    ``x = x + mlp(post_attention_layernorm(x))``. ``forward`` dispatches to
    ``forward_train`` while ``self.training`` is set and to ``forward_inference``
    otherwise.
    """

    def __init__(self, config, layer_idx: Optional[int] = None):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = PackedAttention(config, layer_idx)
        self.mlp = Qwen2MLP(config)
        self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(self, *args, **kwargs):
        """Route to the training or inference path based on ``self.training``."""
        if self.training:
            return self.forward_train(*args, **kwargs)
        else:
            return self.forward_inference(*args, **kwargs)

    def forward_train(
        self,
        packed_sequence: torch.Tensor,
        sample_lens: List[int],
        attention_mask,
        packed_position_embeddings: Tuple[torch.Tensor, torch.Tensor],
    ) -> torch.Tensor:
        """Training forward over a packed sequence; returns the updated hidden states."""
        residual = packed_sequence
        packed_sequence = self.input_layernorm(packed_sequence)
        # Self Attention
        packed_sequence = self.self_attn(
            packed_sequence=packed_sequence,
            sample_lens=sample_lens,
            attention_mask=attention_mask,
            packed_position_embeddings=packed_position_embeddings,
        )
        packed_sequence = residual + packed_sequence
        # Fully Connected
        residual = packed_sequence
        packed_sequence = self.post_attention_layernorm(packed_sequence)
        packed_sequence = self.mlp(packed_sequence)
        packed_sequence = residual + packed_sequence
        return packed_sequence

    def forward_inference(
        self,
        packed_query_sequence: torch.Tensor,
        query_lens: torch.Tensor,
        packed_query_position_embeddings: torch.Tensor,
        packed_query_indexes: torch.Tensor,
        past_key_values: Optional[NaiveCache] = None,
        key_values_lens: Optional[torch.Tensor] = None,
        packed_key_value_indexes: Optional[torch.Tensor] = None,
        update_past_key_values=True,
        is_causal=True,
    ) -> Tuple[torch.Tensor, Optional[NaiveCache]]:
        """Inference forward with an optional KV cache.

        Returns:
            ``(packed_query_sequence, past_key_values)``, i.e. the updated hidden
            states and the cache object returned by the attention module. The
            return annotation previously claimed ``BaseNavitOutputWithPast``,
            which this method never returns.
        """
        residual = packed_query_sequence
        packed_query_sequence = self.input_layernorm(packed_query_sequence)
        # Self Attention
        packed_query_sequence, past_key_values = self.self_attn(
            packed_query_sequence=packed_query_sequence,
            query_lens=query_lens,
            packed_query_position_embeddings=packed_query_position_embeddings,
            packed_query_indexes=packed_query_indexes,
            past_key_values=past_key_values,
            key_values_lens=key_values_lens,
            packed_key_value_indexes=packed_key_value_indexes,
            update_past_key_values=update_past_key_values,
            is_causal=is_causal,
        )
        packed_query_sequence = residual + packed_query_sequence
        # Fully Connected
        residual = packed_query_sequence
        packed_query_sequence = self.post_attention_layernorm(packed_query_sequence)
        packed_query_sequence = self.mlp(packed_query_sequence)
        packed_query_sequence = residual + packed_query_sequence
        return packed_query_sequence, past_key_values
class Qwen2MoTDecoderLayer(nn.Module):
    """Mixture-of-Transformers decoder layer: separate norms/MLPs per token type.

    Understanding tokens use ``input_layernorm`` / ``post_attention_layernorm`` /
    ``mlp``; generation tokens use the ``*_moe_gen`` counterparts. The attention
    module (``attn_module``, default ``PackedAttentionMoT``) receives the token
    partition so it can route projections the same way. When ``config.freeze_und``
    is set, training outputs for understanding tokens are detached.
    """

    def __init__(
        self,
        config,
        layer_idx: Optional[int] = None,
        attn_module: Optional[Qwen2Attention] = PackedAttentionMoT,
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.freeze_und = config.freeze_und
        self.self_attn = attn_module(config, layer_idx)
        self.mlp = Qwen2MLP(config)
        self.mlp_moe_gen = Qwen2MLP(config)
        self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.input_layernorm_moe_gen = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm_moe_gen = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(self, *args, **kwargs):
        """Route to the training or inference path based on ``self.training``."""
        if self.training:
            return self.forward_train(*args, **kwargs)
        else:
            return self.forward_inference(*args, **kwargs)

    def forward_train(
        self,
        packed_sequence: torch.Tensor,
        sample_lens: List[int],
        attention_mask,
        packed_position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        packed_und_token_indexes: torch.LongTensor,
        packed_gen_token_indexes: torch.LongTensor,
    ) -> torch.Tensor:
        """Training forward; und/gen token indexes select which expert handles each token."""
        residual = packed_sequence
        packed_sequence_ = packed_sequence.new_zeros(packed_sequence.shape)
        packed_sequence_[packed_und_token_indexes] = self.input_layernorm(packed_sequence[packed_und_token_indexes])
        packed_sequence_[packed_gen_token_indexes] = self.input_layernorm_moe_gen(packed_sequence[packed_gen_token_indexes])
        # Self Attention
        packed_sequence_ = self.self_attn(
            packed_sequence=packed_sequence_,
            sample_lens=sample_lens,
            attention_mask=attention_mask,
            packed_position_embeddings=packed_position_embeddings,
            packed_und_token_indexes=packed_und_token_indexes,
            packed_gen_token_indexes=packed_gen_token_indexes,
        )
        if self.freeze_und:
            # Stop gradients from flowing back through the understanding branch.
            packed_sequence_[packed_und_token_indexes] = packed_sequence_[packed_und_token_indexes].detach()
        packed_sequence = residual + packed_sequence_
        # Fully Connected
        residual = packed_sequence
        packed_sequence_ = packed_sequence.new_zeros(packed_sequence.shape)
        packed_sequence_[packed_und_token_indexes] = self.mlp(
            self.post_attention_layernorm(packed_sequence[packed_und_token_indexes])
        )
        if self.freeze_und:
            packed_sequence_[packed_und_token_indexes] = packed_sequence_[packed_und_token_indexes].detach()
        packed_sequence_[packed_gen_token_indexes] = self.mlp_moe_gen(
            self.post_attention_layernorm_moe_gen(packed_sequence[packed_gen_token_indexes])
        )
        packed_sequence = residual + packed_sequence_
        return packed_sequence

    def forward_inference(
        self,
        packed_query_sequence: torch.Tensor,
        query_lens: torch.Tensor,
        packed_query_position_embeddings: torch.Tensor,
        packed_query_indexes: torch.Tensor,
        past_key_values: Optional[NaiveCache] = None,
        key_values_lens: Optional[torch.Tensor] = None,
        packed_key_value_indexes: Optional[torch.Tensor] = None,
        update_past_key_values=True,
        is_causal=True,
        mode="und",
        packed_vae_token_indexes=None,
        packed_text_indexes=None,
    ) -> Tuple[torch.Tensor, Optional[NaiveCache]]:
        """Inference forward in ``"und"`` or ``"gen"`` mode.

        In ``"und"`` mode all tokens use the understanding modules. In ``"gen"`` mode
        tokens at ``packed_text_indexes`` use the understanding modules and tokens at
        ``packed_vae_token_indexes`` use the ``*_moe_gen`` modules.

        Returns:
            ``(packed_query_sequence, past_key_values)``.

        Raises:
            ValueError: if ``mode`` is neither ``"und"`` nor ``"gen"`` (previously the
                input layernorm was silently skipped in that case).
        """
        if mode not in ("und", "gen"):
            raise ValueError(f"mode must be 'und' or 'gen', got {mode!r}")
        residual = packed_query_sequence
        if mode == "und":
            packed_query_sequence = self.input_layernorm(packed_query_sequence)
        elif mode == "gen":
            packed_query_sequence_ = torch.zeros_like(packed_query_sequence)
            packed_query_sequence_[packed_text_indexes] = self.input_layernorm(packed_query_sequence[packed_text_indexes])
            packed_query_sequence_[packed_vae_token_indexes] = self.input_layernorm_moe_gen(packed_query_sequence[packed_vae_token_indexes])
            packed_query_sequence = packed_query_sequence_
        # Self Attention
        packed_query_sequence, past_key_values = self.self_attn(
            packed_query_sequence=packed_query_sequence,
            query_lens=query_lens,
            packed_query_position_embeddings=packed_query_position_embeddings,
            packed_query_indexes=packed_query_indexes,
            past_key_values=past_key_values,
            key_values_lens=key_values_lens,
            packed_key_value_indexes=packed_key_value_indexes,
            update_past_key_values=update_past_key_values,
            is_causal=is_causal,
            mode=mode,
            packed_vae_token_indexes=packed_vae_token_indexes,
            packed_text_indexes=packed_text_indexes,
        )
        packed_query_sequence = residual + packed_query_sequence
        # Fully Connected
        residual = packed_query_sequence
        if mode == "und":
            packed_query_sequence = self.post_attention_layernorm(packed_query_sequence)
            packed_query_sequence = self.mlp(packed_query_sequence)
        elif mode == "gen":
            packed_text_query_sequence = packed_query_sequence[packed_text_indexes]
            packed_vae_query_sequence = packed_query_sequence[packed_vae_token_indexes]
            # MLP inputs and the output buffer are cast to bfloat16 in gen mode.
            packed_text_query_sequence = self.post_attention_layernorm(packed_text_query_sequence).to(torch.bfloat16)
            packed_vae_query_sequence = self.post_attention_layernorm_moe_gen(packed_vae_query_sequence).to(torch.bfloat16)
            packed_query_sequence_ = torch.zeros_like(packed_query_sequence).to(torch.bfloat16)
            packed_query_sequence_[packed_text_indexes] = self.mlp(packed_text_query_sequence)
            packed_query_sequence_[packed_vae_token_indexes] = self.mlp_moe_gen(packed_vae_query_sequence)
            packed_query_sequence = packed_query_sequence_
        packed_query_sequence = residual + packed_query_sequence
        return packed_query_sequence, past_key_values
class Qwen2MoEDecoderLayer(nn.Module):
    """Decoder layer with shared attention/norms and two MLP experts.

    Attention (``PackedAttention``) and both layernorms are shared by all tokens;
    understanding tokens go through ``mlp`` and generation tokens through
    ``mlp_moe_gen``.
    """

    def __init__(self, config, layer_idx: Optional[int] = None):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = PackedAttention(config, layer_idx)
        self.mlp = Qwen2MLP(config)
        self.mlp_moe_gen = Qwen2MLP(config)
        self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(self, *args, **kwargs):
        """Route to the training or inference path based on ``self.training``."""
        if self.training:
            return self.forward_train(*args, **kwargs)
        else:
            return self.forward_inference(*args, **kwargs)

    def forward_train(
        self,
        packed_sequence: torch.Tensor,
        sample_lens: List[int],
        attention_mask,
        packed_position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        packed_und_token_indexes: torch.LongTensor,
        packed_gen_token_indexes: torch.LongTensor,
    ) -> torch.Tensor:
        """Training forward; only the MLP is routed by token type."""
        residual = packed_sequence
        packed_sequence = self.input_layernorm(packed_sequence)
        # Self Attention
        packed_sequence = self.self_attn(
            packed_sequence=packed_sequence,
            sample_lens=sample_lens,
            attention_mask=attention_mask,
            packed_position_embeddings=packed_position_embeddings,
        )
        packed_sequence = residual + packed_sequence
        # Fully Connected
        residual = packed_sequence
        packed_sequence = self.post_attention_layernorm(packed_sequence)
        packed_sequence_new = packed_sequence.new_zeros(packed_sequence.shape)
        packed_sequence_und = self.mlp(packed_sequence[packed_und_token_indexes])
        packed_sequence_gen = self.mlp_moe_gen(packed_sequence[packed_gen_token_indexes])
        packed_sequence_new[packed_und_token_indexes] = packed_sequence_und
        packed_sequence_new[packed_gen_token_indexes] = packed_sequence_gen
        packed_sequence = residual + packed_sequence_new
        return packed_sequence

    def forward_inference(
        self,
        packed_query_sequence: torch.Tensor,
        query_lens: torch.Tensor,
        packed_query_position_embeddings: torch.Tensor,
        packed_query_indexes: torch.Tensor,
        past_key_values: Optional[NaiveCache] = None,
        key_values_lens: Optional[torch.Tensor] = None,
        packed_key_value_indexes: Optional[torch.Tensor] = None,
        update_past_key_values=True,
        is_causal=True,
        mode="und",
        packed_vae_token_indexes=None,
        packed_text_indexes=None,
    ) -> Tuple[torch.Tensor, Optional[NaiveCache]]:
        """Inference forward; ``mode`` only affects which MLP expert each token uses.

        Returns:
            ``(packed_query_sequence, past_key_values)``.

        Raises:
            ValueError: if ``mode`` is neither ``"und"`` nor ``"gen"``. Previously
                such a mode silently skipped the MLP and returned ``residual + norm(x)``.
        """
        if mode not in ("und", "gen"):
            raise ValueError(f"mode must be 'und' or 'gen', got {mode!r}")
        residual = packed_query_sequence
        packed_query_sequence = self.input_layernorm(packed_query_sequence)
        # Self Attention (shared by all tokens, so mode is not forwarded)
        packed_query_sequence, past_key_values = self.self_attn(
            packed_query_sequence=packed_query_sequence,
            query_lens=query_lens,
            packed_query_position_embeddings=packed_query_position_embeddings,
            packed_query_indexes=packed_query_indexes,
            past_key_values=past_key_values,
            key_values_lens=key_values_lens,
            packed_key_value_indexes=packed_key_value_indexes,
            update_past_key_values=update_past_key_values,
            is_causal=is_causal,
        )
        packed_query_sequence = residual + packed_query_sequence
        # Fully Connected
        residual = packed_query_sequence
        packed_query_sequence = self.post_attention_layernorm(packed_query_sequence)
        if mode == "und":
            packed_query_sequence = self.mlp(packed_query_sequence)
        elif mode == "gen":
            packed_query_sequence_ = torch.zeros_like(packed_query_sequence).to(torch.bfloat16)
            packed_query_sequence_[packed_text_indexes] = self.mlp(packed_query_sequence[packed_text_indexes])
            packed_query_sequence_[packed_vae_token_indexes] = self.mlp_moe_gen(packed_query_sequence[packed_vae_token_indexes])
            packed_query_sequence = packed_query_sequence_
        packed_query_sequence = residual + packed_query_sequence
        return packed_query_sequence, past_key_values
# Registry from ``config.layer_module`` names to decoder-layer factories.
# Every entry is called as ``factory(config, layer_idx)`` by ``Qwen2Model``;
# the MoT entry pre-binds its attention module so it fits that signature.
Decoder_layer_dict = dict(
    Qwen2DecoderLayer=Qwen2DecoderLayer,
    Qwen2MoEDecoderLayer=Qwen2MoEDecoderLayer,
    Qwen2MoTDecoderLayer=partial(Qwen2MoTDecoderLayer, attn_module=PackedAttentionMoT),
)
class Qwen2Model(Qwen2PreTrainedModel):
    """Qwen2 decoder stack operating on packed (concatenated) sequences.

    ``config.layer_module`` selects the layer class from ``Decoder_layer_dict``.
    Any layer name containing ``"Mo"`` (the MoE / MoT variants) sets ``use_moe``,
    which adds a second final norm (``norm_moe_gen``) for generation tokens and
    forwards the token-routing arguments to every layer.
    """

    def __init__(self, config):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.use_moe = 'Mo' in config.layer_module
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        layer_module = Decoder_layer_dict[config.layer_module]
        self.layers = nn.ModuleList(
            [layer_module(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        if self.use_moe:
            self.norm_moe_gen = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Qwen2RotaryEmbedding(config=config)
        # Initialize weights and apply final processing
        self.post_init()

    def _rotary_position_embeddings(self, sequence, position_ids):
        """Return the ``(cos, sin)`` rotary tables for ``position_ids`` with the added batch axis squeezed away."""
        cos, sin = self.rotary_emb(sequence, position_ids.unsqueeze(0))
        return cos.squeeze(0), sin.squeeze(0)

    def forward(self, *args, **kwargs):
        """Route to the training or inference path based on ``self.training``."""
        if self.training:
            return self.forward_train(*args, **kwargs)
        else:
            return self.forward_inference(*args, **kwargs)

    def forward_train(
        self,
        packed_sequence: torch.Tensor,
        sample_lens: List[int],
        attention_mask,
        packed_position_ids: torch.Tensor,
        packed_und_token_indexes: Optional[torch.LongTensor] = None,
        packed_gen_token_indexes: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        """Training forward over a packed sequence; returns the final-normed hidden states.

        With ``use_moe`` the understanding token indexes are required; a missing
        ``packed_gen_token_indexes`` is replaced by an empty index tensor.
        """
        if self.config.freeze_und:
            packed_sequence[packed_und_token_indexes] = packed_sequence[packed_und_token_indexes].detach()
        # create position embeddings to be shared across the decoder layers
        packed_position_embeddings = self._rotary_position_embeddings(packed_sequence, packed_position_ids)
        extra_inputs = {}
        if self.use_moe:
            assert packed_und_token_indexes is not None
            if packed_gen_token_indexes is None:
                packed_gen_token_indexes = packed_und_token_indexes.new_ones(size=[0])
            extra_inputs.update(
                packed_und_token_indexes=packed_und_token_indexes,
                packed_gen_token_indexes=packed_gen_token_indexes,
            )
        for decoder_layer in self.layers:
            packed_sequence = decoder_layer(
                packed_sequence=packed_sequence,
                sample_lens=sample_lens,
                attention_mask=attention_mask,
                packed_position_embeddings=packed_position_embeddings,
                **extra_inputs
            )
        if self.use_moe:
            packed_sequence_ = torch.zeros_like(packed_sequence)
            packed_sequence_[packed_und_token_indexes] = self.norm(packed_sequence[packed_und_token_indexes])
            if self.config.freeze_und:
                packed_sequence_[packed_und_token_indexes] = packed_sequence_[packed_und_token_indexes].detach()
            packed_sequence_[packed_gen_token_indexes] = self.norm_moe_gen(packed_sequence[packed_gen_token_indexes])
            return packed_sequence_
        else:
            return self.norm(packed_sequence)

    def forward_inference(
        self,
        packed_query_sequence: torch.Tensor,
        query_lens: torch.Tensor,
        packed_query_position_ids: torch.Tensor,
        packed_query_indexes: torch.Tensor,
        past_key_values: Optional[NaiveCache] = None,
        key_values_lens: Optional[torch.Tensor] = None,
        packed_key_value_indexes: Optional[torch.Tensor] = None,
        update_past_key_values=True,
        is_causal=True,
        mode="und",
        packed_vae_token_indexes=None,
        packed_text_indexes=None,
    ) -> BaseNavitOutputWithPast:
        """Inference forward through all layers plus the final norm.

        ``mode`` (``"und"`` / ``"gen"``) and the gen-mode token indexes are only
        forwarded to layers when ``use_moe`` is set; non-MoE models ignore them.

        Raises:
            ValueError: if ``use_moe`` is set and ``mode`` is neither ``"und"`` nor
                ``"gen"``. Previously the final norm was silently skipped.
        """
        if self.use_moe and mode not in ("und", "gen"):
            raise ValueError(f"mode must be 'und' or 'gen', got {mode!r}")
        # create position embeddings to be shared across the decoder layers
        packed_query_position_embeddings = self._rotary_position_embeddings(
            packed_query_sequence, packed_query_position_ids
        )
        extra_inputs = {}
        if self.use_moe:
            extra_inputs.update(mode=mode)
            if mode == 'gen':
                assert packed_vae_token_indexes is not None
                assert packed_text_indexes is not None
                extra_inputs.update(
                    packed_vae_token_indexes=packed_vae_token_indexes,
                    packed_text_indexes=packed_text_indexes,
                )
        for decoder_layer in self.layers:
            packed_query_sequence, past_key_values = decoder_layer(
                packed_query_sequence=packed_query_sequence,
                query_lens=query_lens,
                packed_query_position_embeddings=packed_query_position_embeddings,
                packed_query_indexes=packed_query_indexes,
                past_key_values=past_key_values,
                key_values_lens=key_values_lens,
                packed_key_value_indexes=packed_key_value_indexes,
                update_past_key_values=update_past_key_values,
                is_causal=is_causal,
                **extra_inputs,
            )
        if self.use_moe and mode == "gen":
            # Text tokens use the understanding norm, VAE tokens the generation norm.
            packed_query_sequence_ = torch.zeros_like(packed_query_sequence)
            packed_query_sequence_[packed_text_indexes] = self.norm(packed_query_sequence[packed_text_indexes])
            packed_query_sequence_[packed_vae_token_indexes] = self.norm_moe_gen(packed_query_sequence[packed_vae_token_indexes])
            packed_query_sequence = packed_query_sequence_
        else:
            packed_query_sequence = self.norm(packed_query_sequence)
        return BaseNavitOutputWithPast(
            packed_query_sequence=packed_query_sequence,
            past_key_values=past_key_values,
        )
class Qwen2ForCausalLM(Qwen2PreTrainedModel):
    """``Qwen2Model`` plus a bias-free ``lm_head`` projecting to the vocabulary.

    ``forward`` dispatches to ``forward_train`` / ``forward_inference``; both
    delegate to ``self.model`` and return its output unchanged.
    """

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = Qwen2Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Initialize weights and apply final processing
        self.post_init()

    def init_moe(self):
        """Initialise every ``*moe_gen*`` parameter from its understanding counterpart.

        The counterpart is the parameter whose name is obtained by removing
        ``"_moe_gen"`` (e.g. ``...q_proj_moe_gen.weight`` -> ``...q_proj.weight``).
        """
        # Build the state dict once: calling self.state_dict() inside the loop
        # rebuilt the full mapping for every moe_gen parameter.
        state_dict = self.state_dict()
        for name, param in self.named_parameters():
            if "moe_gen" in name:
                original_name = name.replace("_moe_gen", "")
                param.data.copy_(state_dict[original_name].data)

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def forward(self, *args, **kwargs):
        """Route to the training or inference path based on ``self.training``."""
        if self.training:
            return self.forward_train(*args, **kwargs)
        else:
            return self.forward_inference(*args, **kwargs)

    def forward_train(
        self,
        packed_sequence: torch.Tensor,
        sample_lens: List[int],
        attention_mask,
        packed_position_ids: torch.Tensor,
        packed_und_token_indexes: Optional[torch.LongTensor] = None,
        packed_gen_token_indexes: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        """Delegate the packed training forward to ``self.model``."""
        outputs = self.model(
            packed_sequence=packed_sequence,
            sample_lens=sample_lens,
            packed_position_ids=packed_position_ids,
            attention_mask=attention_mask,
            packed_und_token_indexes=packed_und_token_indexes,
            packed_gen_token_indexes=packed_gen_token_indexes,
        )
        return outputs

    def forward_inference(
        self,
        packed_query_sequence: torch.Tensor,
        query_lens: torch.Tensor,
        packed_query_position_ids: torch.Tensor,
        packed_query_indexes: torch.Tensor,
        past_key_values: Optional[NaiveCache] = None,
        key_values_lens: Optional[torch.Tensor] = None,
        packed_key_value_indexes: Optional[torch.Tensor] = None,
        update_past_key_values=True,
        is_causal=True,
        mode="und",
        packed_vae_token_indexes=None,
        packed_text_indexes=None,
    ) -> BaseNavitOutputWithPast:
        """Delegate the cached inference forward to ``self.model``."""
        outputs = self.model(
            packed_query_sequence=packed_query_sequence,
            query_lens=query_lens,
            packed_query_position_ids=packed_query_position_ids,
            packed_query_indexes=packed_query_indexes,
            past_key_values=past_key_values,
            key_values_lens=key_values_lens,
            packed_key_value_indexes=packed_key_value_indexes,
            update_past_key_values=update_past_key_values,
            is_causal=is_causal,
            mode=mode,
            packed_vae_token_indexes=packed_vae_token_indexes,
            packed_text_indexes=packed_text_indexes,
        )
        return outputs
class SiglipVisionConfig(_SiglipVisionConfig):
    r"""
    This is the configuration class to store the configuration of a [`SiglipVisionModel`]. It is used to instantiate a
    Siglip vision encoder according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the vision encoder of the Siglip
    [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        hidden_size (`int`, *optional*, defaults to 768):
            Dimensionality of the encoder layers and the pooler layer.
        intermediate_size (`int`, *optional*, defaults to 3072):
            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        num_channels (`int`, *optional*, defaults to 3):
            Number of channels in the input images.
        image_size (`int`, *optional*, defaults to 224):
            The size (resolution) of each image.
        patch_size (`int`, *optional*, defaults to 16):
            The size (resolution) of each patch.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"selu"`, `"gelu_new"` and `"quick_gelu"` are supported.
        layer_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the layer normalization layers.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        rope (`bool`, *optional*, defaults to `True`):
            Whether to use 2D rotary position embeddings in the attention layers. When `False`, a learned
            absolute position embedding table is added to the patch embeddings instead.

    Example:

    ```python
    >>> from transformers import SiglipVisionConfig, SiglipVisionModel

    >>> # Initializing a SiglipVisionConfig with google/siglip-base-patch16-224 style configuration
    >>> configuration = SiglipVisionConfig()

    >>> # Initializing a SiglipVisionModel (with random weights) from the google/siglip-base-patch16-224 style configuration
    >>> model = SiglipVisionModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "siglip_vision_model"

    def __init__(
        self,
        hidden_size=768,
        intermediate_size=3072,
        num_hidden_layers=12,
        num_attention_heads=12,
        num_channels=3,
        image_size=224,
        patch_size=16,
        hidden_act="gelu_pytorch_tanh",
        layer_norm_eps=1e-6,
        attention_dropout=0.0,
        rope=True,
        **kwargs,
    ):
        super().__init__(
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            num_channels=num_channels,
            image_size=image_size,
            patch_size=patch_size,
            hidden_act=hidden_act,
            layer_norm_eps=layer_norm_eps,
            attention_dropout=attention_dropout,
            **kwargs)
        # Extension over the base SigLIP config: toggles 2D RoPE vs learned positions.
        self.rope = rope
class RotaryEmbedding2D(torch.nn.Module):
    """Precomputed rotary cos/sin tables for every cell of a ``max_h x max_w`` grid.

    Four buffers are registered in this order: ``cos_h``, ``sin_h`` (rotation by
    row index) and ``cos_w``, ``sin_w`` (rotation by column index). Each has one
    row per grid cell, enumerated row-major, and ``dim`` columns for even ``dim``
    (each frequency appears twice, concatenated).
    """

    def __init__(self, dim, max_h, max_w, base=10000):
        super().__init__()
        exponents = torch.arange(0, dim, 2, dtype=torch.int64).float() / dim
        inv_freq = 1.0 / (base ** exponents)
        rows = torch.arange(0, max_h).to(inv_freq.dtype)
        cols = torch.arange(0, max_w).to(inv_freq.dtype)
        # Broadcast each axis index over the whole grid so every cell carries its coordinate.
        row_grid = rows.unsqueeze(1).repeat(1, max_w)
        col_grid = cols.unsqueeze(0).repeat(max_h, 1)
        tables = (
            ("cos_h", "sin_h", row_grid),
            ("cos_w", "sin_w", col_grid),
        )
        for cos_name, sin_name, grid in tables:
            cos_table, sin_table = self._forward_one_side(grid, inv_freq)
            self.register_buffer(cos_name, cos_table)
            self.register_buffer(sin_name, sin_table)

    def _forward_one_side(self, grid, inv_freq):
        """Return ``(cos, sin)`` of ``grid * inv_freq``, frequencies doubled and the two grid axes flattened."""
        angles = grid.unsqueeze(-1) * inv_freq.view(1, 1, -1)
        doubled = torch.cat([angles, angles], dim=-1)
        flat = doubled.flatten(0, 1)
        return flat.cos(), flat.sin()
def rotate_half(x):
    """Rotate the last axis by halves: ``[a, b] -> [-b, a]``.

    For an odd-sized last axis the first part has ``n // 2`` elements and the
    second part the remaining ``n - n // 2``.
    """
    split_at = x.shape[-1] // 2
    leading, trailing = x[..., :split_at], x[..., split_at:]
    return torch.cat((trailing.neg(), leading), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Rotate query and key tensors with rotary position embeddings.

    Each of ``q`` and ``k`` becomes ``t * cos + rotate_half(t) * sin``, where
    ``cos``/``sin`` first receive a singleton axis at ``unsqueeze_dim`` so they
    broadcast against ``t``. For example, tables shaped ``(tokens, head_dim)``
    with inputs shaped ``(tokens, heads, head_dim)`` use ``unsqueeze_dim=1``.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Axis at which a singleton dimension is inserted into ``cos`` and ``sin``.

    Returns:
        `tuple(torch.Tensor)`: the rotated ``(q, k)`` pair.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def _rotate(t):
        return (t * cos_b) + (rotate_half(t) * sin_b)

    return _rotate(q), _rotate(k)
class SiglipVisionEmbeddings(nn.Module):
    """Patch embedding for packed SigLIP pixels.

    Patches are embedded with a strided ``Conv2d``, which can later be swapped for an
    equivalent ``nn.Linear`` via ``convert_conv2d_to_linear``. A learned position
    embedding table exists and is added only when ``config.rope`` is false.
    """

    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size
        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
        )
        patches_per_side = self.image_size // self.patch_size
        self.num_patches_per_side = patches_per_side
        self.num_patches = patches_per_side**2
        self.num_positions = self.num_patches
        if not config.rope:
            self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)

    def convert_conv2d_to_linear(self, config, meta=False):
        """Replace the conv patch embedding with an ``nn.Linear`` carrying the same weights.

        Args:
            config: provides ``num_channels``.
            meta: if True, the new linear layer is created on the ``meta`` device.
        """
        in_features = config.num_channels * self.patch_size ** 2
        linear_kwargs = {"device": "meta"} if meta else {}
        linear_patch_embedding = nn.Linear(in_features, self.embed_dim, bias=True, **linear_kwargs)
        # Conv weights are (out, C, kh, kw); move channels last before flattening, so the
        # linear layer expects each patch flattened in (kh, kw, C) order.
        flat_weight = self.patch_embedding.weight.permute(0, 2, 3, 1).reshape(self.embed_dim, in_features)
        linear_patch_embedding.weight.data = flat_weight
        linear_patch_embedding.bias.data = self.patch_embedding.bias.data
        del self.patch_embedding
        self.patch_embedding = linear_patch_embedding

    def forward(
        self,
        packed_pixel_values: torch.FloatTensor,
        packed_flattened_position_ids: torch.LongTensor
    ) -> torch.Tensor:
        embeddings = self.patch_embedding(packed_pixel_values)
        if self.config.rope:
            # Positions are injected later via rotary embeddings in attention.
            return embeddings
        return embeddings + self.position_embedding(packed_flattened_position_ids)
class SiglipFlashAttention2(SiglipAttention):
    """SigLIP attention over packed variable-length sequences using flash-attn.

    Construction is inherited unchanged from ``SiglipAttention``. When
    ``config.rope`` is set, the first half of every head is rotated with the row
    tables (``cos_h``/``sin_h``) and the second half with the column tables
    (``cos_w``/``sin_w``). Attention is non-causal and computed in bfloat16.
    """

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.IntTensor,
        max_seqlen: int,
        cos_h: torch.Tensor = None,
        sin_h: torch.Tensor = None,
        cos_w: torch.Tensor = None,
        sin_w: torch.Tensor = None,
        **kwargs,
    ) -> torch.Tensor:
        num_tokens, _ = hidden_states.size()
        head_shape = (num_tokens, self.num_heads, self.head_dim)
        q = self.q_proj(hidden_states).view(head_shape)
        k = self.k_proj(hidden_states).view(head_shape)
        v = self.v_proj(hidden_states).view(head_shape)

        if self.config.rope:
            half = self.head_dim // 2
            q_rows, k_rows = apply_rotary_pos_emb(q[:, :, :half], k[:, :, :half], cos_h, sin_h)
            q_cols, k_cols = apply_rotary_pos_emb(q[:, :, half:], k[:, :, half:], cos_w, sin_w)
            q = torch.cat([q_rows, q_cols], dim=-1)
            k = torch.cat([k_rows, k_cols], dim=-1)

        attended = flash_attn_varlen_func(
            q.to(torch.bfloat16),
            k.to(torch.bfloat16),
            v.to(torch.bfloat16),
            cu_seqlens_q=cu_seqlens,
            cu_seqlens_k=cu_seqlens,
            max_seqlen_q=max_seqlen,
            max_seqlen_k=max_seqlen,
            causal=False,
        )
        return self.out_proj(attended.reshape(num_tokens, -1))
class SiglipMLP(nn.Module):
    """Two-layer feed-forward block computing ``fc2(activation(fc1(x)))``."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        expanded = self.fc1(hidden_states)
        return self.fc2(self.activation_fn(expanded))
class SiglipEncoderLayer(nn.Module):
    """Pre-norm transformer layer: ``x + attn(ln1(x))`` then ``x + mlp(ln2(x))``."""

    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        # Creation order (attn, ln1, mlp, ln2) is kept as-is.
        self.self_attn = SiglipFlashAttention2(config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = SiglipMLP(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.IntTensor,
        max_seqlen: int,
        cos_h: torch.Tensor = None,
        sin_h: torch.Tensor = None,
        cos_w: torch.Tensor = None,
        sin_w: torch.Tensor = None
    ) -> torch.Tensor:
        """Run attention and MLP sub-blocks, each wrapped in a residual add.

        ``cu_seqlens``/``max_seqlen`` and the rotary tables are forwarded
        unchanged to the attention module.
        """
        attn_out = self.self_attn(
            hidden_states=self.layer_norm1(hidden_states),
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            cos_h=cos_h,
            sin_h=sin_h,
            cos_w=cos_w,
            sin_w=sin_w
        )
        hidden_states = hidden_states + attn_out
        mlp_out = self.mlp(self.layer_norm2(hidden_states))
        return hidden_states + mlp_out
class SiglipEncoder(nn.Module):
    """Stack of ``config.num_hidden_layers`` SiglipEncoderLayer modules applied in order."""

    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config
        stack = [SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)]
        self.layers = nn.ModuleList(stack)

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        cu_seqlens: torch.IntTensor,
        max_seqlen: int,
        cos_h: torch.Tensor = None,
        sin_h: torch.Tensor = None,
        cos_w: torch.Tensor = None,
        sin_w: torch.Tensor = None,
    ) -> torch.Tensor:
        """Feed the packed embeddings through every layer with shared rotary tables."""
        rotary = dict(cos_h=cos_h, sin_h=sin_h, cos_w=cos_w, sin_w=sin_w)
        hidden = inputs_embeds
        for layer in self.layers:
            hidden = layer(hidden, cu_seqlens, max_seqlen, **rotary)
        return hidden
class SiglipVisionTransformer(nn.Module):
    """Patch embedding -> SiglipEncoder -> final LayerNorm.

    With ``config.rope`` a RotaryEmbedding2D over an
    ``(image_size // patch_size)``-square grid is created, and its cos/sin
    tables are gathered per token (by flattened position id) for the encoder.
    """

    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config
        self.embeddings = SiglipVisionEmbeddings(config)
        if config.rope:
            grid_side = config.image_size // config.patch_size
            head_dim = config.hidden_size // config.num_attention_heads
            # height and width rotations each cover half of a head's channels
            self.rope = RotaryEmbedding2D(head_dim // 2, grid_side, grid_side)
        self.encoder = SiglipEncoder(config)
        self.post_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        packed_pixel_values: torch.Tensor,
        packed_flattened_position_ids: torch.LongTensor,
        cu_seqlens: torch.IntTensor,
        max_seqlen: int,
    ) -> torch.Tensor:
        embedded = self.embeddings(
            packed_pixel_values=packed_pixel_values,
            packed_flattened_position_ids=packed_flattened_position_ids
        )
        rotary_kwargs = {}
        if self.config.rope:
            rotary_kwargs = {
                table: getattr(self.rope, table)[packed_flattened_position_ids]
                for table in ("cos_h", "sin_h", "cos_w", "sin_w")
            }
        encoded = self.encoder(
            inputs_embeds=embedded, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen,
            **rotary_kwargs
        )
        return self.post_layernorm(encoded)
class SiglipVisionModel(SiglipPreTrainedModel):
    """Pretrained-model wrapper around SiglipVisionTransformer for packed inputs."""

    config_class = SiglipVisionConfig
    main_input_name = "packed_pixel_values"

    def __init__(self, config: SiglipVisionConfig):
        super().__init__(config)
        self.vision_model = SiglipVisionTransformer(config)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        """Return the patch-embedding module of the wrapped transformer."""
        embeddings = self.vision_model.embeddings
        return embeddings.patch_embedding

    def forward(
        self,
        packed_pixel_values: torch.Tensor,
        packed_flattened_position_ids: torch.LongTensor,
        cu_seqlens: torch.IntTensor,
        max_seqlen: int,
    ) -> torch.Tensor:
        """Delegate to ``vision_model`` and return its normalized hidden states."""
        model_inputs = dict(
            packed_pixel_values=packed_pixel_values,
            packed_flattened_position_ids=packed_flattened_position_ids,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )
        return self.vision_model(**model_inputs)
class MaxLongEdgeMinShortEdgeResize(torch.nn.Module):
    """Resize an image so its edges fit size limits and are multiples of ``stride``.

    The scale is chosen in three passes:

    1. Shrink (never enlarge) so the longest edge is at most ``max_size``, then
       enlarge if needed so the shortest edge is at least ``min_size``; the
       ``min_size`` requirement wins when the two conflict.
    2. If the area exceeds ``max_pixels / img_num``, rescale. Each edge is
       multiplied by the *area* ratio (not its square root), so this pass
       shrinks more than strictly required.
    3. If the longest edge still exceeds ``max_size``, scale it down to ``max_size``.

    After every pass each edge is rounded to the nearest multiple of
    ``stride`` (never below ``stride``), so the final size can deviate
    slightly from the exact targets.

    Args:
        max_size (int): Maximum size for the longest edge of the image.
        min_size (int): Minimum size for the shortest edge of the image.
        stride (int): Value by which the height and width of the image must be divisible.
        max_pixels (int): Pixel budget for the full image (shared by ``img_num`` images).
        interpolation (InterpolationMode): Interpolation passed to
            ``torchvision.transforms.functional.resize``. Default is
            ``InterpolationMode.BICUBIC``.
        antialias (bool, optional): Whether to apply antialiasing (default is True).
    """

    def __init__(
        self,
        max_size: int,
        min_size: int,
        stride: int,
        max_pixels: int,
        interpolation=InterpolationMode.BICUBIC,
        antialias=True
    ):
        super().__init__()
        self.max_size = max_size
        self.min_size = min_size
        self.stride = stride
        self.max_pixels = max_pixels
        self.interpolation = interpolation
        self.antialias = antialias

    def _make_divisible(self, value, stride):
        """Round *value* to the nearest multiple of *stride*, never returning less than *stride*."""
        nearest_multiple = int(round(value / stride) * stride)
        return max(stride, nearest_multiple)

    def _apply_scale(self, width, height, scale):
        """Scale ``(width, height)`` by *scale* and snap both edges to the stride grid."""
        scaled = (round(width * scale), round(height * scale))
        return tuple(self._make_divisible(edge, self.stride) for edge in scaled)

    def forward(self, img, img_num=1):
        """
        Args:
            img (PIL Image or Tensor): Image to be resized.
            img_num (int): Number of images sharing ``max_pixels``.

        Returns:
            PIL Image or Tensor: Rescaled image with stride-divisible dimensions.
        """
        if isinstance(img, torch.Tensor):
            height, width = img.shape[-2:]
        else:
            width, height = img.size

        shrink = min(self.max_size / max(width, height), 1.0)
        scale = max(shrink, self.min_size / min(width, height))
        new_width, new_height = self._apply_scale(width, height, scale)

        pixel_budget = self.max_pixels / img_num
        if new_width * new_height > pixel_budget:
            area_ratio = pixel_budget / (new_width * new_height)
            new_width, new_height = self._apply_scale(new_width, new_height, area_ratio)

        longest = max(new_width, new_height)
        if longest > self.max_size:
            new_width, new_height = self._apply_scale(new_width, new_height, self.max_size / longest)

        return TF.resize(img, (new_height, new_width), self.interpolation, antialias=self.antialias)
class ImageTransform:
    """Resize, convert to a tensor, and normalize an image.

    Resizing is delegated to ``MaxLongEdgeMinShortEdgeResize`` (see that
    class for the exact sizing rules); ``ToTensor`` and an in-place
    ``Normalize(image_mean, image_std)`` follow.

    Args:
        max_image_size: ``max_size`` for the resize step.
        min_image_size: ``min_size`` for the resize step.
        image_stride: both output edges are multiples of this value.
        max_pixels: pixel budget for the resize step.
        image_mean: per-channel normalization mean (any sequence).
        image_std: per-channel normalization std (any sequence).
    """

    def __init__(
        self,
        max_image_size,
        min_image_size,
        image_stride,
        max_pixels=14*14*9*1024,
        image_mean=(0.5, 0.5, 0.5),
        image_std=(0.5, 0.5, 0.5),
    ):
        self.stride = image_stride
        self.resize_transform = MaxLongEdgeMinShortEdgeResize(
            max_size=max_image_size,
            min_size=min_image_size,
            stride=image_stride,
            max_pixels=max_pixels,
        )
        self.to_tensor_transform = transforms.ToTensor()
        # Immutable defaults plus fresh list copies: Normalize keeps a reference
        # to mean/std, so they must never alias a shared default or caller data.
        self.normalize_transform = transforms.Normalize(
            mean=list(image_mean), std=list(image_std), inplace=True
        )

    def __call__(self, img, img_num=1):
        """Resize (with ``img_num`` forwarded to the resizer), then tensorize and normalize."""
        img = self.resize_transform(img, img_num=img_num)
        img = self.to_tensor_transform(img)
        img = self.normalize_transform(img)
        return img
def decolorization(image):
    """Return a grayscale copy of *image*.

    An ``'RGB'`` image stays ``'RGB'`` (the luminance band is replicated into
    all three channels); every other mode, including ``'L'``, is returned as
    a single-band ``'L'`` image.
    """
    gray_image = image.convert('L')
    if image.mode == 'RGB':
        return Image.merge(image.mode, [gray_image] * 3)
    # 'L' has a single band, so merging three copies into it raises
    # ValueError ("wrong number of bands"); the converted image is already final.
    return gray_image
def downscale(image, scale_factor):
    """Resize *image* by *scale_factor* with bicubic resampling.

    Each side is rounded to the nearest integer and clamped to at least 1 px.
    """
    target_size = tuple(
        max(1, int(round(side * scale_factor)))
        for side in (image.width, image.height)
    )
    return image.resize(target_size, resample=Image.BICUBIC)
def crop(image, crop_factors):
    """Take a uniformly random crop of size ``crop_factors == (height, width)``.

    Returns:
        ``(cropped_image, [[left, top], [right, bottom]])``.

    Raises:
        ValueError: if the requested crop is larger than the image.
    """
    target_h, target_w = crop_factors
    img_w, img_h = image.size
    if target_h > img_h or target_w > img_w:
        raise ValueError("Crop size exceeds image dimensions")
    left = random.randint(0, img_w - target_w)
    top = random.randint(0, img_h - target_h)
    right, bottom = left + target_w, top + target_h
    return image.crop((left, top, right, bottom)), [[left, top], [right, bottom]]
def motion_blur_opencv(image, kernel_size=15, angle=0):
    """Apply a linear motion blur to *image*.

    A ``kernel_size`` x ``kernel_size`` kernel holding a horizontal line of
    ones is rotated by *angle* (degrees, as interpreted by
    ``cv2.getRotationMatrix2D``), normalized to sum to 1, and convolved with
    the image using reflected borders. Multi-channel images are filtered one
    channel at a time. Returns a new PIL image of dtype uint8.
    """
    # Line kernel: one row of ones through the centre.
    line_kernel = np.zeros((kernel_size, kernel_size), dtype=np.float32)
    line_kernel[kernel_size // 2, :] = np.ones(kernel_size, dtype=np.float32)

    # Rotate the line about the kernel centre.
    pivot = (kernel_size / 2 - 0.5, kernel_size / 2 - 0.5)
    rotation = cv2.getRotationMatrix2D(pivot, angle, 1)
    kernel = cv2.warpAffine(line_kernel, rotation, (kernel_size, kernel_size))

    # Normalize the weights; an all-zero kernel is left as-is.
    weight_sum = kernel.sum()
    kernel /= weight_sum if weight_sum != 0 else 1

    pixels = np.array(image)
    if pixels.ndim == 2:
        blurred = cv2.filter2D(pixels, -1, kernel, borderType=cv2.BORDER_REFLECT)
    else:
        # Colour image: convolve every channel independently with the same kernel.
        blurred = np.zeros_like(pixels)
        for channel in range(pixels.shape[2]):
            blurred[..., channel] = cv2.filter2D(
                pixels[..., channel], -1, kernel, borderType=cv2.BORDER_REFLECT
            )
    return Image.fromarray(blurred.astype(np.uint8))
def shuffle_patch(image, num_splits, gap_size=2):
    """Split *image* into a grid of patches, shuffle them and reassemble with gaps.

    Args:
        image: PIL image.
        num_splits: ``(h_splits, w_splits)`` -- number of patch rows and
            columns. The image size need not be divisible; the last row and
            column absorb the remainder.
        gap_size: width in pixels of the white gap between adjacent patches.

    Returns:
        A new image of the same mode, of size
        ``(img_w + (w_splits - 1) * gap_size, img_h + (h_splits - 1) * gap_size)``,
        with the shuffled patches pasted in row-major order.

    Note:
        Slots are laid out from the original row heights / column widths, so
        when the size is not divisible a larger remainder patch can land in a
        smaller slot (overlapping its neighbour or the gap) or vice versa.
    """
    h_splits, w_splits = num_splits
    img_w, img_h = image.size

    base_patch_h = img_h // h_splits
    patch_heights = [base_patch_h] * (h_splits - 1)
    patch_heights.append(img_h - sum(patch_heights))
    base_patch_w = img_w // w_splits
    patch_widths = [base_patch_w] * (w_splits - 1)
    patch_widths.append(img_w - sum(patch_widths))

    # Cut the grid in row-major order.
    patches = []
    current_y = 0
    for patch_h in patch_heights:
        current_x = 0
        for patch_w in patch_widths:
            patches.append(image.crop((current_x, current_y, current_x + patch_w, current_y + patch_h)))
            current_x += patch_w
        current_y += patch_h

    random.shuffle(patches)

    total_width = sum(patch_widths) + (w_splits - 1) * gap_size
    total_height = sum(patch_heights) + (h_splits - 1) * gap_size
    # A colour name lets Pillow pick the right white for the mode; a fixed
    # RGB 3-tuple is rejected by single-band modes such as 'L'.
    new_image = Image.new(image.mode, (total_width, total_height), color="white")

    # Paste the shuffled patches; each step advances by the slot size plus the gap.
    shuffled = iter(patches)
    current_y = 0
    for patch_h in patch_heights:
        current_x = 0
        for patch_w in patch_widths:
            new_image.paste(next(shuffled), (current_x, current_y))
            current_x += patch_w + gap_size
        current_y += patch_h + gap_size
    return new_image
def inpainting(image, num_splits, blank_ratio=0.3, blank_color=(255, 255, 255)):
    """Blank out a random subset of grid patches (e.g. to build inpainting inputs).

    Args:
        image: PIL image (expected RGB; the output canvas is always RGB).
        num_splits: ``(h_splits, w_splits)`` -- number of patch rows and
            columns. The size need not be divisible; the last row and column
            absorb the remainder.
        blank_ratio: fraction (0-1) of patches to blank;
            ``int(total_patches * blank_ratio)`` patches, clamped to
            ``[0, total_patches]``, are chosen with ``random.sample``.
        blank_color: RGB fill colour of blanked patches, e.g. white ``(255, 255, 255)``.

    Returns:
        A new RGB image with the same size as *image*.
    """
    h_splits, w_splits = num_splits
    img_w, img_h = image.size

    base_patch_h = img_h // h_splits
    patch_heights = [base_patch_h] * (h_splits - 1)
    patch_heights.append(img_h - sum(patch_heights))
    base_patch_w = img_w // w_splits
    patch_widths = [base_patch_w] * (w_splits - 1)
    patch_widths.append(img_w - sum(patch_widths))

    # Cut the grid in row-major order.
    patches = []
    current_y = 0
    for patch_h in patch_heights:
        current_x = 0
        for patch_w in patch_widths:
            patches.append(image.crop((current_x, current_y, current_x + patch_w, current_y + patch_h)))
            current_x += patch_w
        current_y += patch_h

    total_patches = h_splits * w_splits
    num_blank = max(0, min(int(total_patches * blank_ratio), total_patches))
    # A set makes the per-patch membership test O(1); the random draw is unchanged.
    blank_indices = set(random.sample(range(total_patches), num_blank))
    processed_patches = [
        Image.new("RGB", patch.size, color=blank_color) if idx in blank_indices else patch
        for idx, patch in enumerate(patches)
    ]

    # Paste every patch back at its original position on a same-size canvas.
    result_image = Image.new("RGB", (img_w, img_h))
    ordered = iter(processed_patches)
    current_y = 0
    for patch_h in patch_heights:
        current_x = 0
        for patch_w in patch_widths:
            result_image.paste(next(ordered), (current_x, current_y))
            current_x += patch_w
        current_y += patch_h
    return result_image
@dataclass
class AutoEncoderParams:
    """Hyper-parameters of ``AutoEncoder`` (``load_ae`` shows the values used in this file)."""
    # nominal input resolution, passed to Encoder and Decoder (Decoder derives z_shape from it)
    resolution: int
    # channel count of the Encoder's conv_in input
    in_channels: int
    # NOTE(review): not read by the AutoEncoder/Encoder/Decoder code in this chunk;
    # presumably the spatial downsampling factor (8 in load_ae) -- confirm
    downsample: int
    # base channel width; resolution level i uses ch * ch_mult[i] channels
    ch: int
    # channel count produced by the Decoder's conv_out
    out_ch: int
    # per-level channel multipliers; its length is the number of resolution levels
    ch_mult: list[int]
    # ResnetBlocks per Encoder level (the Decoder uses num_res_blocks + 1 per level)
    num_res_blocks: int
    # latent channels; the Encoder emits 2 * z_channels (mean and log-variance halves)
    z_channels: int
    # encode() returns scale_factor * (z - shift_factor); decode() inverts it
    scale_factor: float
    shift_factor: float
def swish(x: Tensor) -> Tensor:
    """Swish / SiLU activation, computed as ``x * sigmoid(x)``."""
    gate = torch.sigmoid(x)
    return x * gate
class AttnBlock(nn.Module):
    """Single-head self-attention over all spatial positions, with a residual add.

    The input is GroupNorm-ed, projected to q/k/v with 1x1 convs, flattened
    to ``(b, 1, h*w, c)`` tokens, attended with
    ``scaled_dot_product_attention`` and projected back with ``proj_out``.
    """

    def __init__(self, in_channels: int):
        super().__init__()
        self.in_channels = in_channels
        self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)

    def attention(self, h_: Tensor) -> Tensor:
        """Return the attention output (before ``proj_out``) in ``(b, c, h, w)`` layout."""
        normed = self.norm(h_)
        q, k, v = self.q(normed), self.k(normed), self.v(normed)
        b, c, h, w = q.shape
        q, k, v = (
            rearrange(t, "b c h w -> b 1 (h w) c").contiguous() for t in (q, k, v)
        )
        attended = nn.functional.scaled_dot_product_attention(q, k, v)
        return rearrange(attended, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)

    def forward(self, x: Tensor) -> Tensor:
        residual = self.proj_out(self.attention(x))
        return x + residual
class ResnetBlock(nn.Module):
    """Pre-activation residual block: two GroupNorm -> swish -> 3x3 conv stages.

    When the input and output channel counts differ, the shortcut goes through
    a 1x1 conv (``nin_shortcut``); otherwise it is the identity.

    Args:
        in_channels: channels of the input.
        out_channels: channels of the output; ``None`` (the default) keeps
            ``in_channels``.
    """

    def __init__(self, in_channels: int, out_channels: int | None = None):
        super().__init__()
        self.in_channels = in_channels
        # The None case was already handled here but was unreachable without
        # an explicit argument; it is now the default.
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if self.in_channels != self.out_channels:
            self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        h = x
        h = self.norm1(h)
        h = swish(h)
        h = self.conv1(h)
        h = self.norm2(h)
        h = swish(h)
        h = self.conv2(h)
        if self.in_channels != self.out_channels:
            x = self.nin_shortcut(x)
        return x + h
class Downsample(nn.Module):
    """Halve the spatial size with a stride-2 3x3 conv after asymmetric zero padding."""

    def __init__(self, in_channels: int):
        super().__init__()
        # torch convs only pad symmetrically, so padding is done by hand in forward
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x: Tensor):
        # one extra zero column on the right and one extra row at the bottom
        padded = nn.functional.pad(x, (0, 1, 0, 1), mode="constant", value=0)
        return self.conv(padded)
class Upsample(nn.Module):
    """Double the spatial size (nearest-neighbour) and refine with a 3x3 conv."""

    def __init__(self, in_channels: int):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x: Tensor):
        enlarged = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        return self.conv(enlarged)
class Encoder(nn.Module):
    """Convolutional VAE encoder.

    ``conv_in`` maps ``in_channels`` to ``ch``. Each of the ``len(ch_mult)``
    levels then applies ``num_res_blocks`` ResnetBlocks (width
    ``ch * ch_mult[level]``) and, except at the last level, a stride-2
    Downsample. A ResnetBlock / AttnBlock / ResnetBlock middle stage follows,
    then GroupNorm + swish + ``conv_out`` producing ``2 * z_channels``
    channels.
    """

    def __init__(
        self,
        resolution: int,
        in_channels: int,
        ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        z_channels: int,
    ):
        super().__init__()
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        # downsampling
        self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
        in_ch_mult = (1,) + tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        block_in = self.ch
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            # Never populated; kept so the module layout and forward's attn hook are unchanged.
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
                block_in = block_out
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in)
            self.down.append(down)
        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        # end
        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
        self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x: Tensor) -> Tensor:
        # Only the latest activation is ever read. A single running tensor is
        # kept instead of a list of every intermediate, which would hold all
        # feature maps in memory until the method returns.
        h = self.conv_in(x)
        for i_level in range(self.num_resolutions):
            level = self.down[i_level]
            for i_block in range(self.num_res_blocks):
                h = level.block[i_block](h)
                if len(level.attn) > 0:
                    h = level.attn[i_block](h)
            if i_level != self.num_resolutions - 1:
                h = level.downsample(h)
        # middle
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)
        # end
        h = self.norm_out(h)
        h = swish(h)
        h = self.conv_out(h)
        return h
class Decoder(nn.Module):
    """Convolutional VAE decoder (latent -> image).

    ``conv_in`` maps ``z_channels`` to ``ch * ch_mult[-1]``, a ResnetBlock /
    AttnBlock / ResnetBlock middle stage follows, and then the levels are
    processed from the lowest resolution up. Each level applies
    ``num_res_blocks + 1`` ResnetBlocks and, except level 0, a 2x Upsample.
    GroupNorm + swish + ``conv_out`` produce ``out_ch`` channels.
    ``self.up[i]`` holds level ``i``; level 0 is the highest resolution.
    """

    def __init__(
        self,
        ch: int,
        out_ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        in_channels: int,
        resolution: int,
        z_channels: int,
    ):
        super().__init__()
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.ffactor = 2 ** (self.num_resolutions - 1)

        # width and spatial size at the lowest resolution
        block_in = ch * ch_mult[self.num_resolutions - 1]
        lowest_res = resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, lowest_res, lowest_res)

        self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)

        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)

        # Stages are built from the lowest resolution up, then stored by level index.
        stages = []
        for level in reversed(range(self.num_resolutions)):
            stage_out = ch * ch_mult[level]
            blocks = nn.ModuleList()
            for _ in range(self.num_res_blocks + 1):
                blocks.append(ResnetBlock(in_channels=block_in, out_channels=stage_out))
                block_in = stage_out
            stage = nn.Module()
            stage.block = blocks
            stage.attn = nn.ModuleList()
            if level != 0:
                stage.upsample = Upsample(block_in)
            stages.append(stage)
        self.up = nn.ModuleList(reversed(stages))

        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
        self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, z: Tensor) -> Tensor:
        h = self.conv_in(z)
        h = self.mid.block_2(self.mid.attn_1(self.mid.block_1(h)))
        for level in reversed(range(self.num_resolutions)):
            stage = self.up[level]
            for idx in range(self.num_res_blocks + 1):
                h = stage.block[idx](h)
                if len(stage.attn) > 0:
                    h = stage.attn[idx](h)
            if level != 0:
                h = stage.upsample(h)
        return self.conv_out(swish(self.norm_out(h)))
class DiagonalGaussian(nn.Module):
    """Split moments along ``chunk_dim`` into (mean, log-variance) and sample or return the mean.

    With ``sample=True`` a reparameterized draw ``mean + exp(0.5 * logvar) * eps``
    (``eps ~ N(0, 1)``) is returned; otherwise just ``mean``.
    """

    def __init__(self, sample: bool = True, chunk_dim: int = 1):
        super().__init__()
        self.sample = sample
        self.chunk_dim = chunk_dim

    def forward(self, z: Tensor) -> Tensor:
        mean, logvar = z.chunk(2, dim=self.chunk_dim)
        if not self.sample:
            return mean
        std = torch.exp(0.5 * logvar)
        noise = torch.randn_like(mean)
        return mean + std * noise
class AutoEncoder(ModelMixin, ConfigMixin):
    """Encoder/Decoder VAE with latent normalization, registered as a diffusers model.

    ``encode`` passes the encoder output through a sampling DiagonalGaussian
    and normalizes it as ``scale_factor * (z - shift_factor)``; ``decode``
    undoes that normalization before running the decoder.
    """

    def __init__(self, params: AutoEncoderParams | None = None, **kwargs):
        # Either a ready AutoEncoderParams or its fields as keyword arguments.
        params = AutoEncoderParams(**kwargs) if params is None else params
        super().__init__()
        self.register_to_config(**asdict(params))
        shared = dict(
            resolution=params.resolution,
            in_channels=params.in_channels,
            ch=params.ch,
            ch_mult=params.ch_mult,
            num_res_blocks=params.num_res_blocks,
            z_channels=params.z_channels,
        )
        self.encoder = Encoder(**shared)
        self.decoder = Decoder(out_ch=params.out_ch, **shared)
        self.reg = DiagonalGaussian()
        self.scale_factor = params.scale_factor
        self.shift_factor = params.shift_factor

    def encode(self, x: Tensor) -> Tensor:
        latent = self.reg(self.encoder(x))
        return self.scale_factor * (latent - self.shift_factor)

    def decode(self, z: Tensor) -> Tensor:
        return self.decoder(z / self.scale_factor + self.shift_factor)

    def forward(self, x: Tensor) -> Tensor:
        return self.decode(self.encode(x))

    @classmethod
    def from_config(cls, config, **unused):
        """
        Build an AutoEncoder from a diffusers config dict.

        Keys that are not AutoEncoderParams fields are ignored, as are extra
        keyword arguments.
        """
        known_fields = {field.name for field in fields(AutoEncoderParams)}
        params = AutoEncoderParams(**{key: value for key, value in config.items() if key in known_fields})
        return cls(params)
def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
    """Report state-dict keys that were missing or unexpected during loading.

    Prints nothing when both lists are empty; when both are non-empty the two
    reports are separated by a 79-character rule.
    """
    if missing:
        print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
    if missing and unexpected:
        print("\n" + "-" * 79 + "\n")
    if unexpected:
        print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
def load_ae(local_path: str | None) -> tuple[AutoEncoder, AutoEncoderParams]:
    """Build the 16-latent-channel AutoEncoder and optionally load its weights.

    Args:
        local_path: path to a safetensors checkpoint, or ``None`` to keep the
            freshly initialized weights.

    Returns:
        ``(ae, ae_params)`` -- the model and the hyper-parameters it was built from.
        (The previous annotation claimed a bare ``AutoEncoder``.)

    Loading is non-strict: missing/unexpected keys are printed, not raised.
    ``assign=True`` makes the parameters take the loaded tensors rather than
    copying into the initialized ones.
    """
    ae_params = AutoEncoderParams(
        resolution=256,
        in_channels=3,
        downsample=8,
        ch=128,
        out_ch=3,
        ch_mult=[1, 2, 4, 4],
        num_res_blocks=2,
        z_channels=16,
        scale_factor=0.3611,
        shift_factor=0.1159,
    )
    ae = AutoEncoder(ae_params)
    if local_path is not None:
        sd = load_sft(local_path)
        missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
        print_load_warning(missing, unexpected)
    return ae, ae_params
# System prompts inserted at the start of the context when ``think=True`` in
# InterleaveInferencer.interleave_inference: the VLM variant is used when
# understanding_output is set (text answer), the GEN variant otherwise (image
# generation). Both ask the model to wrap its reasoning in <think></think> tags.
VLM_THINK_SYSTEM_PROMPT = '''You should first think about the reasoning process in the mind and then provide the user with the answer.
The reasoning process is enclosed within <think> </think> tags, i.e. <think> reasoning process here </think> answer here'''
GEN_THINK_SYSTEM_PROMPT = '''You should first think about the planning process in the mind and then generate the image.
The planning process is enclosed within <think> </think> tags, i.e. <think> planning process here </think> image here'''
class InterleaveInferencer:
    """Run BAGEL inference over an interleaved list of text and images.

    While walking the inputs three contexts are maintained (dicts holding
    ``kv_lens``, ``ropes`` and ``past_key_values``):

    * ``gen_context`` -- receives every input in order;
    * ``cfg_text_context`` -- a snapshot of ``gen_context`` taken just before
      the latest text input (or just after the latest image); passed to
      ``gen_image`` as ``cfg_text_precontext``;
    * ``cfg_img_context`` -- receives the system prompt and text inputs but
      never images; passed to ``gen_image`` as ``cfg_img_precontext``.

    Only single-sequence inference is supported.
    """

    def __init__(self, model, vae_model, tokenizer, vae_transform, vit_transform, new_token_ids):
        self.model = model
        self.vae_model = vae_model
        self.tokenizer = tokenizer
        self.vae_transform = vae_transform
        self.vit_transform = vit_transform
        self.new_token_ids = new_token_ids

    def _to_device(self, d, device):
        """Move each top-level tensor value of *d* to *device*, in place.

        The walk is shallow: tensors nested inside lists, tuples or dicts are
        left where they are. Returns *d* itself.
        """
        for k, v in d.items():
            if torch.is_tensor(v):
                d[k] = v.to(device)
        return d

    def to(self, device):
        """Move the model and the VAE to *device*; returns ``self``."""
        self.model = self.model.to(device)
        self.vae_model = self.vae_model.to(device)
        return self

    def init_gen_context(self):
        """Return an empty context: KV length 0, RoPE offset 0 and a new NaiveCache.

        The cache is sized by ``model.config.llm_config.num_hidden_layers``.
        """
        gen_context = {
            'kv_lens': [0],
            'ropes': [0],
            'past_key_values': NaiveCache(self.model.config.llm_config.num_hidden_layers),
        }
        return gen_context

    @torch.no_grad()
    def update_context_text(self, text, gen_context):
        """Append *text* to *gen_context*; the dict is updated in place and returned."""
        # used for interleave data, currently only support 1 data inference,
        past_key_values = gen_context['past_key_values']
        kv_lens = gen_context['kv_lens']
        ropes = gen_context['ropes']
        generation_input, kv_lens, ropes = self.model.prepare_prompts(
            curr_kvlens=kv_lens,
            curr_rope=ropes,
            prompts=[text],
            tokenizer=self.tokenizer,
            new_token_ids=self.new_token_ids,
        )
        generation_input = self._to_device(generation_input,
                                           next(self.model.parameters()).device)
        past_key_values = self.model.forward_cache_update_text(past_key_values, **generation_input)
        gen_context['kv_lens'] = kv_lens
        gen_context['ropes'] = ropes
        gen_context['past_key_values'] = past_key_values
        return gen_context

    @torch.no_grad()
    def update_context_image(self, image, gen_context, vae=True, vit=True):
        """Append *image* to *gen_context* as VAE and/or ViT tokens (in that order).

        At least one of *vae* / *vit* must be true. The dict is updated in
        place and returned.
        """
        # used for interleave data, currently only support 1 data inference,
        assert vae or vit
        past_key_values = gen_context['past_key_values']
        kv_lens = gen_context['kv_lens']
        ropes = gen_context['ropes']
        device = next(self.model.parameters()).device
        if vae:
            ## update vae
            generation_input, kv_lens, ropes = self.model.prepare_vae_images(
                curr_kvlens=kv_lens,
                curr_rope=ropes,
                images=[image],
                transforms=self.vae_transform,
                new_token_ids=self.new_token_ids,
            )
            generation_input = self._to_device(generation_input, device)
            past_key_values = self.model.forward_cache_update_vae(self.vae_model, past_key_values, **generation_input)
        if vit:
            ## update vit
            generation_input, kv_lens, ropes = self.model.prepare_vit_images(
                curr_kvlens=kv_lens,
                curr_rope=ropes,
                images=[image],
                transforms=self.vit_transform,
                new_token_ids=self.new_token_ids,
            )
            generation_input = self._to_device(generation_input, device)
            past_key_values = self.model.forward_cache_update_vit(past_key_values, **generation_input)
        gen_context['kv_lens'] = kv_lens
        gen_context['ropes'] = ropes
        gen_context['past_key_values'] = past_key_values
        return gen_context

    @torch.no_grad()
    def gen_image(
        self,
        image_shape,
        gen_context,
        cfg_text_scale=4.0,
        cfg_img_scale=1.5,
        cfg_text_precontext=None,
        cfg_img_precontext=None,
        cfg_interval=(0.4, 1.0),
        cfg_renorm_min=0.0,
        cfg_renorm_type="global",
        num_timesteps=50,
        timestep_shift=3.0
    ):
        """Generate an image of size ``image_shape == (H, W)`` conditioned on *gen_context*.

        ``cfg_text_precontext`` and ``cfg_img_precontext`` are required despite
        their ``None`` defaults (they are indexed unconditionally). The
        remaining keyword arguments are forwarded to ``model.generate_image``.
        Returns a PIL image.
        """
        device = next(self.model.parameters()).device
        past_key_values = gen_context['past_key_values']
        kv_lens = gen_context['kv_lens']
        ropes = gen_context['ropes']
        generation_input = self.model.prepare_vae_latent(
            curr_kvlens=kv_lens,
            curr_rope=ropes,
            image_sizes=[image_shape],
            new_token_ids=self.new_token_ids,
        )
        generation_input = self._to_device(generation_input, device)
        # text cfg
        cfg_text_past_key_values = cfg_text_precontext['past_key_values']
        kv_lens_cfg = cfg_text_precontext['kv_lens']
        ropes_cfg = cfg_text_precontext['ropes']
        generation_input_cfg_text = self.model.prepare_vae_latent_cfg(
            curr_kvlens=kv_lens_cfg,
            curr_rope=ropes_cfg,
            image_sizes=[image_shape],
        )
        generation_input_cfg_text = self._to_device(generation_input_cfg_text, device)
        # img cfg
        cfg_img_past_key_values = cfg_img_precontext['past_key_values']
        kv_lens_cfg = cfg_img_precontext['kv_lens']
        ropes_cfg = cfg_img_precontext['ropes']
        generation_input_cfg_img = self.model.prepare_vae_latent_cfg(
            curr_kvlens=kv_lens_cfg,
            curr_rope=ropes_cfg,
            image_sizes=[image_shape],
        )
        generation_input_cfg_img = self._to_device(generation_input_cfg_img, device)
        unpacked_latent = self.model.generate_image(
            past_key_values=past_key_values,
            cfg_text_past_key_values=cfg_text_past_key_values,
            cfg_img_past_key_values=cfg_img_past_key_values,
            num_timesteps=num_timesteps,
            cfg_text_scale=cfg_text_scale,
            cfg_img_scale=cfg_img_scale,
            cfg_interval=cfg_interval,
            cfg_renorm_min=cfg_renorm_min,
            cfg_renorm_type=cfg_renorm_type,
            timestep_shift=timestep_shift,
            **generation_input,
            cfg_text_packed_position_ids=generation_input_cfg_text['cfg_packed_position_ids'],
            cfg_text_packed_query_indexes=generation_input_cfg_text['cfg_packed_query_indexes'],
            cfg_text_key_values_lens=generation_input_cfg_text['cfg_key_values_lens'],
            cfg_text_packed_key_value_indexes=generation_input_cfg_text['cfg_packed_key_value_indexes'],
            cfg_img_packed_position_ids=generation_input_cfg_img['cfg_packed_position_ids'],
            cfg_img_packed_query_indexes=generation_input_cfg_img['cfg_packed_query_indexes'],
            cfg_img_key_values_lens=generation_input_cfg_img['cfg_key_values_lens'],
            cfg_img_packed_key_value_indexes=generation_input_cfg_img['cfg_packed_key_value_indexes'],
        )
        image = self.decode_image(unpacked_latent[0], image_shape)
        return image

    def decode_image(self, latent, image_shape):
        """Unpatchify a packed latent for an ``(H, W)`` image and decode it with the VAE.

        The latent is reshaped to ``(1, H // d, W // d, p, p, C)`` (``d`` =
        ``model.latent_downsample``, ``p`` = ``model.latent_patch_size``,
        ``C`` = ``model.latent_channel``), rearranged to ``(1, C, h*p, w*p)``,
        decoded, mapped from [-1, 1] to [0, 255] and returned as a PIL image.
        """
        H, W = image_shape
        h, w = H // self.model.latent_downsample, W // self.model.latent_downsample
        latent = latent.reshape(1, h, w, self.model.latent_patch_size, self.model.latent_patch_size, self.model.latent_channel)
        latent = torch.einsum("nhwpqc->nchpwq", latent)
        latent = latent.reshape(1, self.model.latent_channel, h * self.model.latent_patch_size, w * self.model.latent_patch_size)
        image = self.vae_model.decode(latent)
        image = (image * 0.5 + 0.5).clamp(0, 1)[0].permute(1, 2, 0) * 255
        image = Image.fromarray((image).to(torch.uint8).cpu().numpy())
        return image

    @torch.no_grad()
    def gen_text(self, gen_context, max_length: int = 500, do_sample: bool = True, temperature: float = 1.0):
        """Generate text after *gen_context* and return the decoded message body.

        Works on a deep copy of the context, so the caller's cache is left
        untouched. The result is the part after the first ``<|im_start|>``
        within the text preceding the first ``<|im_end|>``.
        """
        gen_context = deepcopy(gen_context)
        past_key_values = gen_context['past_key_values']
        kv_lens = gen_context['kv_lens']
        ropes = gen_context['ropes']
        device = next(self.model.parameters()).device
        generation_input = self.model.prepare_start_tokens(kv_lens, ropes, self.new_token_ids)
        generation_input = self._to_device(generation_input, device)
        unpacked_latent = self.model.generate_text(
            past_key_values=past_key_values,
            max_length=max_length,
            do_sample=do_sample,
            temperature=temperature,
            end_token_id=self.new_token_ids['eos_token_id'],
            **generation_input,
        )
        output = self.tokenizer.decode(unpacked_latent[:,0])
        output = output.split('<|im_end|>')[0].split('<|im_start|>')[1]
        return output

    @torch.no_grad()
    def interleave_inference(
        self,
        input_lists: List[Union[str, Image.Image]],
        think=False,
        understanding_output=False,
        max_think_token_n=1000,
        do_sample=False,
        text_temperature=0.3,
        cfg_text_scale=3.0,
        cfg_img_scale=1.5,
        cfg_interval=(0.4, 1.0),
        timestep_shift=3.0,
        num_timesteps=50,
        cfg_renorm_min=0.0,
        cfg_renorm_type="global",
        image_shapes=(1024, 1024),
    ) -> List[Union[str, Image.Image]]:
        """Feed *input_lists* into the contexts, then generate text or an image.

        With ``understanding_output`` a text answer is produced and images are
        added as ViT tokens only. Otherwise an image is generated; with
        ``think`` a planning text is generated first, appended to the main
        context and also returned. The output size is ``image_shapes`` unless
        an input image was given, in which case the last image's (resized)
        size is used. Returns the list of generated outputs in order.
        """
        output_list = []
        gen_context = self.init_gen_context()
        cfg_text_context = deepcopy(gen_context)
        cfg_img_context = deepcopy(gen_context)
        # NOTE(review): autocast is hard-wired to CUDA / bfloat16.
        with torch.autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16):
            if think:
                if understanding_output:
                    system_prompt = VLM_THINK_SYSTEM_PROMPT
                else:
                    system_prompt = GEN_THINK_SYSTEM_PROMPT
                gen_context = self.update_context_text(system_prompt, gen_context)
                cfg_img_context = self.update_context_text(system_prompt, cfg_img_context)
            for input_term in input_lists:
                if isinstance(input_term, str):
                    cfg_text_context = deepcopy(gen_context)
                    gen_context = self.update_context_text(input_term, gen_context)
                    cfg_img_context = self.update_context_text(input_term, cfg_img_context)
                elif isinstance(input_term, Image.Image):
                    input_term = self.vae_transform.resize_transform(pil_img2rgb(input_term))
                    gen_context = self.update_context_image(input_term, gen_context, vae=not understanding_output)
                    image_shapes = input_term.size[::-1]
                    cfg_text_context = deepcopy(gen_context)
                else:
                    raise ValueError(f"Unsupported input type: {type(input_term)}")
            if understanding_output:
                generated_text = self.gen_text(gen_context, do_sample=do_sample, temperature=text_temperature, max_length=max_think_token_n)
                output_list.append(generated_text)
            else:
                if think:
                    generated_text = self.gen_text(gen_context, do_sample=do_sample, temperature=text_temperature, max_length=max_think_token_n)
                    gen_context = self.update_context_text(generated_text, gen_context)
                    output_list.append(generated_text)
                img = self.gen_image(
                    image_shapes,
                    gen_context,
                    cfg_text_precontext=cfg_text_context,
                    cfg_img_precontext=cfg_img_context,
                    cfg_text_scale=cfg_text_scale,
                    cfg_img_scale=cfg_img_scale,
                    cfg_interval=cfg_interval,
                    timestep_shift=timestep_shift,
                    num_timesteps=num_timesteps,
                    cfg_renorm_min=cfg_renorm_min,
                    cfg_renorm_type=cfg_renorm_type,
                )
                output_list.append(img)
        return output_list

    def __call__(
        self,
        image: Optional[Image.Image] = None,
        text: Optional[str] = None,
        **kwargs
    ) -> Dict[str, Any]:
        """Run ``interleave_inference`` on ``[image, text]`` (whichever are given).

        Returns ``{'image': ..., 'text': ...}`` holding the last generated
        image and text (``None`` when absent). With neither input, a message is
        printed and both entries are ``None``.
        """
        output_dict = {'image': None, 'text': None}
        if image is None and text is None:
            print('Please provide at least one input: either an image or text.')
            return output_dict
        input_list = []
        if image is not None:
            input_list.append(image)
        if text is not None:
            input_list.append(text)
        output_list = self.interleave_inference(input_list, **kwargs)
        for i in output_list:
            if isinstance(i, Image.Image):
                output_dict['image'] = i
            elif isinstance(i, str):
                output_dict['text'] = i
        return output_dict
class BagelPipeline(DiffusionPipeline):
    """Diffusers pipeline wrapper around BAGEL's ``InterleaveInferencer``.

    Registers the BAGEL model, VAE and tokenizer as pipeline modules and
    routes ``__call__`` to an ``InterleaveInferencer`` built from them,
    filling in task-specific default hyperparameters (text-to-image, image
    editing, image understanding) that the caller may override.
    """

    model_cpu_offload_seq = "bagel_model"

    def __init__(self, bagel_model, vae, tokenizer):
        """Register the modules and build the internal inferencer.

        Args:
            bagel_model: The BAGEL model; also handed to the inferencer.
            vae: The VAE model; also handed to the inferencer.
            tokenizer: Tokenizer; ``add_special_tokens`` is applied to it
                before it is handed to the inferencer.
        """
        super().__init__()
        self.register_modules(
            bagel_model=bagel_model,
            vae=vae,
            tokenizer=tokenizer,
        )
        tokenizer, new_token_ids, _ = add_special_tokens(tokenizer)
        # NOTE(review): the ImageTransform arguments are kept verbatim; their
        # meaning is defined by ImageTransform elsewhere in the file.
        self._inferencer = InterleaveInferencer(
            model=bagel_model,
            vae_model=vae,
            tokenizer=tokenizer,
            vae_transform=ImageTransform(1024, 512, 16),
            vit_transform=ImageTransform(980, 224, 14),
            new_token_ids=new_token_ids,
        )

    @staticmethod
    def _default_infer_kwargs(image, text, think, understanding_output) -> Dict[str, Any]:
        """Return the default inference hyperparameters for the requested task.

        - text only: text-to-image defaults.
        - image + text, not understanding: image-editing defaults.
        - image + ``understanding_output``: text-generation defaults.
        - anything else: no defaults.

        For the two image-generation tasks, ``think=True`` adds the
        thinking-text defaults on top.
        """
        if text is not None and image is None:
            defaults: Dict[str, Any] = {
                "cfg_text_scale": 4.0,
                "cfg_img_scale": 1.0,
                "cfg_interval": (0.4, 1.0),
                "timestep_shift": 3.0,
                "num_timesteps": 50,
                "cfg_renorm_min": 0.0,
                "cfg_renorm_type": "global",
            }
        elif image is not None and text is not None and not understanding_output:
            defaults = {
                "cfg_text_scale": 4.0,
                "cfg_img_scale": 2.0,
                "cfg_interval": (0.0, 1.0),
                "timestep_shift": 3.0,
                "num_timesteps": 50,
                "cfg_renorm_min": 0.0,
                "cfg_renorm_type": "text_channel",
            }
        elif image is not None and understanding_output:
            return {
                "max_think_token_n": 1000,
                "do_sample": False,
            }
        else:
            return {}
        if think:
            # Only the image-generation branches reach here.
            defaults.update({
                "max_think_token_n": 1000,
                "do_sample": False,
                "text_temperature": 0.3,
            })
        return defaults

    def __call__(
        self,
        *,
        image: Optional[Image.Image] = None,
        text: Optional[str] = None,
        think: bool = False,
        understanding_output: bool = False,
        **infer_kwargs
    ) -> Dict[str, Any]:
        """
        Supports:
          - text→image          (pass text=…)
          - text→image + think  (+ think=True)
          - image→image edit    (pass image=…, text=…)
          - image→image+think   (+ think=True)
          - image→understanding (+ understanding_output=True)
        Any other kwargs (cfg_text_scale, num_timesteps, etc.) override the
        task defaults chosen by ``_default_infer_kwargs``.

        Returns:
            A dict that contains ``"images"`` (a one-element list) when the
            inferencer produced an image, and ``"text"`` when it produced
            text; keys are omitted when the corresponding output is ``None``.
        """
        defaults = self._default_infer_kwargs(image, text, think, understanding_output)
        for key, value in defaults.items():
            # Caller-supplied values always win over the task defaults.
            infer_kwargs.setdefault(key, value)
        result: Dict[str, Any] = self._inferencer(
            image=image,
            text=text,
            think=think,
            understanding_output=understanding_output,
            **infer_kwargs,
        )
        out: Dict[str, Any] = {}
        if result.get("image") is not None:
            out["images"] = [result["image"]]
        if result.get("text") is not None:
            out["text"] = result["text"]
        return out

    def to(self, device=None, *args, **kwargs):
        """Move the pipeline's modules (and the inferencer, if movable) to ``device``.

        Extra positional/keyword arguments (e.g. ``dtype``) are forwarded to
        ``DiffusionPipeline.to``, so calls like ``pipe.to("cuda", dtype=...)``
        work as with any diffusers pipeline. Returns ``self``.
        """
        if device is None:
            super().to(*args, **kwargs)
        else:
            super().to(device, *args, **kwargs)
        inferencer = getattr(self, "_inferencer", None)
        # The inferencer wraps the same bagel_model/vae objects registered in
        # __init__, which super().to() has already moved; only delegate when
        # the inferencer actually exposes its own ``to``.
        if device is not None and inferencer is not None and hasattr(inferencer, "to"):
            inferencer.to(device)
        return self