| | |
| | |
| | |
| | |
| | |
| |
|
| | """ |
| | Sandboxed Python code executor for the REPL environment. |
| | |
| | Uses smolagents.LocalPythonExecutor as the backend for battle-tested sandboxed |
| | execution, with RLM-specific features on top: |
| | - Context loading (set_context) |
| | - Variable access (get_variable, list_variables) |
| | - Function injection (inject_function for llm_query, llm_query_batched) |
| | - Output capped at 8,192 characters per turn (configurable) |
| | - Persistent namespace across code blocks |
| | """ |
| |
|
| | import json |
| | import logging |
| | import time |
| | import traceback |
| | from collections.abc import Callable |
| | from typing import Any, Dict, List, Optional |
| |
|
| | from smolagents import LocalPythonExecutor |
| |
|
# Module-level logger named after this module. The NullHandler gives the
# logger a handler of its own, so the debug records emitted below stay silent
# unless the host application configures logging.
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())
| |
|
| |
|
class PythonExecutor:
    """Sandboxed Python code executor with persistent namespace.

    Wraps smolagents.LocalPythonExecutor with RLM-specific features:
    - Context loading for RLM tasks
    - Variable tracking for observation
    - Function injection for llm_query, llm_query_batched
    - Configurable output length limit (default 8192 chars per Prime Intellect)
    """

    # Modules sandboxed code may import when the caller does not pass a list.
    _DEFAULT_ALLOWED_IMPORTS = (
        "re",
        "json",
        "math",
        "random",
        "collections",
        "itertools",
        "functools",
        "operator",
        "string",
        "textwrap",
        "difflib",
        "statistics",
        "decimal",
        "fractions",
        "datetime",
        "copy",
        "pprint",
        "typing",
        "dataclasses",
        "enum",
        "bisect",
        "heapq",
        "array",
        "struct",
        "base64",
        "hashlib",
        "hmac",
        "uuid",
    )

    # repr() of a newly created variable is cut to this many characters in
    # the ``locals_snapshot`` returned by execute().
    _MAX_REPR_LENGTH = 500

    def __init__(
        self,
        max_output_length: int = 8192,
        allowed_imports: Optional[List[str]] = None,
    ):
        """Initialize the executor.

        Args:
            max_output_length: Maximum characters for stdout/stderr (default 8192)
            allowed_imports: List of allowed module names for import. ``None``
                selects the built-in default list; an explicit empty list
                means "no additional imports".

        Note:
            smolagents.LocalPythonExecutor does NOT support wall-clock timeouts.
            Instead, it limits operations (10M ops) and while iterations (1M).
        """
        self.max_output_length = max_output_length
        self.allowed_imports = self._resolve_allowed_imports(allowed_imports)

        # Backend that actually runs the code.
        self._executor = self._new_backend()

        # Names set through this wrapper or created by executed code.
        self._user_variables: set[str] = set()

        # Callables exposed to executed code, keyed by their in-sandbox name.
        self._callable_tools: Dict[str, Callable[..., Any]] = {}

        self._register_helpers()

    @classmethod
    def _resolve_allowed_imports(
        cls, allowed_imports: Optional[List[str]]
    ) -> List[str]:
        """Return the import allow-list to use as a fresh list.

        Only ``None`` falls back to the defaults. Testing truthiness instead
        would silently turn an explicit ``[]`` back into the full default
        list. The result is a copy so later mutation of the caller's list
        does not change this executor.
        """
        if allowed_imports is None:
            return list(cls._DEFAULT_ALLOWED_IMPORTS)
        return list(allowed_imports)

    def _new_backend(self) -> Any:
        """Create a LocalPythonExecutor configured with the allow-list."""
        return LocalPythonExecutor(
            additional_authorized_imports=self.allowed_imports
        )

    def _register_helpers(self) -> None:
        """Register helper functions with the executor."""

        def safe_json_dumps(obj: Any) -> str:
            # Objects json cannot encode are replaced by their repr().
            return json.dumps(obj, default=repr)

        helpers = {
            "format_exc": traceback.format_exc,
            "safe_json_dumps": safe_json_dumps,
        }
        for name, func in helpers.items():
            self.inject_function(name, func)

    def _sync_callable_tools(self) -> None:
        """Sync callable functions with the executor via send_tools.

        Best effort: a failure is logged at debug level and otherwise
        ignored, so execution continues without the extra tools.
        """
        if not self._callable_tools:
            return
        try:
            self._executor.send_tools(self._callable_tools)
        except Exception:
            logger.debug(
                "send_tools failed; continuing without extra tools",
                exc_info=True,
            )

    def set_context(self, context: str, variable_name: str = "context") -> None:
        """Load context into namespace as a variable.

        Args:
            context: The context string to load
            variable_name: Name of the variable (default "context")
        """
        self.set_variable(variable_name, context)

    def set_variable(self, name: str, value: Any) -> None:
        """Set a variable in the namespace.

        Args:
            name: Variable name
            value: Variable value
        """
        if hasattr(self._executor, "state"):
            self._executor.state[name] = value
        else:
            # Backend exposes no ``state`` mapping: keep the value on a side
            # dict attached to the backend so get_variable can still find it.
            self._executor._injected_vars = getattr(
                self._executor, "_injected_vars", {}
            )
            self._executor._injected_vars[name] = value

        self._user_variables.add(name)

    def get_variable(self, name: str) -> Optional[Any]:
        """Retrieve a variable from namespace.

        Args:
            name: Variable name

        Returns:
            The variable value or None if not found
        """
        if hasattr(self._executor, "state"):
            return self._executor.state.get(name)

        if hasattr(self._executor, "_injected_vars"):
            return self._executor._injected_vars.get(name)

        return None

    def list_variables(self) -> List[str]:
        """List non-private variables in namespace.

        Returns:
            Sorted list of names that do not start with an underscore, taken
            from the backend state (when it has one) and from the names
            tracked by this wrapper (set variables, injected functions and
            variables created by executed code).
        """
        names = set(self._user_variables)
        state = getattr(self._executor, "state", None)
        if state is not None:
            names.update(state)
        # Filter after merging so tracked names obey the same "non-private"
        # rule as state keys; sort so the order is stable between runs.
        return sorted(name for name in names if not name.startswith("_"))

    def execute(self, code: str) -> Dict[str, Any]:
        """Execute Python code and return results.

        Args:
            code: Python code to execute

        Returns:
            Dictionary with stdout, stderr, locals_snapshot, execution_time,
            success, and exception fields. ``locals_snapshot`` maps each
            non-private name that appeared in the backend state during this
            call to a (possibly shortened) repr of its value.
        """
        # perf_counter is monotonic, so the measured duration cannot be
        # distorted by wall-clock adjustments.
        started = time.perf_counter()
        success = True
        exception_msg: Optional[str] = None
        stdout_parts: list[str] = []
        stderr_parts: list[str] = []

        # Remember which names existed so new ones can be reported afterwards.
        state_before = getattr(self._executor, "state", None)
        pre_state_keys = set(state_before) if state_before is not None else set()

        try:
            exec_result = self._executor(code)
        except Exception as e:
            success = False
            exception_msg = (
                f"{type(e).__name__}: {str(e)}\n{traceback.format_exc()}"
            )
            stderr_parts.append(exception_msg)
        else:
            success, exception_msg = self._collect_result(
                exec_result, stdout_parts, stderr_parts
            )

        execution_time = time.perf_counter() - started

        new_locals = self._snapshot_new_locals(pre_state_keys)

        stdout = "\n".join(part for part in stdout_parts if part)
        stderr = "\n".join(part for part in stderr_parts if part)

        return {
            "stdout": self._truncate(stdout),
            "stderr": self._truncate(stderr),
            "locals_snapshot": new_locals,
            "execution_time": execution_time,
            "success": success,
            "exception": exception_msg,
        }

    def _collect_result(
        self,
        exec_result: Any,
        stdout_parts: List[str],
        stderr_parts: List[str],
    ) -> tuple[bool, Optional[str]]:
        """Extract output and failure information from a backend result.

        The result object is probed attribute by attribute; every probe is
        individually guarded so one unreadable attribute does not hide the
        others. Text is appended to ``stdout_parts`` / ``stderr_parts``.

        Returns:
            ``(success, exception_message)`` for the execution.
        """
        success = True
        exception_msg: Optional[str] = None

        # Captured print output.
        try:
            logs = getattr(exec_result, "logs", None)
            if logs:
                stdout_parts.append(str(logs))
        except Exception:
            logger.debug("Failed to read exec_result.logs", exc_info=True)

        # Result value: JSON when possible, repr() otherwise.
        try:
            out_val = getattr(exec_result, "output", None)
            if out_val is not None:
                try:
                    stdout_parts.append(json.dumps(out_val))
                except Exception:
                    stdout_parts.append(repr(out_val))
        except Exception:
            logger.debug("Failed to read exec_result.output", exc_info=True)

        # Either attribute, when truthy, marks the run as failed; if both are
        # present the later one ("exception") supplies the reported message.
        for attr in ("error", "exception"):
            try:
                failure = getattr(exec_result, attr, None)
                if failure:
                    message = str(failure)
                    stderr_parts.append(message)
                    success = False
                    exception_msg = message
            except Exception:
                logger.debug(
                    "Failed to read exec_result.%s", attr, exc_info=True
                )

        # An exit code, when the result has one, takes precedence over a
        # ``success`` flag.
        try:
            if hasattr(exec_result, "exit_code"):
                if (
                    exec_result.exit_code is not None
                    and exec_result.exit_code != 0
                ):
                    success = False
            elif hasattr(exec_result, "success"):
                success = bool(exec_result.success)
        except Exception:
            logger.debug(
                "Failed to determine exec_result exit code", exc_info=True
            )

        return success, exception_msg

    def _snapshot_new_locals(self, pre_state_keys: set[str]) -> Dict[str, str]:
        """Describe non-private state entries that are not in ``pre_state_keys``.

        Each new name is recorded as a user variable and mapped to the repr
        of its value, shortened to ``_MAX_REPR_LENGTH`` characters plus
        ``"..."``; values whose repr() raises map to ``"<unrepresentable>"``.
        """
        new_locals: Dict[str, str] = {}
        state = getattr(self._executor, "state", None)
        if state is None:
            return new_locals

        limit = self._MAX_REPR_LENGTH
        for key in list(state):
            if key in pre_state_keys or key.startswith("_"):
                continue
            # Track the name before repr() so it is recorded even when the
            # value cannot be represented.
            self._user_variables.add(key)
            try:
                val_repr = repr(state[key])
            except Exception:
                new_locals[key] = "<unrepresentable>"
                continue
            if len(val_repr) > limit:
                val_repr = val_repr[:limit] + "..."
            new_locals[key] = val_repr
        return new_locals

    def _truncate(self, text: str) -> str:
        """Cap ``text`` at max_output_length, appending a note with the full size."""
        if len(text) <= self.max_output_length:
            return text
        return (
            text[: self.max_output_length]
            + f"\n... (truncated, total {len(text)} chars)"
        )

    def reset(self) -> None:
        """Reset namespace to initial state.

        A fresh backend is created and all tracked variables and injected
        functions are dropped; only the built-in helpers are registered
        again, so callers must re-inject their own functions afterwards.
        """
        self._executor = self._new_backend()
        self._user_variables.clear()
        self._callable_tools.clear()
        self._register_helpers()

    def inject_function(self, name: str, func: Callable[..., Any]) -> None:
        """Inject a callable function into the namespace.

        Used for adding llm_query, llm_query_batched, FINAL, etc.

        Args:
            name: Function name in namespace
            func: The callable to inject
        """
        self._callable_tools[name] = func
        self._user_variables.add(name)
        self._sync_callable_tools()
| |
|