JAX
The JAX backend is experimental.
What This Backend Changes
rlmesh.jax keeps the same environment, model, and sandbox behavior as the shared RLMesh client
APIs, but decodes tensor leaves to JAX arrays. Space wrappers returned from JAX clients also sample
JAX-compatible values.
Install it with:
pip install "rlmesh[jax]"
| Concrete API | Shared behavior | Backend-specific behavior |
|---|---|---|
| | Remote Environments single clients | Observations, actions, and render frames use arrays. |
| | Remote Environments vector clients | Batched values use JAX-compatible containers. |
| | | |
| | Sandbox single sandbox sessions | Owned sandbox client is |
| | Sandbox vector sandbox sessions | Owned sandbox client is |
Conversion Semantics
- asarray(tensor) imports over DLPack. XLA shares RLMesh’s 64-byte-aligned buffers zero-copy and copies otherwise; JAX arrays are immutable either way, so there is no mutation hazard.
- from_array(array) moves the array to CPU if needed, blocks until ready, and copies the elements into a fresh RLMesh tensor.
- int64, uint64, and float64 values require JAX 64-bit mode (jax.config.update("jax_enable_x64", True)); without it JAX itself demotes those dtypes.

Requires jax >= 0.4.24, the first release with DLPack bool support. ensure_available enforces the floor at runtime.
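The dtype demotion mentioned above is plain JAX behavior and can be observed without RLMesh. The following is a minimal sketch that uses only JAX and NumPy; jax_dtype_for and WIDE_DTYPES are hypothetical names introduced for illustration, and the flag is toggled mid-process only to show both modes side by side.

```python
try:
    import numpy as np
    import jax
    import jax.numpy as jnp
except ImportError:  # JAX is an optional extra, so tolerate its absence
    jax = None

WIDE_DTYPES = ("int64", "uint64", "float64")


def jax_dtype_for(np_dtype: str) -> str:
    """Name of the dtype JAX assigns to a NumPy array of np_dtype."""
    return str(jnp.asarray(np.zeros(2, dtype=np_dtype)).dtype)


if jax is not None:
    # 32-bit mode (the JAX default): 64-bit inputs are demoted.
    jax.config.update("jax_enable_x64", False)
    demoted = {name: jax_dtype_for(name) for name in WIDE_DTYPES}

    # 64-bit mode: the same inputs keep their width.
    jax.config.update("jax_enable_x64", True)
    preserved = {name: jax_dtype_for(name) for name in WIDE_DTYPES}

    print(demoted)
    print(preserved)
```

In a real program the flag would normally be set once at startup, before any arrays are created, rather than switched back and forth.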
Value Helpers
- rlmesh.jax.ensure_available()
Raise if JAX is not installed or is older than the supported floor.
- Return type:
None
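A version floor of this kind can be approximated with the standard library alone. The sketch below is not RLMesh's implementation: parse_release, meets_floor, and installed_jax_meets_floor are hypothetical helpers, the floor value is the one stated under Conversion Semantics, and the parser deliberately ignores pre-release and development suffixes.

```python
import re
from importlib import metadata

JAX_FLOOR = (0, 4, 24)  # floor stated in this page, used here as an assumption


def parse_release(version: str) -> tuple:
    """Return the leading numeric components of a version string."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        return ()
    return tuple(int(part) for part in match.group().split("."))


def meets_floor(version: str, floor: tuple = JAX_FLOOR) -> bool:
    """Compare numerically, so that 0.10.x ranks above 0.4.x."""
    return parse_release(version) >= floor


def installed_jax_meets_floor() -> bool:
    """False when JAX is missing or older than the floor."""
    try:
        return meets_floor(metadata.version("jax"))
    except metadata.PackageNotFoundError:
        return False
```

Comparing integer tuples rather than raw strings matters here: a string comparison would wrongly rank "0.10.1" below "0.4.24".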
- rlmesh.jax.asarray(tensor)
Return a JAX array for an RLMesh tensor.
- Parameters:
tensor (Tensor) – RLMesh tensor value to convert.
- Returns:
JAX array imported over DLPack. XLA shares 64-byte-aligned buffers and copies otherwise; either way the result is immutable.
- Return type:
object
- rlmesh.jax.from_array(array)
Encode a JAX array as an RLMesh value.
- Parameters:
array (object) – JAX array to encode.
- Returns:
Tensor for non-scalar arrays, or a primitive for scalar values.
- Return type:
Tensor | None | bool | int | float | str | bytes
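The two converters can be combined into a round trip. This is a sketch that relies only on the three helpers documented here, with roundtrip as a hypothetical wrapper name; it assumes a non-scalar input, since from_array returns a primitive rather than a Tensor for scalar values. The import guard exists because both packages are optional.

```python
try:
    import jax.numpy as jnp
    import rlmesh.jax as rj
except ImportError:  # either optional package may be missing
    rj = None


def roundtrip(array):
    """Encode a non-scalar JAX array as an RLMesh tensor, then decode it."""
    rj.ensure_available()          # fail early on a missing or too-old JAX
    tensor = rj.from_array(array)  # copies the elements into a fresh tensor
    return rj.asarray(tensor)      # DLPack import back into a JAX array


if rj is not None:
    original = jnp.arange(6, dtype=jnp.float32).reshape(2, 3)
    restored = roundtrip(original)
    print(restored.shape, restored.dtype)
```

Because from_array copies and JAX arrays are immutable, the restored array does not alias the original.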
- rlmesh.jax.space_from_spec(spec)
Create a JAX-adapted space wrapper for a native space spec.
- Parameters:
spec (SpaceSpec)
- Return type:
Space[None | bool | int | float | str | bytes | object | list[None | bool | int | float | str | bytes | object | list[JaxValue] | tuple[JaxValue, …] | dict[str, JaxValue]] | tuple[None | bool | int | float | str | bytes | object | list[JaxValue] | tuple[JaxValue, …] | dict[str, JaxValue], …] | dict[str, None | bool | int | float | str | bytes | object | list[JaxValue] | tuple[JaxValue, …] | dict[str, JaxValue]]]
RemoteEnv
- final class rlmesh.jax.RemoteEnv
Bases: RemoteEnvBase[None | bool | int | float | str | bytes | object | list[JaxValue] | tuple[JaxValue, …] | dict[str, JaxValue], None | bool | int | float | str | bytes | object | list[JaxValue] | tuple[JaxValue, …] | dict[str, JaxValue]]
Experimental JAX-backed remote client for one environment.
Tensor leaves decode to JAX arrays while Python primitives and nested containers are preserved.
- Parameters:
address – Endpoint address such as "tcp://127.0.0.1:5555".
host – TCP host helper used when address is omitted.
port – TCP port helper used when address is omitted.
path – Unix socket path helper used when address is omitted.
transport – Explicit transport selector.
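Because decoded values mix JAX arrays with Python primitives inside lists, tuples, and dicts, JAX's pytree utilities are one way to post-process them. The sketch below uses only JAX; the observation is a hand-built stand-in rather than something returned by RLMesh, and normalize_arrays is a hypothetical helper.

```python
try:
    import jax
    import jax.numpy as jnp
except ImportError:  # JAX is an optional extra, so tolerate its absence
    jax = None


def normalize_arrays(value):
    """Scale uint8 array leaves to float32 in [0, 1]; leave other leaves alone."""

    def leaf_fn(leaf):
        if isinstance(leaf, jax.Array) and leaf.dtype == jnp.uint8:
            return leaf.astype(jnp.float32) / 255.0
        return leaf  # primitives and other arrays pass through unchanged

    return jax.tree_util.tree_map(leaf_fn, value)


if jax is not None:
    # Hand-built stand-in for a decoded observation (assumption for illustration).
    observation = {
        "pixels": jnp.full((2, 2), 255, dtype=jnp.uint8),
        "step": 3,
        "label": "start",
        "history": [jnp.zeros(2, dtype=jnp.float32), 0.5],
    }
    prepared = normalize_arrays(observation)
```

tree_map rebuilds the same container structure around the transformed leaves, so the dict keys and the list layout survive. JAX treats None as an empty subtree, so None values also pass through untouched.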
RemoteVectorEnv
- final class rlmesh.jax.RemoteVectorEnv
Bases: RemoteVectorEnvBase[None | bool | int | float | str | bytes | object | list[JaxValue] | tuple[JaxValue, …] | dict[str, JaxValue], None | bool | int | float | str | bytes | object | list[JaxValue] | tuple[JaxValue, …] | dict[str, JaxValue]]
Experimental JAX-backed remote client for vectorized environments.
- Parameters:
address – Endpoint address such as "tcp://127.0.0.1:5555".
host – TCP host helper used when address is omitted.
port – TCP port helper used when address is omitted.
path – Unix socket path helper used when address is omitted.
transport – Explicit transport selector.
Model
- final class rlmesh.jax.Model
Bases: ModelBase[None | bool | int | float | str | bytes | object | list[JaxValue] | tuple[JaxValue, …] | dict[str, JaxValue], None | bool | int | float | str | bytes | object | list[JaxValue] | tuple[JaxValue, …] | dict[str, JaxValue]]
Experimental JAX-backed model: predict works in JAX values.
The JAX-typed ModelBase; see it for the source/spec construction and run(env, seeds=[...]) -> RunResult eval.
Sandbox
- final class rlmesh.jax.SandboxEnv
Bases: SandboxEnvBase[None | bool | int | float | str | bytes | object | list[JaxValue] | tuple[JaxValue, …] | dict[str, JaxValue], None | bool | int | float | str | bytes | object | list[JaxValue] | tuple[JaxValue, …] | dict[str, JaxValue]]
Experimental JAX-backed owned sandbox session for one environment.
- Parameters:
source – Gymnasium id, explicit gym:// source, or pinned environment source.
base_image – Optional Docker base image override.
rlmesh_package – Optional RLMesh package, wheel, or "local" installed in the sandbox.
packages – Extra environment packages installed in the sandbox.
imports – Import names checked during sandbox startup.
trust_remote_code – Allow remote environment code to execute.
allow_unpinned_hf – Allow Hugging Face sources without a pinned revision.
**gym_make_kwargs – Keyword arguments forwarded to environment creation.
- final class rlmesh.jax.SandboxVectorEnv
Bases: SandboxVectorEnvBase[None | bool | int | float | str | bytes | object | list[JaxValue] | tuple[JaxValue, …] | dict[str, JaxValue], None | bool | int | float | str | bytes | object | list[JaxValue] | tuple[JaxValue, …] | dict[str, JaxValue]]
Experimental JAX-backed owned sandbox session for vectorized environments.
- Parameters:
source – Gymnasium id, explicit gym:// source, or pinned environment source.
num_envs – Number of environment instances to create.
vectorization_mode – Vectorization mode requested inside the sandbox.
base_image – Optional Docker base image override.
rlmesh_package – Optional RLMesh package, wheel, or "local" installed in the sandbox.
packages – Extra environment packages installed in the sandbox.
imports – Import names checked during sandbox startup.
trust_remote_code – Allow remote environment code to execute.
allow_unpinned_hf – Allow Hugging Face sources without a pinned revision.
**env_make_kwargs – Keyword arguments forwarded to environment creation.