Exemple #1
0
    def _write_last(self):
        """Optionally pad the episode, then insert it into every table.

        If a padding function is configured and the writer holds fewer than
        `self._max_sequence_length` steps, padding steps are appended until
        that length is reached. The complete writer history is then packed
        into a `base.Trajectory`, a priority is computed for each table, and
        one item per table is created; `flush` is called after each item.
        """
        writer = self._writer

        should_pad = (self._padding_fn is not None
                      and writer.episode_steps < self._max_sequence_length)
        if should_pad:
            columns = writer.history

            def _blank_like(column):
                # Shape and dtype are read from the column's last element.
                newest = column[-1]
                return self._padding_fn(newest.shape, newest.dtype)

            padded_fields = {
                'observation': columns['observation'],
                'action': columns['action'],
                'reward': columns['reward'],
                'discount': columns['discount'],
                'extras': columns.get('extras', ()),
            }
            filler = tree.map_structure(_blank_like, padded_fields)
            # Padding steps never mark the beginning of an episode.
            filler['start_of_episode'] = False

            # Re-check the step count on every iteration rather than
            # precomputing how many padding steps are missing.
            while writer.episode_steps < self._max_sequence_length:
                writer.append(filler)

        # Take a full-length slice of every column in the (possibly padded)
        # history and pack the result into a base.Trajectory structure.
        full_episode = tree.map_structure(lambda column: column[:],
                                          writer.history)
        episode = base.Trajectory(**full_episode)

        # One priority per table for this episode.
        priorities_by_table = utils.calculate_priorities(
            self._priority_fns, episode)

        # Insert the episode into each table with its priority.
        for table, table_priority in priorities_by_table.items():
            writer.create_item(table, table_priority, episode)
            writer.flush(self._max_in_flight_items)
Exemple #2
0
    def signature(cls,
                  environment_spec: specs.EnvironmentSpec,
                  extras_spec: types.NestedSpec = (),
                  sequence_length: Optional[int] = None):
        """Build a signature describing the items written to a Reverb table.

        Signatures are useful for validating data types and shapes; see
        Reverb's documentation for details on how they are used.

        Args:
          environment_spec: A `specs.EnvironmentSpec` whose fields are nested
            structures with leaf nodes that have `.shape` and `.dtype`
            attributes. This should come from the environment that will be
            used to generate the data inserted into the Reverb table.
          extras_spec: A nested structure with leaf nodes that have `.shape`
            and `.dtype` attributes. The structure (and shapes/dtypes) of this
            must be the same as the `extras` passed into `ReverbAdder.add`.
          sequence_length: An optional integer representing the expected
            length of sequences that will be added to replay. It becomes the
            leading dimension of every spec; `None` leaves it unspecified.

        Returns:
          A `Trajectory` whose leaf nodes are `tf.TensorSpec` objects.
        """
        def _with_leading_time_axis(path: Iterable[str], leaf: tf.TensorSpec):
            # Prepend the sequence dimension and name the spec after its path
            # inside the (environment_spec, extras_spec) tuple.
            return tf.TensorSpec(shape=(sequence_length,) + tuple(leaf.shape),
                                 dtype=leaf.dtype,
                                 name='/'.join(map(str, path)))

        env_with_time, extras_with_time = tree.map_structure_with_path(
            _with_leading_time_axis, (environment_spec, extras_spec))

        # One boolean flag per step in the sequence.
        start_flag_spec = tf.TensorSpec(shape=(sequence_length, ),
                                        dtype=tf.bool,
                                        name='start_of_episode')

        # The environment spec fields are passed positionally to Trajectory.
        return base.Trajectory(*env_with_time,
                               start_of_episode=start_flag_spec,
                               extras=extras_with_time)
Exemple #3
0
    def _maybe_create_item(self,
                           sequence_length: int,
                           *,
                           end_of_episode: bool = False,
                           force: bool = False):

        # Check conditions under which a new item is created.
        first_write = self._writer.episode_steps == sequence_length
        # NOTE(bshahr): the following line assumes that the only way sequence_length
        # is less than self._sequence_length, is if the episode is shorter than
        # self._sequence_length.
        period_reached = (
            self._writer.episode_steps > self._sequence_length
            and ((self._writer.episode_steps - self._sequence_length) %
                 self._period == 0))

        if not first_write and not period_reached and not force:
            return

        # TODO(b/183945808): will need to change to adhere to the new protocol.
        if not end_of_episode:
            get_traj = operator.itemgetter(slice(-sequence_length - 1, -1))
        else:
            get_traj = operator.itemgetter(slice(-sequence_length, None))

        history = self._writer.history
        trajectory = base.Trajectory(**tree.map_structure(get_traj, history))

        # Compute priorities for the buffer.
        table_priorities = utils.calculate_priorities(self._priority_fns,
                                                      trajectory)

        # Create a prioritized item for each table.
        for table_name, priority in table_priorities.items():
            self._writer.create_item(table_name, priority, trajectory)
            self._writer.flush(self._max_in_flight_items)