Example #1
0
File: wrappers.py  Project: wau/agents
    def __init__(self, env: py_environment.PyEnvironment,
                 idx: Union[Sequence[int], np.ndarray]):
        """Creates an observation filter wrapper.

        Args:
          env: Environment to wrap. Its observation spec must be a single
            (non-nested) array spec.
          idx: Array of indexes pointing to elements to include in output.
            Stored as `self._idx`; the output observation spec keeps the
            wrapped spec's other attributes but takes the shape `idx.shape`.

        Raises:
          ValueError: If observation spec is nested.
          ValueError: If indexes are not single-dimensional.
          ValueError: If no index is provided.
          ValueError: If one of the indexes is out of bounds, i.e. lies
            outside `[-n, n)` where `n` is the first dimension of the wrapped
            observation spec.
        """
        super(ObservationFilterWrapper, self).__init__(env)
        idx = np.array(idx)
        observation_spec = env.observation_spec()
        if tf.nest.is_nested(observation_spec):
            raise ValueError(
                'ObservationFilterWrapper only works with single-array '
                'observations (not nested).')
        if len(idx.shape) != 1:
            raise ValueError('ObservationFilterWrapper only works with '
                             'single-dimensional indexes for filtering.')
        if idx.shape[0] < 1:
            raise ValueError(
                'At least one index needs to be provided for filtering.')
        # Bound the indexes from below as well as above. A large negative
        # index (e.g. -100 for a 3-element observation) is just as invalid as
        # a too-large positive one. Negative indexes within [-n, -1] are
        # accepted (presumably resolved NumPy-style when `self._idx` is
        # applied -- confirm in the reset/step code).
        num_elements = observation_spec.shape[0]
        if not np.all((idx >= -num_elements) & (idx < num_elements)):
            raise ValueError('One of the indexes is out of bounds.')

        self._idx = idx
        self._observation_spec = observation_spec.replace(shape=idx.shape)
Example #2
0
File: wrappers.py  Project: wau/agents
    def __init__(
        self,
        env: py_environment.PyEnvironment,
        history_length: int = 3,
        include_actions: bool = False,
        tile_first_step_obs: bool = False,
    ):
        """Initializes a HistoryWrapper.

        Args:
          env: Environment to wrap.
          history_length: Length of the history to attach; used as the
            `maxlen` of both history buffers.
          include_actions: Whether actions should be included in the history.
          tile_first_step_obs: If True the observation on reset is tiled to
            fill the history.
        """
        super(HistoryWrapper, self).__init__(env)

        # Store the configuration flags on the instance.
        self._history_length = history_length
        self._include_actions = include_actions
        self._tile_first_step_obs = tile_first_step_obs

        # Zero-valued placeholders derived from the wrapped env's specs.
        wrapped_obs_spec = env.observation_spec()
        self._zero_observation = self._zeros_from_spec(wrapped_obs_spec)
        self._zero_action = self._zeros_from_spec(env.action_spec())

        # Two independent buffers, each bounded to `history_length` entries.
        self._observation_history, self._action_history = (
            collections.deque(maxlen=history_length),
            collections.deque(maxlen=history_length),
        )

        # Computed last, after every attribute above has been set.
        self._observation_spec = self._get_observation_spec()
