Exemplo n.º 1
0
    def __init__(self, env, idx):
        """Creates an observation filter wrapper.

        The exposed observation spec is the wrapped environment's spec with
        its shape replaced by `idx.shape`, i.e. one element per index.

        Args:
          env: Environment to wrap. Its observation spec must be a single
            (non-nested) array spec with at least one dimension.
          idx: Array-like of integer indexes pointing to elements (along the
            first axis of the observation) to include in output. Indexes in
            the range [-n, n) are accepted, where n is the size of that axis.
            NOTE(review): negative indexes assume the observations are
            filtered NumPy-style (`obs[idx]`) — confirm against the
            step/reset filtering.

        Raises:
          ValueError: If observation spec is nested.
          ValueError: If indexes are not single-dimensional.
          ValueError: If no index is provided.
          ValueError: If indexes are not integers (e.g. floats or booleans).
          ValueError: If observation spec is a scalar (no axis to filter).
          ValueError: If one of the indexes is out of bounds.
        """
        super(ObservationFilterWrapper, self).__init__(env)
        idx = np.array(idx)
        # Fetch the spec once rather than re-querying the environment for
        # every check below.
        observation_spec = env.observation_spec()
        if tf.nest.is_nested(observation_spec):
            raise ValueError(
                'ObservationFilterWrapper only works with single-array '
                'observations (not nested).')
        if len(idx.shape) != 1:
            raise ValueError('ObservationFilterWrapper only works with '
                             'single-dimensional indexes for filtering.')
        if idx.shape[0] < 1:
            raise ValueError(
                'At least one index needs to be provided for filtering.')
        # Boolean masks and float arrays are not index lists: a mask would
        # select sum(mask) elements while the spec below declares idx.shape.
        if not np.issubdtype(idx.dtype, np.integer):
            raise ValueError(
                'ObservationFilterWrapper indexes must be integers.')
        if len(observation_spec.shape) < 1:
            raise ValueError(
                'ObservationFilterWrapper requires an observation spec with '
                'at least one dimension.')
        num_elements = observation_spec.shape[0]
        # Check both bounds: the previous upper-bound-only check let any
        # negative index (even below -num_elements) through.
        if np.any((idx < -num_elements) | (idx >= num_elements)):
            raise ValueError('One of the indexes is out of bounds.')

        self._idx = idx
        self._observation_spec = array_spec.update_spec_shape(
            observation_spec, idx.shape)
Exemplo n.º 2
0
 def _update_shape(spec):
   """Returns a copy of `spec` with a leading axis of `self._history_length`.

   NOTE(review): `self._history_length` presumably counts stacked past
   observations — confirm against the enclosing wrapper.
   """
   # New shape = (history_length, *original shape).
   new_shape = (self._history_length,) + spec.shape
   return array_spec.update_spec_shape(spec, new_shape)