anomalyOS / api /startup.py
CaffeinatedCoding's picture
Upload folder using huggingface_hub
2ac1b86 verified
# api/startup.py
# All model and index loading happens here β€” once at FastAPI startup
# Everything stays in memory for the entire server lifetime
# Never load models per-request
import os
import json
import time
import sys
import torch
import clip
from src.patchcore import patchcore
from src.retriever import retriever
from src.graph import knowledge_graph
from src.depth import depth_estimator
from src.xai import gradcam, shap_explainer
from src.cache import inference_cache
from src.orchestrator import init_orchestrator
from api.logger import init_logger
# Startup timestamp β€” used for uptime calculation in /health
STARTUP_TIME = None
MODEL_VERSION = "v1.0"
def download_artifacts():
"""Download all required artifacts from HF Dataset at startup."""
from huggingface_hub import hf_hub_download, snapshot_download
import shutil
HF_REPO = "CaffeinatedCoding/anomalyos-logs"
token = os.environ.get("HF_TOKEN")
os.makedirs("data", exist_ok=True)
files_to_download = [
("models/pca_256.pkl", "data/pca_256.pkl"),
("models/midas_small.onnx", "data/midas_small.onnx"),
("configs/thresholds.json", "data/thresholds.json"),
("graph/knowledge_graph.json", "data/knowledge_graph.json"),
("indexes/index1_category.faiss", "data/index1_category.faiss"),
("indexes/index1_metadata.json", "data/index1_metadata.json"),
("indexes/index2_defect.faiss", "data/index2_defect.faiss"),
("indexes/index2_metadata.json", "data/index2_metadata.json"),
]
# Index 3 β€” one per category
categories = [
'bottle','cable','capsule','carpet','grid','hazelnut',
'leather','metal_nut','pill','screw','tile','toothbrush',
'transistor','wood','zipper'
]
for cat in categories:
files_to_download.append((
f"indexes/index3_{cat}.faiss",
f"data/index3_{cat}.faiss"
))
for repo_path, local_path in files_to_download:
if os.path.exists(local_path):
print(f"Already exists: {local_path}")
continue
try:
print(f"Downloading {repo_path}...")
downloaded = hf_hub_download(
repo_id=HF_REPO,
filename=repo_path,
repo_type="dataset",
token=token,
local_dir="/tmp/artifacts"
)
shutil.copy(downloaded, local_path)
print(f" β†’ {local_path}")
except Exception as e:
print(f" WARNING: Could not download {repo_path}: {e}")
def load_all():
"""
Called once from FastAPI lifespan on startup.
Order matters β€” patchcore before orchestrator, logger before anything logs.
"""
global STARTUP_TIME
STARTUP_TIME = time.time()
print("=" * 50)
print("AnomalyOS startup sequence")
print("=" * 50)
# Download artifacts first
download_artifacts()
# ── CPU thread tuning ─────────────────────────────────────
# HF Spaces CPU Basic = 2 vCPU
# Limit PyTorch threads to match β€” prevents over-subscription
torch.set_num_threads(2)
torch.set_default_dtype(torch.float32)
print(f"PyTorch threads: {torch.get_num_threads()}")
# ── Logger ────────────────────────────────────────────────
hf_token = os.environ.get("HF_TOKEN", "")
init_logger(hf_token)
# ── PatchCore extractor ───────────────────────────────────
patchcore.load()
# ── FAISS indexes ─────────────────────────────────────────
# Index 3 is lazy-loaded β€” not loaded here
retriever.load_indexes()
# ── Knowledge graph ───────────────────────────────────────
knowledge_graph.load()
# ── MiDaS depth estimator ─────────────────────────────────
try:
depth_estimator.load()
except FileNotFoundError as e:
print(f"WARNING: {e}")
print("Depth features will return zeros β€” inference continues")
# ── CLIP model ────────────────────────────────────────────
# Loaded here, injected into orchestrator
print("Loading CLIP ViT-B/32...", flush=True)
try:
print(" [Downloading CLIP weights...]", flush=True)
clip_model, clip_preprocess = clip.load("ViT-B/32", device="cpu")
print(" [CLIP weights loaded, setting eval mode...]", flush=True)
clip_model.eval()
print("CLIP loaded βœ“", flush=True)
except Exception as e:
print(f"ERROR loading CLIP: {e}", flush=True)
raise
# DEBUG: Aggressive output buffer flushing after CLIP
sys.stdout.write("[DEBUG] Point 1: After CLIP load\n")
sys.stdout.flush()
print("Loading thresholds...", flush=True)
sys.stdout.write("[DEBUG] Point 2: After thresholds print\n")
sys.stdout.flush()
sys.stdout.write("[DEBUG] Point 2a: Building thresholds path\n")
sys.stdout.flush()
thresholds_path = os.path.join(
os.environ.get("DATA_DIR", "data"), "thresholds.json"
)
sys.stdout.write(f"[DEBUG] Point 2b: Checking if {thresholds_path} exists\n")
sys.stdout.flush()
if os.path.exists(thresholds_path):
sys.stdout.write("[DEBUG] Point 2c: File exists, opening\n")
sys.stdout.flush()
with open(thresholds_path) as f:
sys.stdout.write("[DEBUG] Point 2d: File opened, loading JSON\n")
sys.stdout.flush()
thresholds = json.load(f)
print(f"Thresholds loaded βœ“ {len(thresholds)} categories", flush=True)
else:
thresholds = {}
print("WARNING: thresholds.json not found β€” using score > 0.5 fallback", flush=True)
sys.stdout.write("[DEBUG] Point 3: After thresholds loading\n")
sys.stdout.flush()
# ── GradCAM++ ─────────────────────────────────────────────
sys.stdout.write("[DEBUG] Point 4: Before GradCAM load\n")
sys.stdout.flush()
print("Loading GradCAM++...", flush=True)
try:
sys.stdout.write("[DEBUG] Point 4a: Inside GradCAM load try\n")
sys.stdout.flush()
gradcam.load()
print("GradCAM++ loaded βœ“", flush=True)
except Exception as e:
print(f"WARNING: GradCAM++ load failed: {e}", flush=True)
print("Forensics mode will run without GradCAM++", flush=True)
sys.stdout.write("[DEBUG] Point 5: After GradCAM load\n")
sys.stdout.flush()
# ── SHAP background ───────────────────────────────────────
sys.stdout.write("[DEBUG] Point 6: Before SHAP load\n")
sys.stdout.flush()
print("Loading SHAP background...", flush=True)
sys.stdout.write("[DEBUG] Point 6a: After SHAP print\n")
sys.stdout.flush()
bg_path = os.path.join(
os.environ.get("DATA_DIR", "data"), "shap_background.npy"
)
try:
if os.path.exists(bg_path):
sys.stdout.write("[DEBUG] Point 6b: SHAP file exists, loading\n")
sys.stdout.flush()
shap_explainer.load_background(bg_path)
print("SHAP background loaded βœ“", flush=True)
else:
print(f"WARNING: SHAP background not found at {bg_path}", flush=True)
print("SHAP explanations will use default background", flush=True)
except Exception as e:
print(f"WARNING: SHAP background load failed: {e}", flush=True)
print("SHAP explanations will use default background", flush=True)
# ── Inject into orchestrator ──────────────────────────────
sys.stdout.write("[DEBUG] Point 7: Before orchestrator init\n")
sys.stdout.flush()
print("Initializing orchestrator...", flush=True)
sys.stdout.write("[DEBUG] Point 7a: About to call init_orchestrator\n")
sys.stdout.flush()
try:
init_orchestrator(clip_model, clip_preprocess, thresholds)
sys.stdout.write("[DEBUG] Point 7b: init_orchestrator returned\n")
sys.stdout.flush()
print("Orchestrator initialized βœ“", flush=True)
except Exception as e:
print(f"ERROR initializing orchestrator: {e}", flush=True)
raise
sys.stdout.write("[DEBUG] Point 8: After orchestrator init β€” about to print completion\n")
sys.stdout.flush()
elapsed = time.time() - STARTUP_TIME
print("=" * 50, flush=True)
print(f"Startup complete in {elapsed:.1f}s βœ“", flush=True)
print(f"Model version: {MODEL_VERSION}", flush=True)
print("=" * 50, flush=True)
return {
"clip_model": clip_model,
"clip_preprocess": clip_preprocess,
"thresholds": thresholds
}
def get_uptime() -> float:
if STARTUP_TIME is None:
return 0.0
return time.time() - STARTUP_TIME