Spaces:
Running
Running
| # api/startup.py | |
| # All model and index loading happens here β once at FastAPI startup | |
| # Everything stays in memory for the entire server lifetime | |
| # Never load models per-request | |
| import os | |
| import json | |
| import time | |
| import sys | |
| import torch | |
| import clip | |
| from src.patchcore import patchcore | |
| from src.retriever import retriever | |
| from src.graph import knowledge_graph | |
| from src.depth import depth_estimator | |
| from src.xai import gradcam, shap_explainer | |
| from src.cache import inference_cache | |
| from src.orchestrator import init_orchestrator | |
| from api.logger import init_logger | |
# Startup timestamp (epoch seconds) — assigned by load_all() and read by
# get_uptime() for the /health uptime report. None until startup has run.
STARTUP_TIME = None
# Reported in the startup banner; bump when the deployed artifact set changes.
MODEL_VERSION = "v1.0"
def download_artifacts():
    """Fetch every required artifact from the HF Dataset repo into data/.

    Files already present locally are kept as-is. Each download is
    best-effort: a failure prints a warning and startup continues, so a
    single missing artifact never takes the whole server down.
    """
    from huggingface_hub import hf_hub_download, snapshot_download
    import shutil

    repo_id = "CaffeinatedCoding/anomalyos-logs"
    auth_token = os.environ.get("HF_TOKEN")
    os.makedirs("data", exist_ok=True)

    # Fixed artifacts: models, config, graph, and the shared indexes 1 & 2.
    manifest = [
        ("models/pca_256.pkl", "data/pca_256.pkl"),
        ("models/midas_small.onnx", "data/midas_small.onnx"),
        ("configs/thresholds.json", "data/thresholds.json"),
        ("graph/knowledge_graph.json", "data/knowledge_graph.json"),
        ("indexes/index1_category.faiss", "data/index1_category.faiss"),
        ("indexes/index1_metadata.json", "data/index1_metadata.json"),
        ("indexes/index2_defect.faiss", "data/index2_defect.faiss"),
        ("indexes/index2_metadata.json", "data/index2_metadata.json"),
    ]

    # Index 3 is sharded one file per category (MVTec AD class list).
    index3_categories = (
        'bottle', 'cable', 'capsule', 'carpet', 'grid', 'hazelnut',
        'leather', 'metal_nut', 'pill', 'screw', 'tile', 'toothbrush',
        'transistor', 'wood', 'zipper',
    )
    manifest.extend(
        (f"indexes/index3_{name}.faiss", f"data/index3_{name}.faiss")
        for name in index3_categories
    )

    for remote_path, local_path in manifest:
        if os.path.exists(local_path):
            print(f"Already exists: {local_path}")
            continue
        try:
            print(f"Downloading {remote_path}...")
            fetched = hf_hub_download(
                repo_id=repo_id,
                filename=remote_path,
                repo_type="dataset",
                token=auth_token,
                local_dir="/tmp/artifacts",
            )
            # hf_hub_download caches under /tmp; copy into the flat data/ layout.
            shutil.copy(fetched, local_path)
            print(f" β {local_path}")
        except Exception as exc:
            print(f" WARNING: Could not download {remote_path}: {exc}")
