File size: 7,973 Bytes
3496b49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70651b8
 
afa9536
70651b8
 
6ffe733
3496b49
70651b8
 
 
 
 
 
 
a4c9320
70651b8
 
 
3496b49
 
70651b8
 
 
 
 
 
 
 
 
 
6ffe733
3496b49
afa9536
 
 
 
 
 
 
 
3496b49
afa9536
 
 
 
 
6d77e26
70651b8
 
 
3496b49
6ffe733
3496b49
 
70651b8
3496b49
 
 
 
 
 
 
 
 
70651b8
 
3496b49
70651b8
 
6ffe733
3496b49
 
70651b8
3496b49
70651b8
3496b49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70651b8
 
 
 
 
 
 
 
 
 
 
3496b49
 
 
 
70651b8
3496b49
 
 
 
 
 
 
 
 
 
 
 
 
70651b8
3496b49
 
 
 
70651b8
 
 
3496b49
70651b8
 
 
 
3496b49
 
6d77e26
3496b49
70651b8
3496b49
70651b8
 
7ffdca2
70651b8
 
 
6ffe733
 
 
6d77e26
70651b8
 
 
3496b49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
"""HFIE custom handler for lerobot folding_final (pi0.5).

Runs on HFIE's stock Python 3.11 container. lerobot 0.5.1 is 3.12-only
at the parser level (PEP 695 syntax), so __init__ patches the installed
package via the sibling patch_lerobot_for_py311.py before importing.

Required endpoint env vars:
  PIP_IGNORE_REQUIRES_PYTHON=1     # so pip will install lerobot==0.5.1
                                   # despite its requires-python>=3.12 pin
  HF_TOKEN=<token>                  # for gated google/paligemma-3b-pt-224
                                    # tokenizer used by the preprocessor

Wire format (mirrors `LocalPolicyClient.infer` in policy_relay.py):

  Input  {"inputs": {
              "images": {"left_wrist": <b64 jpeg>,
                          "right_wrist": <b64 jpeg>,
                          "base":       <b64 jpeg>},
              "state":  [16 floats, in radians],
              "prompt": "fold the towel"
          }}

  Output {"actions":          [[16 floats in radians] × T],  # absolute
          "actions_relative": [[16 floats] × T],  # raw model output (pre-postprocessor)
          "chunk_len": T,
          "action_dim": 16,
          "action_feature_names": [...]}

The dataset records joints + gripper in DEGREES; our wire format is
RADIANS. The handler converts at the model boundary — state rad→deg
going in, action deg→rad coming out. Without this the normalizer baked
into the model's processors emits values like -42 rad for a gripper.
"""
import base64
import importlib.util
import io
import logging
import sys
from pathlib import Path
from typing import Any

import numpy as np
import torch
from PIL import Image

log = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")

# Model weights / processor configs are pulled from this Hub repo at startup.
UPSTREAM_REPO = "lerobot-data-collection/folding_final"
# Task string used when a request carries no prompt.
DEFAULT_PROMPT = "fold the towel"

# Unit conversion factors between the RADIAN wire format and the DEGREE
# space the dataset (and therefore the model's normalizers) uses.
DEG_PER_RAD = 180.0 / np.pi
RAD_PER_DEG = np.pi / 180.0

# Output column names: right arm first, then left; within each arm the
# seven joints in order, followed by the gripper.
ACTION_FEATURE_NAMES = [
    f"{arm}_{part}.pos"
    for arm in ("right", "left")
    for part in (*(f"joint_{idx}" for idx in range(1, 8)), "gripper")
]


def _patch_lerobot_in_place() -> None:
    """Apply the bundled Python 3.11 compatibility patch to lerobot, in-process.

    The sibling ``patch_lerobot_for_py311`` module is imported and its
    ``main()`` called directly instead of being run as a subprocess: on
    HFIE, child Python processes appear not to see site-packages
    (``find_spec("lerobot")`` returns None there even though this process
    has lerobot installed). Re-running on warm restarts is expected to be
    safe; that idempotence is the patch script's responsibility.
    """
    handler_dir = str(Path(__file__).parent)
    # Make the sibling patch module importable.
    if handler_dir not in sys.path:
        sys.path.insert(0, handler_dir)
    log.info("patching lerobot for Python 3.11 ...")
    # Diagnostic only: record where this interpreter resolves lerobot from.
    lerobot_spec = importlib.util.find_spec("lerobot")
    origin = "<not found>" if lerobot_spec is None else lerobot_spec.origin
    log.info("lerobot find_spec: %s", origin)
    from patch_lerobot_for_py311 import main as run_patch  # type: ignore
    run_patch()


