Spaces:
Build error
Build error
import json
import os
import tempfile
import time

import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import display, JSON
from speciesnet import DEFAULT_MODEL, SUPPORTED_MODELS, SpeciesNet
# ------------------------------------------------------
# LOAD MODEL
# ------------------------------------------------------
# Log which model variant is the default and which ones this speciesnet
# install supports, then load the default model once at startup so every
# Gradio request reuses the same loaded instance.
print("Default SpeciesNet model:", DEFAULT_MODEL)
print("Supported SpeciesNet models:", SUPPORTED_MODELS)
model = SpeciesNet(DEFAULT_MODEL)
# ------------------------------------------------------
# VALIDATION FUNCTIONS
# ------------------------------------------------------
def validate_predictions_structure(pred):
    """
    Validate internal structure for both detection and classification.

    Ensures a single SpeciesNet prediction block has the keys and shapes
    the visualizer relies on.

    Args:
        pred: dict expected to contain "filepath", "detections" (list of
            dicts with "bbox" [x, y, w, h], "conf", "label") and
            "classifications" (dict with parallel "classes"/"scores" lists).

    Returns:
        True when the structure is valid.

    Raises:
        ValueError: describing the first structural problem found.
    """
    required_keys = ["filepath", "detections", "classifications"]
    for key in required_keys:
        if key not in pred:
            raise ValueError(f" Missing key '{key}' in prediction block")

    # --- Validate detections (list of dicts) ---
    if not isinstance(pred["detections"], list):
        raise ValueError(" detections must be a list")
    for det in pred["detections"]:
        if not all(k in det for k in ["bbox", "conf", "label"]):
            raise ValueError(" Each detection must contain bbox, conf, label")
        if len(det["bbox"]) != 4:
            raise ValueError(" bbox must be [x, y, w, h]")

    # --- Validate classifications ---
    cls = pred["classifications"]
    if not isinstance(cls, dict):
        raise ValueError(" classifications must be a dictionary")
    for key in ["classes", "scores"]:
        if key not in cls:
            raise ValueError(f" classifications missing '{key}'")
    # Fix: without this check, non-list values would raise a confusing
    # TypeError from len() below instead of a descriptive ValueError.
    if not isinstance(cls["classes"], list) or not isinstance(cls["scores"], list):
        raise ValueError(" classes and scores must be lists")
    if len(cls["classes"]) != len(cls["scores"]):
        raise ValueError(" classes and scores length mismatch")
    return True
def validate_model_output(predictions_dict):
    """
    Validate the full SpeciesNet response before it is visualized.

    Confirms the top-level "predictions" list is present and well-typed,
    then runs the per-entry structural checks on every prediction block.

    Args:
        predictions_dict: raw dict returned by SpeciesNet.predict().

    Raises:
        ValueError: if the top level or any prediction block is malformed.
    """
    if "predictions" not in predictions_dict:
        raise ValueError(" Output missing top-level 'predictions' key")

    predictions = predictions_dict["predictions"]
    if not isinstance(predictions, list):
        raise ValueError(" 'predictions' must be a list")

    print(f" Total prediction entries: {len(predictions)}")

    # Check each block individually so the error points at the culprit.
    for index, block in enumerate(predictions, start=1):
        print(f"\n--- Checking prediction #{index} ---")
        validate_predictions_structure(block)

    print("\n Output format validated successfully!\n")
