def make_environment_spec(environment: dm_env.Environment) -> EnvironmentSpec:
    """Builds the `EnvironmentSpec` reported by `environment`.

    Args:
      environment: The environment whose specs are queried.

    Returns:
      An `EnvironmentSpec` bundling the observation, action, reward and
      discount specs of `environment`.
    """
    obs_spec = environment.observation_spec()
    act_spec = environment.action_spec()
    rew_spec = environment.reward_spec()
    disc_spec = environment.discount_spec()
    return EnvironmentSpec(
        observations=obs_spec,
        actions=act_spec,
        rewards=rew_spec,
        discounts=disc_spec,
    )
def transition_dataset(environment: dm_env.Environment) -> tf.data.Dataset:
    """Builds an endless dataset of one fake Reverb N-step transition sample.

    Args:
      environment: Environment whose observation, action, reward and discount
        specs are used to generate the fake transition values.

    Returns:
      A `tf.data.Dataset` that yields the same fake `reverb.ReplaySample`
      indefinitely.
    """
    obs = environment.observation_spec().generate_value()
    act = environment.action_spec().generate_value()
    rew = environment.reward_spec().generate_value()
    disc = environment.discount_spec().generate_value()
    # The same observation is reused as the transition's next observation.
    transition = (obs, act, rew, disc, obs)

    # Fixed sample metadata: key 0, certain probability, a one-item table.
    sample_info = reverb.SampleInfo(
        key=np.array(0, np.uint64),
        probability=np.array(1.0, np.float64),
        table_size=np.array(1, np.int64),
        priority=np.array(1.0, np.float64),
    )
    replay_sample = reverb.ReplaySample(info=sample_info, data=transition)
    return tf.data.Dataset.from_tensors(replay_sample).repeat()
def __init__(self,
             environment: dm_env.Environment,
             name_filter: Optional[Sequence[str]] = None):
    """Initializes a new ConcatObservationWrapper.

    Args:
      environment: Environment to wrap.
      name_filter: Sequence of observation names to keep. None keeps them all.
    """
    super().__init__(environment)
    obs_spec = environment.observation_spec()
    wanted = list(obs_spec.keys()) if name_filter is None else name_filter
    # Keep only requested names that exist in the spec, in the filter's order.
    available = obs_spec.keys()
    self._obs_names = [name for name in wanted if name in available]
    # Derive shape/dtype of the concatenated observation by converting an
    # all-zeros observation built from the original spec.
    sample = self._convert_observation(_zeros_like(obs_spec))
    self._observation_spec = dm_env.specs.BoundedArray(
        shape=sample.shape,
        dtype=sample.dtype,
        minimum=-np.inf,
        maximum=np.inf,
        name='state',
    )
def _make_ma_environment_spec(
        self, environment: dm_env.Environment) -> Dict[str, EnvironmentSpec]:
    """Builds one `EnvironmentSpec` per agent of a multi-agent environment.

    As a side effect, stores `environment.extra_spec()` on `self.extra_specs`.

    Args:
      environment: Multi-agent environment whose observation, action, reward
        and discount specs are indexed by agent, and which exposes
        `possible_agents`.

    Returns:
      A dict mapping each agent in `environment.possible_agents` to its
      `EnvironmentSpec`.
    """
    obs_specs = environment.observation_spec()
    act_specs = environment.action_spec()
    rew_specs = environment.reward_spec()
    disc_specs = environment.discount_spec()
    self.extra_specs = environment.extra_spec()
    return {
        agent: EnvironmentSpec(
            observations=obs_specs[agent],
            actions=act_specs[agent],
            rewards=rew_specs[agent],
            discounts=disc_specs[agent],
        )
        for agent in environment.possible_agents
    }
