Example #1
    def test_pugail(self):
        def dummy_discriminator(
                state: losses.State,
                transition: types.Transition) -> losses.DiscriminatorOutput:
            return transition.observation, state

        zero_transition = types.Transition(0., 0., 0., 0., 0.)
        zero_transition = tree.map_structure(
            lambda x: jnp.expand_dims(x, axis=0), zero_transition)

        one_transition = types.Transition(1., 0., 0., 0., 0.)
        one_transition = tree.map_structure(
            lambda x: jnp.expand_dims(x, axis=0), one_transition)

        prior = .7
        loss_fn = losses.pugail_loss(positive_class_prior=prior,
                                     entropy_coefficient=0.)
        loss, _ = loss_fn(dummy_discriminator, {}, one_transition,
                          zero_transition, ())

        d_one = jax.nn.sigmoid(dummy_discriminator({}, one_transition)[0])
        d_zero = jax.nn.sigmoid(dummy_discriminator({}, zero_transition)[0])
        expected_loss = (-prior * jnp.log(d_one) - jnp.log(1. - d_zero) +
                         prior * jnp.log(1. - d_one))

        self.assertAlmostEqual(loss, expected_loss, places=6)
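
A quick sanity check on the arithmetic above: with this pass-through discriminator, d_one = sigmoid(1) and d_zero = sigmoid(0), so the expected PUGAIL loss can be reproduced with plain NumPy. A minimal sketch (the numbers assume the transitions above):

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

prior = 0.7
d_one, d_zero = sigmoid(1.0), sigmoid(0.0)
# Same expression as the test, with the double negation simplified:
# -p * log(d1) - log(1 - d0) + p * log(1 - d1).
expected = (-prior * np.log(d_one) - np.log(1.0 - d_zero) +
            prior * np.log(1.0 - d_one))
print(expected)  # ~ -0.0069 for these inputs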
Example #2
    def test_weighted_generator(self):
        data0 = types.Transition(np.array([[1], [2], [3]]), (), _REWARD, (),
                                 ())
        it0 = iter([data0])

        data1 = types.Transition(np.array([[4], [5], [6]]), (), _REWARD, (),
                                 ())
        data2 = types.Transition(np.array([[7], [8], [9]]), (), _REWARD, (),
                                 ())
        it1 = iter([
            reverb.ReplaySample(info=reverb.SampleInfo(
                *[() for _ in reverb.SampleInfo.tf_dtypes()]),
                                data=data1),
            reverb.ReplaySample(info=reverb.SampleInfo(
                *[() for _ in reverb.SampleInfo.tf_dtypes()]),
                                data=data2)
        ])

        weighted_it = builder._generate_samples_with_demonstrations(
            it0, it1, policy_to_expert_data_ratio=2, batch_size=3)

        np.testing.assert_array_equal(
            next(weighted_it).data.observation, np.array([[1], [4], [5]]))
        np.testing.assert_array_equal(
            next(weighted_it).data.observation, np.array([[7], [8], [2]]))
        self.assertRaises(StopIteration, lambda: next(weighted_it))
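
The expected batches above are consistent with a per-batch split of batch_size // (ratio + 1) demonstration transitions, the rest coming from the replay iterator; this split is inferred from the expected outputs rather than from the builder's internals. A small sketch of the arithmetic:

# Inferred per-batch composition; the exact interleaving order is an
# implementation detail of the builder.
batch_size, policy_to_expert_data_ratio = 3, 2
demo_per_batch = batch_size // (policy_to_expert_data_ratio + 1)  # 1
policy_per_batch = batch_size - demo_per_batch                    # 2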
Example #3
def replay_sample_to_sars_transition(
    sample: reverb.ReplaySample,
    is_sequence: bool) -> types.Transition:
  """Converts the replay sample to a types.Transition.

  NB: If is_sequence is True then the last next_observation of each sequence is
  rubbish. Don't train on it.

  Args:
    sample: The replay sample
    is_sequence: If False we expect the sample data to match the
      types.Transition already. Otherwise we expect a batch of sequences of
      steps.

  Returns:
    A types.Transition built from the sample data. The number of leading
    dimensions will be unchanged, so expect 2 ([Batch, Time]) for sequence-based
    samples and 1 ([Batch]) otherwise.
    NB: If is_sequence is True then the last next_observation of each sequence
    is rubbish. Don't train on it.
  """
  if not is_sequence:
    return types.Transition(*sample.data)
  # Note that the last next_observation is invalid.
  steps = sample.data
  return types.Transition(
      observation=steps.observation,
      action=steps.action,
      reward=steps.reward,
      discount=steps.discount,
      next_observation=jnp.roll(steps.observation, shift=-1, axis=1))
Example #4
def replay_sample_to_sars_transition(
        sample: reverb.ReplaySample,
        is_sequence: bool,
        strip_last_transition: bool = False,
        flatten_batch: bool = False) -> types.Transition:
    """Converts the replay sample to a types.Transition.

  NB: If is_sequence is True then the last next_observation of each sequence is
  rubbish. Don't train on it.

  Args:
    sample: The replay sample
    is_sequence: If False we expect the sample data to match the
      types.Transition already. Otherwise we expect a batch of sequences of
      steps.
    strip_last_transition: If True and is_sequence, the last transition will be
      stripped as its next_observation field is incorrect.
    flatten_batch: If True and is_sequence, the two batch dimensions will be
      flattened into one.

  Returns:
    A types.Transition built from the sample data.
    If is_sequence and strip_last_transition are both True, the output will be
    smaller than the input, as the last transition of every sequence will have
    been removed.
  """
    if not is_sequence:
        return types.Transition(*sample.data)
    # Note that the last next_observation is invalid.
    steps = sample.data

    def roll(observation):
        return np.roll(observation, shift=-1, axis=1)

    transitions = types.Transition(observation=steps.observation,
                                   action=steps.action,
                                   reward=steps.reward,
                                   discount=steps.discount,
                                   next_observation=tree.map_structure(
                                       roll, steps.observation),
                                   extras=steps.extras)
    if strip_last_transition:
        # We remove the last transition as its next_observation field is incorrect.
        # It has been obtained by rolling the observation field, such that
        # transitions.next_observations[:, -1] is transitions.observations[:, 0]
        transitions = jax.tree_map(lambda x: x[:, :-1, ...], transitions)
    if flatten_batch:
        # Merge the 2 leading batch dimensions into 1.
        transitions = jax.tree_map(
            lambda x: np.reshape(x, (-1, ) + x.shape[2:]), transitions)
    return transitions
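
The roll-based construction is easy to inspect in isolation. A minimal NumPy sketch showing why the last next_observation of each sequence is invalid and how strip_last_transition removes it (toy data, not the actual replay format):

import numpy as np

obs = np.array([[0, 1, 2, 3]])             # [batch=1, time=4]
next_obs = np.roll(obs, shift=-1, axis=1)  # [[1, 2, 3, 0]]
# next_obs[:, -1] wrapped around to obs[:, 0], so it is rubbish.
stripped = next_obs[:, :-1]                # [[1, 2, 3]]: valid targets only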
Example #5
    def test_step(self):
        simple_spec = specs.Array(shape=(), dtype=float)

        spec = specs.EnvironmentSpec(simple_spec, simple_spec, simple_spec,
                                     simple_spec)

        discriminator = _make_discriminator(spec)
        ail_network = ail_networks.AILNetworks(discriminator,
                                               imitation_reward_fn=lambda x: x,
                                               direct_rl_networks=None)

        loss = losses.gail_loss()

        optimizer = optax.adam(.01)

        step = jax.jit(
            functools.partial(ail_learning.ail_update_step,
                              optimizer=optimizer,
                              ail_network=ail_network,
                              loss_fn=loss))

        zero_transition = types.Transition(np.array([0.]), np.array([0.]), 0.,
                                           0., np.array([0.]))
        zero_transition = utils.add_batch_dim(zero_transition)

        one_transition = types.Transition(np.array([1.]), np.array([0.]), 0.,
                                          0., np.array([0.]))
        one_transition = utils.add_batch_dim(one_transition)

        key = jax.random.PRNGKey(0)
        discriminator_params, discriminator_state = discriminator.init(key)

        state = ail_learning.DiscriminatorTrainingState(
            optimizer_state=optimizer.init(discriminator_params),
            discriminator_params=discriminator_params,
            discriminator_state=discriminator_state,
            policy_params=None,
            key=key,
            steps=0,
        )

        expected_loss = [1.062, 1.057, 1.052]

        for i in range(3):
            state, loss = step(state, (one_transition, zero_transition))
            self.assertAlmostEqual(loss['total_loss'],
                                   expected_loss[i],
                                   places=3)
Example #6
 def test_sqil_iterator(self):
     demonstrations = [
         types.Transition(np.array([[1], [2], [3]]), (), (), (), ())
     ]
     replay = [
         reverb.ReplaySample(info=(),
                             data=types.Transition(
                                 np.array([[4], [5], [6]]), (), (), (), ()))
     ]
     sqil_it = builder._generate_sqil_samples(iter(demonstrations),
                                              iter(replay))
     np.testing.assert_array_equal(
         next(sqil_it).data.observation, np.array([[1], [3], [5]]))
     np.testing.assert_array_equal(
         next(sqil_it).data.observation, np.array([[2], [4], [6]]))
     self.assertRaises(StopIteration, lambda: next(sqil_it))
Example #7
def transition_dataset(environment: dm_env.Environment) -> tf.data.Dataset:
    """Fake dataset of Reverb N-step transition samples.

  Args:
    environment: Used to create a fake transition by looking at the observation,
      action, discount and reward specs.

  Returns:
    tf.data.Dataset that produces the same fake N-step transition ReverbSample
    object indefinitely.
  """

    observation = environment.observation_spec().generate_value()
    action = environment.action_spec().generate_value()
    reward = environment.reward_spec().generate_value()
    discount = environment.discount_spec().generate_value()
    data = types.Transition(observation, action, reward, discount, observation)

    key = np.array(0, np.uint64)
    probability = np.array(1.0, np.float64)
    table_size = np.array(1, np.int64)
    priority = np.array(1.0, np.float64)
    info = reverb.SampleInfo(key=key,
                             probability=probability,
                             table_size=table_size,
                             priority=priority)
    sample = reverb.ReplaySample(info=info, data=data)

    return tf.data.Dataset.from_tensors(sample).repeat()
Example #8
    def signature(cls,
                  environment_spec: specs.EnvironmentSpec,
                  extras_spec: types.NestedSpec = ()):

        # This function currently assumes that self._discount is a scalar.
        # If it ever becomes a nested structure and/or a np.ndarray, this method
        # will need to know its structure / shape. This is because the signature
        # discount shape is the environment's discount shape and this adder's
        # discount shape broadcasted together. Also, the reward shape is this
        # signature discount shape broadcasted together with the environment
        # reward shape. As long as self._discount is a scalar, it will affect
        # neither the signature discount shape nor the signature reward shape, so we
        # can ignore it.

        rewards_spec, step_discounts_spec = tree_utils.broadcast_structures(
            environment_spec.rewards, environment_spec.discounts)
        rewards_spec = tree.map_structure(_broadcast_specs, rewards_spec,
                                          step_discounts_spec)
        step_discounts_spec = tree.map_structure(copy.deepcopy,
                                                 step_discounts_spec)

        transition_spec = types.Transition(
            environment_spec.observations,
            environment_spec.actions,
            rewards_spec,
            step_discounts_spec,
            environment_spec.observations,  # next_observation
            extras_spec)

        return tree.map_structure_with_path(base.spec_like_to_tensor_spec,
                                            transition_spec)
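
The shape reasoning in the comment above reduces to NumPy-style broadcasting between the reward and discount specs. A tiny sketch with hypothetical shapes:

import numpy as np

# Hypothetical spec shapes: per-objective rewards, shared step discounts.
reward_shape, discount_shape = (3, 1), (3, 4)
print(np.broadcast_shapes(reward_shape, discount_shape))  # (3, 4)
# A scalar self._discount (shape ()) broadcasts against anything without
# changing the result, which is why the method can ignore it.
print(np.broadcast_shapes((), discount_shape))  # (3, 4)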
Example #9
 def _create_dummy_transitions(self):
     return types.Transition(observation=self._DUMMY_OBS,
                             action=self._DUMMY_ACTION,
                             reward=self._DUMMY_REWARD,
                             discount=self._DUMMY_DISCOUNT,
                             next_observation=self._DUMMY_NEXT_OBS,
                             extras={'return': self._DUMMY_RETURN})
Example #10
    def step(self):
        sample = next(self._iterator)
        transitions = types.Transition(*sample.data)

        counts = self._counter.get_counts()
        cur_step = counts.get('learner_steps', 0)
        in_initial_bc_iters = cur_step < self._num_bc_iters

        if in_initial_bc_iters:
            self._state, metrics = self._update_step_in_initial_bc_iters(
                self._state, transitions)
        else:
            self._state, metrics = self._update_step_rest(
                self._state, transitions)

        # Compute elapsed time.
        timestamp = time.time()
        elapsed_time = timestamp - self._timestamp if self._timestamp else 0
        self._timestamp = timestamp

        # Increment counts and record the current time
        counts = self._counter.increment(steps=self._num_sgd_steps_per_step,
                                         walltime=elapsed_time)

        # Attempts to write the logs.
        self._logger.write({**metrics, **counts})
Example #11
def _n_step_transition_from_episode(observations: acme_types.NestedTensor,
                                    actions: tf.Tensor,
                                    rewards: tf.Tensor,
                                    discounts: tf.Tensor,
                                    n_step: int,
                                    discount: float):
  """Produce Reverb-like N-step transition from a full episode.

  Observations, actions, rewards and discounts have the same length. This
  function will ignore the first reward and discount and the last action.

  Args:
    observations: [L, ...] Tensor.
    actions: [L, ...] Tensor.
    rewards: [L] Tensor.
    discounts: [L] Tensor.
    n_step: number of steps to squash into a single transition.
    discount: discount to use for TD updates.

  Returns:
    (o_t, a_t, r_t, d_t, o_tp1) tuple.
  """

  max_index = tf.shape(rewards)[0] - 1
  first = tf.random.uniform(shape=(), minval=0, maxval=max_index - 1,
                            dtype=tf.int32)
  last = tf.minimum(first + n_step, max_index)

  o_t = tree.map_structure(operator.itemgetter(first), observations)
  a_t = tree.map_structure(operator.itemgetter(first), actions)
  o_tp1 = tree.map_structure(operator.itemgetter(last), observations)

  # 0, 1, ..., n-1.
  discount_range = tf.cast(tf.range(last - first), tf.float32)
  # 1, g, ..., g^{n-1}.
  additional_discounts = tf.pow(discount, discount_range)
  # 1, d_t, d_t * d_{t+1}, ..., d_t * ... * d_{t+n-2}.
  discounts = tf.concat([[1.], tf.math.cumprod(discounts[first:last-1])], 0)
  # 1, g * d_t, ..., g^{n-1} * d_t * ... * d_{t+n-2}.
  discounts *= additional_discounts
  # r_t + g * d_t * r_{t+1} + ... + g^{n-1} * d_t * ... * d_{t+n-2} * r_{t+n-1}
  # We have to shift rewards by one so last=max_index corresponds to transitions
  # that include the last reward.
  r_t = tf.reduce_sum(rewards[first+1:last+1] * discounts)

  # g^{n-1} * d_{t} * ... * d_{t+n-2}.
  d_t = discounts[-1]

  key = tf.constant(0, tf.uint64)
  probability = tf.constant(1.0, tf.float64)
  table_size = tf.constant(1, tf.int64)
  priority = tf.constant(1.0, tf.float64)
  info = reverb.SampleInfo(
      key=key,
      probability=probability,
      table_size=table_size,
      priority=priority)
  return reverb.ReplaySample(
      info=info, data=acme_types.Transition(o_t, a_t, r_t, d_t, o_tp1))
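
The discount bookkeeping above can be checked with a small NumPy example. This sketch mirrors the TF ops for one fixed window; the values and the first/last indices are hypothetical:

import numpy as np

rewards = np.array([0.0, 1.0, 2.0, 3.0])    # r_0, ..., r_3
discounts = np.array([1.0, 0.9, 0.8, 0.7])  # d_0, ..., d_3
g, first, last = 0.99, 0, 3                 # n_step = 3, no truncation

additional = g ** np.arange(last - first)   # 1, g, g^2
cum = np.concatenate([[1.0], np.cumprod(discounts[first:last - 1])])
weights = cum * additional                  # 1, g*d_0, g^2*d_0*d_1
r_t = np.sum(rewards[first + 1:last + 1] * weights)  # n-step return
d_t = weights[-1]                           # bootstrap discount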
Example #12
def _episode_to_transition(step: Dict[str, Any]) -> types.Transition:
    return types.Transition(
        observation=step['observation'][:-1],
        action=step['action'][:-1],
        reward=step['reward'][:-1],
        discount=1.0 - tf.cast(step['is_terminal'][1:], dtype=tf.float32),
        # If next step is terminal, then the observation may be arbitrary.
        next_observation=step['observation'][1:],
    )
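
The slicing convention here is the standard one for episode tensors: element i of each [:-1] slice lines up with element i of each [1:] slice. A toy sketch:

import numpy as np

obs = np.array([10, 11, 12, 13])  # s_0, ..., s_3 from one episode
observation = obs[:-1]            # s_0, s_1, s_2
next_observation = obs[1:]        # s_1, s_2, s_3, aligned pairwise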
Example #13
def _batched_step_to_transition(step: rlds.BatchedStep) -> types.Transition:
    return types.Transition(
        observation=tf.nest.map_structure(lambda x: x[0],
                                          step[rlds.OBSERVATION]),
        action=tf.nest.map_structure(lambda x: x[0], step[rlds.ACTION]),
        reward=tf.nest.map_structure(lambda x: x[0], step[rlds.REWARD]),
        discount=1.0 - tf.cast(step[rlds.IS_TERMINAL][1], dtype=tf.float32),
        # If next step is terminal, then the observation may be arbitrary.
        next_observation=tf.nest.map_structure(lambda x: x[1],
                                               step[rlds.OBSERVATION]))
Example #14
    def test_gradient_penalty(self):
        def dummy_discriminator(
                transition: types.Transition) -> networks_lib.Logits:
            return transition.observation + jnp.square(transition.action)

        zero_transition = types.Transition(0., 0., 0., 0., 0.)
        zero_transition = tree.map_structure(
            lambda x: jnp.expand_dims(x, axis=0), zero_transition)
        self.assertEqual(
            losses._compute_gradient_penalty(zero_transition,
                                             dummy_discriminator, 0.),
            1**2 + 0**2)

        one_transition = types.Transition(1., 1., 0., 0., 0.)
        one_transition = tree.map_structure(
            lambda x: jnp.expand_dims(x, axis=0), one_transition)
        self.assertEqual(
            losses._compute_gradient_penalty(one_transition,
                                             dummy_discriminator, 0.),
            1**2 + 2**2)
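
The expected values follow directly from the discriminator's gradients: for D(o, a) = o + a**2 the gradient is (1, 2a), so the squared norm is 1 at the zero transition and 5 at the one transition. A minimal JAX sketch of that computation (not the internals of _compute_gradient_penalty):

import jax
import jax.numpy as jnp

def discriminator(obs, act):
    return obs + jnp.square(act)

grads = jax.grad(discriminator, argnums=(0, 1))(1.0, 1.0)
penalty = sum(jnp.square(g) for g in grads)  # 1**2 + 2**2 = 5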
Example #15
    def test_discrete_actions(self, loss_name):
        with chex.fake_pmap_and_jit():

            num_sgd_steps_per_step = 1
            num_steps = 5

            # Create a fake environment to test with.
            environment = fakes.DiscreteEnvironment(num_actions=10,
                                                    num_observations=100,
                                                    obs_shape=(10, ),
                                                    obs_dtype=np.float32)

            spec = specs.make_environment_spec(environment)
            dataset_demonstration = fakes.transition_dataset(environment)
            dataset_demonstration = dataset_demonstration.map(
                lambda sample: types.Transition(*sample.data))
            dataset_demonstration = dataset_demonstration.batch(
                8).as_numpy_iterator()

            # Construct the agent.
            network = make_networks(spec, discrete_actions=True)

            def logp_fn(logits, actions):
                max_logits = jnp.max(logits, axis=-1, keepdims=True)
                logits = logits - max_logits
                logits_actions = jnp.sum(
                    jax.nn.one_hot(actions, spec.actions.num_values) * logits,
                    axis=-1)

                log_prob = logits_actions - special.logsumexp(logits, axis=-1)
                return log_prob

            if loss_name == 'logp':
                loss_fn = bc.logp(logp_fn=logp_fn)

            elif loss_name == 'rcal':
                base_loss_fn = bc.logp(logp_fn=logp_fn)
                loss_fn = bc.rcal(base_loss_fn, discount=0.99, alpha=0.1)

            else:
                raise ValueError

            learner = bc.BCLearner(
                network=network,
                random_key=jax.random.PRNGKey(0),
                loss_fn=loss_fn,
                optimizer=optax.adam(0.01),
                demonstrations=dataset_demonstration,
                num_sgd_steps_per_step=num_sgd_steps_per_step)

            # Train the agent
            for _ in range(num_steps):
                learner.step()
Example #16
def _step_to_transition(rlds_step: rlds.BatchedStep) -> types.Transition:
    """Converts batched RLDS steps to batched transitions."""
    return types.Transition(
        observation=rlds_step[rlds.OBSERVATION],
        action=rlds_step[rlds.ACTION],
        reward=rlds_step[rlds.REWARD],
        discount=rlds_step[rlds.DISCOUNT],
        # We provide next_observation if an algorithm needs it; however, note
        # that it will only contain s_t and s_t+1, so it will be one element
        # short of all other attributes (which contain s_t-1, s_t, s_t+1).
        next_observation=tree.map_structure(lambda x: x[1:],
                                            rlds_step[rlds.OBSERVATION]),
        extras={
            N_STEP_RETURN: rlds_step[N_STEP_RETURN],
        })
Example #17
    def step(self):
        sample = next(self._iterator)
        transitions = types.Transition(*sample.data)

        self._state, metrics = self._update_step(self._state, transitions)

        # Compute elapsed time.
        timestamp = time.time()
        elapsed_time = timestamp - self._timestamp if self._timestamp else 0
        self._timestamp = timestamp

        # Increment counts and record the current time
        counts = self._counter.increment(steps=1, walltime=elapsed_time)

        # Attempts to write the logs.
        self._logger.write({**metrics, **counts})
Example #18
    def test_make_dataset_transition_adder(self):
        environment = fakes.ContinuousEnvironment()
        environment_spec = specs.make_environment_spec(environment)
        dataset = reverb_dataset.make_dataset(
            server_address=self.server_address,
            environment_spec=environment_spec,
            transition_adder=True)

        environment_spec = types.Transition(
            observation=environment_spec.observations,
            action=environment_spec.actions,
            reward=environment_spec.rewards,
            discount=environment_spec.discounts,
            next_observation=environment_spec.observations,
            extras=())

        self.assertTrue(
            _check_specs(environment_spec, dataset.element_spec.data))
Example #19
  def step(self):
    # Get data from replay (dropping extras if any). Note there is no
    # extra data here because we do not insert any into Reverb.
    sample = next(self._iterator)
    transitions = types.Transition(*sample.data)

    self._state, metrics = self._sgd_step(self._state, transitions)

    # Compute elapsed time.
    timestamp = time.time()
    elapsed_time = timestamp - self._timestamp if self._timestamp else 0
    self._timestamp = timestamp

    # Increment counts and record the current time
    counts = self._counter.increment(steps=1, walltime=elapsed_time)

    # Attempts to write the logs.
    self._logger.write({**metrics, **counts})
Example #20
    def test_continuous_actions(self, loss_name):
        with chex.fake_pmap_and_jit():
            num_sgd_steps_per_step = 1
            num_steps = 5

            # Create a fake environment to test with.
            environment = fakes.ContinuousEnvironment(episode_length=10,
                                                      bounded=True,
                                                      action_dim=6)

            spec = specs.make_environment_spec(environment)
            dataset_demonstration = fakes.transition_dataset(environment)
            dataset_demonstration = dataset_demonstration.map(
                lambda sample: types.Transition(*sample.data))
            dataset_demonstration = dataset_demonstration.batch(
                8).as_numpy_iterator()

            # Construct the agent.
            network = make_networks(spec)

            if loss_name == 'logp':
                loss_fn = bc.logp(logp_fn=lambda dist_params, actions:
                                  dist_params.log_prob(actions))
            elif loss_name == 'mse':
                loss_fn = bc.mse(sample_fn=lambda dist_params, key: dist_params
                                 .sample(seed=key))
            elif loss_name == 'peerbc':
                base_loss_fn = bc.logp(logp_fn=lambda dist_params, actions:
                                       dist_params.log_prob(actions))
                loss_fn = bc.peerbc(base_loss_fn, zeta=0.1)
            else:
                raise ValueError

            learner = bc.BCLearner(
                network=network,
                random_key=jax.random.PRNGKey(0),
                loss_fn=loss_fn,
                optimizer=optax.adam(0.01),
                demonstrations=dataset_demonstration,
                num_sgd_steps_per_step=num_sgd_steps_per_step)

            # Train the agent
            for _ in range(num_steps):
                learner.step()
Example #21
  def step(self):
    with jax.profiler.StepTraceAnnotation('sampling batch'):
      sample = next(self._iterator)
    transitions = types.Transition(*sample.data)

    with jax.profiler.StepTraceAnnotation('train step'):
      self._state, metrics = self._update_step(self._state, transitions)

    # Compute elapsed time.
    timestamp = time.time()
    elapsed_time = timestamp - self._timestamp if self._timestamp else 0
    self._timestamp = timestamp

    # Increment counts and record the current time
    counts = self._counter.increment(
        steps=self._num_sgd_steps_per_step, walltime=elapsed_time)

    # Attempts to write the logs.
    self._logger.write({**metrics, **counts})
Example #22
def transition_iterator_from_spec(
    spec: specs.EnvironmentSpec
) -> Callable[[int], Iterator[types.Transition]]:
    """Constructs fake iterator of transitions.

  Args:
    spec: Constructed fake transitions match the provided specification.

  Returns:
    A callable that given a batch_size returns an iterator of transitions.
  """

    observation = _generate_from_spec(spec.observations)
    action = _generate_from_spec(spec.actions)
    reward = _generate_from_spec(spec.rewards)
    discount = _generate_from_spec(spec.discounts)
    data = types.Transition(observation, action, reward, discount, observation)

    dataset = tf.data.Dataset.from_tensors(data).repeat()

    return lambda batch_size: dataset.batch(batch_size).as_numpy_iterator()
Example #23
        def flatten_fn(sample):
            seq_len = tf.shape(sample.data.observation)[0]
            arange = tf.range(seq_len)
            is_future_mask = tf.cast(arange[:, None] < arange[None],
                                     tf.float32)
            discount = self._config.discount**tf.cast(arange[None] - arange[:, None], tf.float32)  # pylint: disable=line-too-long
            probs = is_future_mask * discount
            # The indexing changes the shape from [seq_len, 1] to [seq_len]
            goal_index = tf.random.categorical(logits=tf.math.log(probs),
                                               num_samples=1)[:, 0]
            state = sample.data.observation[:-1, :self._config.obs_dim]
            next_state = sample.data.observation[1:, :self._config.obs_dim]

            # Create the goal observations in three steps.
            # 1. Take all future states (not future goals).
            # 2. Apply obs_to_goal.
            # 3. Sample one of the future states. Note that we don't look for a goal
            # for the final state, because there are no future states.
            goal = sample.data.observation[:, :self._config.obs_dim]
            goal = contrastive_utils.obs_to_goal_2d(
                goal,
                start_index=self._config.start_index,
                end_index=self._config.end_index)
            goal = tf.gather(goal, goal_index[:-1])
            new_obs = tf.concat([state, goal], axis=1)
            new_next_obs = tf.concat([next_state, goal], axis=1)
            transition = types.Transition(observation=new_obs,
                                          action=sample.data.action[:-1],
                                          reward=sample.data.reward[:-1],
                                          discount=sample.data.discount[:-1],
                                          next_observation=new_next_obs,
                                          extras={
                                              'next_action':
                                              sample.data.action[1:],
                                          })
            # Shift for the transpose_shuffle.
            shift = tf.random.uniform((), 0, seq_len, tf.int32)
            transition = tree.map_structure(
                lambda t: tf.roll(t, shift, axis=0), transition)
            return transition
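
The masked, discounted categorical used for goal sampling is easier to see with concrete numbers. A NumPy sketch of the (unnormalized) probability table, with hypothetical seq_len and discount:

import numpy as np

seq_len, discount = 4, 0.9
arange = np.arange(seq_len)
is_future = (arange[:, None] < arange[None]).astype(np.float32)
probs = is_future * discount ** (arange[None] - arange[:, None])
# Row t holds weights over strictly future steps, decaying geometrically;
# e.g. row 0 is [0, 0.9, 0.81, 0.729] and the last row is all zeros.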
Example #24
def transition_iterator(
    environment: dm_env.Environment
) -> Callable[[int], Iterator[types.Transition]]:
    """Fake dataset of Reverb N-step transition samples.

  Args:
    environment: Used to create a fake transition by looking at the observation,
      action, discount and reward specs.

  Returns:
    A callable that given a batch_size returns an iterator with demonstrations.
  """

    observation = environment.observation_spec().generate_value()
    action = environment.action_spec().generate_value()
    reward = environment.reward_spec().generate_value()
    discount = environment.discount_spec().generate_value()
    data = types.Transition(observation, action, reward, discount, observation)

    dataset = tf.data.Dataset.from_tensors(data).repeat()

    return lambda batch_size: dataset.batch(batch_size).as_numpy_iterator()
Example #25
  def step(self):
    # Get data from replay (dropping extras if any). Note there is no
    # extra data here because we do not insert any into Reverb.
    # TODO(raveman): Add support for offline training, where we do not consume
    # data from the replay buffer.
    sample = next(self._iterator_replay)
    replay_transitions = types.Transition(*sample.data)

    # Get a batch of Transitions from the demonstration.
    demonstration_transitions = next(self._iterator_demonstrations)

    self._state, metrics = self._sgd_step(
        self._state, (replay_transitions, demonstration_transitions))

    # Compute elapsed time.
    timestamp = time.time()
    elapsed_time = timestamp - self._timestamp if self._timestamp else 0
    self._timestamp = timestamp

    # Increment counts and record the current time
    counts = self._counter.increment(steps=1, walltime=elapsed_time)

    # Attempts to write the logs.
    self._logger.write({**metrics, **counts})
Example #26
def transition_dataset_from_spec(
        spec: specs.EnvironmentSpec) -> tf.data.Dataset:
    """Constructs fake dataset of Reverb N-step transition samples.

  Args:
    spec: Constructed fake transitions match the provided specification.

  Returns:
    tf.data.Dataset that produces the same fake N-step transition ReverbSample
    object indefinitely.
  """

    observation = _generate_from_spec(spec.observations)
    action = _generate_from_spec(spec.actions)
    reward = _generate_from_spec(spec.rewards)
    discount = _generate_from_spec(spec.discounts)
    data = types.Transition(observation, action, reward, discount, observation)

    info = tree.map_structure(
        lambda tf_dtype: tf.ones([], tf_dtype.as_numpy_dtype),
        reverb.SampleInfo.tf_dtypes())
    sample = reverb.ReplaySample(info=info, data=data)

    return tf.data.Dataset.from_tensors(sample).repeat()
Example #27
def transition_dataset(environment: dm_env.Environment) -> tf.data.Dataset:
    """Fake dataset of Reverb N-step transition samples.

  Args:
    environment: Used to create a fake transition by looking at the observation,
      action, discount and reward specs.

  Returns:
    tf.data.Dataset that produces the same fake N-step transition ReverbSample
    object indefinitely.
  """

    observation = environment.observation_spec().generate_value()
    action = environment.action_spec().generate_value()
    reward = environment.reward_spec().generate_value()
    discount = environment.discount_spec().generate_value()
    data = types.Transition(observation, action, reward, discount, observation)

    info = tree.map_structure(
        lambda tf_dtype: tf.ones([], tf_dtype.as_numpy_dtype),
        reverb.SampleInfo.tf_dtypes())
    sample = reverb.ReplaySample(info=info, data=data)

    return tf.data.Dataset.from_tensors(sample).repeat()
Example #28
# Each test case below specifies a trajectory and the expected transitions
# that should result from that trajectory. The expected transitions are of the
# form: (observation, action, reward, discount, next_observation, extras).
TEST_CASES = [
    dict(
        testcase_name='OneStepFinalReward',
        n_step=1,
        additional_discount=1.0,
        first=dm_env.restart(1),
        steps=(
            (0, dm_env.transition(reward=0.0, observation=2)),
            (0, dm_env.transition(reward=0.0, observation=3)),
            (0, dm_env.termination(reward=1.0, observation=4)),
        ),
        expected_transitions=(
            types.Transition(1, 0, 0.0, 1.0, 2),
            types.Transition(2, 0, 0.0, 1.0, 3),
            types.Transition(3, 0, 1.0, 0.0, 4),
        )),
    dict(
        testcase_name='OneStepDict',
        n_step=1,
        additional_discount=1.0,
        first=dm_env.restart({'foo': 1}),
        steps=(
            (0, dm_env.transition(reward=0.0, observation={'foo': 2})),
            (0, dm_env.transition(reward=0.0, observation={'foo': 3})),
            (0, dm_env.termination(reward=1.0, observation={'foo': 4})),
        ),
        expected_transitions=(
            types.Transition({'foo': 1}, 0, 0.0, 1.0, {'foo': 2}),
            types.Transition({'foo': 2}, 0, 0.0, 1.0, {'foo': 3}),
            types.Transition({'foo': 3}, 0, 1.0, 0.0, {'foo': 4}),
        )),
]
Example #29
def _spec_to_shapes_and_dtypes(transition_adder: bool,
                               environment_spec: specs.EnvironmentSpec,
                               extra_spec: Optional[types.NestedSpec],
                               sequence_length: Optional[int],
                               convert_zero_size_to_none: bool,
                               using_deprecated_adder: bool):
    """Creates the shapes and dtypes needed to describe the Reverb dataset.

  This takes a `environment_spec`, `extra_spec`, and additional information and
  returns a tuple (shapes, dtypes) that describe the data contained in Reverb.

  Args:
    transition_adder: A boolean, describing if a `TransitionAdder` was used to
      add data.
    environment_spec: A `specs.EnvironmentSpec`, describing the shapes and
      dtypes of the data produced by the environment (and the action).
    extra_spec: A nested structure of objects with a `.shape` and `.dtype`
      property. This describes any additional data the Actor adds into Reverb.
    sequence_length: An optional integer for how long the added sequences are,
      only used with `SequenceAdder`.
    convert_zero_size_to_none: If True, then all shape dimensions that are 0 are
      converted to None. A None dimension is only set at runtime.
    using_deprecated_adder: True if the adder used to generate the data is
      from acme/adders/reverb/deprecated.

  Returns:
    A tuple (dtypes, shapes) that describes the data that has been added into
    Reverb.
  """
    # The *transition* adder is special in that it also adds an arrival state.
    if transition_adder:
        adder_spec = types.Transition(
            observation=environment_spec.observations,
            action=environment_spec.actions,
            reward=environment_spec.rewards,
            discount=environment_spec.discounts,
            next_observation=environment_spec.observations,
            extras=() if not extra_spec else extra_spec)
    elif using_deprecated_adder and deprecated_base is not None:
        adder_spec = deprecated_base.Step(
            observation=environment_spec.observations,
            action=environment_spec.actions,
            reward=environment_spec.rewards,
            discount=environment_spec.discounts,
            extras=() if not extra_spec else extra_spec)
    else:
        adder_spec = adders.Step(observation=environment_spec.observations,
                                 action=environment_spec.actions,
                                 reward=environment_spec.rewards,
                                 discount=environment_spec.discounts,
                                 start_of_episode=specs.Array(shape=(),
                                                              dtype=bool),
                                 extras=() if not extra_spec else extra_spec)

    # Extract the shapes and dtypes from these specs.
    get_dtype = lambda x: tf.as_dtype(x.dtype)
    get_shape = lambda x: tf.TensorShape(x.shape)
    if sequence_length:
        get_shape = lambda x: tf.TensorShape([sequence_length, *x.shape])

    if convert_zero_size_to_none:
        # TODO(b/143692455): Consider making this default behaviour.
        get_shape = lambda x: tf.TensorShape(
            [s if s else None for s in x.shape])
    shapes = tree.map_structure(get_shape, adder_spec)
    dtypes = tree.map_structure(get_dtype, adder_spec)
    return shapes, dtypes
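
The convert_zero_size_to_none path is the subtle one: any 0 in a spec shape becomes an unknown dimension that is only fixed at runtime. A small sketch using a dm_env spec (the shape values are hypothetical):

import numpy as np
import tensorflow as tf
from dm_env import specs

spec = specs.Array(shape=(0, 3), dtype=np.float32)
get_shape = lambda x: tf.TensorShape([s if s else None for s in x.shape])
print(get_shape(spec))  # (None, 3)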
Example #30
    def _write(self):
        # NOTE: we do not check that the buffer is of length N here. This means
        # that at the beginning of an episode we will add the initial N-1
        # transitions (of size 1, 2, ...) and at the end of an episode (when
        # called from write_last) we will write the final transitions of size (N,
        # N-1, ...). See the Note in the docstring.

        # Form the n-step transition given the steps.
        observation = self._buffer[0].observation
        action = self._buffer[0].action
        extras = self._buffer[0].extras
        next_observation = self._next_observation

        # Give the same tree structure to the n-step return accumulator,
        # n-step discount accumulator, and self.discount, so that they can be
        # iterated in parallel using tree.map_structure.
        (n_step_return, total_discount,
         self_discount) = tree_utils.broadcast_structures(
             self._buffer[0].reward, self._buffer[0].discount, self._discount)

        # Copy total_discount, so that accumulating into it doesn't affect
        # _buffer[0].discount.
        total_discount = tree.map_structure(np.copy, total_discount)

        # Broadcast n_step_return to have the broadcasted shape of
        # reward * discount. Also copy, to avoid accumulating into
        # _buffer[0].reward.
        n_step_return = tree.map_structure(
            lambda r, d: np.copy(np.broadcast_to(r,
                                                 np.broadcast(r, d).shape)),
            n_step_return, total_discount)

        # NOTE: total_discount will have one fewer discount factor than
        # step.discounts. This is so that when the learner/update uses an
        # additional discount we don't apply it twice. Inside the following
        # loop we will apply this right before summing up the n_step_return.
        for step in itertools.islice(self._buffer, 1, None):
            (step_discount, step_reward,
             total_discount) = tree_utils.broadcast_structures(
                 step.discount, step.reward, total_discount)

            # Equivalent to: `total_discount *= self._discount`.
            tree.map_structure(operator.imul, total_discount, self_discount)

            # Equivalent to: `n_step_return += step.reward * total_discount`.
            tree.map_structure(lambda nsr, sr, td: operator.iadd(nsr, sr * td),
                               n_step_return, step_reward, total_discount)

            # Equivalent to: `total_discount *= step.discount`.
            tree.map_structure(operator.imul, total_discount, step_discount)

        transition = types.Transition(observation=observation,
                                      action=action,
                                      reward=n_step_return,
                                      discount=total_discount,
                                      next_observation=next_observation,
                                      extras=extras)

        # Create a list of steps.
        if self._final_step_placeholder is None:
            # utils.final_step_like is expensive (around 0.085ms) to run every time
            # so we cache its output.
            self._final_step_placeholder = utils.final_step_like(
                self._buffer[0], next_observation)
        final_step: base.Step = self._final_step_placeholder._replace(
            observation=next_observation)
        steps = list(self._buffer) + [final_step]

        # Calculate the priority for this transition.
        table_priorities = utils.calculate_priorities(self._priority_fns,
                                                      steps)

        # Insert the transition into replay along with its priority.
        self._writer.append(transition)
        for table, priority in table_priorities.items():
            self._writer.create_item(table=table,
                                     num_timesteps=1,
                                     priority=priority)
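
In the scalar case, the tree.map_structure accumulation in the loop above reduces to a few lines of ordinary arithmetic. A scalar sketch with hypothetical numbers:

g = 0.99                     # self._discount
rewards = [1.0, 2.0, 3.0]    # buffer[i].reward
discounts = [0.9, 0.8, 0.7]  # buffer[i].discount

n_step_return = rewards[0]
total_discount = discounts[0]
for r, d in zip(rewards[1:], discounts[1:]):
    total_discount *= g      # total_discount *= self._discount
    n_step_return += r * total_discount
    total_discount *= d      # total_discount *= step.discount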