| """ |
| Vector database service for interacting with Qdrant |
| """ |
|
|
| from typing import List, Dict, Any |
|
|
| from fastapi import HTTPException |
| from qdrant_client import QdrantClient |
| from qdrant_client.models import Distance, PointStruct, VectorParams |
|
|
class VectorDatabaseClient:
    """Thin wrapper around a ``QdrantClient`` bound to a single collection.

    Stores the connection settings plus the target collection's name and
    vector size, and exposes helpers to create that collection, upsert image
    embeddings into it, run similarity searches against it, and list all
    collections on the server.
    """

    def __init__(self, url: str, api_key: str, collection_name: str, embedding_size: int):
        """Store the configuration and construct the underlying Qdrant client.

        Args:
            url: Qdrant server URL, passed unchanged to ``QdrantClient``.
            api_key: API key, passed unchanged to ``QdrantClient``.
            collection_name: Collection that every other method operates on.
            embedding_size: Vector dimensionality used when
                ``ensure_collection_exists`` has to create the collection.
        """
        self.url = url
        self.api_key = api_key
        self.collection_name = collection_name
        self.embedding_size = embedding_size
        self.client = QdrantClient(url=url, api_key=api_key)

    def ensure_collection_exists(self) -> None:
        """Create the configured collection if it is not already present.

        A newly created collection stores ``embedding_size``-dimensional
        vectors compared by cosine distance. An existing collection is left
        untouched. Prints a one-line status message to stdout in both cases.
        """
        # Reuse list_collections() so the "which collections exist" lookup
        # lives in one place instead of being duplicated here.
        if self.collection_name in self.list_collections():
            # NOTE(review): the existing collection's vector size/distance is
            # not compared against self.embedding_size; a mismatch would only
            # surface later, when upserts or searches fail.
            print(f"ℹ️ Collection '{self.collection_name}' already exists.")
            return

        # NOTE(review): check-then-create is not atomic. If another process
        # creates the collection between the two calls, create_collection may
        # raise. Confirm whether concurrent startup is possible.
        self.client.create_collection(
            collection_name=self.collection_name,
            vectors_config=VectorParams(
                size=self.embedding_size,
                distance=Distance.COSINE,
            ),
        )
        print(f"✅ Collection '{self.collection_name}' created.")

    def add_image(self, image_id: str, embedding: List[float], payload: Dict[str, Any]) -> None:
        """Upsert a single image embedding into the configured collection.

        Args:
            image_id: Point ID. Upserting an existing ID overwrites that point.
                NOTE(review): Qdrant only accepts unsigned integers or UUID
                strings as point IDs; confirm callers pass a UUID string.
            embedding: Vector to store. Presumably of length
                ``embedding_size`` (TODO confirm at call sites).
            payload: Arbitrary metadata stored alongside the vector.
        """
        point = PointStruct(id=image_id, vector=embedding, payload=payload)
        self.client.upsert(collection_name=self.collection_name, points=[point])

    def search_by_vector(self, embedding: List[float], limit: int = 1) -> List[Dict[str, Any]]:
        """Return the stored points most similar to ``embedding``.

        Args:
            embedding: Query vector.
            limit: Maximum number of hits to return (default 1).

        Returns:
            One dict per hit with keys ``"id"``, ``"score"`` and ``"payload"``,
            in the order the Qdrant client returns them.
        """
        # NOTE(review): newer qdrant-client releases deprecate
        # QdrantClient.search in favour of query_points. Confirm against the
        # pinned client version before upgrading.
        results = self.client.search(
            collection_name=self.collection_name,
            query_vector=embedding,
            limit=limit,
        )
        return [{"id": r.id, "score": r.score, "payload": r.payload} for r in results]

    def list_collections(self) -> List[str]:
        """Return the names of all collections on the Qdrant server."""
        return [c.name for c in self.client.get_collections().collections]
|
|