JAX

The JAX backend is experimental.

What This Backend Changes

rlmesh.jax keeps the same environment, model, and sandbox behavior as the shared RLMesh client APIs, but decodes tensor leaves to JAX arrays. Space wrappers returned from JAX clients also sample JAX-compatible values.

Install it with:

pip install "rlmesh[jax]"

Concrete API, with its shared behavior and backend-specific behavior:

  • rlmesh.jax.RemoteEnv – shared behavior: Remote Environments, single clients. Backend-specific: observations, actions, and render frames use JAX arrays.

  • rlmesh.jax.RemoteVectorEnv – shared behavior: Remote Environments, vector clients. Backend-specific: batched values use JAX-compatible containers.

  • rlmesh.jax.Model – shared behavior: Models. Backend-specific: predict_fn receives JAX-decoded observations.

  • rlmesh.jax.SandboxEnv – shared behavior: Sandbox, single sandbox sessions. Backend-specific: the owned sandbox client is rlmesh.jax.RemoteEnv.

  • rlmesh.jax.SandboxVectorEnv – shared behavior: Sandbox, vector sandbox sessions. Backend-specific: the owned sandbox client is rlmesh.jax.RemoteVectorEnv.

Conversion Semantics

  • asarray(tensor) imports over DLPack. XLA shares RLMesh’s 64-byte-aligned buffers zero-copy and copies otherwise; JAX arrays are immutable either way, so there is no mutation hazard.

  • from_array(array) moves the array to CPU if needed, blocks until ready, and copies the elements into a fresh RLMesh tensor.

  • int64, uint64, and float64 values require JAX 64-bit mode (jax.config.update("jax_enable_x64", True)); without it, JAX itself demotes those dtypes to int32, uint32, and float32.

  • Requires jax >= 0.4.24, the first release with DLPack bool support. ensure_available enforces the floor at runtime.
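The 64-bit rule in the list above can be turned into a small preflight check that flags dtypes which would be silently narrowed. This is a plain-Python sketch that calls neither JAX nor RLMesh; X32_DEMOTION, effective_dtype, and demoted_dtypes are hypothetical helpers, and the mapping mirrors JAX's standard 32-bit canonicalization when jax_enable_x64 is off.

```python
# Hypothetical preflight helpers; not part of rlmesh.jax.
# With jax_enable_x64 off, JAX canonicalizes these 64-bit dtypes to 32-bit ones.
X32_DEMOTION = {
    "int64": "int32",
    "uint64": "uint32",
    "float64": "float32",
}


def effective_dtype(dtype: str, x64_enabled: bool) -> str:
    """Return the dtype a value ends up with on the JAX side."""
    if x64_enabled:
        return dtype
    return X32_DEMOTION.get(dtype, dtype)


def demoted_dtypes(dtypes: list[str], x64_enabled: bool) -> list[str]:
    """List the dtypes that would lose range or precision, in input order."""
    return [d for d in dtypes if effective_dtype(d, x64_enabled) != d]


if __name__ == "__main__":
    spec_dtypes = ["float32", "float64", "int64", "bool"]
    print(demoted_dtypes(spec_dtypes, x64_enabled=False))  # ['float64', 'int64']
    print(demoted_dtypes(spec_dtypes, x64_enabled=True))   # []
```

An empty result with x64 disabled means the values can be decoded without enabling 64-bit mode at all.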

Value Helpers

rlmesh.jax.ensure_available()

Raise if JAX is not installed or is older than the supported floor.

Return type:

None
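The exact checks inside ensure_available are not documented here. As an illustration only, a version-floor check of this kind can be sketched with the standard library; MIN_JAX, parse_release, check_jax_version, ensure_jax_available, and the use of RuntimeError are all hypothetical choices, not RLMesh's implementation.

```python
import importlib.metadata
import re

MIN_JAX = (0, 4, 24)  # floor stated in the conversion notes above


def parse_release(version: str) -> tuple[int, int, int]:
    """Extract (major, minor, patch) from strings like "0.4.30" or "0.4.24.dev1"."""
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    major, minor, patch = match.groups(default="0")
    return int(major), int(minor), int(patch)


def check_jax_version(version: str) -> None:
    """Raise if the given JAX version is older than MIN_JAX."""
    if parse_release(version) < MIN_JAX:
        floor = ".".join(map(str, MIN_JAX))
        raise RuntimeError(f"JAX {version} is too old; rlmesh.jax needs jax >= {floor}")


def ensure_jax_available() -> None:
    """Hypothetical stand-in for rlmesh.jax.ensure_available()."""
    try:
        installed = importlib.metadata.version("jax")
    except importlib.metadata.PackageNotFoundError:
        raise RuntimeError('JAX is not installed; try: pip install "rlmesh[jax]"') from None
    check_jax_version(installed)
```

Comparing integer tuples rather than raw strings matters here: as strings, "0.10.0" sorts before "0.4.24".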

rlmesh.jax.asarray(tensor)

Return a JAX array for an RLMesh tensor.

Parameters:

tensor (Tensor) – RLMesh tensor value to convert.

Returns:

JAX array imported over DLPack. XLA shares 64-byte-aligned buffers and copies otherwise; either way the result is immutable.

Return type:

object
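Whether asarray shares a buffer or copies it depends on where the buffer starts in memory. The standard-library sketch below shows the alignment test (the start address is a multiple of 64) and the common over-allocate-and-offset way to obtain an aligned byte buffer. It does not touch RLMesh or JAX; address_of, is_aligned, and aligned_bytes are hypothetical helpers for illustration.

```python
import ctypes

ALIGNMENT = 64  # alignment at which XLA can share the buffer instead of copying


def address_of(buf) -> int:
    """Return the start address of a writable, C-contiguous byte buffer."""
    view = memoryview(buf)
    return ctypes.addressof((ctypes.c_char * view.nbytes).from_buffer(view))


def is_aligned(address: int, alignment: int = ALIGNMENT) -> bool:
    """True if the address falls on an alignment boundary."""
    return address % alignment == 0


def aligned_bytes(nbytes: int, alignment: int = ALIGNMENT) -> memoryview:
    """Over-allocate, then slice so the returned view starts on a boundary."""
    raw = bytearray(nbytes + alignment - 1)
    # Bytes to skip so the slice begins at the next multiple of `alignment`.
    offset = (-address_of(raw)) % alignment
    # The memoryview keeps `raw` alive and pins it, so the address stays valid.
    return memoryview(raw)[offset:offset + nbytes]


if __name__ == "__main__":
    buf = aligned_bytes(256)
    print(buf.nbytes, is_aligned(address_of(buf)))  # 256 True
```

A buffer that fails is_aligned is still imported correctly; per the notes above, XLA simply copies it instead of sharing it.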

rlmesh.jax.from_array(array)

Encode a JAX array as an RLMesh value.

Parameters:

array (object) – JAX array to encode.

Returns:

Tensor for non-scalar arrays, or a primitive for scalar values.

Return type:

Tensor | None | bool | int | float | str | bytes

rlmesh.jax.space_from_spec(spec)

Create a JAX-adapted space wrapper for a native space spec.

Parameters:

spec (SpaceSpec) – Native space spec to wrap.

Return type:

Space[JaxValue], where JaxValue is the recursive alias None | bool | int | float | str | bytes | object | list[JaxValue] | tuple[JaxValue, …] | dict[str, JaxValue].

RemoteEnv

final class rlmesh.jax.RemoteEnv

Bases: RemoteEnvBase[JaxValue, JaxValue]

Experimental JAX-backed remote client for one environment.

Tensor leaves decode to JAX arrays while Python primitives and nested containers are preserved.

Parameters:
  • address – Endpoint address such as "tcp://127.0.0.1:5555".

  • host – TCP host helper used when address is omitted.

  • port – TCP port helper used when address is omitted.

  • path – Unix socket path helper used when address is omitted.

  • transport – Explicit transport selector.
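The decoding rule for RemoteEnv (tensor leaves become JAX arrays while primitives and nested containers are preserved) amounts to a recursive map over a tree of lists, tuples, and dicts. A minimal pure-Python sketch, assuming those three container types are the only tree nodes, as in the JaxValue annotations; map_leaves is a hypothetical helper, and the convert callback stands in for whatever turns an RLMesh tensor into a JAX array. With JAX installed, jax.tree_util.tree_map performs the same kind of traversal.

```python
from typing import Any, Callable


def map_leaves(value: Any, convert: Callable[[Any], Any]) -> Any:
    """Apply `convert` to every non-container leaf and rebuild the containers.

    Lists, tuples, and dicts are traversed recursively and rebuilt as list,
    tuple, and dict with the same order and keys; everything else, including
    None and Python primitives, is treated as a leaf.
    """
    if isinstance(value, list):
        return [map_leaves(item, convert) for item in value]
    if isinstance(value, tuple):
        return tuple(map_leaves(item, convert) for item in value)
    if isinstance(value, dict):
        return {key: map_leaves(item, convert) for key, item in value.items()}
    return convert(value)


if __name__ == "__main__":
    obs = {"pos": [1.0, 2.0], "meta": ("episode", 3), "done": None}
    # Stand-in for "decode a tensor leaf": double floats, leave the rest alone.
    print(map_leaves(obs, lambda leaf: leaf * 2 if isinstance(leaf, float) else leaf))
    # {'pos': [2.0, 4.0], 'meta': ('episode', 3), 'done': None}
```

Only the leaf callback differs between backends in this picture; the container structure passes through unchanged.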

RemoteVectorEnv

final class rlmesh.jax.RemoteVectorEnv

Bases: RemoteVectorEnvBase[JaxValue, JaxValue]

Experimental JAX-backed remote client for vectorized environments.

Parameters:
  • address – Endpoint address such as "tcp://127.0.0.1:5555".

  • host – TCP host helper used when address is omitted.

  • port – TCP port helper used when address is omitted.

  • path – Unix socket path helper used when address is omitted.

  • transport – Explicit transport selector.

Model

final class rlmesh.jax.Model

Bases: ModelBase[JaxValue, JaxValue]

Experimental JAX-backed model: predict operates on JAX values.

This is ModelBase specialized to JAX value types; see ModelBase for construction from a source or spec and for evaluation with run(env, seeds=[...]) -> RunResult.

Sandbox

final class rlmesh.jax.SandboxEnv

Bases: SandboxEnvBase[JaxValue, JaxValue]

Experimental JAX-backed owned sandbox session for one environment.

Parameters:
  • source – Gymnasium id, explicit gym:// source, or pinned environment source.

  • base_image – Optional Docker base image override.

  • rlmesh_package – Optional RLMesh package, wheel, or "local" installed in the sandbox.

  • packages – Extra environment packages installed in the sandbox.

  • imports – Import names checked during sandbox startup.

  • trust_remote_code – Allow remote environment code to execute.

  • allow_unpinned_hf – Allow Hugging Face sources without a pinned revision.

  • **gym_make_kwargs – Keyword arguments forwarded to environment creation.

final class rlmesh.jax.SandboxVectorEnv

Bases: SandboxVectorEnvBase[JaxValue, JaxValue]

Experimental JAX-backed owned sandbox session for vectorized environments.

Parameters:
  • source – Gymnasium id, explicit gym:// source, or pinned environment source.

  • num_envs – Number of environment instances to create.

  • vectorization_mode – Vectorization mode requested inside the sandbox.

  • base_image – Optional Docker base image override.

  • rlmesh_package – Optional RLMesh package, wheel, or "local" installed in the sandbox.

  • packages – Extra environment packages installed in the sandbox.

  • imports – Import names checked during sandbox startup.

  • trust_remote_code – Allow remote environment code to execute.

  • allow_unpinned_hf – Allow Hugging Face sources without a pinned revision.

  • **env_make_kwargs – Keyword arguments forwarded to environment creation.