| from typing import Dict, List, Any |
| import onnxruntime as ort |
| import numpy as np |
| from PIL import Image |
| import io |
| import base64 |
| import os |
|
|
|
|
class EndpointHandler:
    """Point-prompted segmentation handler backed by two ONNX sessions.

    The handler loads ``edge_sam_3x_encoder.onnx`` and
    ``edge_sam_3x_decoder.onnx`` once at construction time. Each call
    resizes the input image to ``INPUT_SIZE`` x ``INPUT_SIZE``, encodes it,
    decodes a mask for the supplied prompt points and returns a binary mask
    in that same square frame.

    Prompt coordinates are forwarded to the decoder unchanged, so they must
    be expressed in the resized ``INPUT_SIZE`` x ``INPUT_SIZE`` frame (the
    default prompt is the centre of that frame).
    """

    # Side length, in pixels, of the square encoder input and of the mask
    # returned to the caller.
    INPUT_SIZE = 1024

    def __init__(self, path=""):
        """Create the encoder and decoder inference sessions.

        Args:
            path: Directory that contains the two ``.onnx`` files. An empty
                value means the current working directory.
        """
        model_dir = path or "."

        # Only request providers this onnxruntime build actually offers, so
        # CPU-only installs are not asked for CUDA. Fall back to the full
        # preference list if the query yields nothing usable.
        preferred = ["CUDAExecutionProvider", "CPUExecutionProvider"]
        available = ort.get_available_providers()
        providers = [p for p in preferred if p in available] or preferred

        self.encoder = ort.InferenceSession(
            os.path.join(model_dir, "edge_sam_3x_encoder.onnx"),
            providers=providers,
        )
        self.decoder = ort.InferenceSession(
            os.path.join(model_dir, "edge_sam_3x_decoder.onnx"),
            providers=providers,
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Segment the object indicated by the prompt points.

        Args:
            data: Request payload. ``data["inputs"]`` is the image: a
                base64 string (optionally a ``data:`` URL), raw encoded
                image bytes, or an already-decoded PIL image.
                ``data["parameters"]`` is optional and may contain:

                * ``point_coords``: list of ``[x, y]`` pairs
                  (default ``[[512, 512]]``).
                * ``point_labels``: one label per point (default ``[1]``);
                  a single label is applied to every point.
                * ``return_mask_image``: when true (default) the mask is
                  included as a base64-encoded PNG under ``"mask"``.

        Returns:
            A one-element list. On success the element holds
            ``mask_shape``, ``has_object`` and optionally ``mask``. On any
            failure it holds ``error``, ``type`` and ``traceback`` instead;
            exceptions are never propagated to the caller.
        """
        try:
            inputs = data.get("inputs", data)
            # ``"parameters": null`` in the request must behave like an
            # absent key rather than crash on ``None.get``.
            params = data.get("parameters") or {}

            # Validate the cheap prompt data before paying for inference.
            coords, labels = self._parse_points(params)
            pixels = self._preprocess(self._load_image(inputs))

            embeddings = self.encoder.run(None, {"image": pixels})[0]
            decoder_outputs = self.decoder.run(None, {
                "image_embeddings": embeddings,
                "point_coords": coords,
                "point_labels": labels,
            })

            # The second decoder output is treated as the mask logits.
            mask = self._postprocess(decoder_outputs[1])

            result = {
                "mask_shape": list(mask.shape),
                "has_object": bool(mask.max() > 0),
            }
            if params.get("return_mask_image", True):
                buffer = io.BytesIO()
                # A 2-D uint8 array is inferred as mode "L"; no explicit
                # ``mode`` argument is needed.
                Image.fromarray(mask).save(buffer, format="PNG")
                result["mask"] = base64.b64encode(buffer.getvalue()).decode()

            return [result]

        except Exception as e:
            # Top-level boundary: report the failure in-band.
            # NOTE(review): the traceback exposes server-side paths to the
            # client; kept because callers may rely on this key.
            import traceback
            return [{
                "error": str(e),
                "type": type(e).__name__,
                "traceback": traceback.format_exc(),
            }]

    @staticmethod
    def _parse_points(params):
        """Return ``(coords, labels)`` shaped ``(1, N, 2)`` and ``(1, N)``.

        Raises:
            ValueError: If the coordinates are empty or not ``[x, y]``
                pairs, or the label count is neither 1 nor ``N``.
        """
        raw_coords = params.get("point_coords")
        if raw_coords is None:
            raw_coords = [[512, 512]]
        raw_labels = params.get("point_labels")
        if raw_labels is None:
            raw_labels = [1]

        coords = np.asarray(raw_coords, dtype=np.float32)
        labels = np.asarray(raw_labels, dtype=np.float32).reshape(1, -1)

        if coords.size == 0 or coords.size % 2:
            raise ValueError(
                "point_coords must be a non-empty list of [x, y] pairs"
            )
        coords = coords.reshape(1, -1, 2)

        num_points = coords.shape[1]
        if labels.shape[1] == 1 and num_points > 1:
            # A single label (including the default) applies to every point.
            labels = np.repeat(labels, num_points, axis=1)
        if labels.shape[1] != num_points:
            raise ValueError(
                f"point_labels has {labels.shape[1]} entries "
                f"but point_coords has {num_points} points"
            )
        return coords, labels

    @staticmethod
    def _load_image(inputs):
        """Turn the request's ``inputs`` value into a PIL image.

        Raises:
            TypeError: If ``inputs`` is not a string, bytes or image object.
        """
        if isinstance(inputs, str):
            # Accept "data:image/png;base64,...." by dropping the header.
            if inputs.startswith("data:"):
                inputs = inputs.partition(",")[2]
            inputs = base64.b64decode(inputs)
        if isinstance(inputs, (bytes, bytearray)):
            return Image.open(io.BytesIO(inputs))
        # Duck-typed check for an already-decoded PIL image.
        if hasattr(inputs, "mode") and hasattr(inputs, "convert"):
            return inputs
        raise TypeError(
            f"Unsupported 'inputs' type {type(inputs).__name__}; expected a "
            "base64 string, raw image bytes or a PIL image"
        )

    def _preprocess(self, image):
        """Convert a PIL image into the encoder's float32 input tensor.

        The result has shape ``(1, 3, INPUT_SIZE, INPUT_SIZE)`` with values
        scaled to ``[0, 1]``.
        """
        if image.mode != "RGB":
            image = image.convert("RGB")
        image = image.resize((self.INPUT_SIZE, self.INPUT_SIZE), Image.BILINEAR)
        # NOTE(review): only a /255 scaling is applied; confirm the exported
        # encoder does not expect mean/std-normalised input.
        pixels = np.array(image).astype(np.float32) / 255.0
        # HWC -> CHW, then add the leading batch axis.
        return pixels.transpose(2, 0, 1)[np.newaxis, :]

    def _postprocess(self, masks):
        """Upsample the first predicted mask and binarise it to 0/255 uint8."""
        logits = Image.fromarray(masks[0, 0]).resize(
            (self.INPUT_SIZE, self.INPUT_SIZE), Image.BILINEAR
        )
        # Values above 0 are foreground.
        return (np.array(logits) > 0.0).astype(np.uint8) * 255
|
|