| """ |
| Image embedding service for generating vector embeddings from images |
| """ |
|
|
| import io |
| import os |
| import base64 |
| from typing import List, Tuple |
|
|
| import open_clip |
| import torch |
| from fastapi import UploadFile, HTTPException |
| from PIL import Image |
| import torch |
|
|
|
|
class ImageEmbeddingModel:
    """Generate image embeddings with an open_clip model.

    The model named by ``model_name`` is created once when the instance is
    constructed, moved to CUDA when ``torch.cuda.is_available()`` (CPU
    otherwise) and switched to eval mode. The public methods accept an image
    as a PIL object, a FastAPI upload, a base64 string or a folder on disk and
    return the embedding(s) as plain Python lists of floats.
    """

    def __init__(self, model_name: str):
        """Load ``model_name`` and its preprocessing transforms.

        Args:
            model_name: Model identifier passed verbatim to
                ``open_clip.create_model_and_transforms``.
        """
        self.model_name = model_name
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model, self.preprocess_train, self.preprocess_val = self._initialize_model()

    def _initialize_model(self) -> Tuple:
        """Create the model, move it to ``self.device`` and set eval mode.

        Returns:
            ``(model, preprocess_train, preprocess_val)`` exactly as produced
            by ``open_clip.create_model_and_transforms``.
        """
        model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms(self.model_name)
        # No tokenizer is created here: this class only embeds images and
        # never used or returned one.
        model.to(self.device)
        model.eval()
        return model, preprocess_train, preprocess_val

    def get_embedding_from_pil(self, image: Image.Image) -> List[float]:
        """Return the embedding of a single PIL image.

        Args:
            image: Image to embed; it is passed through ``self.preprocess_val``.

        Returns:
            The first (only) row of ``model.encode_image(..., normalize=True)``
            as a list of floats.
        """
        processed_image = self.preprocess_val(image).unsqueeze(0).to(self.device)

        # Mixed precision only on CUDA. On CPU autocast is explicitly disabled
        # (instead of requesting the unsupported float32 autocast dtype, which
        # makes PyTorch emit a warning on every call and disable it anyway).
        use_autocast = self.device == "cuda"
        with torch.no_grad(), torch.amp.autocast(device_type=self.device, enabled=use_autocast):
            image_features = self.model.encode_image(processed_image, normalize=True)

        # Tensor.tolist() yields Python floats directly; no NumPy round-trip.
        return image_features[0].cpu().tolist()

    async def get_embedding_from_upload(self, image_file: UploadFile) -> List[float]:
        """Return the embedding of an uploaded image file.

        Args:
            image_file: Upload whose full content is read and decoded as an
                image, then converted to RGB.

        Raises:
            HTTPException: 400 if reading, decoding or embedding fails; the
                original exception is chained as ``__cause__``.
        """
        try:
            contents = await image_file.read()
            img = Image.open(io.BytesIO(contents)).convert("RGB")
            return self.get_embedding_from_pil(img)
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Invalid image: {str(e)}") from e

    def get_embedding_from_base64(self, base64_data: str) -> List[float]:
        """Return the embedding of a base64-encoded image.

        Args:
            base64_data: Raw base64 text, optionally preceded by a header that
                ends with a comma (e.g. a ``data:image/png;base64,`` prefix).

        Raises:
            HTTPException: 400 if decoding or embedding fails; the original
                exception is chained as ``__cause__``.
        """
        try:
            # Drop everything up to and including the first comma so that a
            # data-URL style header is not fed to the base64 decoder.
            if ',' in base64_data:
                base64_data = base64_data.split(',', 1)[1]

            image_bytes = base64.b64decode(base64_data)
            image = Image.open(io.BytesIO(image_bytes)).convert("RGB")

            return self.get_embedding_from_pil(image)
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Invalid base64 image: {str(e)}") from e

    def get_embeddings_from_folder(self, image_folder: str) -> List[List[float]]:
        """Return embeddings for every PNG/JPEG image directly inside a folder.

        Files are visited in ``os.listdir`` order and matched
        case-insensitively on the ``.png``, ``.jpg`` and ``.jpeg`` suffixes.
        A file that cannot be processed is reported on stdout and skipped, so
        the result may contain fewer entries than there are image files.

        Args:
            image_folder: Path of the directory to scan (not recursive).

        Returns:
            One embedding per successfully processed image.

        Raises:
            HTTPException: 404 if ``image_folder`` is not an existing directory.
        """
        # isdir (not exists): a path to a regular file must be rejected here
        # rather than crash later inside os.listdir.
        if not os.path.isdir(image_folder):
            raise HTTPException(status_code=404, detail=f"Folder not found: {image_folder}")

        embeddings: List[List[float]] = []
        for image_name in os.listdir(image_folder):
            if not image_name.lower().endswith(('.png', '.jpg', '.jpeg')):
                continue
            try:
                image_path = os.path.join(image_folder, image_name)
                # The context manager closes the underlying file handle once
                # the RGB copy has been made.
                with Image.open(image_path) as raw_image:
                    img = raw_image.convert("RGB")
                embeddings.append(self.get_embedding_from_pil(img))
            except Exception as e:
                # Best effort: one bad file must not abort the whole folder.
                print(f"Error processing {image_name}: {str(e)}")

        return embeddings
|
|