Example #3
0
File: wrappers.py  Project: wau/agents
    def __init__(self, env: py_environment.PyEnvironment,
                 num_extra_actions: int):
        """Initializes an instance of `ExtraDisabledActionsWrapper`.

        The wrapped env's action spec maximum is raised by
        `num_extra_actions`. A constant mask is built in which the original
        actions are marked 1 and the extra actions are marked 0. The
        observation spec becomes the pair
        `(env.observation_spec(), mask_spec)`.

        Args:
          env: The environment to wrap. Its action spec is expected to be a
            bounded spec exposing `minimum`/`maximum` (assumed scalar bounds
            -- TODO confirm for non-scalar action specs).
          num_extra_actions: The number of extra actions to add.
        """
        super(ExtraDisabledActionsWrapper, self).__init__(env)
        orig_action_spec = env.action_spec()
        self._action_spec = array_spec.BoundedArraySpec(
            shape=orig_action_spec.shape,
            dtype=orig_action_spec.dtype,
            minimum=orig_action_spec.minimum,
            maximum=orig_action_spec.maximum + num_extra_actions)
        mask_spec = array_spec.ArraySpec(
            shape=[self._action_spec.maximum - self._action_spec.minimum + 1],
            dtype=np.int64)
        self._masked_observation_spec = (env.observation_spec(), mask_spec)

        num_orig_actions = (
            orig_action_spec.maximum - orig_action_spec.minimum + 1)
        mask_row = [1] * num_orig_actions + [0] * num_extra_actions
        # `batch_size` is None for an unbatched environment (PyEnvironment
        # API). Multiplying a list by None raises TypeError, so in that case
        # emit a single mask row shaped exactly like `mask_spec`. The dtype
        # is pinned to match `mask_spec` instead of relying on NumPy's
        # platform-default integer type.
        batch_size = self.batch_size
        if batch_size is None:
            self._constant_mask = np.array(mask_row, dtype=np.int64)
        else:
            self._constant_mask = np.array(
                [mask_row] * batch_size, dtype=np.int64)
Example #4
0
File: wrappers.py  Project: wau/agents
    def __init__(self,
                 env: py_environment.PyEnvironment,
                 observations_allowlist: Optional[Sequence[Text]] = None):
        """Initializes a wrapper to flatten environment observations.

        Args:
          env: A `py_environment.PyEnvironment` environment to wrap.
          observations_allowlist: A list of observation keys that want to be
            observed from the environment. All other observations returned
            are filtered out. If not provided, all observations will be kept.
            Additionally, if this is provided, the environment is expected to
            return a dictionary of observations.

        Raises:
          ValueError: If the current environment does not return a dictionary
            of observations and observations_allowlist is provided.
          ValueError: If the observations_allowlist keys are not found in the
            environment.
          ValueError: If the observations that will actually be flattened
            (after allowlist filtering) do not all share a single dtype.
        """
        super(FlattenObservationsWrapper, self).__init__(env)

        env_observation_spec = env.observation_spec()

        # If an observations allowlist is provided:
        #  Check that the environment returns a dictionary of observations.
        #  Check that the set of allowed keys is found in the environment keys.
        if observations_allowlist is not None:
            if not isinstance(env_observation_spec, dict):
                raise ValueError(
                    'If you provide an observations allowlist, the current environment '
                    'must return a dictionary of observations! The returned observation'
                    ' spec is type %s.' % (type(env_observation_spec)))

            # Check that observation allowlist keys are valid observation keys.
            unknown_keys = set(observations_allowlist).difference(
                env_observation_spec.keys())
            if unknown_keys:
                raise ValueError(
                    'The observation allowlist contains keys not found in the '
                    'environment! Unknown keys: %s' % list(unknown_keys))

        self._observations_allowlist = observations_allowlist
        # Select the observations that will be flattened.
        # NOTE(review): if `_filter_observations` deletes keys from its
        # argument in place, this also alters the dict returned by the wrapped
        # env -- confirm against its implementation.
        observations_spec = env_observation_spec
        if self._observations_allowlist is not None:
            observations_spec = self._filter_observations(observations_spec)

        # Check that all observations that will be flattened have the same
        # dtype; it is used for the flattened ArraySpec. The check runs on the
        # filtered spec, so observations excluded by the allowlist do not
        # constrain the dtype. `tf.nest.flatten` also reaches leaf specs
        # inside nested structures, which `.values()` would not.
        env_dtypes = list(
            set(spec.dtype for spec in tf.nest.flatten(observations_spec)))
        if len(env_dtypes) != 1:
            raise ValueError(
                'The observation spec must all have the same dtypes! '
                'Currently found dtypes: %s' % (env_dtypes))
        self._observation_spec_dtype = env_dtypes[0]

        # Compute the observation length after flattening the observation items
        # and nested structure. Observation specs are not batched.
        observation_total_len = sum(
            int(np.prod(observation.shape))
            for observation in self._flatten_nested_observations(
                observations_spec, is_batched=False))

        # Update the observation spec as an array of one-dimension.
        self._flattened_observation_spec = array_spec.ArraySpec(
            shape=(observation_total_len, ),
            dtype=self._observation_spec_dtype,
            name='packed_observations')