Example #1
    def __init__(self, env: py_environment.PyEnvironment,
                 num_actions: np.ndarray):
        """Constructs a wrapper for discretizing the action space.

    **Note:** Only environments with a single BoundedArraySpec are supported.

    Args:
      env: Environment to wrap.
      num_actions: A np.array of the same shape as the environment's
        action_spec. Elements in the array specify the number of actions to
        discretize to for each dimension.

    Raises:
      ValueError: If the action_spec shape and the limits shape are not equal.
    """
        super(ActionDiscretizeWrapper, self).__init__(env)

        action_spec = tf.nest.flatten(env.action_spec())
        if len(action_spec) != 1:
            raise ValueError(
                'ActionDiscretizeWrapper only supports environments with a single '
                'action spec. Got {}'.format(env.action_spec()))

        action_spec = action_spec[0]
        self._original_spec = action_spec
        self._num_actions = np.broadcast_to(num_actions, action_spec.shape)

        if action_spec.shape != self._num_actions.shape:
            raise ValueError(
                'Spec {} and limit shape do not match. Got {}'.format(
                    action_spec, self._num_actions.shape))

        self._discrete_spec, self._action_map = self._discretize_spec(
            action_spec, self._num_actions)
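
A minimal usage sketch for this wrapper, assuming it is importable from the surrounding module; make_continuous_env is a hypothetical factory returning a py_environment.PyEnvironment whose action spec is a single BoundedArraySpec of shape (2,):

import numpy as np

env = make_continuous_env()  # hypothetical PyEnvironment with a (2,) action spec
# Discretize the first action dimension into 5 buckets and the second into 7.
discrete_env = ActionDiscretizeWrapper(env, num_actions=np.array([5, 7]))
# Actions passed to discrete_env.step() are now integer bucket indices.
print(discrete_env.action_spec())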
Example #2
    def __init__(self, env: py_environment.PyEnvironment, flat_dtype=None):
        """Creates a FlattenActionWrapper.

    Args:
      env: Environment to wrap.
      flat_dtype: Optional, if set to a np.dtype the flat action_spec uses this
        dtype.

    Raises:
      ValueError: If any of the action_spec shapes has ndim > 1.
      ValueError: If dtypes differ across action specs and flat_dtype is not
        set.
    """
        super(FlattenActionWrapper, self).__init__(env)
        self._original_action_spec = env.action_spec()
        flat_action_spec = tf.nest.flatten(env.action_spec())

        if any([len(s.shape) > 1 for s in flat_action_spec]):
            raise ValueError('ActionSpec shapes should all have ndim == 1.')

        if flat_dtype is None and any(
            [s.dtype != flat_action_spec[0].dtype for s in flat_action_spec]):
            raise ValueError(
                'All action_spec dtypes must match, or `flat_dtype` should be set.'
            )

        # Use s.shape[0] when a dimension is present, or 1 for scalar shapes ().
        shape = (sum((s.shape and s.shape[0]) or 1 for s in flat_action_spec),)

        if all([
                isinstance(s, array_spec.BoundedArraySpec)
                for s in flat_action_spec
        ]):
            minimums = [
                np.broadcast_to(s.minimum, shape=s.shape)
                for s in flat_action_spec
            ]
            maximums = [
                np.broadcast_to(s.maximum, shape=s.shape)
                for s in flat_action_spec
            ]

            minimum = np.hstack(minimums)
            maximum = np.hstack(maximums)
            self._action_spec = array_spec.BoundedArraySpec(
                shape=shape,
                dtype=flat_dtype or flat_action_spec[0].dtype,
                minimum=minimum,
                maximum=maximum,
                name='FlattenedActionSpec')
        else:
            self._action_spec = array_spec.ArraySpec(
                shape=shape,
                dtype=flat_dtype or flat_action_spec[0].dtype,
                name='FlattenedActionSpec')

        self._flat_action_spec = flat_action_spec
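
A usage sketch, again with a hypothetical environment; make_nested_action_env stands in for an environment whose action_spec() is a nest of several 1-D specs with differing dtypes, so flat_dtype is supplied explicitly:

import numpy as np

env = make_nested_action_env()  # hypothetical PyEnvironment with a nested action spec
flat_env = FlattenActionWrapper(env, flat_dtype=np.float32)
# All component specs are concatenated into one 1-D spec named 'FlattenedActionSpec'.
print(flat_env.action_spec())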
Example #3
    def __init__(
        self,
        env: py_environment.PyEnvironment,
        history_length: int = 3,
        include_actions: bool = False,
        tile_first_step_obs: bool = False,
    ):
        """Initializes a HistoryWrapper.

    Args:
      env: Environment to wrap.
      history_length: Length of the history to attach.
      include_actions: Whether actions should be included in the history.
      tile_first_step_obs: If True, the observation on reset is tiled to fill
        the history.
    """
        super(HistoryWrapper, self).__init__(env)
        self._history_length = history_length
        self._include_actions = include_actions
        self._tile_first_step_obs = tile_first_step_obs

        self._zero_observation = self._zeros_from_spec(env.observation_spec())
        self._zero_action = self._zeros_from_spec(env.action_spec())

        self._observation_history = collections.deque(maxlen=history_length)
        self._action_history = collections.deque(maxlen=history_length)

        self._observation_spec = self._get_observation_spec()
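
A usage sketch with a hypothetical make_env factory; with tile_first_step_obs=True the reset observation is repeated to fill the history rather than the zero observations built above:

env = make_env()  # hypothetical PyEnvironment
history_env = HistoryWrapper(env, history_length=4, include_actions=True,
                             tile_first_step_obs=True)
# The observation spec now describes a stacked history of the last 4 steps
# (and, because include_actions=True, the matching actions).
print(history_env.observation_spec())
time_step = history_env.reset()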
Example #4
def validate_py_environment(
    environment: py_environment.PyEnvironment,
    episodes: int = 5,
    observation_and_action_constraint_splitter: Optional[
        types.Splitter] = None):
    """Validates the environment follows the defined specs."""
    time_step_spec = environment.time_step_spec()
    action_spec = environment.action_spec()

    random_policy = random_py_policy.RandomPyPolicy(
        time_step_spec=time_step_spec,
        action_spec=action_spec,
        observation_and_action_constraint_splitter=(
            observation_and_action_constraint_splitter))

    if environment.batch_size is not None:
        batched_time_step_spec = array_spec.add_outer_dims_nest(
            time_step_spec, outer_dims=(environment.batch_size, ))
    else:
        batched_time_step_spec = time_step_spec

    episode_count = 0
    time_step = environment.reset()

    while episode_count < episodes:
        if not array_spec.check_arrays_nest(time_step, batched_time_step_spec):
            raise ValueError('Given `time_step`: %r does not match expected '
                             '`time_step_spec`: %r' %
                             (time_step, batched_time_step_spec))

        action = random_policy.action(time_step).action
        time_step = environment.step(action)

        episode_count += np.sum(time_step.is_last())
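
A sketch of how this validator might be called; make_masked_env is a hypothetical environment whose observation is an (observation, mask) tuple, so the splitter simply unpacks it:

def split_obs_and_mask(observation):
    # Hypothetical layout: the environment already emits (observation, mask).
    return observation[0], observation[1]

validate_py_environment(
    make_masked_env(),  # hypothetical PyEnvironment with an action mask
    episodes=5,
    observation_and_action_constraint_splitter=split_obs_and_mask)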
Example #5
    def __init__(self, env: py_environment.PyEnvironment,
                 num_extra_actions: int):
        """Initializes an instance of `ExtraDisabledActionsWrapper`.

    Args:
      env: The environment to wrap.
      num_extra_actions: The number of extra actions to add.
    """
        super(ExtraDisabledActionsWrapper, self).__init__(env)
        orig_action_spec = env.action_spec()
        self._action_spec = array_spec.BoundedArraySpec(
            shape=orig_action_spec.shape,
            dtype=orig_action_spec.dtype,
            minimum=orig_action_spec.minimum,
            maximum=orig_action_spec.maximum + num_extra_actions)
        mask_spec = array_spec.ArraySpec(
            shape=[self._action_spec.maximum - self._action_spec.minimum + 1],
            dtype=np.int64)
        self._masked_observation_spec = (env.observation_spec(), mask_spec)
        self._constant_mask = np.array(
            [[1] * (orig_action_spec.maximum - orig_action_spec.minimum + 1) +
             [0] * num_extra_actions] * self.batch_size)
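
A usage sketch with a hypothetical discrete environment; the wrapper widens the action spec by num_extra_actions and pairs every observation with a constant mask whose trailing entries are zero, so the extra actions are never legal:

env = make_discrete_env()  # hypothetical PyEnvironment with a scalar discrete action spec
padded_env = ExtraDisabledActionsWrapper(env, num_extra_actions=2)
# The observation spec is now a (original_observation_spec, mask_spec) pair.
print(padded_env.observation_spec())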
Example #6
def validate_py_environment(environment: py_environment.PyEnvironment,
                            episodes: int = 5):
    """Validates the environment follows the defined specs."""
    time_step_spec = environment.time_step_spec()
    action_spec = environment.action_spec()

    random_policy = random_py_policy.RandomPyPolicy(
        time_step_spec=time_step_spec, action_spec=action_spec)

    episode_count = 0
    time_step = environment.reset()

    while episode_count < episodes:
        if not array_spec.check_arrays_nest(time_step, time_step_spec):
            raise ValueError(
                'Given `time_step`: %r does not match expected `time_step_spec`: %r'
                % (time_step, time_step_spec))

        action = random_policy.action(time_step).action
        time_step = environment.step(action)

        if time_step.is_last():
            episode_count += 1
            time_step = environment.reset()
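
This variant differs from Example #4 in three ways: it takes no observation-and-action constraint splitter, it checks time_step against the unbatched time_step_spec, and it resets the environment explicitly whenever time_step.is_last() is true. A minimal call, with make_env a hypothetical unbatched environment factory:

# Raises ValueError if any time step produced by the environment violates its specs.
validate_py_environment(make_env(), episodes=3)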