# ------------------------------------------------------
# VISUALIZATION
# ------------------------------------------------------
def draw_predictions(image_path, predictions_dict):
    """
    Draw detection boxes and the top classification label onto an image.

    Entries whose top class does not contain "mammalia" or "aves" are
    skipped entirely (no box is drawn for vehicles, blanks, humans, etc.).

    Args:
        image_path: path of the image that was sent to the model.
        predictions_dict: SpeciesNet output dict with a "predictions" list.

    Returns:
        The annotated image as an RGB array (cv2 loads BGR; converted at
        the end for display).

    Raises:
        ValueError: if the image cannot be read from disk.
    """
    img = cv2.imread(image_path)
    if img is None:
        raise ValueError(f"Could not load image: {image_path}")
    img_h, img_w, _ = img.shape
    for pred in predictions_dict.get("predictions", []):
        detections = pred.get("detections", [])
        classifications = pred.get("classifications", {})
        classes = classifications.get("classes", [])
        scores = classifications.get("scores", [])
        top_class_name = None
        top_score = None
        if len(classes) > 0:
            # Class strings are ";"-separated taxonomy paths; keep only the
            # last segment for display. Index 0 is treated as the top-scoring
            # class — assumes classes/scores are sorted by score, TODO confirm
            # against the speciesnet output contract.
            top_class_name = classes[0].split(";")[-1]
            top_score = scores[0]
        # SKIP NON-ANIMALS
        if len(classes) == 0:
            continue
        taxon = classes[0].lower()
        if not ("mammalia" in taxon or "aves" in taxon):
            continue
        for det in detections:
            bbox = det["bbox"]
            conf = det["conf"]
            label = det["label"]
            # bbox is normalized [x, y, w, h]; scale to pixel corner coords.
            x, y, w, h = bbox
            x1 = int(x * img_w)
            y1 = int(y * img_h)
            x2 = int((x + w) * img_w)
            y2 = int((y + h) * img_h)
            cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 3)
            detection_text = f"{label} ({conf:.2f})"
            classification_text = (
                f"{top_class_name} ({top_score:.2f})" if top_class_name else ""
            )
            # Classification line (if any) stacked above the detection line.
            text_lines = []
            if classification_text:
                text_lines.append(classification_text)
            text_lines.append(detection_text)
            # Measure every line so the filled background fits the widest
            # line and the full stacked height.
            total_text_height = 0
            text_widths = []
            for line in text_lines:
                (text_w, text_h), _ = cv2.getTextSize(
                    line, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2
                )
                total_text_height += text_h + 5
                text_widths.append(text_w)
            max_text_width = max(text_widths)
            # Solid green label background above the box, clamped so it
            # never extends past the top image edge.
            cv2.rectangle(
                img,
                (x1, max(y1 - total_text_height - 10, 0)),
                (x1 + max_text_width + 10, y1),
                (0, 255, 0),
                -1,
            )
            # Render bottom-up (reversed) so the lines appear in their
            # listed order from top to bottom.
            y_text = y1 - 5
            for line in text_lines[::-1]:
                cv2.putText(
                    img,
                    line,
                    (x1 + 5, y_text),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    0.6,
                    (0, 0, 0),
                    2,
                    cv2.LINE_AA,
                )
                (_, text_h), _ = cv2.getTextSize(
                    line, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2
                )
                y_text -= text_h + 5
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
# ------------------------------------------------------
# INFERENCE FUNCTION
# ------------------------------------------------------
def inference(image):
    """
    Run SpeciesNet on an uploaded image and build the Gradio outputs.

    The image is written to a unique temporary file (SpeciesNet predicts
    from file paths), the raw output is validated and saved to
    last_output.json, and the annotated image plus pretty JSON is returned.

    Args:
        image: PIL image provided by the Gradio image input.

    Returns:
        Tuple of (annotated RGB image array, pretty-printed JSON string).
    """
    # Fix: a fixed "temp_image.jpg" in the working directory is clobbered
    # by concurrent requests and never removed; use a unique temp file and
    # clean it up when done.
    fd, filepath = tempfile.mkstemp(suffix=".jpg")
    os.close(fd)
    try:
        image.save(filepath)

        start = time.time()
        predictions_dict = model.predict(
            instances_dict={
                "instances": [
                    {
                        "filepath": filepath,
                        # "country": "VNM",
                    }
                ]
            }
        )
        end = time.time()
        print(f"\n⏱ Inference Time: {end - start:.2f} sec")

        # --- Validate format ---
        validate_model_output(predictions_dict)

        # --- Save JSON ---
        with open("last_output.json", "w") as f:
            json.dump(predictions_dict, f, indent=4)
        print(" Saved JSON to last_output.json\n")

        # --- Draw Visualization ---
        annotated_image = draw_predictions(filepath, predictions_dict)
        pretty_json = json.dumps(predictions_dict, indent=4)
        return annotated_image, pretty_json
    finally:
        # Best-effort cleanup; never let a missing temp file mask a result.
        try:
            os.remove(filepath)
        except OSError:
            pass
# ------------------------------------------------------
# GRADIO UI
# ------------------------------------------------------
# Single-image demo: upload -> (annotated image, raw model JSON).
output_components = [
    gr.Image(label="Detection + Classification Output"),
    gr.JSON(label="Raw Model Output"),
]
iface = gr.Interface(
    fn=inference,
    inputs=gr.Image(type="pil"),
    outputs=output_components,
    title=" SpeciesNet Wildlife Detector + Classifier",
    description="Upload a wildlife camera image.",
)
iface.launch()