| """HFIE custom handler for lerobot folding_final (pi0.5). |
| |
| Runs on HFIE's stock Python 3.11 container. lerobot 0.5.1 is 3.12-only |
| at the parser level (PEP 695 syntax), so __init__ patches the installed |
| package via the sibling patch_lerobot_for_py311.py before importing. |
| |
| Required endpoint env vars: |
| PIP_IGNORE_REQUIRES_PYTHON=1 # so pip will install lerobot==0.5.1 |
| # despite its requires-python>=3.12 pin |
| HF_TOKEN=<token> # for gated google/paligemma-3b-pt-224 |
| # tokenizer used by the preprocessor |
| |
| Wire format (mirrors `LocalPolicyClient.infer` in policy_relay.py): |
| |
| Input {"inputs": { |
| "images": {"left_wrist": <b64 jpeg>, |
| "right_wrist": <b64 jpeg>, |
| "base": <b64 jpeg>}, |
| "state": [16 floats, in radians], |
| "prompt": "fold the towel" |
| }} |
| |
| Output {"actions": [[16 floats in radians] × T], # absolute |
| "actions_relative":[[16 floats] × T], # raw model output, normalized (nominally [-1, 1]), not radians |
| "chunk_len": T, |
| "action_dim": 16, |
| "action_feature_names": [...]} |
| |
| The dataset records joints + gripper in DEGREES; our wire format is |
| RADIANS. The handler converts at the model boundary — state rad→deg |
| going in, action deg→rad coming out. Without this the normalizer baked |
| into the model's processors emits values like -42 rad for a gripper. |
| """ |
| import base64 |
| import importlib.util |
| import io |
| import logging |
| import sys |
| from pathlib import Path |
| from typing import Any |
|
|
| import numpy as np |
| import torch |
| from PIL import Image |
|
|
# Module logger; root logging is configured at import time so INFO-level
# progress messages from the handler show up in the endpoint logs.
log = logging.getLogger(name=__name__)
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(message)s",
    level=logging.INFO,
)
|
|
# Hub repo holding the policy weights and processor configs.
UPSTREAM_REPO = "lerobot-data-collection/folding_final"
# Task prompt used when a request does not supply one.
DEFAULT_PROMPT = "fold the towel"

# Conversion factors between the wire format (radians) and the dataset
# units (degrees).
DEG_PER_RAD = 180.0 / np.pi
RAD_PER_DEG = np.pi / 180.0

# Names of the 16 returned action dims, in order: the right arm
# (joint_1..joint_7, then gripper), followed by the left arm in the same order.
ACTION_FEATURE_NAMES = [
    f"{side}_{part}.pos"
    for side in ("right", "left")
    for part in [f"joint_{idx}" for idx in range(1, 8)] + ["gripper"]
]
|
|
|
|
def _patch_lerobot_in_place() -> None:
    """Patch the installed lerobot package for Python 3.11, in-process.

    Steps:
    1. Put this file's directory on ``sys.path`` so that the sibling
       ``patch_lerobot_for_py311`` module can be imported.
    2. Log where lerobot resolves from.
    3. Call that module's ``main()``.

    The patch is imported and called here instead of being run as a
    subprocess. On HFIE, child Python processes were observed not to see
    the installed lerobot (``find_spec("lerobot")`` returned None there).

    Intended to be safe on warm restarts. That relies on the patch script's
    ``main()`` being idempotent; this function does not enforce it.
    """
    handler_dir = str(Path(__file__).parent)
    if handler_dir not in sys.path:
        sys.path.insert(0, handler_dir)
    log.info("patching lerobot for Python 3.11 ...")

    # Diagnostic only: record where lerobot resolves from.
    lerobot_spec = importlib.util.find_spec("lerobot")
    log.info(
        "lerobot find_spec: %s",
        "<not found>" if lerobot_spec is None else lerobot_spec.origin,
    )

    # Deferred import: only resolvable once handler_dir is on sys.path.
    from patch_lerobot_for_py311 import main as run_patch

    run_patch()