def __init__(
        self,
        environment: dm_env.Environment,
        additional_discount: float = 0.99,
        max_abs_reward: Optional[float] = 1.0,
        resize_shape: Optional[Tuple[int, int]] = (84, 84),
        num_action_repeats: int = 4,
        num_pooled_frames: int = 2,
        zero_discount_on_life_loss: bool = True,
        num_stacked_frames: int = 4,
        grayscaling: bool = True,
):
    """Wraps an Atari environment with the `atari` preprocessing pipeline.

    Args:
      environment: Atari environment. Its observation spec must unpack into
        `(rgb_spec, lives_spec)` with `rgb_spec` interleaved as
        `(height, width, 3)`, and its action spec's minimum must be 0.
      additional_discount: Forwarded to the `atari` processor.
      max_abs_reward: Forwarded to the `atari` processor.
      resize_shape: `(height, width)` forwarded to the `atari` processor and
        used for the output observation shape. If `None`, the output shape
        uses the wrapped environment's native `(height, width)` (assumes the
        processor leaves frames unresized in that case — TODO confirm).
      num_action_repeats: Forwarded to the `atari` processor.
      num_pooled_frames: Forwarded to the `atari` processor.
      zero_discount_on_life_loss: Forwarded to the `atari` processor.
      num_stacked_frames: Forwarded to the `atari` processor; also the size of
        the last axis of the output observation shape.
      grayscaling: Forwarded to the `atari` processor. If `True` the output
        shape is `(H, W, num_stacked_frames)`, otherwise
        `(H, W, 3, num_stacked_frames)`.

    Raises:
      ValueError: If pixel observations do not have 3 channels on axis 2, or
        if actions are not zero-indexed.
    """
    rgb_spec, unused_lives_spec = environment.observation_spec()
    # Check the rank first so a malformed spec raises the documented
    # ValueError instead of an IndexError.
    if len(rgb_spec.shape) < 3 or rgb_spec.shape[2] != 3:
        raise ValueError(
            'This wrapper assumes interleaved pixel observations with shape '
            '(height, width, channels).')
    if int(environment.action_spec().minimum) != 0:
        raise ValueError('This wrapper assumes zero-indexed actions.')

    self._environment = environment
    self._processor = atari(
        additional_discount=additional_discount,
        max_abs_reward=max_abs_reward,
        resize_shape=resize_shape,
        num_action_repeats=num_action_repeats,
        num_pooled_frames=num_pooled_frames,
        zero_discount_on_life_loss=zero_discount_on_life_loss,
        num_stacked_frames=num_stacked_frames,
        grayscaling=grayscaling,
    )

    # `resize_shape` is Optional; previously `None + (...)` raised TypeError.
    # Without a resize target, use the native frame size (as AtariWrapper does
    # for `scale_dims=None`).
    if resize_shape is None:
        frame_shape = tuple(rgb_spec.shape[:2])
    else:
        frame_shape = tuple(resize_shape)

    if grayscaling:
        self._observation_shape = frame_shape + (num_stacked_frames,)
        self._observation_spec_name = 'grayscale'
    else:
        self._observation_shape = frame_shape + (3, num_stacked_frames)
        self._observation_spec_name = 'RGB'
    self._reset_next_step = True
def transition_iterator(
    environment: dm_env.Environment
) -> Callable[[int], Iterator[types.Transition]]:
    """Fake source of batched N-step transitions.

    Args:
      environment: Used to create a fake transition by looking at the
        observation, action, discount and reward specs.

    Returns:
      A callable that, given a batch_size, returns a numpy iterator yielding
      batches of the same fake `types.Transition` indefinitely.
    """
    obs = environment.observation_spec().generate_value()
    act = environment.action_spec().generate_value()
    rew = environment.reward_spec().generate_value()
    disc = environment.discount_spec().generate_value()
    # The same observation is reused as the transition's next observation.
    transition = types.Transition(obs, act, rew, disc, obs)
    endless = tf.data.Dataset.from_tensors(transition).repeat()

    def make_iterator(batch_size):
        return endless.batch(batch_size).as_numpy_iterator()

    return make_iterator
def transition_dataset(environment: dm_env.Environment) -> tf.data.Dataset:
    """Builds an endless dataset of one fake Reverb N-step transition sample.

    Args:
      environment: Environment whose observation, action, reward and discount
        specs are used to generate the fake transition values.

    Returns:
      A `tf.data.Dataset` that yields the same fake `reverb.ReplaySample`
      (holding a `types.Transition`) indefinitely.
    """
    obs = environment.observation_spec().generate_value()
    act = environment.action_spec().generate_value()
    rew = environment.reward_spec().generate_value()
    disc = environment.discount_spec().generate_value()
    # The same observation is reused as the transition's next observation.
    transition = types.Transition(obs, act, rew, disc, obs)

    def _scalar_one(tf_dtype):
        # Scalar tensor of ones in the field's numpy dtype.
        return tf.ones([], tf_dtype.as_numpy_dtype)

    info = tree.map_structure(_scalar_one, reverb.SampleInfo.tf_dtypes())
    replay_sample = reverb.ReplaySample(info=info, data=transition)
    return tf.data.Dataset.from_tensors(replay_sample).repeat()
