Spaces:
Running
Running
| import os | |
| from pathlib import Path | |
| from typing import List, Optional, Tuple | |
| from uuid import uuid4 | |
| import numpy as np | |
| from chromadb import Client | |
| from chromadb.config import Settings | |
class ChromaVectorStore:
    """Embedding store backed by a single ChromaDB collection.

    The collection is created with ``hnsw:space`` set to ``cosine`` and with
    no embedding function, so callers always supply precomputed vectors.

    Configuration is read from the environment at construction time:

    * ``CHROMA_COLLECTION`` -- collection name (default ``repo_qa_chunks``).
    * ``CHROMA_UPSERT_BATCH_SIZE`` -- rows per ``add`` call (default 64,
      clamped to at least 1).
    * ``CHROMA_PATH`` -- persistence directory; when set it takes precedence
      over the ``index_path`` argument.
    """

    def __init__(self, embedding_dim: int, index_path: Optional[str] = None, persist: bool = True):
        """Create the client and make sure the collection exists.

        Args:
            embedding_dim: Dimensionality reported by ``get_stats``. It is
                stored but not enforced against incoming vectors.
            index_path: Persistence directory used when ``CHROMA_PATH`` is not
                set; falls back to ``./data/chroma`` when both are missing.
            persist: When true a persistent client is created at the resolved
                path; otherwise a client with default settings is used.
        """
        self.embedding_dim = embedding_dim
        self.collection_name = os.getenv("CHROMA_COLLECTION", "repo_qa_chunks")
        self.upsert_batch_size = max(1, int(os.getenv("CHROMA_UPSERT_BATCH_SIZE", "64")))
        self.persist_path = os.getenv("CHROMA_PATH", index_path or "./data/chroma")
        self.persist = persist
        self.client = self._create_client()
        self.collection = self._ensure_collection()

    def _create_client(self):
        """Build the Chroma client, creating the persist directory if needed."""
        if self.persist:
            Path(self.persist_path).mkdir(parents=True, exist_ok=True)
            return Client(
                Settings(
                    is_persistent=True,
                    persist_directory=self.persist_path,
                    anonymized_telemetry=False,
                )
            )
        return Client(Settings(anonymized_telemetry=False))

    def _ensure_collection(self):
        """Return the configured collection, creating it when absent."""
        return self.client.get_or_create_collection(
            name=self.collection_name,
            embedding_function=None,
            metadata={"hnsw:space": "cosine"},
        )

    def add_embeddings(self, embeddings: np.ndarray, metadata: List[dict]) -> List[str]:
        """Insert vectors with their metadata and return the generated ids.

        Each metadata dict is sanitized for storage; its ``content`` entry is
        stored as the Chroma document rather than as metadata, and the
        generated id is written into the stored metadata under ``id``.

        Args:
            embeddings: A 2-D array with one row per item, or a 1-D array
                holding a single vector.
            metadata: One dict per embedding row.

        Returns:
            Freshly generated hex UUIDs, one per inserted row, in input
            order. An empty list when ``embeddings`` is empty.

        Raises:
            ValueError: If the number of rows differs from ``len(metadata)``.
        """
        if embeddings.size == 0:
            return []

        embeddings = embeddings.astype("float32")
        if embeddings.ndim == 1:
            embeddings = embeddings.reshape(1, -1)

        # Fail before anything is written: a mismatch would otherwise only
        # surface inside a later batch, after earlier batches were stored.
        if embeddings.shape[0] != len(metadata):
            raise ValueError(
                f"embeddings has {embeddings.shape[0]} rows but "
                f"metadata has {len(metadata)} entries"
            )

        ids = [uuid4().hex for _ in metadata]
        total_points = len(ids)
        batch_size = self.upsert_batch_size
        total_batches = (total_points + batch_size - 1) // batch_size

        for start in range(0, total_points, batch_size):
            end = start + batch_size
            batch_ids = ids[start:end]
            batch_embeddings = embeddings[start:end].tolist()

            batch_metadata = []
            batch_documents = []
            for idx, meta in zip(batch_ids, metadata[start:end]):
                payload = self._sanitize_metadata(meta)
                payload["id"] = idx
                batch_metadata.append(payload)
                batch_documents.append(str(meta.get("content") or ""))

            batch_number = (start // batch_size) + 1
            print(
                f"[chroma] Adding batch {batch_number}/{total_batches} "
                f"points={len(batch_ids)} progress={start}/{total_points}",
                flush=True,
            )
            self.collection.add(
                ids=batch_ids,
                embeddings=batch_embeddings,
                metadatas=batch_metadata,
                documents=batch_documents,
            )
        return ids

    def search(
        self,
        query_embedding: np.ndarray,
        k: int = 10,
        repo_filter: Optional[int] = None,
    ) -> List[Tuple[float, dict]]:
        """Query the collection for the entries nearest to one vector.

        Args:
            query_embedding: The query vector; for a 2-D array only the first
                row is used.
            k: Maximum number of results. Non-positive values yield ``[]``.
            repo_filter: When given, restricts results to entries whose
                ``repository_id`` metadata equals this value.

        Returns:
            ``(score, payload)`` tuples in the order Chroma returned them.
            ``score`` is ``1 - distance`` clamped to ``[0, 1]``; ``payload``
            is the stored metadata plus ``id`` and ``content``.
        """
        if k <= 0:
            return []

        if query_embedding.ndim == 1:
            query_embedding = query_embedding.reshape(1, -1)
        query_embedding = query_embedding.astype("float32")

        where = {"repository_id": repo_filter} if repo_filter is not None else None
        results = self.collection.query(
            query_embeddings=[query_embedding[0].tolist()],
            n_results=k,
            where=where,
            include=["documents", "metadatas", "distances"],
        )

        # Results are nested one list per query; only one query is issued.
        ids = (results.get("ids") or [[]])[0]
        documents = (results.get("documents") or [[]])[0]
        metadatas = (results.get("metadatas") or [[]])[0]
        distances = (results.get("distances") or [[]])[0]

        hits = []
        for idx, document, meta, distance in zip(ids, documents, metadatas, distances):
            payload = dict(meta or {})
            payload["id"] = payload.get("id") or idx
            payload["content"] = document or ""
            hits.append((self._distance_to_score(distance), payload))
        return hits

    def remove_repository(self, repo_id: int) -> None:
        """Delete every entry whose ``repository_id`` metadata equals ``repo_id``."""
        self.collection.delete(where={"repository_id": repo_id})

    def clear(self) -> None:
        """Drop the collection and recreate it empty.

        Deletion is best-effort: a failure is reported on stdout and the
        collection is still (re)acquired afterwards.
        """
        try:
            self.client.delete_collection(name=self.collection_name)
        except Exception as exc:
            # Kept broad on purpose so clearing never crashes the caller, but
            # the reason is reported instead of being swallowed silently.
            print(
                f"[chroma] delete_collection({self.collection_name!r}) failed: {exc}",
                flush=True,
            )
        self.collection = self._ensure_collection()

    def save(self) -> None:
        """Call ``client.persist()`` when the client exposes it; otherwise do nothing."""
        persist = getattr(self.client, "persist", None)
        if callable(persist):
            persist()

    def load(self) -> None:
        """Re-acquire the collection handle from the client."""
        self.collection = self._ensure_collection()

    def keep_alive(self) -> dict:
        """Call ``client.heartbeat()`` when available and return current stats."""
        heartbeat = getattr(self.client, "heartbeat", None)
        if callable(heartbeat):
            heartbeat()
        return self.get_stats()

    def get_stats(self) -> dict:
        """Return vector count, embedding dimension, collection name and persist path.

        ``persist_path`` is ``None`` when the store was created with
        ``persist=False``.
        """
        return {
            "total_vectors": self.collection.count(),
            "embedding_dim": self.embedding_dim,
            "collection_name": self.collection_name,
            "persist_path": self.persist_path if self.persist else None,
        }

    @staticmethod
    def _sanitize_metadata(meta: dict) -> dict:
        """Return a copy of ``meta`` reduced to str/int/float/bool values.

        The ``content`` key is dropped, ``None`` becomes an empty string, and
        any other value type is converted with ``str()``.
        """
        sanitized = {}
        for key, value in meta.items():
            if key == "content":
                continue
            if value is None:
                sanitized[key] = ""
            elif isinstance(value, (str, int, float, bool)):
                sanitized[key] = value
            else:
                sanitized[key] = str(value)
        return sanitized

    @staticmethod
    def _distance_to_score(distance: float) -> float:
        """Map a distance to ``1 - distance`` clamped to ``[0, 1]``; ``None`` gives 0.0."""
        if distance is None:
            return 0.0
        return max(0.0, min(1.0, 1.0 - float(distance)))