|
|
|
|
class EndpointHandler:
    """HFIE entry point wrapping the pi0.5 policy and its processor pipelines.

    ``__init__`` does the one-time setup:
    - patches lerobot for Python 3.11;
    - downloads ``UPSTREAM_REPO`` and loads the policy plus the
      pre/post-processor pipelines;
    - runs one warm-up inference.

    ``__call__`` serves a single request. State arrives in radians and is
    converted to degrees for the model; absolute actions are converted back
    to radians on the way out.
    """

    def __init__(self, path: str = "") -> None:
        """Load the policy and processors, then warm up.

        Args:
            path: Supplied by HFIE. Unused here: weights always come from
                ``UPSTREAM_REPO`` via ``snapshot_download``.
        """
        # Must run before any lerobot import below: the installed package
        # only parses on Python 3.11 once patched.
        _patch_lerobot_in_place()

        from huggingface_hub import snapshot_download
        from lerobot.policies.pi05.modeling_pi05 import PI05Policy
        from lerobot.processor.pipeline import PolicyProcessorPipeline
        from lerobot.processor.relative_action_processor import (
            RelativeActionsProcessorStep,
        )
        from lerobot.processor.converters import (
            policy_action_to_transition,
            transition_to_policy_action,
        )

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        log.info("device: %s", self.device)

        log.info("downloading %s ...", UPSTREAM_REPO)
        weights = snapshot_download(repo_id=UPSTREAM_REPO)
        log.info("loading PI05Policy ...")
        self.policy = PI05Policy.from_pretrained(weights).to(self.device).eval()

        log.info("loading processors ...")
        self.preproc = PolicyProcessorPipeline.from_pretrained(
            weights,
            config_filename="policy_preprocessor.json",
            # self.device is either cuda or cpu, so .type is "cuda" or "cpu".
            overrides={"device_processor": {"device": self.device.type}},
        )
        # The postprocessor's absolute-actions step is handed the
        # preprocessor's relative step. Presumably it uses it to turn the
        # model's relative output back into absolute positions.
        # NOTE(review): if that step keeps per-request state between preproc
        # and postproc, concurrently dispatched requests could mix states.
        # Confirm HFIE calls this handler serially.
        rel_step = next(
            (
                step
                for step in self.preproc.steps
                if isinstance(step, RelativeActionsProcessorStep)
            ),
            None,
        )
        if rel_step is None:
            log.warning(
                "preprocessor has no RelativeActionsProcessorStep; "
                "passing relative_step=None to the postprocessor"
            )
        self.postproc = PolicyProcessorPipeline.from_pretrained(
            weights,
            config_filename="policy_postprocessor.json",
            overrides={
                "absolute_actions_processor": {"relative_step": rel_step}
            },
            to_transition=policy_action_to_transition,
            to_output=transition_to_policy_action,
        )

        log.info("warming up ...")
        self._warmup()
        log.info("ready")

    @torch.inference_mode()
    def __call__(self, data: dict[str, Any]) -> dict[str, Any]:
        """Run one inference request.

        Args:
            data: Request body. The payload is read from ``data["inputs"]``
                when that key exists, otherwise from ``data`` itself. It must
                contain ``state`` (16 radians). It may also contain ``images``
                (camera name -> base64 image) and ``prompt``.

        Returns:
            Dict with these keys:
            - ``actions``: absolute actions in radians, 16 dims per step.
            - ``actions_relative``: the raw model output, truncated to
              16 dims.
            - ``chunk_len``, ``action_dim``, ``action_feature_names``.

        Raises:
            KeyError: if ``state`` is missing.
            ValueError: if ``state`` is not 16 finite numbers.
        """
        inputs = data.get("inputs", data)
        # Validate first so malformed requests fail before any model work.
        state = self._parse_state(inputs["state"])
        prompt = inputs.get("prompt")
        # A JSON null would otherwise become the literal prompt "None".
        prompt = DEFAULT_PROMPT if prompt is None else str(prompt)

        # Wire format is radians; the model's normalizers expect degrees.
        state_deg = state * DEG_PER_RAD

        images_in = inputs.get("images") or {}
        batch = self._build_batch(images_in, state_deg, prompt)
        raw = self.policy.predict_action_chunk(batch)

        # Snapshot the raw output before the postprocessor consumes it.
        # The explicit copy matters because on CPU .numpy() shares memory
        # with the tensor.
        actions_rel = raw.squeeze(0).detach().to("cpu").float().numpy().copy()
        abs_deg = self.postproc(raw)
        actions_abs = (
            abs_deg.squeeze(0).detach().to("cpu").float().numpy() * RAD_PER_DEG
        )

        # Keep only the 16 real action dims (drops any padding). Each array
        # is sliced on its own; slicing is a no-op when already <= 16 wide.
        actions_rel = actions_rel[:, :16]
        actions_abs = actions_abs[:, :16]

        return {
            "actions": actions_abs.tolist(),
            "actions_relative": actions_rel.tolist(),
            "chunk_len": int(actions_abs.shape[0]),
            "action_dim": int(actions_abs.shape[1]),
            "action_feature_names": ACTION_FEATURE_NAMES,
        }

    @staticmethod
    def _parse_state(raw_state: Any) -> np.ndarray:
        """Return *raw_state* as a float32 array of shape (16,).

        Raises:
            ValueError: if the shape is not (16,) or any entry is NaN/inf.
        """
        state = np.asarray(raw_state, dtype=np.float32)
        if state.shape != (16,):
            raise ValueError(f"state must be 16-dim, got shape {state.shape}")
        if not np.isfinite(state).all():
            raise ValueError("state must contain only finite values")
        return state

    def _build_batch(self, images_in, state_deg, prompt):
        """Assemble the raw observation dict and run it through the preprocessor.

        Each camera image becomes a (1, 3, H, W) float32 tensor scaled to
        [0, 1]. A missing or undecodable camera is replaced by a black
        480x640 frame, and a warning is logged.
        """
        batch: dict[str, Any] = {}
        for cam in ("left_wrist", "right_wrist", "base"):
            arr = self._decode_image(images_in.get(cam))
            if arr is None:
                log.warning(
                    "camera %r missing or undecodable; using a black frame", cam
                )
                arr = np.zeros((480, 640, 3), dtype=np.uint8)
            t = torch.from_numpy(arr).permute(2, 0, 1).to(torch.float32) / 255.0
            batch[f"observation.images.{cam}"] = t.unsqueeze(0)
        batch["observation.state"] = torch.from_numpy(state_deg).unsqueeze(0)
        batch["task"] = prompt
        return self.preproc(batch)

    @staticmethod
    def _strip_data_url(b64):
        """Return the payload of a ``data:...,<payload>`` URL.

        Any other input is returned unchanged.
        """
        if isinstance(b64, str) and b64.startswith("data:"):
            return b64.partition(",")[2]
        return b64

    @staticmethod
    def _decode_image(b64):
        """Decode a base64 image into an (H, W, 3) uint8 RGB array.

        Accepts plain base64 or a ``data:<mime>;base64,<payload>`` URL.
        Returns None for empty input, or when decoding fails (the failure
        is logged), so the caller can substitute a placeholder frame.
        """
        if not b64:
            return None
        payload = EndpointHandler._strip_data_url(b64)
        if not payload:
            return None
        try:
            with Image.open(io.BytesIO(base64.b64decode(payload))) as img:
                # np.array (not asarray) gives a writable copy.
                # torch.from_numpy warns on read-only arrays.
                return np.array(img.convert("RGB"))
        except Exception as e:  # deliberate best-effort: fall back to None
            log.warning("image decode failed: %s", e)
            return None

    def _warmup(self) -> None:
        """Run one inference on black 224x224 JPEGs and a zero state."""
        buf = io.BytesIO()
        Image.fromarray(np.zeros((224, 224, 3), dtype=np.uint8)).save(
            buf, format="JPEG"
        )
        dummy = base64.b64encode(buf.getvalue()).decode("ascii")
        self({
            "inputs": {
                "images": {
                    "left_wrist": dummy,
                    "right_wrist": dummy,
                    "base": dummy,
                },
                "state": [0.0] * 16,
                "prompt": DEFAULT_PROMPT,
            }
        })
|