import random
from typing import Optional

import gradio as gr
from datasets import load_dataset
from PIL import Image

from model import predict
| |
|
# Training split of the "AIOmarRehan/AnimalsDataset" dataset, loaded once at
# module import time; random_example() draws its samples from it.
# NOTE(review): load_dataset may fetch from the Hugging Face Hub on first run.
dataset = load_dataset(path="AIOmarRehan/AnimalsDataset", split="train")
| |
|
def classify_image(img: Optional[Image.Image]):
    """Run the classifier on an image and format the result for the UI.

    Args:
        img: The image to classify, or ``None`` when no image has been
            provided (handled explicitly below).

    Returns:
        A 3-tuple ``(label, confidence, probabilities)``: the label returned
        by ``predict``, its confidence rounded to 3 decimals, and a dict with
        the same keys as ``predict``'s probability dict whose values are
        rounded to 3 decimals. With no image, returns
        ``("No image uploaded", 0.0, {})``.
    """
    if img is None:
        # Use 0.0 (not 0) so the confidence slot has the same type on both paths.
        return "No image uploaded", 0.0, {}

    label, confidence, probs = predict(img)
    # Coerce to builtin float before rounding. round() on NumPy/tensor scalars
    # keeps the non-builtin type, which JSON serialisation may reject.
    # NOTE(review): confirm the value types model.predict actually returns.
    return (
        label,
        round(float(confidence), 3),
        {k: round(float(v), 3) for k, v in probs.items()},
    )
| |
|
| | |
def random_example():
    """Draw one random sample from ``dataset`` for display.

    Returns:
        A 3-tuple ``(image, image, label_name)``. The sample's image is
        converted to RGB and returned twice as the same object, followed by
        the string name that the dataset's ``"label"`` feature maps the
        sample's integer label to.
    """
    sample = random.choice(dataset)
    rgb_image = sample["image"].convert("RGB")

    label_feature = dataset.features["label"]
    label_name = label_feature.int2str(sample["label"])

    return rgb_image, rgb_image, label_name
| |
|
| | |
with gr.Blocks() as demo:
    gr.Markdown("## Animal Image Classifier with Random Dataset Samples")

    # Input row: upload an image, or pull a random one from the dataset.
    with gr.Row():
        input_img = gr.Image(type="pil", label="Upload an image")
        rand_img = gr.Button("Random Dataset Image")

    pred_btn = gr.Button("Predict")

    # Classifier outputs, filled by classify_image.
    output_label = gr.Label(label="Predicted Class")
    output_conf = gr.Number(label="Confidence")
    output_probs = gr.JSON(label="All Probabilities")

    # Preview of the last random dataset sample and its dataset label.
    rand_display = gr.Image(type="pil", label="Random Dataset Sample")
    rand_label = gr.Textbox(label="Sample Label")

    # Event wiring (must live inside the Blocks context).
    pred_btn.click(
        fn=classify_image,
        inputs=input_img,
        outputs=[output_label, output_conf, output_probs],
    )
    rand_img.click(
        fn=random_example,
        outputs=[input_img, rand_display, rand_label],
    )


if __name__ == "__main__":
    # Start the Gradio app only when run as a script, not on import.
    demo.launch()