# folding-final-handler / handler.py
# Last commit afa9536 (verified), author riklo:
#   "handler: patch lerobot in-process (subprocess loses site-packages on HFIE)"
"""HFIE custom handler for lerobot folding_final (pi0.5).
Runs on HFIE's stock Python 3.11 container. lerobot 0.5.1 is 3.12-only
at the parser level (PEP 695 syntax), so __init__ patches the installed
package via the sibling patch_lerobot_for_py311.py before importing.
Required endpoint env vars:
PIP_IGNORE_REQUIRES_PYTHON=1 # so pip will install lerobot==0.5.1
# despite its requires-python>=3.12 pin
HF_TOKEN=<token> # for gated google/paligemma-3b-pt-224
# tokenizer used by the preprocessor
Wire format (mirrors `LocalPolicyClient.infer` in policy_relay.py):
Input {"inputs": {
"images": {"left_wrist": <b64 jpeg>,
"right_wrist": <b64 jpeg>,
"base": <b64 jpeg>},
"state": [16 floats, in radians],
"prompt": "fold the towel"
}}
Output {"actions": [[16 floats in radians] × T], # absolute
"actions_relative":[[16 floats] × T], # raw [-1, 1] model output
"chunk_len": T,
"action_dim": 16,
"action_feature_names": [...]}
The dataset records joints + gripper in DEGREES; our wire format is
RADIANS. The handler converts at the model boundary — state rad→deg
going in, action deg→rad coming out. Without this the normalizer baked
into the model's processors emits values like -42 rad for a gripper.
"""
import base64
import importlib.util
import io
import logging
import sys
from pathlib import Path
from typing import Any
import numpy as np
import torch
from PIL import Image
log = logging.getLogger(__name__)
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(message)s",
)

# Hub repo holding the policy weights and processor configs.
UPSTREAM_REPO = "lerobot-data-collection/folding_final"
# Task text used when a request carries no prompt.
DEFAULT_PROMPT = "fold the towel"

# Conversion factors between the radian wire format and the degree units the
# dataset (and hence the model's normalizer) was recorded in.
DEG_PER_RAD = 180.0 / np.pi
RAD_PER_DEG = np.pi / 180.0

# Ordered action/state layout: right arm (7 joints + gripper), then the left
# arm in the same order — 16 entries in total.
ACTION_FEATURE_NAMES = [
    feature
    for side in ("right", "left")
    for feature in (
        *(f"{side}_joint_{idx}.pos" for idx in range(1, 8)),
        f"{side}_gripper.pos",
    )
]
def _patch_lerobot_in_place() -> None:
    """Apply the bundled Python 3.11 compatibility patch to lerobot, in-process.

    The sibling ``patch_lerobot_for_py311`` module is imported and its
    ``main()`` is invoked directly. A subprocess is deliberately avoided:
    on HFIE, child Python processes appeared not to see the parent's
    site-packages (``find_spec("lerobot")`` returned None there even though
    the parent had it installed). The patch script is described as
    idempotent, so calling this on warm restarts is expected to be safe.
    """
    handler_dir = str(Path(__file__).parent)
    # Make the sibling patch script importable.
    if handler_dir not in sys.path:
        sys.path.insert(0, handler_dir)
    log.info("patching lerobot for Python 3.11 ...")
    # Diagnostic only: record which lerobot installation this interpreter resolves.
    lerobot_spec = importlib.util.find_spec("lerobot")
    origin = "<not found>" if lerobot_spec is None else lerobot_spec.origin
    log.info("lerobot find_spec: %s", origin)
    from patch_lerobot_for_py311 import main as run_patch  # type: ignore

    run_patch()
