Exemple #1
0
def make_environment_spec(environment: dm_env.Environment) -> EnvironmentSpec:
  """Builds the `EnvironmentSpec` for `environment`.

  Args:
    environment: Environment whose observation, action, reward and discount
      specs are queried.

  Returns:
    An `EnvironmentSpec` holding the four specs reported by `environment`.
  """
  observation_spec = environment.observation_spec()
  action_spec = environment.action_spec()
  reward_spec = environment.reward_spec()
  discount_spec = environment.discount_spec()
  return EnvironmentSpec(
      observations=observation_spec,
      actions=action_spec,
      rewards=reward_spec,
      discounts=discount_spec)
Exemple #2
0
def transition_dataset(environment: dm_env.Environment) -> tf.data.Dataset:
    """Builds an endless dataset that repeats one fake Reverb sample.

    Args:
        environment: Its observation, action, reward and discount specs are
            asked to generate the values that make up the fake transition.

    Returns:
        A `tf.data.Dataset` that yields the same `reverb.ReplaySample`
        indefinitely. The sample's data is the tuple
        `(observation, action, reward, discount, observation)`.
    """
    fake_observation = environment.observation_spec().generate_value()
    fake_action = environment.action_spec().generate_value()
    fake_reward = environment.reward_spec().generate_value()
    fake_discount = environment.discount_spec().generate_value()

    # The same generated observation is used in both observation slots.
    transition = (
        fake_observation,
        fake_action,
        fake_reward,
        fake_discount,
        fake_observation,
    )

    # Fixed scalar metadata for the single fake sample.
    sample_info = reverb.SampleInfo(
        key=np.array(0, np.uint64),
        probability=np.array(1.0, np.float64),
        table_size=np.array(1, np.int64),
        priority=np.array(1.0, np.float64),
    )
    fake_sample = reverb.ReplaySample(info=sample_info, data=transition)

    return tf.data.Dataset.from_tensors(fake_sample).repeat()
Exemple #3
0
    def __init__(
            self,
            env: Environment,
            alpha: float,
            lamda: float,
            n_features: int,
            logger: Logger = NullLogger(),
    ):
        """Sets up the parameter vectors, logger and an `EGreedy` policy.

        Args:
            env: Environment whose `action_spec()` exposes `num_values`.
            alpha: Stored as `self.alpha` (presumably the step size --
                confirm against the update rule).
            lamda: Stored as `self.lamda` (presumably the trace-decay
                rate -- confirm against the update rule).
            n_features: `self.w` and `self.z` are zero vectors of length
                `n_features + 1`.
            logger: Stored as `self.logger`. The default `NullLogger()` is
                created once when the method is defined, so every instance
                built without an explicit logger shares that object.
        """
        super().__init__()
        self.alpha = alpha
        self.lamda = lamda
        # NOTE(review): the extra slot is presumably a bias term -- confirm.
        self.w = jnp.zeros(n_features + 1)
        self.z = jnp.zeros(n_features + 1)
        self.logger = logger
        # NOTE(review): the bound method `env.action_spec` (not its result) is
        # handed to EGreedy; left unchanged -- confirm that is what EGreedy
        # expects.
        self.policy = EGreedy(env.action_spec, 0.1)

        # The original read `num_actions` for one array and `num_values` for
        # the other; both now use `num_values` so the two arrays always have
        # the same length and only one attribute is required of the spec.
        num_values = env.action_spec().num_values
        self._actions = jnp.arange(num_values)
        # temporary value estimate for the starting state
        self._q_old = jnp.zeros(num_values)
