| import numpy as np |
| from sklearn.model_selection import StratifiedKFold |
| from sklearn.metrics import classification_report, accuracy_score, f1_score |
| from sklearn.ensemble import RandomForestClassifier |
| import joblib |
| from chromadb import Client, Settings |
| import os |
| import json |
| from datetime import datetime |
| import warnings |
# Install a process-wide "ignore" filter for every Warning category.
# NOTE(review): this also silences library diagnostics (e.g. sklearn warnings
# about sparsely populated classes or undefined metrics) — consider narrowing.
warnings.simplefilter('ignore')
|
|
class TopicClassifier:
    """Random-forest classifier that predicts a record's topic cluster.

    Features are the embeddings stored in the ChromaDB collection
    ``"healthcare_qa"``; labels are parsed from each record's ``"cluster"``
    metadata value. Records labelled ``"noise"`` (or lacking a cluster) are
    excluded from training and evaluation.
    """

    def __init__(self, chroma_uri: str = "./Data/database"):
        """Open the persistent ChromaDB store and the ``healthcare_qa`` collection.

        Args:
            chroma_uri: Directory of the persistent ChromaDB database.
        """
        self.chroma_uri = chroma_uri
        self.client = Client(Settings(
            persist_directory=chroma_uri,
            anonymized_telemetry=False,
            is_persistent=True,
        ))
        self.collection = self.client.get_collection("healthcare_qa")
        self.model = None  # fitted RandomForestClassifier after training
        self.X = None      # embedding matrix; rows aligned with self.y
        self.y = None      # integer cluster labels (noise removed)

    @staticmethod
    def _cluster_label(metadata) -> int:
        """Map one record's metadata to an integer class label.

        ``"noise"``, a missing ``"cluster"`` key, or missing metadata (``None``)
        map to -1. Otherwise the label is the integer after the first
        underscore, e.g. ``"cluster_3"`` -> 3.
        """
        # Records stored without metadata may come back as None; treat as noise.
        cluster = (metadata or {}).get("cluster", "noise")
        if cluster == "noise":
            return -1
        return int(cluster.split("_")[1])

    @staticmethod
    def _build_model(random_state=42):
        """Return a fresh, unfitted random forest with the project's settings."""
        return RandomForestClassifier(
            n_estimators=100,
            max_depth=None,
            n_jobs=-1,
            random_state=random_state,
        )

    def load_data(self):
        """Load embeddings and cluster labels from the collection.

        Populates ``self.X`` (one embedding per row) and ``self.y`` (integer
        labels), dropping every record whose label is noise (-1).
        """
        print("正在加载数据...")

        result = self.collection.get(include=["embeddings", "metadatas"])
        X = np.asarray(result["embeddings"])
        y = np.array(
            [self._cluster_label(meta) for meta in result["metadatas"]],
            dtype=int,
        )

        keep = y != -1  # -1 marks noise / unlabelled records
        self.X = X[keep]
        self.y = y[keep]

        print(f"数据加载完成,特征形状: {self.X.shape}")
        print(f"类别数量: {len(np.unique(self.y))}")

    def train_and_evaluate(self, n_splits=5, refit=True, random_state=42):
        """Estimate performance with stratified k-fold CV, then fit the final model.

        Args:
            n_splits: Number of cross-validation folds.
            refit: If True (default), retrain on *all* loaded data after
                cross-validation so that ``self.model`` (and what
                ``save_model`` writes) uses every labelled sample. If False,
                ``self.model`` is the model fitted on the last fold.
            random_state: Seed for fold shuffling and for the random forests.

        Returns:
            dict mapping ``"accuracy"``, ``"macro_f1"`` and ``"weighted_f1"``
            to the list of per-fold scores.

        Raises:
            ValueError: if no labelled (non-noise) samples are available.
        """
        if self.X is None or self.y is None:
            self.load_data()
        if len(self.y) == 0:
            raise ValueError("没有可用于训练的已标注数据")

        print(f"\n开始{n_splits}折交叉验证...")

        skf = StratifiedKFold(
            n_splits=n_splits, shuffle=True, random_state=random_state
        )
        fold_scores = {'accuracy': [], 'macro_f1': [], 'weighted_f1': []}
        fold_model = None

        for fold, (train_idx, val_idx) in enumerate(skf.split(self.X, self.y), 1):
            print(f"\n第 {fold} 折验证:")

            X_train, X_val = self.X[train_idx], self.X[val_idx]
            y_train, y_val = self.y[train_idx], self.y[val_idx]

            print("训练模型...")
            fold_model = self._build_model(random_state)
            fold_model.fit(X_train, y_train)
            y_pred = fold_model.predict(X_val)

            fold_scores['accuracy'].append(accuracy_score(y_val, y_pred))
            fold_scores['macro_f1'].append(f1_score(y_val, y_pred, average='macro'))
            fold_scores['weighted_f1'].append(f1_score(y_val, y_pred, average='weighted'))

            print("\n分类报告:")
            print(classification_report(y_val, y_pred))

        print("\n总体性能:")
        summary_labels = (
            ('accuracy', '平均准确率'),
            ('macro_f1', '平均宏F1分数'),
            ('weighted_f1', '平均加权F1分数'),
        )
        for key, label in summary_labels:
            scores = fold_scores[key]
            print(f"{label}: {np.mean(scores):.4f} ± {np.std(scores):.4f}")

        if refit:
            # CV only estimates generalization; the deliverable model should
            # be trained on every labelled sample, not on one fold's subset.
            print("\n在全部数据上训练最终模型...")
            final_model = self._build_model(random_state)
            final_model.fit(self.X, self.y)
            self.model = final_model
        else:
            self.model = fold_model

        return fold_scores

    def save_model(self, model_dir: str = "./models") -> str:
        """Persist ``self.model`` to a timestamped ``.joblib`` file.

        Args:
            model_dir: Output directory; created if it does not exist.

        Returns:
            Path of the written file.

        Raises:
            ValueError: if no model has been trained yet.
        """
        if self.model is None:
            raise ValueError("模型尚未训练")

        os.makedirs(model_dir, exist_ok=True)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        model_path = os.path.join(model_dir, f"topic_classifier_{timestamp}.joblib")

        joblib.dump(self.model, model_path)
        print(f"\n模型已保存到: {model_path}")
        return model_path
|
|
def main():
    """Script entry point: cross-validate the classifier, then persist it."""
    pipeline = TopicClassifier()
    pipeline.train_and_evaluate()
    pipeline.save_model()


if __name__ == "__main__":
    main()
|
|