gxm.wrappers#

Wrappers for gxm environments.

class ClipReward(env, unwrap=True, min=-1.0, max=1.0)#

Bases: Wrapper

Wrapper that clips the reward to a specified range.
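The clipping itself can be sketched in a few lines of JAX. This is only an illustration of the idea, not the wrapper's actual implementation; the function name `clip_reward` and the parameter names `min_reward` and `max_reward` are assumptions made for this example.

```python
import jax.numpy as jnp


def clip_reward(reward, min_reward=-1.0, max_reward=1.0):
    """Clip a scalar or batched reward to [min_reward, max_reward]."""
    return jnp.clip(reward, min_reward, max_reward)


# Rewards outside the range are moved to the nearest bound; others are unchanged.
rewards = jnp.array([-5.0, -0.5, 0.0, 0.25, 12.0])
clipped = clip_reward(rewards)
```

With the default range, large positive and negative rewards are limited to 1.0 and -1.0 respectively.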

clip(reward)#
Return type:

Array

env: Environment#
init(key)#

Initialize the environment and return the initial state.

Parameters:

key (Array) – A JAX random key for any stochastic initialization.

Return type:

tuple[EnvironmentState, Timestep]

Returns:

A tuple containing the initial environment state and the initial timestep.

reset(key, env_state)#

Reset the environment to its initial state.

Parameters:
  • key (Array) – A JAX random key for any stochasticity in the environment.

  • env_state (EnvironmentState) – The current state of the environment.

Return type:

tuple[EnvironmentState, Timestep]

Returns:

A tuple containing the reset environment state and the initial timestep.

step(key, env_state, action)#

Perform a step in the environment given an action.

Parameters:
  • key (Array) – A JAX random key for any stochasticity in the environment.

  • env_state (EnvironmentState) – The current state of the environment.

  • action (Any) – The action to take in the environment.

Return type:

tuple[EnvironmentState, Timestep]

Returns:

A tuple containing the new environment state and the resulting timestep.

class Discretize(env, actions, unwrap=True)#

Bases: Wrapper

Wrapper that discretizes a continuous action space. Maps a discrete set of actions to the continuous action space of the environment. The actions are specified as a list of continuous actions \(A\). The action space exposed by the wrapper is then \(\{0, 1, \ldots, |A|-1\}\).

>>> import gxm
>>> import jax.numpy as jnp
>>> from gxm.wrappers import Discretize
>>> env = gxm.make("Gymnasium/Pendulum-v1")
>>> actions = jnp.array([[-2.0], [0.0], [2.0]])
>>> env = Discretize(env, actions)

The actions passed to the Discretize wrapper need to be of shape \((|A|, D)\), where \(|A|\) is the number of discrete actions and \(D\) is the dimensionality of the continuous action space of the wrapped environment.
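The mapping from a discrete index to a continuous action can be pictured as a table lookup. The sketch below is an assumption about how such a lookup might work, not the wrapper's code; `to_continuous` is a hypothetical helper, and the table uses \(D = 1\) as for a one-dimensional torque action.

```python
import jax.numpy as jnp

# Hypothetical action table of shape (|A|, D) = (3, 1).
actions = jnp.array([[-2.0], [0.0], [2.0]])


def to_continuous(actions, index):
    """Look up the continuous action for a discrete index in {0, ..., |A|-1}."""
    return actions[index]


single = to_continuous(actions, 2)                    # shape (D,)
batch = to_continuous(actions, jnp.array([0, 2, 1]))  # shape (3, D)
```

Indexing with an integer array handles a batch of discrete actions in one call.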

actions: Any#
env: Environment#
init(key)#

Initialize the environment and return the initial state.

Parameters:

key (Array) – A JAX random key for any stochastic initialization.

Return type:

tuple[EnvironmentState, Timestep]

Returns:

A tuple containing the initial environment state and the initial timestep.

reset(key, env_state)#

Reset the environment to its initial state.

Parameters:
  • key (Array) – A JAX random key for any stochasticity in the environment.

  • env_state (EnvironmentState) – The current state of the environment.

Return type:

tuple[EnvironmentState, Timestep]

Returns:

A tuple containing the reset environment state and the initial timestep.

step(key, env_state, action)#

Perform a step in the environment given an action.

Parameters:
  • key (Array) – A JAX random key for any stochasticity in the environment.

  • env_state (EnvironmentState) – The current state of the environment.

  • action (Any) – The action to take in the environment.

Return type:

tuple[EnvironmentState, Timestep]

Returns:

A tuple containing the new environment state and the resulting timestep.

class EpisodeCounter(env, unwrap=True)#

Bases: Wrapper

A wrapper that counts the number of episodes completed in the environment.
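As a rough illustration, an episode counter can be kept as an integer in the wrapper state and incremented whenever a step ends an episode. The function `update_episode_count` and the `done` flag are assumptions for this sketch; the fields of EpisodeCounterState are not documented here.

```python
import jax.numpy as jnp


def update_episode_count(count, done):
    """Add one to the counter whenever the current step ends an episode."""
    return count + jnp.asarray(done, dtype=jnp.int32)


# Five steps, two of which end an episode.
count = jnp.int32(0)
for done in [False, False, True, False, True]:
    count = update_episode_count(count, done)
```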

init(key)#

Initialize the environment and return the initial state.

Parameters:

key (Array) – A JAX random key for any stochastic initialization.

Return type:

tuple[EpisodeCounterState, Timestep]

Returns:

A tuple containing the initial environment state and the initial timestep.

reset(key, env_state)#

Reset the environment to its initial state.

Parameters:
  • key (Array) – A JAX random key for any stochasticity in the environment.

  • env_state (EpisodeCounterState) – The current state of the environment.

Return type:

tuple[EpisodeCounterState, Timestep]

Returns:

A tuple containing the reset environment state and the initial timestep.

step(key, env_state, action)#

Perform a step in the environment given an action.

Parameters:
  • key (Array) – A JAX random key for any stochasticity in the environment.

  • env_state (EpisodeCounterState) – The current state of the environment.

  • action (Any) – The action to take in the environment.

Return type:

tuple[EpisodeCounterState, Timestep]

Returns:

A tuple containing the new environment state and the resulting timestep.

class EpisodicLife(env)#

Bases: Wrapper[EpisodicLifeState]

A wrapper that makes losing a life in an environment (like Atari games) count as the end of an episode. It assumes that the environment’s timestep info dictionary contains a “lives” key indicating the number of lives remaining.
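The life-loss check can be sketched as a comparison between the previous and the current number of lives. This is a minimal illustration under the assumption that the wrapper remembers the previous lives count in its state; `episodic_done` and its arguments are hypothetical names.

```python
import jax.numpy as jnp


def episodic_done(done, prev_lives, lives):
    """Report the end of an episode if the game ended or a life was lost."""
    life_lost = jnp.asarray(lives) < jnp.asarray(prev_lives)
    return jnp.logical_or(done, life_lost)


lost_life = episodic_done(False, prev_lives=3, lives=2)
no_change = episodic_done(False, prev_lives=3, lives=3)
game_over = episodic_done(True, prev_lives=1, lives=1)
```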

env: Environment#
init(key)#

Initialize the environment and return the initial state.

Parameters:

key (Array) – A JAX random key for any stochastic initialization.

Return type:

tuple[EpisodicLifeState, Timestep]

Returns:

A tuple containing the initial environment state and the initial timestep.

reset(key, env_state)#

Reset the environment to its initial state.

Parameters:
  • key (Array) – A JAX random key for any stochasticity in the environment.

  • env_state (EpisodicLifeState) – The current state of the environment.

Return type:

tuple[EpisodicLifeState, Timestep]

Returns:

A tuple containing the reset environment state and the initial timestep.

step(key, env_state, action)#

Perform a step in the environment given an action.

Parameters:
  • key (Array) – A JAX random key for any stochasticity in the environment.

  • env_state (EpisodicLifeState) – The current state of the environment.

  • action (Any) – The action to take in the environment.

Return type:

tuple[EpisodicLifeState, Timestep]

Returns:

A tuple containing the new environment state and the resulting timestep.

class Evaluate(env, unwrap=True)#

Bases: Wrapper[EvaluateState]

env: Environment#
init(key)#

Initialize the environment and return the initial state.

Parameters:

key (Array) – A JAX random key for any stochastic initialization.

Return type:

tuple[EvaluateState, Timestep]

Returns:

A tuple containing the initial environment state and the initial timestep.

reset(key, env_state)#

Reset the environment to its initial state.

Parameters:
  • key (Array) – A JAX random key for any stochasticity in the environment.

  • env_state (EvaluateState) – The current state of the environment.

Return type:

tuple[EvaluateState, Timestep]

Returns:

A tuple containing the reset environment state and the initial timestep.

step(key, env_state, action)#

Perform a step in the environment given an action.

Parameters:
  • key (Array) – A JAX random key for any stochasticity in the environment.

  • env_state (EvaluateState) – The current state of the environment.

  • action (Any) – The action to take in the environment.

Return type:

tuple[EvaluateState, Timestep]

Returns:

A tuple containing the new environment state and the resulting timestep.

class FlattenObservation(env, unwrap=True)#

Bases: Wrapper

Wrapper that flattens the observations of the environment into a single array.

classmethod flatten(obs)#
Return type:

Array
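One plausible way to turn a structured observation into a single Array is to ravel every leaf of the observation pytree and concatenate the results. The sketch below shows that approach under this assumption; it is not necessarily what `flatten` does internally, and the observation dictionary is made up for the example.

```python
import jax
import jax.numpy as jnp


def flatten(obs):
    """Ravel every leaf of the observation pytree and concatenate them."""
    leaves = jax.tree_util.tree_leaves(obs)
    return jnp.concatenate([jnp.ravel(leaf) for leaf in leaves])


# Hypothetical structured observation with 6 + 4 = 10 entries in total.
obs = {"position": jnp.zeros((2, 3)), "velocity": jnp.ones((4,))}
flat = flatten(obs)
```

Dictionary leaves are visited in sorted key order, so the layout of the flat array is deterministic.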

init(key)#

Initialize the environment and return the initial state.

Parameters:

key (Array) – A JAX random key for any stochastic initialization.

Return type:

tuple[WrapperState, Timestep]

Returns:

A tuple containing the initial environment state and the initial timestep.

reset(key, env_state)#

Reset the environment to its initial state.

Parameters:
  • key (Array) – A JAX random key for any stochasticity in the environment.

  • env_state (WrapperState) – The current state of the environment.

Return type:

tuple[WrapperState, Timestep]

Returns:

A tuple containing the reset environment state and the initial timestep.

step(key, env_state, action)#

Perform a step in the environment given an action.

Parameters:
  • key (Array) – A JAX random key for any stochasticity in the environment.

  • env_state (WrapperState) – The current state of the environment.

  • action (Any) – The action to take in the environment.

Return type:

tuple[WrapperState, Timestep]

Returns:

A tuple containing the new environment state and the resulting timestep.

class IgnoreTruncation(env, actions)#

Bases: Wrapper

A wrapper that treats truncation as termination and removes the corresponding observation from the timestep.

>>> import gxm
>>> from gxm.wrappers import IgnoreTruncation
>>> env = gxm.make("Gymnax/CartPole-v1")
>>> env = IgnoreTruncation(env)

env: Environment#
init(key)#

Initialize the environment and return the initial state.

Parameters:

key (Array) – A JAX random key for any stochastic initialization.

Return type:

tuple[EnvironmentState, Timestep]

Returns:

A tuple containing the initial environment state and the initial timestep.

reset(key, env_state)#

Reset the environment to its initial state.

Parameters:
  • key (Array) – A JAX random key for any stochasticity in the environment.

  • env_state (EnvironmentState) – The current state of the environment.

Return type:

tuple[EnvironmentState, Timestep]

Returns:

A tuple containing the reset environment state and the initial timestep.

step(key, env_state, action)#

Perform a step in the environment given an action.

Parameters:
  • key (Array) – A JAX random key for any stochasticity in the environment.

  • env_state (EnvironmentState) – The current state of the environment.

  • action (Array) – The action to take in the environment.

Return type:

tuple[EnvironmentState, Timestep]

Returns:

A tuple containing the new environment state and the resulting timestep.

class RecordEpisodeStatistics(env, unwrap=True, gamma=1.0, n_episodes=1)#

Bases: Wrapper[RecordEpisodeStatisticsState]

A wrapper that records the episode length \(T\), the episodic return \(J(\tau) = \sum_{t=0}^{T} r_t\), and the discounted episodic return \(G(\tau) = \sum_{t=0}^{T} \gamma^t r_t\) at the end of each episode. The statistics can be accessed from the info field of the Timestep returned by the environment, which contains the statistics of the most recently finished episode. By default, the discount factor \(\gamma\) is set to 1.0, so the episodic return and the discounted episodic return are the same.
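The two return formulas can be checked on a short reward sequence. The helper `episode_returns` below is a hypothetical function written for this example; it computes both sums over the rewards of one finished episode and is not the wrapper's implementation.

```python
import jax.numpy as jnp


def episode_returns(rewards, gamma=1.0):
    """Return (J, G): the plain and the discounted sum of an episode's rewards."""
    steps = jnp.arange(rewards.shape[0])
    episodic_return = jnp.sum(rewards)
    discounted_return = jnp.sum(gamma ** steps * rewards)
    return episodic_return, discounted_return


rewards = jnp.array([1.0, 1.0, 1.0, 1.0])
J, G = episode_returns(rewards, gamma=0.5)      # G = 1 + 0.5 + 0.25 + 0.125
J_same, G_same = episode_returns(rewards)       # gamma = 1.0, so G equals J
```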

gamma: float#

The discount factor \(\gamma\) for calculating the discounted episodic return.

static get_averaged_stats(episode_stats)#
Return type:

dict[str, Array]

init(key)#

Initialize the environment and return the initial state.

Parameters:

key (Array) – A JAX random key for any stochastic initialization.

Return type:

tuple[RecordEpisodeStatisticsState, Timestep]

Returns:

A tuple containing the initial environment state and the initial timestep.

n_episodes: int#

The number of past episodes to record statistics for.

reset(key, env_state)#

Reset the environment to its initial state.

Parameters:
  • key (Array) – A JAX random key for any stochasticity in the environment.

  • env_state (RecordEpisodeStatisticsState) – The current state of the environment.

Return type:

tuple[RecordEpisodeStatisticsState, Timestep]

Returns:

A tuple containing the reset environment state and the initial timestep.

step(key, env_state, action)#

Perform a step in the environment given an action.

Parameters:
  • key (Array) – A JAX random key for any stochasticity in the environment.

  • env_state (RecordEpisodeStatisticsState) – The current state of the environment.

  • action (Any) – The action to take in the environment.

Return type:

tuple[RecordEpisodeStatisticsState, Timestep]

Returns:

A tuple containing the new environment state and the resulting timestep.

class Rollout(env)#

Bases: Wrapper

Wrapper that adds a rollout method to the environment.

init(key)#

Initialize the environment and return the initial state.

Parameters:

key (Array) – A JAX random key for any stochastic initialization.

Return type:

tuple[EnvironmentState, Timestep]

Returns:

A tuple containing the initial environment state and the initial timestep.

reset(key, env_state)#

Reset the environment to its initial state.

Parameters:
  • key (Array) – A JAX random key for any stochasticity in the environment.

  • env_state (EnvironmentState) – The current state of the environment.

Return type:

tuple[EnvironmentState, Timestep]

Returns:

A tuple containing the reset environment state and the initial timestep.

rollout(key, env_state, pi, num_steps)#
Return type:

tuple[EnvironmentState, Trajectory]
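A rollout of this kind is commonly written with `jax.lax.scan`. The sketch below uses a toy step function and a toy policy in place of a real environment; the policy signature `pi(key, obs)`, the toy dynamics, and the returned tuple of arrays (standing in for a Trajectory) are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp


def toy_step(key, env_state, action):
    """Hypothetical stand-in for env.step: the state is a running sum of actions."""
    new_state = env_state + action
    reward = -jnp.abs(new_state)
    return new_state, (new_state, reward)  # the observation is the state itself


def toy_policy(key, obs):
    """Hypothetical policy: push the state halfway back toward zero."""
    return -0.5 * obs


def rollout(key, env_state, pi, num_steps):
    def body(state, step_key):
        key_pi, key_env = jax.random.split(step_key)
        action = pi(key_pi, state)
        state, (obs, reward) = toy_step(key_env, state, action)
        return state, (obs, action, reward)

    # One random key per step; scan stacks the per-step outputs along axis 0.
    keys = jax.random.split(key, num_steps)
    return jax.lax.scan(body, env_state, keys)


final_state, (obs, actions, rewards) = rollout(
    jax.random.PRNGKey(0), jnp.float32(8.0), toy_policy, 3
)
```

Because the toy dynamics ignore the random keys, the result is deterministic: the state halves on every step.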

step(key, env_state, action)#

Perform a step in the environment given an action.

Parameters:
  • key (Array) – A JAX random key for any stochasticity in the environment.

  • env_state (EnvironmentState) – The current state of the environment.

  • action (Array) – The action to take in the environment.

Return type:

tuple[EnvironmentState, Timestep]

Returns:

A tuple containing the new environment state and the resulting timestep.

class StackObservations(env, n_stack, padding='reset')#

Bases: Wrapper[StackObservationsState]

Wrapper that stacks the most recent observations along a new axis.
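Observation stacking is often implemented as a rolling buffer. The following sketch assumes that the stack is filled with copies of the first observation after a reset, which is one possible reading of padding='reset' and not confirmed by this page; `init_stack` and `push` are hypothetical helpers.

```python
import jax.numpy as jnp


def init_stack(first_obs, n_stack):
    """Fill the whole stack with copies of the first observation."""
    return jnp.stack([first_obs] * n_stack, axis=0)


def push(stack, obs):
    """Drop the oldest observation and append the newest along axis 0."""
    return jnp.concatenate([stack[1:], obs[None]], axis=0)


stack = init_stack(jnp.array([1.0, 2.0]), n_stack=3)   # shape (3, 2)
stack = push(stack, jnp.array([3.0, 4.0]))             # newest observation is last
```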

init(key)#

Initialize the environment and return the initial state.

Parameters:

key (Array) – A JAX random key for any stochastic initialization.

Return type:

tuple[StackObservationsState, Timestep]

Returns:

A tuple containing the initial environment state and the initial timestep.

num_stack: int#
padding: str#
reset(key, env_state)#

Reset the environment to its initial state.

Parameters:
  • key (Array) – A JAX random key for any stochasticity in the environment.

  • env_state (StackObservationsState) – The current state of the environment.

Return type:

tuple[StackObservationsState, Timestep]

Returns:

A tuple containing the reset environment state and the initial timestep.

step(key, env_state, action)#

Perform a step in the environment given an action.

Parameters:
  • key (Array) – A JAX random key for any stochasticity in the environment.

  • env_state (StackObservationsState) – The current state of the environment.

  • action (Any) – The action to take in the environment.

Return type:

tuple[StackObservationsState, Timestep]

Returns:

A tuple containing the new environment state and the resulting timestep.

class StepCounter(env, unwrap=True)#

Bases: Wrapper

A wrapper that counts the number of steps taken in the environment.

init(key)#

Initialize the environment and return the initial state.

Parameters:

key (Array) – A JAX random key for any stochastic initialization.

Return type:

tuple[StepCounterState, Timestep]

Returns:

A tuple containing the initial environment state and the initial timestep.

reset(key, env_state)#

Reset the environment to its initial state.

Parameters:
  • key (Array) – A JAX random key for any stochasticity in the environment.

  • env_state (StepCounterState) – The current state of the environment.

Return type:

tuple[StepCounterState, Timestep]

Returns:

A tuple containing the reset environment state and the initial timestep.

step(key, env_state, action)#

Perform a step in the environment given an action.

Parameters:
  • key (Array) – A JAX random key for any stochasticity in the environment.

  • env_state (StepCounterState) – The current state of the environment.

  • action (Any) – The action to take in the environment.

Return type:

tuple[StepCounterState, Timestep]

Returns:

A tuple containing the new environment state and the resulting timestep.

class StickyAction(env, unwrap=True, stickiness=0.25)#

Bases: Wrapper

A wrapper that makes actions sticky with a given probability.
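Sticky actions are usually understood as repeating the previous action with some probability instead of executing the new one. The sketch below follows that common reading; whether the wrapper samples in exactly this way is an assumption, and `sticky_action` is a hypothetical helper.

```python
import jax
import jax.numpy as jnp


def sticky_action(key, prev_action, action, stickiness=0.25):
    """With probability `stickiness`, repeat prev_action instead of action."""
    repeat = jax.random.bernoulli(key, stickiness)
    return jnp.where(repeat, prev_action, action)


key = jax.random.PRNGKey(0)
never_sticky = sticky_action(key, prev_action=0, action=1, stickiness=0.0)
always_sticky = sticky_action(key, prev_action=0, action=1, stickiness=1.0)
```

Using `jnp.where` rather than a Python branch keeps the function compatible with `jax.jit` and `jax.vmap`.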

init(key)#

Initialize the environment and return the initial state.

Parameters:

key (Array) – A JAX random key for any stochastic initialization.

Return type:

tuple[StickyActionState, Timestep]

Returns:

A tuple containing the initial environment state and the initial timestep.

reset(key, env_state)#

Reset the environment to its initial state.

Parameters:
  • key (Array) – A JAX random key for any stochasticity in the environment.

  • env_state (StickyActionState) – The current state of the environment.

Return type:

tuple[StickyActionState, Timestep]

Returns:

A tuple containing the reset environment state and the initial timestep.

step(key, env_state, action)#

Perform a step in the environment given an action.

Parameters:
  • key (Array) – A JAX random key for any stochasticity in the environment.

  • env_state (StickyActionState) – The current state of the environment.

  • action (Any) – The action to take in the environment.

Return type:

tuple[StickyActionState, Timestep]

Returns:

A tuple containing the new environment state and the resulting timestep.

class TimeLimit(env, unwrap=True, time_limit=1000)#

Bases: Wrapper

Wrapper that terminates an episode after a fixed number of steps.
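A time limit can be sketched as a step counter that is compared against the limit on every step. The helper `apply_time_limit` is hypothetical, and the fields of TimeLimitState are not documented here, so the counter is passed explicitly in this example.

```python
import jax.numpy as jnp


def apply_time_limit(step_count, done, time_limit=1000):
    """Advance the step counter and end the episode once the limit is reached."""
    step_count = step_count + 1
    limit_reached = step_count >= time_limit
    return step_count, jnp.logical_or(done, limit_reached)


# With a limit of 3, the third step ends the episode.
count = jnp.int32(0)
history = []
for _ in range(3):
    count, done = apply_time_limit(count, jnp.asarray(False), time_limit=3)
    history.append(bool(done))
```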

env: Environment#
init(key)#

Initialize the environment and return the initial state.

Parameters:

key (Array) – A JAX random key for any stochastic initialization.

Return type:

tuple[TimeLimitState, Timestep]

Returns:

A tuple containing the initial environment state and the initial timestep.

reset(key, env_state)#

Reset the environment to its initial state.

Parameters:
  • key (Array) – A JAX random key for any stochasticity in the environment.

  • env_state (TimeLimitState) – The current state of the environment.

Return type:

tuple[TimeLimitState, Timestep]

Returns:

A tuple containing the reset environment state and the initial timestep.

step(key, env_state, action)#

Perform a step in the environment given an action.

Parameters:
  • key (Array) – A JAX random key for any stochasticity in the environment.

  • env_state (TimeLimitState) – The current state of the environment.

  • action (Any) – The action to take in the environment.

Return type:

tuple[TimeLimitState, Timestep]

Returns:

A tuple containing the new environment state and the resulting timestep.

class Wrapper(env, unwrap=True)#

Bases: Generic[TWrapperState], Environment[TWrapperState]

Base class for environment wrappers in gxm.

env: Environment#
get_wrapper(wrapper_type)#

Retrieve the first wrapper of a specific type from the environment.

Parameters:

wrapper_type (type[Environment]) – The type of the wrapper to retrieve.

Return type:

Environment

Returns:

The first wrapper of the specified type.

Raises:

ValueError – If no wrapper of the specified type is found.

has_wrapper(wrapper_type)#

Check if the environment or any of its wrappers is of a specific type.

Parameters:

wrapper_type (type[Environment]) – The type to check for.

Return type:

bool

Returns:

True if the environment or any of its wrappers is of the specified type, False otherwise.

unwrap: bool = True#
property unwrapped: Environment#

Retrieve the base environment by unwrapping all wrappers.

Returns:

The base environment without any wrappers.
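The lookup methods and the unwrapped property can be illustrated with a small chain of nested wrappers. The classes below (`ToyEnv`, `ToyWrapper`, `ToyClip`, `ToyLimit`) are hypothetical stand-ins, not gxm classes; the sketch only assumes that each wrapper stores the wrapped environment in its `env` attribute.

```python
class ToyEnv:
    """Hypothetical base environment (not part of gxm)."""


class ToyWrapper(ToyEnv):
    """Hypothetical wrapper that stores the wrapped environment in `env`."""

    def __init__(self, env):
        self.env = env

    def get_wrapper(self, wrapper_type):
        # Walk from the outermost wrapper inwards and return the first match.
        current = self
        while isinstance(current, ToyWrapper):
            if isinstance(current, wrapper_type):
                return current
            current = current.env
        raise ValueError(f"No wrapper of type {wrapper_type.__name__} found.")

    def has_wrapper(self, wrapper_type):
        # The base environment at the end of the chain is checked as well.
        current = self
        while isinstance(current, ToyWrapper):
            if isinstance(current, wrapper_type):
                return True
            current = current.env
        return isinstance(current, wrapper_type)

    @property
    def unwrapped(self):
        # Strip wrappers until a non-wrapper environment remains.
        current = self
        while isinstance(current, ToyWrapper):
            current = current.env
        return current


class ToyClip(ToyWrapper):
    pass


class ToyLimit(ToyWrapper):
    pass


base = ToyEnv()
env = ToyLimit(ToyClip(base))
inner_clip = env.get_wrapper(ToyClip)
```

Here `env.unwrapped` returns `base`, and asking a `ToyClip`-only chain for a `ToyLimit` raises ValueError.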