| | import torch |
| | import torch.nn.functional as F |
| | from PIL import Image |
| | from transformers import AutoProcessor, AutoConfig |
| | from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VisionTransformerPretrainedModel |
| | from tqdm import tqdm |
| | from safetensors.torch import load_file |
| | import os |
| |
|
class Qwen2_5_VL_ImageEncoder:
    """Standalone image embedder built from the Qwen2.5-VL vision tower.

    Loads only the vision transformer weights (``model.safetensors`` under
    ``model_path``) and produces one L2-normalized embedding per image by
    mean-pooling the flattened ViT output tokens belonging to each image.
    """

    def __init__(self, model_path: str, device: str = "cuda", dtype=torch.bfloat16):
        """Load processor, vision config, and checkpoint weights.

        Args:
            model_path: Directory containing ``model.safetensors`` with the
                vision-tower state dict.
            device: Torch device string for inference (default ``"cuda"``).
            dtype: Model/compute dtype (default ``torch.bfloat16``).
        """
        self.device = device
        self.dtype = dtype

        print(f"Loading processor and model from {model_path}...")
        # NOTE(review): the processor and config are loaded from hard-coded
        # paths rather than from ``model_path`` — presumably all checkpoints
        # share one processor/config, but confirm before reusing this class
        # outside the original environment.
        self.processor = AutoProcessor.from_pretrained(
            "/mnt/ai4sci_develop_storage/home/chaohao/LCO-Embedding/Training/Qwen2.5-VL-ViT-Only",
            trust_remote_code=True,
        )

        config = AutoConfig.from_pretrained('/mnt/workspace/workgroup/chx/Qwen2.5-VL-7B-Instruct')
        config = config.vision_config

        # Instantiate the vision tower only (no language model), then load
        # the checkpoint weights strictly so missing/extra keys fail loudly.
        self.model = Qwen2_5_VisionTransformerPretrainedModel(config)

        safe_path = os.path.join(model_path, "model.safetensors")
        state_dict = load_file(safe_path)
        self.model.load_state_dict(state_dict, strict=True)

        self.model.to(device=self.device, dtype=self.dtype)
        self.model.eval()
        print("Model loaded successfully.")

    def _process_batch_forward(self, images):
        """Run one forward pass and return L2-normalized embeddings on CPU.

        Args:
            images: List of PIL images (already converted to RGB).

        Returns:
            torch.Tensor of shape [len(images), hidden_dim] on CPU.
        """
        # NOTE(review): these message items lack the usual
        # {"role": ..., "content": [...]} wrapping expected by chat templates —
        # verify the processor's template accepts this flat form. The text is
        # only used to build processor inputs; embeddings come from the ViT.
        messages_list = [
            [
                {"type": "image", "image": img},
                {"type": "text", "text": "Describe this image."},
            ] for img in images
        ]

        text_inputs = [
            self.processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True)
            for msg in messages_list
        ]

        inputs = self.processor(
            images=images,
            text=text_inputs,
            return_tensors="pt",
            padding=True,
        )

        pixel_values = inputs["pixel_values"].to(self.device, dtype=self.dtype)
        grid_thw = inputs["image_grid_thw"].to(self.device)

        # The vision tower returns a flat [total_tokens, hidden_dim] tensor
        # covering all images in the batch concatenated along dim 0.
        hidden_states = self.model(hidden_states=pixel_values, grid_thw=grid_thw)

        # Some processor versions emit grid_thw as [B, 1, 3]; normalize to [B, 3].
        if grid_thw.dim() == 3 and grid_thw.size(1) == 1:
            grid_thw = grid_thw.squeeze(1)

        batch_size = grid_thw.shape[0]

        # Tokens per image after the 2x2 spatial merge of the vision tower.
        H, W = grid_thw[:, 1], grid_thw[:, 2]
        sizes = ((H // 2) * (W // 2)).long()

        # Safety net: if the computed per-image sizes disagree with the actual
        # token count (e.g. rounding/merge edge cases), absorb the difference
        # into the last image so index_add_ below covers every token.
        total_tokens = hidden_states.shape[0]
        if sizes.sum().item() != total_tokens:
            sizes[-1] += (total_tokens - sizes.sum().item())

        # Map each token row to its source image index, then segment-sum.
        batch_indices = torch.repeat_interleave(
            torch.arange(batch_size, device=self.device),
            sizes,
        )

        pooled_sum = torch.zeros(
            (batch_size, hidden_states.shape[-1]),
            dtype=self.dtype,
            device=self.device,
        )
        pooled_sum.index_add_(0, batch_indices, hidden_states)

        # Mean pool; clamp guards against a (theoretical) zero-token image.
        counts = sizes.unsqueeze(1).to(dtype=self.dtype).clamp(min=1.0)
        embeds = pooled_sum / counts

        embeds = F.normalize(embeds, p=2, dim=-1)

        return embeds.cpu()

    @torch.no_grad()
    def encode_batch(self, images: list, batch_size: int = 32, show_progress: bool = True):
        """Encode a list of PIL images into embeddings.

        Args:
            images: List of PIL Images.
            batch_size: Number of images to process at once.
            show_progress: Wrap the batch loop in a tqdm progress bar.

        Returns:
            torch.Tensor: Concatenated embeddings [Total_Images, Hidden_Dim]
            on CPU, or an empty tensor if ``images`` is empty.
        """
        all_embeddings = []

        iterator = range(0, len(images), batch_size)
        if show_progress:
            iterator = tqdm(iterator, desc="Encoding Batches", unit="batch")

        for i in iterator:
            batch_images = images[i : i + batch_size]

            # Normalize mode so grayscale/RGBA inputs don't break the processor.
            batch_images = [img.convert("RGB") for img in batch_images]

            try:
                batch_embeds = self._process_batch_forward(batch_images)
                all_embeddings.append(batch_embeds)
            except Exception:
                print(f"Error processing batch starting at index {i}")
                # Bare raise preserves the original traceback
                # (``raise e`` would re-anchor it here).
                raise

        if not all_embeddings:
            return torch.empty(0)

        return torch.cat(all_embeddings, dim=0)
| |
|
| | |
if __name__ == "__main__":
    # Hoist evaluation-only imports out of the per-language loop: they are
    # loop-invariant and re-importing every iteration is pure noise.
    from datasets import load_dataset
    from scipy.stats import spearmanr
    from sklearn.metrics.pairwise import paired_cosine_distances

    MODEL_PATHS = [
        "/mnt/ai4sci_develop_storage/home/chaohao/LCO-Embedding/Training/checkpoints/final/checkpoint-500",
        "/mnt/ai4sci_develop_storage/home/chaohao/LCO-Embedding/Training/checkpoints/final/checkpoint-550",
    ]

    # STS-B style evaluation: embed both sides of each sentence pair, then
    # report the Spearman correlation between cosine similarity and the
    # human-annotated score, per language.
    for MODEL_PATH in MODEL_PATHS:
        encoder = Qwen2_5_VL_ImageEncoder(MODEL_PATH)

        spearmans = []
        for lang in ["en", "de", "es", "fr", "it", "nl", "pl", "pt", "ru", "zh"]:
            dataset = load_dataset(
                "/mnt/ai4sci_develop_storage/home/chaohao/LCO-Embedding/Training/a_eval/stsb",
                lang,
            )["test"]
            # NOTE(review): encode_batch calls ``.convert("RGB")`` on every
            # item, so sentence1/sentence2 must already be PIL images
            # (e.g. rendered text) — verify the dataset schema, since plain
            # strings would crash here.
            anchors = dataset["sentence1"]
            positive = dataset["sentence2"]

            embeddings1 = encoder.encode_batch(anchors, batch_size=32)
            embeddings2 = encoder.encode_batch(positive, batch_size=32)
            groundtruth = dataset["score"]

            # bf16 is not numpy-convertible; upcast to float32 first.
            embeddings1 = embeddings1.cpu().float().numpy()
            embeddings2 = embeddings2.cpu().float().numpy()

            cos_sim = 1 - paired_cosine_distances(embeddings1, embeddings2)
            spearman_corr, _ = spearmanr(cos_sim, groundtruth)
            spearmans.append(round(spearman_corr, 2))
            # Printed inside the loop so partial results survive a crash.
            print("Spearman correlation:", spearmans)