def __init__(self, env: py_environment.PyEnvironment, num_actions: np.ndarray):
  """Builds a wrapper that exposes a discretized view of `env`'s actions.

  **Note:** Only environments whose action spec flattens to a single
  BoundedArraySpec are supported.

  Args:
    env: Environment to wrap.
    num_actions: A np.array broadcastable to the shape of the environment's
      action_spec. Each element gives the number of discrete actions for the
      corresponding action dimension.

  Raises:
    ValueError: If the environment has more than one action spec, or if
      `num_actions` cannot be broadcast to the action_spec's shape.
  """
  super(ActionDiscretizeWrapper, self).__init__(env)

  flat_specs = tf.nest.flatten(env.action_spec())
  if len(flat_specs) != 1:
    raise ValueError(
        'ActionDiscretizeWrapper only supports environments with a single '
        'action spec. Got {}'.format(env.action_spec()))

  spec = flat_specs[0]
  self._original_spec = spec
  # broadcast_to raises ValueError on its own if the shapes are incompatible;
  # the explicit check below is kept as a defensive guard.
  self._num_actions = np.broadcast_to(num_actions, spec.shape)
  if spec.shape != self._num_actions.shape:
    raise ValueError(
        'Spec {} and limit shape do not match. Got {}'.format(
            spec, self._num_actions.shape))

  self._discrete_spec, self._action_map = self._discretize_spec(
      spec, self._num_actions)
def __init__(self, env: py_environment.PyEnvironment, flat_dtype=None):
  """Builds a wrapper exposing `env`'s actions as one flat, rank-1 spec.

  Args:
    env: Environment to wrap.
    flat_dtype: Optional; when set to a np.dtype, the flat action_spec uses
      this dtype instead of the (then required to be uniform) spec dtypes.

  Raises:
    ValueError: If any of the action_spec shapes has ndim > 1.
    ValueError: If dtypes differ across action specs and flat_dtype is not
      set.
  """
  super(FlattenActionWrapper, self).__init__(env)
  self._original_action_spec = env.action_spec()
  flat_action_spec = tf.nest.flatten(env.action_spec())

  for s in flat_action_spec:
    if len(s.shape) > 1:
      raise ValueError('ActionSpec shapes should all have ndim == 1.')

  first_dtype = flat_action_spec[0].dtype
  if flat_dtype is None and any(
      s.dtype != first_dtype for s in flat_action_spec):
    raise ValueError(
        'All action_spec dtypes must match, or `flat_dtype` should be set.'
    )

  # Scalar specs (shape ()) contribute one dimension each; rank-1 specs
  # contribute their length.
  total = sum((s.shape and s.shape[0]) or 1 for s in flat_action_spec)
  # The flattened spec is always rank 1, so `shape` is a 1-tuple.
  shape = (total,)
  dtype = flat_dtype or first_dtype

  if all(
      isinstance(s, array_spec.BoundedArraySpec) for s in flat_action_spec):
    # Expand per-spec scalar bounds to the spec's shape before stacking so
    # the flat bounds line up element-wise with the flat action vector.
    minimum = np.hstack([
        np.broadcast_to(s.minimum, shape=s.shape) for s in flat_action_spec
    ])
    maximum = np.hstack([
        np.broadcast_to(s.maximum, shape=s.shape) for s in flat_action_spec
    ])
    self._action_spec = array_spec.BoundedArraySpec(
        shape=shape,
        dtype=dtype,
        minimum=minimum,
        maximum=maximum,
        name='FlattenedActionSpec')
  else:
    self._action_spec = array_spec.ArraySpec(
        shape=shape, dtype=dtype, name='FlattenedActionSpec')

  self._flat_action_spec = flat_action_spec
