| from transformers import CLIPFeatureExtractor |
| from safety_checker import StableDiffusionSafetyChecker |
| import torch |
| from PIL import Image |
| import gradio as gr |
| from pathlib import Path |
|
|
| device = "cuda" if torch.cuda.is_available() else "cpu" |
| safety_checker = StableDiffusionSafetyChecker.from_pretrained( |
| "CompVis/stable-diffusion-safety-checker" |
| ).to(device) |
| feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32") |
|
|
|
|
| import gradio as gr |
|
|
|
|
| def image_classifier(files): |
| images = [Image.open(file).convert("RGB").resize((512, 512)) for file in files] |
|
|
| safety_checker_input = feature_extractor(images, return_tensors="pt").to(device) |
| has_nsfw_concepts = safety_checker( |
| images=[images], clip_input=safety_checker_input.pixel_values.to(torch.float16) |
| ) |
| results = [ |
| {"has_nsfw": nsfw, "file": Path(file).name} |
| for (nsfw, file) in zip(has_nsfw_concepts, files) |
| ] |
| return {"results": results} |
|
|
|
|
| demo = gr.Interface( |
| title="Stable Diffusion Safety Checker API", |
| fn=image_classifier, |
| inputs=gr.File(file_count="multiple", file_types=["image"]), |
| outputs="json", |
| api_name="classify", |
| ) |
| demo.launch() |