| import json |
| from chromadb import Client, Settings, EmbeddingFunction |
| from pprint import pprint |
| import random |
| import os |
| from sentence_transformers import SentenceTransformer |
| import torch |
|
|
| |
# Directory of the persistent Chroma database (used as ``persist_directory``).
# Relative, so it resolves against the current working directory.
CHROMA_URI = "./Data/database"
# sentence-transformers model used to embed both stored documents and queries.
EMBEDDING_MODEL_NAME = "pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-stsb"
# Presumably the embedding dimension produced by EMBEDDING_MODEL_NAME; it is not
# referenced by the code visible here — TODO confirm whether it is still needed.
VECTOR_DIM = 768
|
|
class BioEmbeddingFunction(EmbeddingFunction):
    """Chroma embedding function backed by a BioBERT sentence-transformer.

    The model named by ``EMBEDDING_MODEL_NAME`` is loaded once, at
    construction time, and placed on the GPU when CUDA is available or on
    the CPU otherwise.
    """

    def __init__(self):
        use_gpu = torch.cuda.is_available()
        self.device = torch.device("cuda" if use_gpu else "cpu")
        encoder = SentenceTransformer(EMBEDDING_MODEL_NAME)
        encoder.to(self.device)
        self.model = encoder

    # The parameter must stay named ``input`` (shadowing the builtin) because
    # that is the name Chroma's EmbeddingFunction interface uses.
    def __call__(self, input: list[str]) -> list[list[float]]:
        """Embed the given texts and return unit-length vectors as plain lists."""
        vectors = self.model.encode(
            input,
            convert_to_numpy=True,
            normalize_embeddings=True,
        )
        # numpy array -> nested Python lists, which is what Chroma accepts.
        return vectors.tolist()
|
|
def _directory_size_bytes(path):
    """Return the combined size, in bytes, of every file under *path*.

    ``os.path.getsize`` on a directory only reports the size of the directory
    entry itself, not the data stored inside it, so the persistent database
    directory is walked file by file. A plain file path is measured directly;
    a path that does not exist yields 0.
    """
    if os.path.isfile(path):
        return os.path.getsize(path)
    total = 0
    for root, _dirs, files in os.walk(path):
        for name in files:
            try:
                total += os.path.getsize(os.path.join(root, name))
            except OSError:
                # A file may vanish or be unreadable between listing and stat;
                # skip it instead of aborting the whole report.
                continue
    return total


def _distance_to_similarity(distance, space):
    """Convert a Chroma query distance into a cosine-style similarity score.

    Chroma's default ``"l2"`` space returns the *squared* Euclidean distance;
    for unit-length vectors that equals ``2 - 2*cos``, so ``cos = 1 - d / 2``.
    The ``"cosine"`` and ``"ip"`` spaces return ``1 - cos`` / ``1 - dot``, for
    which ``1 - d`` is the similarity. Assumes stored vectors are normalised
    like the query vectors (``normalize_embeddings=True`` in
    ``BioEmbeddingFunction``) — TODO confirm for data ingested elsewhere.
    """
    if space == "l2":
        return 1 - distance / 2
    return 1 - distance


def test_database():
    """Smoke-test the persistent ``healthcare_qa`` Chroma collection.

    Prints basic database information, up to two randomly chosen stored
    records, and the top hit for the query "diabetes". Sampling and querying
    are skipped when the collection is empty.
    """
    print("=" * 50)
    print("开始测试数据库")
    print("=" * 50)

    client = Client(Settings(
        persist_directory=CHROMA_URI,
        anonymized_telemetry=False,
        is_persistent=True,
    ))

    embedding_func = BioEmbeddingFunction()
    collection = client.get_or_create_collection(
        name="healthcare_qa",
        embedding_function=embedding_func,
    )

    # Read the count once and reuse it below.
    total_items = collection.count()

    print("\n1. 数据库基本信息:")
    print(f"数据库位置: {os.path.abspath(CHROMA_URI)}")
    print(f"数据库大小: {_directory_size_bytes(CHROMA_URI) / 1024 / 1024:.2f} MB")
    print(f"总条目数: {total_items} 条")
    print(f"使用的嵌入模型: {EMBEDDING_MODEL_NAME}")

    if total_items == 0:
        # Nothing to sample or query; Chroma may also reject get() with an
        # empty id list, so stop here.
        print("\n集合为空，跳过样本展示和查询测试。")
        return

    print("\n2. 随机样本展示:")
    # Sample from the ids actually stored rather than assuming they are the
    # strings "0".."N-1" (otherwise get() silently returns fewer/no rows).
    # include=[] fetches ids only, which is acceptable for a smoke test.
    stored_ids = collection.get(include=[])["ids"]
    sample_ids = random.sample(stored_ids, min(2, len(stored_ids)))

    results = collection.get(
        ids=sample_ids,
        include=["documents", "metadatas"],
    )

    for i, (doc, metadata) in enumerate(zip(results["documents"], results["metadatas"]), 1):
        print(f"\n样本 {i}:")
        print("-" * 40)
        print("文档内容:")
        print(doc)
        print("\n元数据:")
        pprint(metadata)
        print("-" * 40)

    print("\n3. 测试查询功能:")
    query = "diabetes"
    results = collection.query(
        query_texts=[query],
        n_results=1,
        include=["documents", "metadatas", "distances"],
    )

    # Distance metric of the HNSW index; Chroma defaults to squared L2.
    # NOTE(review): newer Chroma versions may record this in the collection
    # configuration rather than metadata — confirm for the installed version.
    space = (collection.metadata or {}).get("hnsw:space", "l2")

    print(f"\n使用查询词 '{query}' 的结果:")
    for i, (doc, metadata, distance) in enumerate(zip(
        results["documents"][0],
        results["metadatas"][0],
        results["distances"][0],
    ), 1):
        print(f"\n结果 {i}:")
        print("-" * 40)
        print(f"相似度得分: {_distance_to_similarity(distance, space):.4f}")
        print("\n文档内容:")
        print(doc)
        print("\n元数据:")
        pprint(metadata)
        print("-" * 40)
| |
|
|
|
|
| |
|
|
if __name__ == "__main__":
    # Run the database smoke test only when executed as a script, not on import.
    test_database()
|
|