| import gradio as gr |
| import base64 |
| import json |
| import os |
| from PIL import Image |
| import io |
| from handler import EndpointHandler |
|
|
| |
def _create_handler():
    """Build the MobileCLIP endpoint handler, returning None if that fails.

    Any exception raised while constructing the handler, or while reading its
    ``device`` attribute for the status message, is printed and swallowed so
    the UI can still start and report "Handler not initialized".
    """
    try:
        instance = EndpointHandler()
        print(f"Handler initialized successfully! Device: {instance.device}")
    except Exception as exc:
        print(f"Error initializing handler: {exc}")
        return None
    return instance


print("Initializing MobileCLIP handler...")
# Module-level handler used by every callback; None means initialization failed.
handler = _create_handler()
|
|
def classify_image(image, top_k=10):
    """Classify an uploaded image with the MobileCLIP handler.

    Callback for the public image-classification tab.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or None when
            the input is empty or has been cleared.
        top_k: Number of top results to request. It is coerced to ``int``
            because slider values may arrive as floats.

    Returns:
        A ``(markdown, rows)`` tuple. On success, ``markdown`` is a numbered
        ranking and ``rows`` is a list of ``(label, score)`` tuples for the
        results table. On any failure, ``markdown`` is an error message and
        ``rows`` is None.
    """
    if handler is None:
        return "Error: Handler not initialized", None

    if image is None:
        return "Please upload an image", None

    try:
        # The handler receives the image as a base64 string, so round-trip
        # the PIL image through an in-memory PNG.
        buffered = io.BytesIO()
        image.save(buffered, format="PNG")
        img_b64 = base64.b64encode(buffered.getvalue()).decode()

        result = handler({
            "inputs": {
                "image": img_b64,
                "top_k": int(top_k),
            }
        })

        if isinstance(result, list):
            lines = [f"**Top {len(result)} Classifications:**", ""]
            chart_data = []
            for rank, item in enumerate(result, 1):
                score = item["score"]
                lines.append(
                    f"{rank}. **{item['label']}** (ID: {item['id']}): {score * 100:.2f}%"
                )
                chart_data.append((item["label"], score))
            return "\n".join(lines) + "\n", chart_data

        # A non-list result is treated as an error payload. Guard against a
        # non-dict so it is reported cleanly instead of failing on ``.get``.
        if isinstance(result, dict):
            return f"Error: {result.get('error', 'Unknown error')}", None
        return "Error: Unknown error", None

    except Exception as e:
        # Report any failure (encoding, handler crash, malformed result item)
        # in the UI instead of raising into Gradio.
        return f"Error: {str(e)}", None
|
|
def upsert_labels_admin(admin_token, new_items_json):
    """Add new labels through the handler's ``upsert_labels`` operation.

    Args:
        admin_token: Admin token forwarded to the handler, which performs the
            authorization check (an "unauthorized" error is reported as an
            invalid token).
        new_items_json: JSON text describing the items to upsert. An empty
            value is sent as an empty list.

    Returns:
        A status message string for the admin panel.
    """
    if handler is None:
        return "Error: Handler not initialized"

    if not admin_token:
        return "Error: Admin token required"

    try:
        items = json.loads(new_items_json) if new_items_json else []

        result = handler({
            "inputs": {
                "op": "upsert_labels",
                "token": admin_token,
                "items": items,
            }
        })

        # Guard against a non-dict response so it is reported instead of
        # failing on ``.get``.
        if not isinstance(result, dict):
            return "❌ Error: Unknown error"

        if result.get("status") == "ok":
            return (
                f"✅ Success! Added {result.get('added', 0)} new labels. "
                f"Current version: {result.get('labels_version', 'unknown')}"
            )
        if result.get("error") == "unauthorized":
            return "❌ Error: Invalid admin token"
        # Prefer the handler's detailed message, then its error code.
        return f"❌ Error: {result.get('detail', result.get('error', 'Unknown error'))}"

    except json.JSONDecodeError:
        return "❌ Error: Invalid JSON format"
    except Exception as e:
        return f"❌ Error: {str(e)}"
