from functools import partial
from typing import Any, Dict, List, Optional

import torch
from torch import nn


class BaseEncoder(nn.Module):
    def __init__(self, parent: nn.Module) -> None:
        super().__init__()
        # Keep the parent model in a plain list so it is not registered as a
        # submodule (its parameters would otherwise be duplicated in this
        # module's state dict).
        self._parent = [parent]

    @property
    def parent(self) -> nn.Module:
        return self._parent[0]


class BasicImageEncoder(BaseEncoder):
    def __init__(
        self,
        parent: nn.Module,
        start_tokens: Optional[str] = None,
        end_tokens: Optional[str] = "\n",
    ) -> None:
        super().__init__(parent)
        self.start_tokens = start_tokens
        self.end_tokens = end_tokens

    def embed_tokens(self, tokens: Optional[str]) -> Optional[torch.Tensor]:
        # Tokenize the literal start/end string and embed it with the parent
        # LLM's input embedding table.
        if tokens is None:
            return None
        token_ids = self.parent.tokenizer(tokens).input_ids
        token_ids = torch.tensor(token_ids, device=self.parent.device)
        return self.parent.llm.model.embed_tokens(token_ids)

    def _process_features(
        self,
        features: torch.Tensor,
        start_token_embeds: Optional[torch.Tensor],
        end_token_embeds: Optional[torch.Tensor],
    ) -> torch.Tensor:
        # Wrap the image features with the embedded start/end tokens along the
        # token (sequence) dimension.
        if start_token_embeds is not None:
            features = torch.cat([start_token_embeds, features], dim=0)
        if end_token_embeds is not None:
            features = torch.cat([features, end_token_embeds], dim=0)
        return features

    def forward(self, images: List[torch.Tensor], config: Dict[str, Any], device: torch.device) -> List[torch.Tensor]:
        # Stack the images into a single batch, encode them with the parent's
        # vision tower, then wrap each image's features with the embedded
        # start/end tokens and move them to the requested device.
        images = torch.stack(images, dim=0)
        features = self.parent.encode_images(images, block_sizes=config.get("block_sizes"))
        process_features = partial(
            self._process_features,
            start_token_embeds=self.embed_tokens(self.start_tokens),
            end_token_embeds=self.embed_tokens(self.end_tokens),
        )
        return [process_features(f).to(device) for f in features]
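
    # Illustrative usage sketch, assuming a `parent` model that exposes the
    # attributes used above (`tokenizer`, `device`, `encode_images`, and the
    # LLM embedding table); the tensor shapes are made up for the example:
    #
    #     encoder = BasicImageEncoder(parent)
    #     images = [torch.randn(3, 448, 448) for _ in range(2)]
    #     embeds = encoder(images, config={}, device=parent.device)
    #     # `embeds` is a list with one (num_tokens, hidden_size) tensor per image.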


class BasicVideoEncoder(BaseEncoder):
    def __init__(
        self,
        parent: nn.Module,
        start_tokens: Optional[str] = None,
        end_tokens: Optional[str] = "\n",
    ) -> None:
        super().__init__(parent)
        self.start_tokens = start_tokens
        self.end_tokens = end_tokens

    def embed_tokens(self, tokens: Optional[str]) -> Optional[torch.Tensor]:
        # Same as BasicImageEncoder.embed_tokens: tokenize the literal token
        # string and embed it with the parent LLM's input embedding table.
        if tokens is None:
            return None
        token_ids = self.parent.tokenizer(tokens).input_ids
        token_ids = torch.tensor(token_ids, device=self.parent.device)
        return self.parent.llm.model.embed_tokens(token_ids)

    def _process_features(
        self,
        features: torch.Tensor,
        start_token_embeds: Optional[torch.Tensor],
        end_token_embeds: Optional[torch.Tensor],
    ) -> torch.Tensor:
        # features: (num_frames, tokens_per_frame, hidden_size). Broadcast the
        # start/end embeddings across frames, concatenate them along the
        # per-frame token dimension, then flatten frames into one sequence.
        if start_token_embeds is not None:
            start_embeds = torch.stack([start_token_embeds] * features.shape[0], dim=0)
            features = torch.cat([start_embeds, features], dim=1)
        if end_token_embeds is not None:
            end_embeds = torch.stack([end_token_embeds] * features.shape[0], dim=0)
            features = torch.cat([features, end_embeds], dim=1)
        return features.flatten(0, 1)

    def forward(self, videos: List[torch.Tensor], config: Dict[str, Any]) -> List[torch.Tensor]:
        # Concatenate all frames from all videos into one image batch, encode
        # them, then split the features back into per-video chunks.
        num_frames = [video.shape[0] for video in videos]
        images = torch.cat(videos, dim=0)
        features = self.parent.encode_images(images)
        features = torch.split(features, num_frames)
        process_features = partial(
            self._process_features,
            start_token_embeds=self.embed_tokens(self.start_tokens),
            end_token_embeds=self.embed_tokens(self.end_tokens),
        )
        return [process_features(f) for f in features]


def pool(x: torch.Tensor, size: int, dim: int) -> torch.Tensor:
    """Average-pool groups of `size` consecutive elements along dimension `dim` (its length must be divisible by `size`)."""
    return x.view(x.shape[:dim] + (-1, size) + x.shape[dim + 1 :]).mean(dim + 1)
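
# Illustrative example of `pool`; the (frames, height, width, channels) shape
# below is an assumption for the sake of the example:
#
#     x = torch.randn(8, 16, 16, 32)
#     pool(x, size=8, dim=0).shape  # (1, 16, 16, 32): temporal pooling
#     pool(x, size=4, dim=1).shape  # (8, 4, 16, 32): spatial pooling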


class TSPVideoEncoder(BasicVideoEncoder):
    def __init__(
        self,
        parent: nn.Module,
        start_tokens: Optional[str] = None,
        end_tokens: Optional[str] = "\n",
        sep_tokens: Optional[str] = None,
    ) -> None:
        super().__init__(parent, start_tokens=start_tokens, end_tokens=end_tokens)
        # Each entry is a (temporal, height, width) pooling factor applied to
        # the per-frame feature grid; the number of frames is expected to be
        # divisible by the temporal factor.
        self.pool_sizes = [[8, 1, 1]]
        self.sep_tokens = sep_tokens

    def _process_features(
        self,
        inputs: torch.Tensor,
        start_token_embeds: Optional[torch.Tensor],
        end_token_embeds: Optional[torch.Tensor],
        sep_token_embeds: Optional[torch.Tensor],
    ) -> torch.Tensor:
        # inputs: (num_frames, tokens_per_frame, hidden_size), where each
        # frame's tokens form a square grid of side nl.
        nt, ns = inputs.shape[:2]
        nl = int(ns**0.5)
        outputs = []
        for pool_size in self.pool_sizes:
            # Reshape to (frames, height, width, hidden), pool each axis by the
            # corresponding factor, then flatten the grid back into tokens.
            features = inputs.view(nt, nl, nl, -1)
            for dim, p in enumerate(pool_size):
                features = pool(features, p, dim=dim)
            features = features.flatten(1, 2)
            features = super()._process_features(
                features,
                start_token_embeds=start_token_embeds,
                end_token_embeds=end_token_embeds,
            )
            if sep_token_embeds is not None:
                features = torch.cat([features, sep_token_embeds], dim=0)
            outputs.append(features)
        return torch.cat(outputs, dim=0)

    def forward(self, videos: List[torch.Tensor], config: Dict[str, Any]) -> List[torch.Tensor]:
        # Identical to BasicVideoEncoder.forward, except that the embedded
        # separator tokens are also passed to _process_features.
        num_frames = [video.shape[0] for video in videos]
        images = torch.cat(videos, dim=0)
        features = self.parent.encode_images(images)
        features = torch.split(features, num_frames)
        process_features = partial(
            self._process_features,
            start_token_embeds=self.embed_tokens(self.start_tokens),
            end_token_embeds=self.embed_tokens(self.end_tokens),
            sep_token_embeds=self.embed_tokens(self.sep_tokens),
        )
        return [process_features(f) for f in features]
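

# Illustrative usage sketch for the video path, assuming a `parent` model that
# exposes the attributes used above; the start token string and tensor shapes
# here are assumptions, not part of any fixed API. Frame counts are multiples
# of 8 to satisfy the temporal pooling factor.
#
#     encoder = TSPVideoEncoder(parent, start_tokens="<vid>", end_tokens="\n")
#     videos = [torch.randn(8, 3, 448, 448), torch.randn(16, 3, 448, 448)]
#     embeds = encoder(videos, config={})
#     # `embeds` is a list with one (sequence_length, hidden_size) tensor per video.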