Exemple #4
0
 def _make_ma_environment_spec(
         self,
         environment: dm_env.Environment) -> Dict[str, EnvironmentSpec]:
     """Builds one `EnvironmentSpec` per agent of `environment`.

     Also stores `environment.extra_spec()` on `self.extra_specs`.

     Args:
         environment: Environment whose observation, action, reward and
             discount specs are each indexed by agent.

     Returns:
         A dict mapping every entry of `environment.possible_agents` to the
         `EnvironmentSpec` assembled from that agent's four specs.
     """
     obs_by_agent = environment.observation_spec()
     act_by_agent = environment.action_spec()
     rew_by_agent = environment.reward_spec()
     disc_by_agent = environment.discount_spec()
     self.extra_specs = environment.extra_spec()
     return {
         agent_id: EnvironmentSpec(
             observations=obs_by_agent[agent_id],
             actions=act_by_agent[agent_id],
             rewards=rew_by_agent[agent_id],
             discounts=disc_by_agent[agent_id],
         )
         for agent_id in environment.possible_agents
     }
    def __init__(
        self,
        environment: dm_env.Environment,
        additional_discount: float = 0.99,
        max_abs_reward: Optional[float] = 1.0,
        resize_shape: Optional[Tuple[int, int]] = (84, 84),
        num_action_repeats: int = 4,
        num_pooled_frames: int = 2,
        zero_discount_on_life_loss: bool = True,
        num_stacked_frames: int = 4,
        grayscaling: bool = True,
    ):
        """Validates `environment` and sets up the Atari preprocessing state.

        Args:
            environment: Environment whose `observation_spec()` unpacks into
                an RGB spec and a second spec (unused here), and whose
                `action_spec()` has a `minimum`.
            additional_discount: Forwarded to `atari`.
            max_abs_reward: Forwarded to `atari`.
            resize_shape: Forwarded to `atari`; also the leading `(height,
                width)` of the stored observation shape. When `None`, the
                first two dimensions of the RGB spec are used for the stored
                shape.
            num_action_repeats: Forwarded to `atari`.
            num_pooled_frames: Forwarded to `atari`.
            zero_discount_on_life_loss: Forwarded to `atari`.
            num_stacked_frames: Forwarded to `atari`; also the trailing
                dimension of the stored observation shape.
            grayscaling: Forwarded to `atari`; selects whether the stored
                observation shape carries a 3-channel axis.

        Raises:
            ValueError: If the third dimension of the RGB spec is not 3, or
                if the action spec's minimum is not 0.
        """
        rgb_spec, unused_lives_spec = environment.observation_spec()
        if rgb_spec.shape[2] != 3:
            raise ValueError(
                'This wrapper assumes interleaved pixel observations with shape '
                '(height, width, channels).')
        if int(environment.action_spec().minimum) != 0:
            raise ValueError('This wrapper assumes zero-indexed actions.')

        self._environment = environment
        self._processor = atari(
            additional_discount=additional_discount,
            max_abs_reward=max_abs_reward,
            resize_shape=resize_shape,
            num_action_repeats=num_action_repeats,
            num_pooled_frames=num_pooled_frames,
            zero_discount_on_life_loss=zero_discount_on_life_loss,
            num_stacked_frames=num_stacked_frames,
            grayscaling=grayscaling,
        )

        # `resize_shape` is Optional; concatenating None with a tuple would
        # raise TypeError. Fall back to the raw frame's (height, width).
        # NOTE(review): assumes `atari` leaves frames unresized when
        # resize_shape is None -- confirm against the processor.
        if resize_shape is None:
            frame_shape = tuple(rgb_spec.shape[:2])
        else:
            # tuple() also accepts a list-valued resize_shape.
            frame_shape = tuple(resize_shape)

        if grayscaling:
            self._observation_shape = frame_shape + (num_stacked_frames, )
            self._observation_spec_name = 'grayscale'
        else:
            self._observation_shape = frame_shape + (3, num_stacked_frames)
            self._observation_spec_name = 'RGB'

        self._reset_next_step = True
Exemple #6
0
def transition_iterator(
    environment: dm_env.Environment
) -> Callable[[int], Iterator[types.Transition]]:
    """Builds a factory of iterators over one repeated fake transition.

    Args:
        environment: Its observation, action, reward and discount specs are
            asked to generate the values that make up the fake transition.

    Returns:
        A callable that, given a batch_size, returns a NumPy iterator over
        batches of the same fake `types.Transition`.
    """
    fake_observation = environment.observation_spec().generate_value()
    fake_action = environment.action_spec().generate_value()
    fake_reward = environment.reward_spec().generate_value()
    fake_discount = environment.discount_spec().generate_value()

    # The same generated observation is used in both observation slots.
    fake_transition = types.Transition(
        fake_observation,
        fake_action,
        fake_reward,
        fake_discount,
        fake_observation,
    )
    repeated = tf.data.Dataset.from_tensors(fake_transition).repeat()

    def make_iterator(batch_size):
        return repeated.batch(batch_size).as_numpy_iterator()

    return make_iterator
Exemple #7
0
def transition_dataset(environment: dm_env.Environment) -> tf.data.Dataset:
    """Builds an endless dataset that repeats one fake Reverb sample.

    Args:
        environment: Its observation, action, reward and discount specs are
            asked to generate the values that make up the fake transition.

    Returns:
        A `tf.data.Dataset` that yields the same `reverb.ReplaySample`
        indefinitely. The sample's data is a `types.Transition` and every
        field of its info is a scalar one.
    """
    fake_observation = environment.observation_spec().generate_value()
    fake_action = environment.action_spec().generate_value()
    fake_reward = environment.reward_spec().generate_value()
    fake_discount = environment.discount_spec().generate_value()

    # The same generated observation is used in both observation slots.
    fake_transition = types.Transition(
        fake_observation,
        fake_action,
        fake_reward,
        fake_discount,
        fake_observation,
    )

    def _scalar_one(tf_dtype):
        # A scalar one in the NumPy dtype matching the given TF dtype.
        return tf.ones([], tf_dtype.as_numpy_dtype)

    sample_info = tree.map_structure(
        _scalar_one, reverb.SampleInfo.tf_dtypes())
    fake_sample = reverb.ReplaySample(info=sample_info, data=fake_transition)

    return tf.data.Dataset.from_tensors(fake_sample).repeat()
 def __init__(self, environment: dm_env.Environment):
     """Stores `environment` and checks that its actions start at zero.

     Args:
         environment: Environment to wrap; the first element of its
             `action_spec()` must have a `minimum` equal to 0.

     Raises:
         ValueError: If that minimum, converted to int, is not 0.
     """
     self._environment = environment
     first_action_spec = environment.action_spec()[0]
     if int(first_action_spec.minimum) == 0:
         return
     raise ValueError(
         'This wrapper assumes zero-indexed actions. Use the Atari setting '
         'zero_indexed_actions=\"true\" to get actions in this format.')
Exemple #9
0
 def __init__(self, environment: dm_env.Environment, clip: bool = False):
     """Wraps `environment`, caching its action spec and the clip flag.

     Args:
         environment: Environment passed to the parent initialiser; its
             `action_spec()` is stored on `self._action_spec`.
         clip: Stored on `self._clip`.
     """
     super().__init__(environment)
     self._clip = clip
     self._action_spec = environment.action_spec()