Example #1
    def _write_last(self):
        if self._padding_fn is not None and self._writer.episode_steps < self._max_sequence_length:
            history = self._writer.history
            padding_step = dict(observation=history['observation'],
                                action=history['action'],
                                reward=history['reward'],
                                discount=history['discount'],
                                extras=history.get('extras', ()))
            # Get shapes and dtypes from the last element.
            padding_step = tree.map_structure(
                lambda col: self._padding_fn(col[-1].shape, col[-1].dtype),
                padding_step)
            padding_step['start_of_episode'] = False
            while self._writer.episode_steps < self._max_sequence_length:
                self._writer.append(padding_step)

        trajectory = tree.map_structure(lambda x: x[:], self._writer.history)

        # Pack the history into a base.Step structure and get a numpy-converted
        # variant for priority computation.
        trajectory = base.Step(**trajectory)
        trajectory_np = tree.map_structure(lambda x: x.numpy(), trajectory)

        # Calculate the priority for this episode.
        table_priorities = utils.calculate_priorities(self._priority_fns,
                                                      trajectory_np)

        # Create a prioritized item for each table.
        for table_name, priority in table_priorities.items():
            self._writer.create_item(table_name, priority, trajectory)
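The `padding_fn` above is only assumed to take a leaf's shape and dtype and return a single padding value; a minimal sketch of such a function (zero padding is a hypothetical choice, not necessarily the one used in practice) could look like:

import numpy as np

def zero_padding_fn(shape, dtype):
    # Return one padding value per leaf: an all-zero array with the same
    # shape and dtype as the last real step, so short episodes can be
    # appended to until they reach max_sequence_length.
    return np.zeros(shape, dtype=dtype)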
Example #2
def final_step_like(step: base.Step,
                    next_observation: types.NestedArray) -> base.Step:
    """Return a list of steps with the final step zero-filled."""
    # Make zero-filled components so we can fill out the last step.
    zero_action, zero_reward, zero_discount, zero_extras = tree.map_structure(
        zeros_like, (step.action, step.reward, step.discount, step.extras))

    # Return a final step that only has next_observation.
    return base.Step(observation=next_observation,
                     action=zero_action,
                     reward=zero_reward,
                     discount=zero_discount,
                     extras=zero_extras)
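The `zeros_like` helper called above is not shown in this snippet; a minimal sketch, assuming each leaf is a numpy array or scalar, might be:

import numpy as np

def zeros_like(x):
    # Map a leaf (array or scalar) to zeros of the same shape and dtype, so
    # the final step carries neutral action/reward/discount/extras values.
    return np.zeros_like(x)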
Example #3
    def _maybe_create_item(self,
                           sequence_length: int,
                           *,
                           end_of_episode: bool = False,
                           force: bool = False):

        # Check conditions under which a new item is created.
        first_write = self._writer.episode_steps == sequence_length
        # NOTE(bshahr): the following line assumes that the only way
        # sequence_length can be less than self._sequence_length is when the
        # episode is shorter than self._sequence_length.
        period_reached = (
            self._writer.episode_steps > self._sequence_length
            and ((self._writer.episode_steps - self._sequence_length) %
                 self._period == 0))

        if not first_write and not period_reached and not force:
            return

        # TODO(b/183945808): will need to change to adhere to the new protocol.
        if not end_of_episode:
            get_traj = operator.itemgetter(slice(-sequence_length - 1, -1))
        else:
            get_traj = operator.itemgetter(slice(-sequence_length, None))

        get_traj_np = lambda x: get_traj(x).numpy()

        history = self._writer.history
        trajectory = base.Step(**tree.map_structure(get_traj, history))
        trajectory_np = base.Step(**tree.map_structure(get_traj_np, history))

        # Compute priorities for the buffer.
        table_priorities = utils.calculate_priorities(self._priority_fns,
                                                      trajectory_np)

        # Create a prioritized item for each table.
        for table_name, priority in table_priorities.items():
            self._writer.create_item(table_name, priority, trajectory)
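The first_write/period_reached logic above fixes the schedule on which items are written. The standalone check below (a sketch, not part of the adder) reproduces that schedule for a hypothetical sequence_length of 4 and period of 2:

def should_create_item(episode_steps, sequence_length, period):
    # Mirror the condition above: create the first item once a full sequence
    # is available, then another item every `period` steps after that.
    first_write = episode_steps == sequence_length
    period_reached = (episode_steps > sequence_length
                      and (episode_steps - sequence_length) % period == 0)
    return first_write or period_reached

# Items are created at episode steps 4, 6, 8, ...
assert [t for t in range(1, 10) if should_create_item(t, 4, 2)] == [4, 6, 8]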
Example #4
    def signature(cls,
                  environment_spec: specs.EnvironmentSpec,
                  extras_spec: types.NestedSpec = (),
                  sequence_length: Optional[int] = None):
        """This is a helper method for generating signatures for Reverb tables.

        Signatures are useful for validating data types and shapes; see Reverb's
        documentation for details on how they are used.

        Args:
          environment_spec: A `specs.EnvironmentSpec` whose fields are nested
            structures with leaf nodes that have `.shape` and `.dtype` attributes.
            This should come from the environment that will be used to generate
            the data inserted into the Reverb table.
          extras_spec: A nested structure with leaf nodes that have `.shape` and
            `.dtype` attributes. The structure (and shapes/dtypes) of this must
            be the same as the `extras` passed into `ReverbAdder.add`.
          sequence_length: An optional integer representing the expected length of
            sequences that will be added to replay.

        Returns:
          A `Step` whose leaf nodes are `tf.TensorSpec` objects.
        """
        def add_time_dim(paths: Iterable[str], spec: tf.TensorSpec):
            return tf.TensorSpec(shape=(sequence_length, *spec.shape),
                                 dtype=spec.dtype,
                                 name='/'.join(str(p) for p in paths))

        trajectory_env_spec, trajectory_extras_spec = tree.map_structure_with_path(
            add_time_dim, (environment_spec, extras_spec))

        spec_step = base.Step(*trajectory_env_spec,
                              start_of_episode=tf.TensorSpec(
                                  shape=(sequence_length, ),
                                  dtype=tf.bool,
                                  name='start_of_episode'),
                              extras=trajectory_extras_spec)

        return spec_step
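Assuming this classmethod is exposed on an adder class such as Acme's SequenceAdder, a hedged usage sketch (the environment specs below are made up for illustration) would be:

import numpy as np
from acme import specs
from acme.adders import reverb as adders

env_spec = specs.EnvironmentSpec(
    observations=specs.Array(shape=(4,), dtype=np.float32, name='observation'),
    actions=specs.DiscreteArray(num_values=3, name='action'),
    rewards=specs.Array(shape=(), dtype=np.float32, name='reward'),
    discounts=specs.BoundedArray(shape=(), dtype=np.float32,
                                 minimum=0., maximum=1., name='discount'))

# Each leaf of the returned Step is a tf.TensorSpec with a leading time
# dimension of length sequence_length.
signature = adders.SequenceAdder.signature(env_spec, sequence_length=10)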
Example #5
    def _write(self):
        # Convenient getters for use in tree operations.
        get_first = lambda x: x[self._first_idx]
        get_last = lambda x: x[self._last_idx]
        # Note: this getter is meant to be used on a TrajectoryWriter.history to
        # obtain its numpy values.
        get_all_np = lambda x: x[self._first_idx:self._last_idx].numpy()

        # Get the state, action, next_state, as well as possibly extras for the
        # transition that is about to be written.
        s, a = tree.map_structure(get_first,
                                  (self._writer.history['observation'],
                                   self._writer.history['action']))
        s_ = tree.map_structure(get_last, self._writer.history['observation'])

        # Maybe get extras to add to the transition later.
        if 'extras' in self._writer.history:
            extras = tree.map_structure(get_first,
                                        self._writer.history['extras'])

        # Note: at the beginning of an episode we will add the initial N-1
        # transitions (of size 1, 2, ...) and at the end of an episode (when
        # called from write_last) we will write the final transitions of size (N,
        # N-1, ...). See the Note in the docstring.
        # Get numpy view of the steps to be fed into the priority functions.
        history_np = tree.map_structure(
            get_all_np, {
                k: v
                for k, v in self._writer.history.items()
                if k in base.Step._fields
            })

        # Compute discounted return and geometric discount over n steps.
        n_step_return, total_discount = self._compute_cumulative_quantities(
            history_np)

        # Append the computed n-step return and total discount.
        # Note: if this call to _write() is within a call to _write_last(), then
        # this is the only data being appended and so it is not a partial append.
        self._writer.append(
            dict(n_step_return=n_step_return, total_discount=total_discount),
            partial_step=self._writer.episode_steps <= self._last_idx)

        # Form the n-step transition from the first observation and action in
        # the buffer, together with the cumulative reward and discount computed
        # above.
        n_step_return, total_discount = tree.map_structure(
            lambda x: x[-1], (self._writer.history['n_step_return'],
                              self._writer.history['total_discount']))
        transition = types.Transition(
            observation=s,
            action=a,
            reward=n_step_return,
            discount=total_discount,
            next_observation=s_,
            extras=(extras if 'extras' in self._writer.history else ()))

        # Calculate the priority for this transition.
        table_priorities = utils.calculate_priorities(self._priority_fns,
                                                      base.Step(**history_np))

        # Insert the transition into replay along with its priority.
        for table, priority in table_priorities.items():
            self._writer.create_item(table=table,
                                     priority=priority,
                                     trajectory=transition)
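The `_compute_cumulative_quantities` method called above is not part of this snippet; a standalone sketch of the quantities it is expected to produce, assuming the usual n-step bootstrapping definitions, is:

import numpy as np

def compute_cumulative_quantities(rewards, discounts):
    # n-step return: r_0 + d_0*r_1 + d_0*d_1*r_2 + ...
    # total discount: the product of all discounts in the window.
    n_step_return, total_discount = np.float32(0.0), np.float32(1.0)
    for r, d in zip(rewards, discounts):
        n_step_return += total_discount * r
        total_discount *= d
    return n_step_return, total_discount

# For rewards [1, 1, 1] with a constant discount of 0.9:
# n-step return = 1 + 0.9 + 0.81 = 2.71, total discount = 0.729.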