def __init__(
    self,
    env: py_environment.PyEnvironment,
    history_length: int = 3,
    include_actions: bool = False,
    tile_first_step_obs: bool = False,
):
  """Wraps `env` so observations include a window of recent steps.

  Args:
    env: Environment to wrap.
    history_length: Number of past steps kept in the history.
    include_actions: Whether past actions are included in the history
      alongside observations.
    tile_first_step_obs: If True, the reset observation is tiled to fill the
      entire history instead of padding with zeros.
  """
  super(HistoryWrapper, self).__init__(env)
  self._history_length = history_length
  self._include_actions = include_actions
  self._tile_first_step_obs = tile_first_step_obs

  # Bounded deques automatically evict the oldest entry once full.
  self._observation_history = collections.deque(maxlen=history_length)
  self._action_history = collections.deque(maxlen=history_length)

  # Zero-valued placeholders used to pad the history before `history_length`
  # real steps have been seen.
  self._zero_observation = self._zeros_from_spec(env.observation_spec())
  self._zero_action = self._zeros_from_spec(env.action_spec())

  self._observation_spec = self._get_observation_spec()
def validate_py_environment(
    environment: py_environment.PyEnvironment,
    episodes: int = 5,
    observation_and_action_constraint_splitter: Optional[
        types.Splitter] = None):
  """Validates the environment follows the defined specs."""
  time_step_spec = environment.time_step_spec()
  action_spec = environment.action_spec()

  random_policy = random_py_policy.RandomPyPolicy(
      time_step_spec=time_step_spec,
      action_spec=action_spec,
      observation_and_action_constraint_splitter=(
          observation_and_action_constraint_splitter))

  # Batched environments emit time steps with a leading batch dimension, so
  # the spec used for checking must carry that outer dim as well.
  batch_size = environment.batch_size
  if batch_size is None:
    batched_time_step_spec = time_step_spec
  else:
    batched_time_step_spec = array_spec.add_outer_dims_nest(
        time_step_spec, outer_dims=(batch_size,))

  time_step = environment.reset()
  episode_count = 0
  while episode_count < episodes:
    if not array_spec.check_arrays_nest(time_step, batched_time_step_spec):
      raise ValueError('Given `time_step`: %r does not match expected '
                       '`time_step_spec`: %r' %
                       (time_step, batched_time_step_spec))

    time_step = environment.step(random_policy.action(time_step).action)
    # `is_last()` is an array for batched environments; summing counts every
    # episode that ended this step. No explicit reset: `step` auto-resets.
    episode_count += np.sum(time_step.is_last())
def __init__(self, env: py_environment.PyEnvironment, num_extra_actions: int):
  """Builds a wrapper that appends always-disabled actions to `env`.

  Args:
    env: The environment to wrap.
    num_extra_actions: The number of extra (masked-out) actions to add.
  """
  super(ExtraDisabledActionsWrapper, self).__init__(env)
  orig_spec = env.action_spec()

  # Widen the action range by `num_extra_actions` beyond the original max.
  self._action_spec = array_spec.BoundedArraySpec(
      shape=orig_spec.shape,
      dtype=orig_spec.dtype,
      minimum=orig_spec.minimum,
      maximum=orig_spec.maximum + num_extra_actions)

  # One mask entry per action in the widened range.
  total_actions = self._action_spec.maximum - self._action_spec.minimum + 1
  mask_spec = array_spec.ArraySpec(shape=[total_actions], dtype=np.int64)
  self._masked_observation_spec = (env.observation_spec(), mask_spec)

  # Original actions are enabled (1); the appended extras are disabled (0).
  num_valid = orig_spec.maximum - orig_spec.minimum + 1
  row = [1] * num_valid + [0] * num_extra_actions
  self._constant_mask = np.array([row] * self.batch_size)
def validate_py_environment(environment: py_environment.PyEnvironment,
                            episodes: int = 5):
  """Validates the environment follows the defined specs."""
  time_step_spec = environment.time_step_spec()
  random_policy = random_py_policy.RandomPyPolicy(
      time_step_spec=time_step_spec, action_spec=environment.action_spec())

  time_step = environment.reset()
  episode_count = 0
  while episode_count < episodes:
    # Every emitted time step must conform to the declared spec.
    if not array_spec.check_arrays_nest(time_step, time_step_spec):
      raise ValueError(
          'Given `time_step`: %r does not match expected `time_step_spec`: %r'
          % (time_step, time_step_spec))

    time_step = environment.step(random_policy.action(time_step).action)

    # Explicitly reset at episode boundaries rather than relying on the
    # environment to auto-reset.
    if time_step.is_last():
      episode_count += 1
      time_step = environment.reset()