def load_all():
    """
    Run the full startup sequence. Called once from the FastAPI lifespan hook.

    Order matters — the logger must be initialized before anything logs, and
    PatchCore must be loaded before the orchestrator is initialized. Optional
    components (depth, GradCAM++, SHAP background) degrade gracefully;
    required components (CLIP, orchestrator) re-raise and abort startup.

    Returns:
        dict with "clip_model", "clip_preprocess", and "thresholds" so the
        caller can stash them (e.g. on app state).
    """
    global STARTUP_TIME
    STARTUP_TIME = time.time()
    print("=" * 50)
    print("AnomalyOS startup sequence")
    print("=" * 50)

    # Artifacts first — everything below reads from data/.
    download_artifacts()

    # ── CPU thread tuning ─────────────────────────────────────
    # HF Spaces CPU Basic = 2 vCPU; matching PyTorch's thread count
    # prevents over-subscription.
    torch.set_num_threads(2)
    torch.set_default_dtype(torch.float32)
    print(f"PyTorch threads: {torch.get_num_threads()}")

    # Logger must come up before any component that logs.
    init_logger(os.environ.get("HF_TOKEN", ""))

    # PatchCore extractor — required; let a failure propagate.
    patchcore.load()

    # FAISS indexes 1 & 2. The per-category Index 3 shards are lazy-loaded
    # on first request, so they are intentionally not loaded here.
    retriever.load_indexes()

    knowledge_graph.load()

    # MiDaS depth estimator — optional: inference continues with zero
    # depth features if the ONNX file is missing.
    try:
        depth_estimator.load()
    except FileNotFoundError as e:
        print(f"WARNING: {e}")
        print("Depth features will return zeros β inference continues")

    clip_model, clip_preprocess = _load_clip()
    thresholds = _load_thresholds()
    _load_xai()
    _init_orchestrator_models(clip_model, clip_preprocess, thresholds)

    elapsed = time.time() - STARTUP_TIME
    print("=" * 50, flush=True)
    print(f"Startup complete in {elapsed:.1f}s β", flush=True)
    print(f"Model version: {MODEL_VERSION}", flush=True)
    print("=" * 50, flush=True)

    return {
        "clip_model": clip_model,
        "clip_preprocess": clip_preprocess,
        "thresholds": thresholds
    }


def _load_clip():
    """Load CLIP ViT-B/32 on CPU in eval mode. Required — re-raises on failure."""
    print("Loading CLIP ViT-B/32...", flush=True)
    try:
        clip_model, clip_preprocess = clip.load("ViT-B/32", device="cpu")
        clip_model.eval()
        print("CLIP loaded β", flush=True)
        return clip_model, clip_preprocess
    except Exception as e:
        print(f"ERROR loading CLIP: {e}", flush=True)
        raise


def _load_thresholds():
    """Load per-category thresholds from DATA_DIR/thresholds.json.

    Returns an empty dict when the file is missing, which downstream code
    treats as the score > 0.5 fallback.
    """
    print("Loading thresholds...", flush=True)
    thresholds_path = os.path.join(
        os.environ.get("DATA_DIR", "data"), "thresholds.json"
    )
    if os.path.exists(thresholds_path):
        with open(thresholds_path) as f:
            thresholds = json.load(f)
        print(f"Thresholds loaded β {len(thresholds)} categories", flush=True)
    else:
        thresholds = {}
        print("WARNING: thresholds.json not found β using score > 0.5 fallback", flush=True)
    return thresholds


def _load_xai():
    """Load GradCAM++ and the SHAP background. Both are optional (best-effort)."""
    print("Loading GradCAM++...", flush=True)
    try:
        gradcam.load()
        print("GradCAM++ loaded β", flush=True)
    except Exception as e:
        print(f"WARNING: GradCAM++ load failed: {e}", flush=True)
        print("Forensics mode will run without GradCAM++", flush=True)

    print("Loading SHAP background...", flush=True)
    bg_path = os.path.join(
        os.environ.get("DATA_DIR", "data"), "shap_background.npy"
    )
    try:
        if os.path.exists(bg_path):
            shap_explainer.load_background(bg_path)
            print("SHAP background loaded β", flush=True)
        else:
            print(f"WARNING: SHAP background not found at {bg_path}", flush=True)
            print("SHAP explanations will use default background", flush=True)
    except Exception as e:
        print(f"WARNING: SHAP background load failed: {e}", flush=True)
        print("SHAP explanations will use default background", flush=True)


def _init_orchestrator_models(clip_model, clip_preprocess, thresholds):
    """Inject the loaded models into the orchestrator. Required — re-raises."""
    print("Initializing orchestrator...", flush=True)
    try:
        init_orchestrator(clip_model, clip_preprocess, thresholds)
        print("Orchestrator initialized β", flush=True)
    except Exception as e:
        print(f"ERROR initializing orchestrator: {e}", flush=True)
        raise
def get_uptime() -> float:
    """Return seconds elapsed since load_all() ran, or 0.0 before startup."""
    return 0.0 if STARTUP_TIME is None else time.time() - STARTUP_TIME