|
|
def reload_labels_admin(admin_token, version):
    """Reload a stored label version through the handler's ``reload_labels`` op.

    Args:
        admin_token: Admin token forwarded to the handler, which performs the
            authorization check (an "unauthorized" error is reported as an
            invalid token).
        version: Version number from the ``gr.Number`` input. Only an empty
            value (None) falls back to version 1.

    Returns:
        A status message string for the admin panel.
    """
    if handler is None:
        return "Error: Handler not initialized"

    if not admin_token:
        return "Error: Admin token required"

    try:
        # Only an empty field defaults to 1. An explicit 0 is forwarded so the
        # handler can validate it, instead of silently reloading version 1
        # (the old truthiness test treated 0 as "empty").
        requested_version = 1 if version in (None, "") else int(version)

        result = handler({
            "inputs": {
                "op": "reload_labels",
                "token": admin_token,
                "version": requested_version,
            }
        })

        # Guard against a non-dict response so it is reported instead of
        # failing on ``.get``.
        if not isinstance(result, dict):
            return "❌ Error: Unknown error"

        status = result.get("status")
        error = result.get("error")
        current = result.get("labels_version", "unknown")

        if status == "ok":
            return f"✅ Labels reloaded successfully! Current version: {current}"
        if status == "nochange":
            return f"ℹ️ No change needed. Current version: {current}"
        if error == "unauthorized":
            return "❌ Error: Invalid admin token"
        if error == "invalid_version":
            return "❌ Error: Invalid version number"
        # Same fallback order as upsert_labels_admin: detail, then error code.
        return f"❌ Error: {result.get('detail', result.get('error', 'Unknown error'))}"

    except Exception as e:
        return f"❌ Error: {str(e)}"
|
|
def get_current_stats():
    """Summarize the handler's current label state as Markdown.

    Reads ``class_ids``, ``labels_version``, ``device`` and ``class_names``
    from the handler when present, and falls back to defaults otherwise.

    Returns:
        A Markdown bullet list, or an error message string.
    """
    if handler is None:
        return "Handler not initialized"

    try:
        num_labels = len(handler.class_ids) if hasattr(handler, "class_ids") else 0
        version = getattr(handler, "labels_version", 1)
        device = getattr(handler, "device", "unknown")

        # Build the Markdown line by line instead of from a triple-quoted
        # literal, so the result does not depend on source indentation.
        # Markdown renders lines indented by 4+ spaces as a code block.
        lines = [
            "**Current Statistics:**",
            f"- Number of labels: {num_labels}",
            f"- Labels version: {version}",
            f"- Device: {device}",
            "- Model: MobileCLIP-B",
        ]

        class_names = getattr(handler, "class_names", None)
        if class_names is not None and len(class_names) > 0:
            # str() each name so non-string labels don't break the join.
            sample = ", ".join(str(name) for name in class_names[:5])
            if len(class_names) > 5:
                sample += "..."
            lines.append(f"- Sample labels: {sample}")

        return "\n".join(lines) + "\n"
    except Exception as e:
        return f"Error getting stats: {str(e)}"
