Example #1
def make_environment_spec(environment: dm_env.Environment) -> EnvironmentSpec:
  """Returns an `EnvironmentSpec` describing values used by an environment."""
  return EnvironmentSpec(
      observations=environment.observation_spec(),
      actions=environment.action_spec(),
      rewards=environment.reward_spec(),
      discounts=environment.discount_spec())
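A minimal usage sketch (hedged): `env` below stands for any `dm_env.Environment` instance and is not part of the snippet above.

# Hypothetical usage: build a single spec object from an existing environment.
spec = make_environment_spec(env)
print(spec.actions)       # same object returned by env.action_spec()
print(spec.observations)  # same object returned by env.observation_spec()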
Example #2
def transition_dataset(environment: dm_env.Environment) -> tf.data.Dataset:
    """Fake dataset of Reverb N-step transition samples.

    Args:
      environment: Used to create a fake transition by looking at the
        observation, action, discount and reward specs.

    Returns:
      tf.data.Dataset that produces the same fake N-step transition
      ReplaySample object indefinitely.
    """

    observation = environment.observation_spec().generate_value()
    action = environment.action_spec().generate_value()
    reward = environment.reward_spec().generate_value()
    discount = environment.discount_spec().generate_value()
    data = (observation, action, reward, discount, observation)

    key = np.array(0, np.uint64)
    probability = np.array(1.0, np.float64)
    table_size = np.array(1, np.int64)
    priority = np.array(1.0, np.float64)
    info = reverb.SampleInfo(key=key,
                             probability=probability,
                             table_size=table_size,
                             priority=priority)
    sample = reverb.ReplaySample(info=info, data=data)

    return tf.data.Dataset.from_tensors(sample).repeat()
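A usage sketch, assuming `env` is any `dm_env.Environment` and the module-level imports (`numpy as np`, `reverb`, `tensorflow as tf`) are in scope:

# Hypothetical usage: every element is the same fake ReplaySample.
dataset = transition_dataset(env)
for sample in dataset.take(2):
    observation, action, reward, discount, next_observation = sample.data
    print(sample.info.key, reward)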
Example #3
    def __init__(self,
                 environment: dm_env.Environment,
                 name_filter: Optional[Sequence[str]] = None):
        """Initializes a new ConcatObservationWrapper.

        Args:
          environment: Environment to wrap.
          name_filter: Sequence of observation names to keep. None keeps them all.
        """
        super().__init__(environment)
        observation_spec = environment.observation_spec()
        if name_filter is None:
            name_filter = list(observation_spec.keys())
        self._obs_names = [
            x for x in name_filter if x in observation_spec.keys()
        ]

        dummy_obs = _zeros_like(observation_spec)
        dummy_obs = self._convert_observation(dummy_obs)
        self._observation_spec = dm_env.specs.BoundedArray(
            shape=dummy_obs.shape,
            dtype=dummy_obs.dtype,
            minimum=-np.inf,
            maximum=np.inf,
            name='state')
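A usage sketch; the observation names passed to `name_filter` are illustrative, `env` is assumed to return a dict-style observation (e.g. a dm_control task), and the wrapper's `observation_spec()` method (not shown) is assumed to return the `BoundedArray` built above:

# Hypothetical usage: keep two named entries and concatenate them into one array.
wrapped = ConcatObservationWrapper(env, name_filter=['position', 'velocity'])
print(wrapped.observation_spec().shape)  # single flat 'state' spec built above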
Example #4
    def _make_ma_environment_spec(
            self,
            environment: dm_env.Environment) -> Dict[str, EnvironmentSpec]:
        """Returns a dict of `EnvironmentSpec`s describing the values used by
        the environment for each agent."""
        specs = {}
        observation_specs = environment.observation_spec()
        action_specs = environment.action_spec()
        reward_specs = environment.reward_spec()
        discount_specs = environment.discount_spec()
        self.extra_specs = environment.extra_spec()
        for agent in environment.possible_agents:
            specs[agent] = EnvironmentSpec(
                observations=observation_specs[agent],
                actions=action_specs[agent],
                rewards=reward_specs[agent],
                discounts=discount_specs[agent],
            )
        return specs
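A sketch of how the returned mapping might be consumed; `builder` stands for the enclosing object and `ma_env` for a multi-agent environment exposing `possible_agents` and per-agent specs (both names are illustrative):

# Hypothetical usage: one EnvironmentSpec per agent id.
agent_specs = builder._make_ma_environment_spec(ma_env)
for agent_id, agent_spec in agent_specs.items():
    print(agent_id, agent_spec.actions)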
Example #5
    def __init__(
        self,
        environment: dm_env.Environment,
        additional_discount: float = 0.99,
        max_abs_reward: Optional[float] = 1.0,
        resize_shape: Optional[Tuple[int, int]] = (84, 84),
        num_action_repeats: int = 4,
        num_pooled_frames: int = 2,
        zero_discount_on_life_loss: bool = True,
        num_stacked_frames: int = 4,
        grayscaling: bool = True,
    ):
        rgb_spec, unused_lives_spec = environment.observation_spec()
        if rgb_spec.shape[2] != 3:
            raise ValueError(
                'This wrapper assumes interleaved pixel observations with shape '
                '(height, width, channels).')
        if int(environment.action_spec().minimum) != 0:
            raise ValueError('This wrapper assumes zero-indexed actions.')

        self._environment = environment
        self._processor = atari(
            additional_discount=additional_discount,
            max_abs_reward=max_abs_reward,
            resize_shape=resize_shape,
            num_action_repeats=num_action_repeats,
            num_pooled_frames=num_pooled_frames,
            zero_discount_on_life_loss=zero_discount_on_life_loss,
            num_stacked_frames=num_stacked_frames,
            grayscaling=grayscaling,
        )

        if grayscaling:
            self._observation_shape = resize_shape + (num_stacked_frames, )
            self._observation_spec_name = 'grayscale'
        else:
            self._observation_shape = resize_shape + (3, num_stacked_frames)
            self._observation_spec_name = 'RGB'

        self._reset_next_step = True
Example #6
def transition_iterator(
    environment: dm_env.Environment
) -> Callable[[int], Iterator[types.Transition]]:
    """Fake dataset of Reverb N-step transition samples.

    Args:
      environment: Used to create a fake transition by looking at the
        observation, action, discount and reward specs.

    Returns:
      A callable that, given a batch_size, returns an iterator of demonstrations.
    """

    observation = environment.observation_spec().generate_value()
    action = environment.action_spec().generate_value()
    reward = environment.reward_spec().generate_value()
    discount = environment.discount_spec().generate_value()
    data = types.Transition(observation, action, reward, discount, observation)

    dataset = tf.data.Dataset.from_tensors(data).repeat()

    return lambda batch_size: dataset.batch(batch_size).as_numpy_iterator()
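A usage sketch; `env` is any `dm_env.Environment`, and the iterator yields numpy batches of `types.Transition`:

# Hypothetical usage: build an iterator of batched fake transitions.
make_iterator = transition_iterator(env)
iterator = make_iterator(4)    # batch size of 4
batch = next(iterator)         # types.Transition with a leading batch dimension
print(batch.reward.shape)      # (4,) for a scalar reward spec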
Example #7
def transition_dataset(environment: dm_env.Environment) -> tf.data.Dataset:
    """Fake dataset of Reverb N-step transition samples.

    Args:
      environment: Used to create a fake transition by looking at the
        observation, action, discount and reward specs.

    Returns:
      tf.data.Dataset that produces the same fake N-step transition
      ReplaySample object indefinitely.
    """

    observation = environment.observation_spec().generate_value()
    action = environment.action_spec().generate_value()
    reward = environment.reward_spec().generate_value()
    discount = environment.discount_spec().generate_value()
    data = types.Transition(observation, action, reward, discount, observation)

    info = tree.map_structure(
        lambda tf_dtype: tf.ones([], tf_dtype.as_numpy_dtype),
        reverb.SampleInfo.tf_dtypes())
    sample = reverb.ReplaySample(info=info, data=data)

    return tf.data.Dataset.from_tensors(sample).repeat()
Example #8
    def __init__(self,
                 environment: dm_env.Environment,
                 *,
                 max_abs_reward: Optional[float] = None,
                 scale_dims: Optional[Tuple[int, int]] = (84, 84),
                 action_repeats: int = 4,
                 pooled_frames: int = 2,
                 zero_discount_on_life_loss: bool = False,
                 expose_lives_observation: bool = False,
                 num_stacked_frames: int = 4,
                 max_episode_len: Optional[int] = None,
                 to_float: bool = False,
                 grayscaling: bool = True):
        """Initializes a new AtariWrapper.

        Args:
          environment: An Atari environment.
          max_abs_reward: Maximum absolute reward value before clipping is applied.
            If set to `None` (default), no clipping is applied.
          scale_dims: Image size for the rescaling step after grayscaling, given as
            `(height, width)`. Set to `None` to disable resizing.
          action_repeats: Number of times to step the wrapped environment for each
            given action.
          pooled_frames: Number of observations to pool over. Set to 1 to disable
            frame pooling.
          zero_discount_on_life_loss: If `True`, sets the discount to zero when the
            number of lives decreases in the Atari environment.
          expose_lives_observation: If `False`, the `lives` part of the observation
            is discarded, otherwise it is kept as part of an observation tuple. This
            does not affect the `zero_discount_on_life_loss` feature. When disabled,
            the observation consists of a single pixel array, otherwise it is a
            tuple (pixel_array, lives).
          num_stacked_frames: Number of recent (pooled) observations to stack into
            the returned observation.
          max_episode_len: Number of frames before truncating the episode. By
            default, there is no maximum length.
          to_float: If `True`, rescales RGB observations to floats in [0, 1].
          grayscaling: If `True`, returns a grayscale version of the observations.
            In this case, the observation is 3D (H, W, num_stacked_frames). If
            `False`, the observations are RGB and have shape
            (H, W, C, num_stacked_frames).

        Raises:
          ValueError: For various invalid inputs.
        """
        if not 1 <= pooled_frames <= action_repeats:
            raise ValueError("pooled_frames ({}) must be between 1 and "
                             "action_repeats ({}) inclusive".format(
                                 pooled_frames, action_repeats))

        if zero_discount_on_life_loss:
            super().__init__(_ZeroDiscountOnLifeLoss(environment))
        else:
            super().__init__(environment)

        if not max_episode_len:
            max_episode_len = np.inf

        self._frame_stacker = frame_stacking.FrameStacker(
            num_frames=num_stacked_frames)
        self._action_repeats = action_repeats
        self._pooled_frames = pooled_frames
        self._scale_dims = scale_dims
        self._max_abs_reward = max_abs_reward or np.inf
        self._to_float = to_float
        self._expose_lives_observation = expose_lives_observation

        if scale_dims:
            self._height, self._width = scale_dims
        else:
            spec = environment.observation_spec()
            self._height, self._width = spec[RGB_INDEX].shape[:2]

        self._episode_len = 0
        self._max_episode_len = max_episode_len
        self._reset_next_step = True

        self._grayscaling = grayscaling

        # Based on underlying observation spec, decide whether lives are to be
        # included in output observations.
        observation_spec = self._environment.observation_spec()
        spec_names = [spec.name for spec in observation_spec]
        if "lives" in spec_names and spec_names.index("lives") != 1:
            raise ValueError(
                "`lives` observation needs to have index 1 in Atari.")

        self._observation_spec = self._init_observation_spec()

        self._raw_observation = None
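A hypothetical construction sketch; `make_raw_atari_env` is a placeholder for whatever factory produces the underlying ALE-based `dm_env.Environment` and is not part of the wrapper above:

# Hypothetical usage: wrap a raw Atari environment with standard preprocessing.
raw_env = make_raw_atari_env('Pong')   # placeholder factory, not defined here
env = AtariWrapper(raw_env,
                   max_abs_reward=1.0,
                   scale_dims=(84, 84),
                   num_stacked_frames=4)
timestep = env.reset()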