def __init__(self,
             environment: dm_env.Environment,
             *,
             max_abs_reward: Optional[float] = None,
             scale_dims: Optional[Tuple[int, int]] = (84, 84),
             action_repeats: int = 4,
             pooled_frames: int = 2,
             zero_discount_on_life_loss: bool = False,
             expose_lives_observation: bool = False,
             num_stacked_frames: int = 4,
             max_episode_len: Optional[int] = None,
             to_float: bool = False,
             grayscaling: bool = True):
    """Initializes a new AtariWrapper.

    Args:
      environment: An Atari environment.
      max_abs_reward: Maximum absolute reward value before clipping is applied.
        If set to `None` (default) or 0, no clipping is applied.
      scale_dims: Image size for the rescaling step after grayscaling, given as
        `(height, width)`. Set to `None` to disable resizing; the native frame
        size of the environment's RGB observation is then used.
      action_repeats: Number of times to step wrapped environment for each
        given action.
      pooled_frames: Number of observations to pool over. Set to 1 to disable
        frame pooling. Must be between 1 and `action_repeats` inclusive.
      zero_discount_on_life_loss: If `True`, sets the discount to zero when the
        number of lives decreases in the Atari environment.
      expose_lives_observation: If `False`, the `lives` part of the observation
        is discarded, otherwise it is kept as part of an observation tuple.
        This does not affect the `zero_discount_on_life_loss` feature. When
        disabled, the observation consists of a single pixel array, otherwise
        it is a tuple (pixel_array, lives).
      num_stacked_frames: Number of recent (pooled) observations to stack into
        the returned observation.
      max_episode_len: Number of frames before truncating episode. By default
        (`None`, or 0), there is no maximum length.
      to_float: If `True`, rescales RGB observations to floats in [0, 1].
      grayscaling: If `True` returns a grayscale version of the observations.
        In this case, the observation is 3D (H, W, num_stacked_frames). If
        `False` the observations are RGB and have shape
        (H, W, C, num_stacked_frames).

    Raises:
      ValueError: For invalid inputs, including `pooled_frames` outside
        `[1, action_repeats]` and an underlying observation spec whose
        `lives` entry is not at index 1.
    """
    if not 1 <= pooled_frames <= action_repeats:
        raise ValueError(
            f"pooled_frames ({pooled_frames}) must be between 1 and "
            f"action_repeats ({action_repeats}) inclusive")

    # Optionally interpose the life-loss discount wrapper beneath this one.
    if zero_discount_on_life_loss:
        super().__init__(_ZeroDiscountOnLifeLoss(environment))
    else:
        super().__init__(environment)

    # Falsy values (None or 0) mean "no episode length limit".
    if not max_episode_len:
        max_episode_len = np.inf

    self._frame_stacker = frame_stacking.FrameStacker(
        num_frames=num_stacked_frames)
    self._action_repeats = action_repeats
    self._pooled_frames = pooled_frames
    self._scale_dims = scale_dims
    # Falsy values (None or 0) disable reward clipping.
    self._max_abs_reward = max_abs_reward or np.inf
    self._to_float = to_float
    self._expose_lives_observation = expose_lives_observation

    if scale_dims:
        self._height, self._width = scale_dims
    else:
        # No rescaling: take height/width from the raw RGB observation spec.
        spec = environment.observation_spec()
        self._height, self._width = spec[RGB_INDEX].shape[:2]

    self._episode_len = 0
    self._max_episode_len = max_episode_len
    self._reset_next_step = True
    self._grayscaling = grayscaling

    # Based on underlying observation spec, decide whether lives are to be
    # included in output observations. `self._environment` is presumably set
    # by the base wrapper's __init__ — TODO confirm.
    observation_spec = self._environment.observation_spec()
    spec_names = [spec.name for spec in observation_spec]
    if "lives" in spec_names and spec_names.index("lives") != 1:
        raise ValueError(
            "`lives` observation needs to have index 1 in Atari.")

    self._observation_spec = self._init_observation_spec()
    self._raw_observation = None