|
|
| |
print("Creating Gradio interface...")
with gr.Blocks(title="MobileCLIP Image Classifier") as demo:
    gr.Markdown("""
    # 🖼️ MobileCLIP-B Zero-Shot Image Classifier

    Upload an image to classify it using MobileCLIP-B model with dynamic label management.
    """)

    # Public tab: upload an image and show the top-k predictions.
    with gr.Tab("🔍 Image Classification"):
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(
                    type="pil",
                    label="Upload Image",
                )
                top_k_slider = gr.Slider(
                    minimum=1,
                    maximum=50,
                    value=10,
                    step=1,
                    label="Number of top results to show",
                )
                classify_btn = gr.Button("🔍 Classify Image", variant="primary")

            with gr.Column():
                output_text = gr.Markdown(label="Classification Results")
                output_chart = gr.Dataframe(
                    headers=["Label", "Confidence"],
                    label="Classification Scores",
                    interactive=False,
                )

        # Classify on an explicit button press...
        classify_btn.click(
            fn=classify_image,
            inputs=[input_image, top_k_slider],
            outputs=[output_text, output_chart],
        )

        # ...and again whenever the image input changes.
        input_image.change(
            fn=classify_image,
            inputs=[input_image, top_k_slider],
            outputs=[output_text, output_chart],
        )

    # Admin tab: token-gated label management.
    with gr.Tab("🔧 Admin Panel"):
        gr.Markdown("""
        ### Admin Functions
        **Note:** Requires admin token (set via environment variable `ADMIN_TOKEN`)
        """)

        with gr.Row():
            admin_token_input = gr.Textbox(
                label="Admin Token",
                type="password",
                placeholder="Enter admin token",
            )

        with gr.Accordion("📊 Current Statistics", open=True):
            # Initial value is computed once when the UI is built; the button
            # re-queries the handler on demand.
            stats_display = gr.Markdown(value=get_current_stats())
            refresh_stats_btn = gr.Button("🔄 Refresh Stats")
            refresh_stats_btn.click(
                fn=get_current_stats,
                inputs=[],
                outputs=stats_display,
            )

        with gr.Accordion("➕ Add New Labels", open=False):
            gr.Markdown("""
            Add new labels by providing JSON array:
            ```json
            [
              {"id": 100, "name": "new_object", "prompt": "a photo of a new_object"},
              {"id": 101, "name": "another_object", "prompt": "a photo of another_object"}
            ]
            ```
            """)
            new_items_input = gr.Code(
                label="New Items JSON",
                language="json",
                lines=5,
                value='[\n {"id": 100, "name": "example", "prompt": "a photo of example"}\n]',
            )
            upsert_btn = gr.Button("➕ Add Labels", variant="primary")
            upsert_output = gr.Markdown()

            upsert_btn.click(
                fn=upsert_labels_admin,
                inputs=[admin_token_input, new_items_input],
                outputs=upsert_output,
            )

        with gr.Accordion("🔄 Reload Label Version", open=False):
            gr.Markdown("Reload labels from a specific version stored in the Hub")
            version_input = gr.Number(
                label="Version Number",
                value=1,
                precision=0,
            )
            reload_btn = gr.Button("🔄 Reload Version", variant="primary")
            reload_output = gr.Markdown()

            reload_btn.click(
                fn=reload_labels_admin,
                inputs=[admin_token_input, version_input],
                outputs=reload_output,
            )

    # Static documentation tab.
    with gr.Tab("ℹ️ About"):
        gr.Markdown("""
        ## About MobileCLIP-B Classifier

        This Space provides a web interface for Apple's MobileCLIP-B model, optimized for fast zero-shot image classification.

        ### Features:
        - 🚀 **Fast inference**: < 30ms on GPU
        - 🏷️ **Dynamic labels**: Add/update labels without redeployment
        - 🔄 **Version control**: Track and reload label versions
        - 📊 **Visual results**: Classification scores and confidence

        ### Environment Variables (set in Space Settings):
        - `ADMIN_TOKEN`: Secret token for admin operations
        - `HF_LABEL_REPO`: Hub repository for label storage
        - `HF_WRITE_TOKEN`: Token with write permissions to label repo
        - `HF_READ_TOKEN`: Token with read permissions (optional)

        ### Model Details:
        - **Architecture**: MobileCLIP-B with MobileOne blocks
        - **Text Encoder**: Transformer-based, 77 token context
        - **Image Size**: 224x224
        - **Embedding Dim**: 512

        ### License:
        Model weights are licensed under Apple Sample Code License (ASCL).
        """)


print("Gradio interface created successfully!")
|
|
def _launch_app():
    """Start serving the module-level ``demo`` Blocks app."""
    print("Launching Gradio app...")
    demo.launch()


if __name__ == "__main__":
    _launch_app()