class EndpointHandler:
    """HFIE custom handler serving pi0.5 action chunks for folding_final.

    ``__init__`` loads the policy and its pre/post-processors once; each
    ``__call__`` turns one request (camera images, a 16-dim state in
    radians, and a prompt) into one action chunk in radians.
    """

    # Wire-format DoF: 7 joints + 1 gripper per arm, two arms (the same
    # layout as the module-level ACTION_FEATURE_NAMES list).
    _ACTION_DIM = 16
    # Request image keys; each becomes an "observation.images.<cam>" entry.
    _CAMERAS = ("left_wrist", "right_wrist", "base")
    # (H, W, C) of the black frame substituted for a missing/undecodable image.
    _BLANK_FRAME_SHAPE = (480, 640, 3)

    def __init__(self, path: str = "") -> None:
        """Patch lerobot, load the policy and processors, then warm up.

        Args:
            path: Directory argument supplied by HFIE (presumably the local
                repo path; unverified). Unused: weights are always pulled
                from UPSTREAM_REPO via ``snapshot_download``.
        """
        _patch_lerobot_in_place()

        # Imports DEFERRED until after patches land — importing earlier
        # would trip the 3.11 SyntaxError before we got a chance.
        from huggingface_hub import snapshot_download
        from lerobot.policies.pi05.modeling_pi05 import PI05Policy
        from lerobot.processor.pipeline import PolicyProcessorPipeline
        from lerobot.processor.relative_action_processor import (
            RelativeActionsProcessorStep,
        )
        from lerobot.processor.converters import (
            policy_action_to_transition,
            transition_to_policy_action,
        )

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        log.info("device: %s", self.device)

        log.info("downloading %s ...", UPSTREAM_REPO)
        weights = snapshot_download(repo_id=UPSTREAM_REPO)
        log.info("loading PI05Policy ...")
        self.policy = PI05Policy.from_pretrained(weights).to(self.device).eval()

        log.info("loading processors ...")
        self.preproc = PolicyProcessorPipeline.from_pretrained(
            weights,
            config_filename="policy_preprocessor.json",
            # self.device is either cuda or cpu, so .type is exactly the
            # "cuda"/"cpu" string the device processor expects.
            overrides={"device_processor": {"device": self.device.type}},
        )
        # Hand the preprocessor's RelativeActionsProcessorStep (None if the
        # pipeline has none) to the postprocessor's absolute-actions step —
        # presumably so relative predictions are mapped back to absolute
        # using the same step instance.
        rel_step = next(
            (s for s in self.preproc.steps
             if isinstance(s, RelativeActionsProcessorStep)),
            None,
        )
        self.postproc = PolicyProcessorPipeline.from_pretrained(
            weights,
            config_filename="policy_postprocessor.json",
            overrides={
                "absolute_actions_processor": {"relative_step": rel_step}
            },
            to_transition=policy_action_to_transition,
            to_output=transition_to_policy_action,
        )

        log.info("warming up ...")
        self._warmup()
        log.info("ready")

    @torch.inference_mode()
    def __call__(self, data: dict[str, Any]) -> dict[str, Any]:
        """Run one inference request.

        Args:
            data: Request body. Fields are read from ``data["inputs"]`` when
                present, otherwise from ``data`` itself:
                ``state`` (16 floats, radians; required),
                ``images`` (dict camera name -> base64 image; optional,
                missing cameras get a black frame) and
                ``prompt`` (str; missing/null/empty -> DEFAULT_PROMPT).

        Returns:
            Dict with ``actions`` (absolute, radians, T x 16),
            ``actions_relative`` (raw model output before postprocessing),
            ``chunk_len``, ``action_dim`` and ``action_feature_names``.

        Raises:
            KeyError: ``state`` is missing.
            ValueError: ``state`` is not 16-dim or contains NaN/inf, or
                ``images`` is not a dict.
        """
        inputs = data.get("inputs", data)
        # `or` also catches an explicit null/empty prompt, which would
        # otherwise reach the model as the literal task "None" or "".
        prompt = str(inputs.get("prompt") or DEFAULT_PROMPT)

        state = np.asarray(inputs["state"], dtype=np.float32)
        if state.shape != (self._ACTION_DIM,):
            raise ValueError(
                f"state must be {self._ACTION_DIM}-dim, got shape {state.shape}"
            )
        # np.asarray maps JSON nulls to NaN for float dtypes; a NaN state
        # would silently yield NaN joint targets, so reject it up front.
        if not np.isfinite(state).all():
            raise ValueError("state contains non-finite values (NaN/inf/null)")

        images_in = inputs.get("images") or {}
        if not isinstance(images_in, dict):
            raise ValueError(
                "images must be a dict of camera name -> base64 image, "
                f"got {type(images_in).__name__}"
            )

        # Radians (wire) → degrees (model conditioning).
        state_deg = state * DEG_PER_RAD

        batch = self._build_batch(images_in, state_deg, prompt)
        raw = self.policy.predict_action_chunk(batch)
        abs_deg = self.postproc(raw)

        actions_rel = self._to_action_array(raw)
        # Degrees (model) → radians (wire).
        actions_abs = self._to_action_array(abs_deg) * RAD_PER_DEG

        return {
            "actions": actions_abs.tolist(),
            "actions_relative": actions_rel.tolist(),
            "chunk_len": int(actions_abs.shape[0]),
            "action_dim": int(actions_abs.shape[1]),
            "action_feature_names": ACTION_FEATURE_NAMES,
        }

    @classmethod
    def _to_action_array(cls, t: torch.Tensor) -> np.ndarray:
        """Convert a batched action tensor to a CPU float32 (T, <=16) array.

        Assumes a leading batch dim of size 1 (squeezed away). Only the
        first ``_ACTION_DIM`` columns are kept — the model may emit extra
        columns (presumably padding; TODO confirm). Each array is trimmed
        on its own width rather than on the relative array's width.
        """
        arr = t.squeeze(0).detach().to("cpu").float().numpy()
        return arr[:, : cls._ACTION_DIM]

    def _build_batch(
        self,
        images_in: dict[str, Any],
        state_deg: np.ndarray,
        prompt: str,
    ) -> Any:
        """Assemble a raw observation batch and run it through the preprocessor.

        Each camera becomes a (1, 3, H, W) float32 tensor scaled to [0, 1];
        a missing or undecodable camera is replaced by a black 480x640
        frame. The state gets a leading batch dim; the prompt is passed as
        ``task``.
        """
        batch: dict[str, Any] = {}
        for cam in self._CAMERAS:
            arr = self._decode_image(images_in.get(cam))
            if arr is None:
                arr = np.zeros(self._BLANK_FRAME_SHAPE, dtype=np.uint8)
            # HWC uint8 -> CHW float32 in [0, 1], then add a batch dim.
            t = torch.from_numpy(arr).permute(2, 0, 1).to(torch.float32) / 255.0
            batch[f"observation.images.{cam}"] = t.unsqueeze(0)
        batch["observation.state"] = torch.from_numpy(state_deg).unsqueeze(0)
        batch["task"] = prompt
        return self.preproc(batch)

    @staticmethod
    def _decode_image(b64):
        """Decode a base64-encoded image into an (H, W, 3) RGB numpy array.

        Returns None for an empty/missing payload or when decoding fails
        (logged as a warning), so the caller can substitute a blank frame.
        """
        if not b64:
            return None
        try:
            with Image.open(io.BytesIO(base64.b64decode(b64))) as img:
                # np.array (not asarray) copies into a writable buffer;
                # asarray on a PIL image can be read-only, which
                # torch.from_numpy warns about and does not support.
                return np.array(img.convert("RGB"))
        except Exception as e:  # noqa: BLE001
            # Deliberate best-effort: a bad frame degrades to a black image
            # rather than failing the whole request.
            log.warning("image decode failed: %s", e)
            return None

    def _warmup(self) -> None:
        """Push one dummy request through the full pipeline at startup.

        Presumably so lazy first-call costs are paid here rather than on
        the first real request. Uses black 224x224 JPEGs for every camera,
        a zero state and DEFAULT_PROMPT.
        """
        buf = io.BytesIO()
        Image.fromarray(np.zeros((224, 224, 3), dtype=np.uint8)).save(
            buf, format="JPEG"
        )
        dummy = base64.b64encode(buf.getvalue()).decode("ascii")
        self({
            "inputs": {
                "images": {cam: dummy for cam in self._CAMERAS},
                "state": [0.0] * self._ACTION_DIM,
                "prompt": DEFAULT_PROMPT,
            }
        })