| from .clip import clip |
| from PIL import Image |
| import torch.nn as nn |
|
|
|
|
# Embedding dimensionality of each supported CLIP backbone's image encoder;
# used to size the linear classification head in CLIPModel.
CHANNELS = {
    "RN50": 1024,
    "ViT-L/14": 768,
}
|
|
class CLIPModel(nn.Module):
    """Linear classifier on top of a CLIP image encoder.

    Loads a CLIP backbone by name and attaches a single linear head whose
    input width is looked up in ``CHANNELS`` for that backbone.
    """

    def __init__(self, name: str, num_classes: int = 1) -> None:
        """
        Args:
            name: CLIP backbone identifier; must be a key of ``CHANNELS``
                (e.g. ``"RN50"`` or ``"ViT-L/14"``).
            num_classes: output width of the linear classification head.
        """
        super().__init__()

        # Load on CPU; the caller is expected to move the module to the
        # target device afterwards.
        self.model, self.preprocess = clip.load(name, device="cpu")
        self.fc = nn.Linear(CHANNELS[name], num_classes)

    def forward(self, x, return_feature: bool = False):
        """Encode images with CLIP and (optionally) apply the linear head.

        Args:
            x: batch of preprocessed images.
            return_feature: if True, return the encoder's
                ``'after_projection'`` features instead of logits.

        Returns:
            The ``'after_projection'`` features when ``return_feature`` is
            True; otherwise the linear head applied to the encoder's
            ``'res_output'`` features.
        """
        # NOTE(review): encode_image here returns a dict of intermediate
        # features ('after_projection', 'res_output'). This relies on the
        # local modified `.clip` package, not the stock OpenAI clip API,
        # whose encode_image returns a plain tensor — confirm against the
        # bundled clip implementation.
        features = self.model.encode_image(x)

        if return_feature:
            return features['after_projection']

        return self.fc(features['res_output'])
|
|
|
|