class EndpointHandler:
    """HFIE entry point serving the folding_final pi0.5 policy.

    Construction patches lerobot, downloads weights from ``UPSTREAM_REPO``,
    builds the policy plus its pre/post-processing pipelines and runs one
    warm-up request. Calling the instance runs one inference request (see
    the module docstring for the wire format).
    """

    def __init__(self, path: str = "") -> None:
        """Patch lerobot, load policy + processors, and warm up.

        Args:
            path: Accepted for the handler signature but unused; weights are
                always fetched from ``UPSTREAM_REPO`` via ``snapshot_download``.
        """
        _patch_lerobot_in_place()
        # Imports DEFERRED until after patches land — importing earlier
        # would trip the 3.11 SyntaxError before we got a chance.
        from huggingface_hub import snapshot_download
        from lerobot.policies.pi05.modeling_pi05 import PI05Policy
        from lerobot.processor.pipeline import PolicyProcessorPipeline
        from lerobot.processor.relative_action_processor import (
            RelativeActionsProcessorStep,
        )
        from lerobot.processor.converters import (
            policy_action_to_transition,
            transition_to_policy_action,
        )

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        device_str = self.device.type  # "cuda" or "cpu"
        log.info("device: %s", self.device)
        log.info("downloading %s ...", UPSTREAM_REPO)
        weights = snapshot_download(repo_id=UPSTREAM_REPO)
        log.info("loading PI05Policy ...")
        self.policy = (
            PI05Policy.from_pretrained(weights).to(self.device).eval()
        )
        log.info("loading processors ...")
        self.preproc = PolicyProcessorPipeline.from_pretrained(
            weights,
            config_filename="policy_preprocessor.json",
            overrides={"device_processor": {"device": device_str}},
        )
        # The postprocessor's "absolute_actions_processor" is handed the
        # preprocessor's RelativeActionsProcessorStep instance (None if the
        # preprocessor has none) — presumably so it can undo the relative
        # transform using the same state; confirm against lerobot.
        rel_step = next(
            (s for s in self.preproc.steps
             if isinstance(s, RelativeActionsProcessorStep)),
            None,
        )
        self.postproc = PolicyProcessorPipeline.from_pretrained(
            weights,
            config_filename="policy_postprocessor.json",
            overrides={
                "absolute_actions_processor": {"relative_step": rel_step}
            },
            to_transition=policy_action_to_transition,
            to_output=transition_to_policy_action,
        )
        log.info("warming up ...")
        self._warmup()
        log.info("ready")

    @torch.inference_mode()
    def __call__(self, data: dict[str, Any]) -> dict[str, Any]:
        """Run one inference request.

        Args:
            data: Request body. The payload is read from ``data["inputs"]``
                when present, otherwise from ``data`` itself. It must contain
                ``state`` (16 floats, radians) and may contain ``images``
                (camera name -> base64 image) and ``prompt``.

        Returns:
            ``actions`` (absolute, radians), ``actions_relative`` (raw model
            output), ``chunk_len``, ``action_dim`` and ``action_feature_names``.

        Raises:
            KeyError: if ``state`` is missing.
            ValueError: if ``state`` is not exactly 16 finite numbers.
        """
        n_dim = len(ACTION_FEATURE_NAMES)
        inputs = data.get("inputs", data)

        # A null or blank prompt falls back to the default instead of being
        # stringified to "None" or sent as an empty task.
        raw_prompt = inputs.get("prompt")
        prompt = (
            str(raw_prompt)
            if raw_prompt is not None and str(raw_prompt).strip()
            else DEFAULT_PROMPT
        )

        state = np.asarray(inputs["state"], dtype=np.float32)
        if state.shape != (n_dim,):
            raise ValueError(f"state must be {n_dim}-dim, got shape {state.shape}")
        # NaN/inf would otherwise propagate silently into the action chunk.
        if not np.isfinite(state).all():
            raise ValueError("state must contain only finite values")
        # Radians (wire) → degrees (model conditioning).
        state_deg = state * DEG_PER_RAD

        images_in = inputs.get("images") or {}
        batch = self._build_batch(images_in, state_deg, prompt)
        raw = self.policy.predict_action_chunk(batch)
        # Snapshot the raw output before post-processing so it cannot be
        # affected should a postprocessor step modify its input in place
        # (not verified either way for lerobot's steps).
        actions_rel = raw.squeeze(0).detach().to("cpu").float().numpy().copy()
        abs_deg = self.postproc(raw)
        # Degrees (model) → radians (wire).
        actions_abs = (
            abs_deg.squeeze(0).detach().to("cpu").float().numpy() * RAD_PER_DEG
        )
        # Keep only the robot's action dims; the model may emit more
        # (presumably padding). No-op when it already emits n_dim.
        actions_rel = actions_rel[:, :n_dim]
        actions_abs = actions_abs[:, :n_dim]
        return {
            "actions": actions_abs.tolist(),
            "actions_relative": actions_rel.tolist(),
            "chunk_len": int(actions_abs.shape[0]),
            "action_dim": int(actions_abs.shape[1]),
            "action_feature_names": ACTION_FEATURE_NAMES,
        }

    def _build_batch(self, images_in, state_deg, prompt):
        """Assemble the raw observation dict and run it through the preprocessor.

        Args:
            images_in: Mapping of camera name -> base64 image (may be empty).
            state_deg: float32 state vector, already converted to degrees.
            prompt: Task text.

        Returns:
            Whatever ``self.preproc`` returns for the assembled batch.
        """
        batch: dict[str, Any] = {}
        for cam in ("left_wrist", "right_wrist", "base"):
            arr = self._decode_image(images_in.get(cam))
            if arr is None:
                # Missing or undecodable camera → black frame, so every
                # expected image key is still present.
                arr = np.zeros((480, 640, 3), dtype=np.uint8)
            # HWC uint8 → 1×C×H×W float32 scaled to [0, 1].
            t = torch.from_numpy(arr).permute(2, 0, 1).to(torch.float32) / 255.0
            batch[f"observation.images.{cam}"] = t.unsqueeze(0)
        batch["observation.state"] = torch.from_numpy(state_deg).unsqueeze(0)
        batch["task"] = prompt
        return self.preproc(batch)

    @staticmethod
    def _decode_image(b64):
        """Decode a base64-encoded image into a writable H×W×3 uint8 RGB array.

        An optional ``data:<mime>;base64,`` prefix is stripped. Returns None
        for empty input, or when decoding fails (a warning is logged).
        """
        if not b64:
            return None
        try:
            if isinstance(b64, str) and b64.startswith("data:"):
                b64 = b64.partition(",")[2]
            with Image.open(io.BytesIO(base64.b64decode(b64))) as img:
                # np.array (not np.asarray) copies into a writable buffer:
                # np.asarray on a PIL image is read-only, which makes
                # torch.from_numpy warn about non-writable arrays.
                return np.array(img.convert("RGB"))
        except Exception as e:  # noqa: BLE001 — best-effort; caller substitutes a black frame
            log.warning("image decode failed: %s", e)
            return None

    def _warmup(self) -> None:
        """Run one dummy request through the full pipeline before serving.

        Uses a black 224×224 JPEG for every camera, an all-zero state and
        the default prompt.
        """
        buf = io.BytesIO()
        Image.fromarray(np.zeros((224, 224, 3), dtype=np.uint8)).save(
            buf, format="JPEG"
        )
        dummy = base64.b64encode(buf.getvalue()).decode("ascii")
        self({
            "inputs": {
                "images": {
                    "left_wrist": dummy,
                    "right_wrist": dummy,
                    "base": dummy,
                },
                "state": [0.0] * 16,
                "prompt": DEFAULT_PROMPT,
            }
        })