def _write_last(self):
    """Pad the episode if configured to, then insert it into every table.

    When a padding function is set and the writer holds fewer than
    `self._max_sequence_length` steps, padding steps are appended until that
    length is reached. The whole writer history is then wrapped in a
    `base.Trajectory`, priorities are computed for it, and one item per table
    is created and flushed.
    """
    writer = self._writer
    max_length = self._max_sequence_length

    if self._padding_fn is not None and writer.episode_steps < max_length:
        columns = writer.history
        template = {
            'observation': columns['observation'],
            'action': columns['action'],
            'reward': columns['reward'],
            'discount': columns['discount'],
            'extras': columns.get('extras', ()),
        }

        def _padding_like_last(column):
            # The most recent entry of the column supplies shape and dtype.
            latest = column[-1]
            return self._padding_fn(latest.shape, latest.dtype)

        filler_step = tree.map_structure(_padding_like_last, template)
        # Padding steps never mark the beginning of an episode.
        filler_step['start_of_episode'] = False

        while writer.episode_steps < max_length:
            writer.append(filler_step)

    # Take a full-length slice of every column and pack the result into a
    # base.Trajectory, which is what the priority functions receive.
    full_columns = tree.map_structure(lambda column: column[:], writer.history)
    episode = base.Trajectory(**full_columns)

    # One priority per table for this episode.
    priorities_by_table = utils.calculate_priorities(self._priority_fns,
                                                     episode)

    # Insert the episode into each table, flushing after every insertion.
    for table, table_priority in priorities_by_table.items():
        writer.create_item(table, table_priority, episode)
        writer.flush(self._max_in_flight_items)
def signature(cls,
              environment_spec: specs.EnvironmentSpec,
              extras_spec: types.NestedSpec = (),
              sequence_length: Optional[int] = None):
    """Build a Reverb table signature for trajectories of a given length.

    Signatures let Reverb validate the dtypes and shapes of inserted data;
    see Reverb's documentation for how they are used.

    Args:
      environment_spec: A `specs.EnvironmentSpec` whose fields are nested
        structures of leaves exposing `.shape` and `.dtype`. It should
        describe the environment that produces the data written to the table.
      extras_spec: A nested structure of leaves exposing `.shape` and
        `.dtype`. Its structure, shapes and dtypes must match the `extras`
        passed to `ReverbAdder.add`.
      sequence_length: Optional expected length of the sequences added to
        replay. It becomes the leading dimension of every spec; when `None`
        that dimension is left unspecified.

    Returns:
      A `Trajectory` whose leaves are `tf.TensorSpec` objects.
    """

    def _prepend_time_axis(path: Iterable[str], leaf: tf.TensorSpec):
        # Name each spec after its path within the nested structure.
        leaf_name = '/'.join(map(str, path))
        return tf.TensorSpec(
            shape=(sequence_length, *leaf.shape),
            dtype=leaf.dtype,
            name=leaf_name)

    # Map both specs in one call so their path-derived names share a root.
    timed_env_spec, timed_extras_spec = tree.map_structure_with_path(
        _prepend_time_axis, (environment_spec, extras_spec))

    start_flag_spec = tf.TensorSpec(
        shape=(sequence_length,), dtype=tf.bool, name='start_of_episode')

    # The environment spec fields are passed positionally to Trajectory.
    return base.Trajectory(
        *timed_env_spec,
        start_of_episode=start_flag_spec,
        extras=timed_extras_spec)
def _maybe_create_item(self,
                       sequence_length: int,
                       *,
                       end_of_episode: bool = False,
                       force: bool = False):
    """Create and flush one item per table if a write is due.

    A write is due when the writer holds exactly `sequence_length` steps,
    when the configured period has elapsed since the first full sequence, or
    when `force` is set. Otherwise the method returns without writing.

    Args:
      sequence_length: Number of steps to include in the created item.
      end_of_episode: If true, the item covers the final `sequence_length`
        steps of the history; otherwise the most recent step is left out.
      force: If true, create the item regardless of the other conditions.
    """
    steps_so_far = self._writer.episode_steps

    is_first_write = steps_so_far == sequence_length
    # NOTE(bshahr): this assumes sequence_length can only be smaller than
    # self._sequence_length when the episode itself is shorter than
    # self._sequence_length.
    is_period_boundary = (
        steps_so_far > self._sequence_length
        and (steps_so_far - self._sequence_length) % self._period == 0)

    if not (is_first_write or is_period_boundary or force):
        return

    # TODO(b/183945808): will need to change to adhere to the new protocol.
    if end_of_episode:
        window = slice(-sequence_length, None)
    else:
        window = slice(-sequence_length - 1, -1)

    sliced_columns = tree.map_structure(lambda column: column[window],
                                        self._writer.history)
    sequence = base.Trajectory(**sliced_columns)

    # One priority per table for this sequence.
    priorities_by_table = utils.calculate_priorities(self._priority_fns,
                                                     sequence)

    # Insert the sequence into each table, flushing after every insertion.
    for table, table_priority in priorities_by_table.items():
        self._writer.create_item(table, table_priority, sequence)
        self._writer.flush(self._max_in_flight_items)