def _write_last(self):
    """Pad the episode to full length (if configured) and write it to replay.

    If a `self._padding_fn` was provided and the episode is shorter than
    `self._max_sequence_length`, padding steps are appended until the writer
    holds exactly `self._max_sequence_length` steps.  The whole history is then
    packed into a `base.Step`, priorities are computed per table, and one item
    per table is created in the underlying Reverb writer.
    """
    if self._padding_fn is not None and self._writer.episode_steps < self._max_sequence_length:
        history = self._writer.history
        # Collect one column per Step field; extras may be absent, in which
        # case an empty tuple stands in for it.
        padding_step = dict(
            observation=history['observation'],
            action=history['action'],
            reward=history['reward'],
            discount=history['discount'],
            extras=history.get('extras', ()))
        # Get shapes and dtypes from the last element of each column so the
        # padding matches the episode's data exactly.
        padding_step = tree.map_structure(
            lambda col: self._padding_fn(col[-1].shape, col[-1].dtype),
            padding_step)
        # Padding never marks the start of an episode.
        padding_step['start_of_episode'] = False
        # Append padding until the writer reaches the target length.
        while self._writer.episode_steps < self._max_sequence_length:
            self._writer.append(padding_step)

    # Full-history slice (x[:]) yields column references suitable for item
    # creation below.
    trajectory = tree.map_structure(lambda x: x[:], self._writer.history)

    # Pack the history into a base.Step structure and get numpy converted
    # variant for priority computation.
    trajectory = base.Step(**trajectory)
    trajectory_np = tree.map_structure(lambda x: x.numpy(), trajectory)

    # Calculate the priority for this episode.
    table_priorities = utils.calculate_priorities(self._priority_fns,
                                                  trajectory_np)

    # Create a prioritized item for each table.
    for table_name, priority in table_priorities.items():
        self._writer.create_item(table_name, priority, trajectory)
def final_step_like(step: base.Step,
                    next_observation: types.NestedArray) -> base.Step:
    """Build a terminal step holding `next_observation` and zero-filled fields.

    The returned step mirrors the structure of `step` for its action, reward,
    discount and extras (all zero-filled) while carrying `next_observation`
    as its observation.
    """
    # Zero-fill every non-observation component in a single tree traversal.
    zeroed = tree.map_structure(
        zeros_like, (step.action, step.reward, step.discount, step.extras))
    action, reward, discount, extras = zeroed

    return base.Step(
        observation=next_observation,
        action=action,
        reward=reward,
        discount=discount,
        extras=extras)
def _maybe_create_item(self,
                       sequence_length: int,
                       *,
                       end_of_episode: bool = False,
                       force: bool = False):
    """Create a replay item when a write condition is met, otherwise no-op.

    An item is created when either (a) this is the first full sequence of the
    episode, (b) the configured period boundary has been reached, or
    (c) `force` is True.

    Args:
      sequence_length: Number of trailing steps to slice from the writer's
        history for this item.
      end_of_episode: If True, slice up to and including the final step;
        otherwise exclude the most recent (still partial) step.
      force: If True, bypass the first-write/period checks.
    """
    # Check conditions under which a new item is created.
    first_write = self._writer.episode_steps == sequence_length
    # NOTE(bshahr): the following line assumes that the only way sequence_length
    # is less than self._sequence_length, is if the episode is shorter than
    # self._sequence_length.
    period_reached = (
        self._writer.episode_steps > self._sequence_length and
        ((self._writer.episode_steps - self._sequence_length) % self._period
         == 0))

    if not first_write and not period_reached and not force:
        return

    # TODO(b/183945808): will need to change to adhere to the new protocol.
    if not end_of_episode:
        # Exclude the last (partial) step: [-sequence_length - 1, -1).
        get_traj = operator.itemgetter(slice(-sequence_length - 1, -1))
    else:
        # Include the final step: [-sequence_length, end).
        get_traj = operator.itemgetter(slice(-sequence_length, None))

    # Numpy view of the same slice, used only for priority computation.
    get_traj_np = lambda x: get_traj(x).numpy()

    history = self._writer.history
    trajectory = base.Step(**tree.map_structure(get_traj, history))
    trajectory_np = base.Step(**tree.map_structure(get_traj_np, history))

    # Compute priorities for the buffer.
    table_priorities = utils.calculate_priorities(self._priority_fns,
                                                  trajectory_np)

    # Create a prioritized item for each table.
    for table_name, priority in table_priorities.items():
        self._writer.create_item(table_name, priority, trajectory)
def signature(cls, environment_spec: specs.EnvironmentSpec,
              extras_spec: types.NestedSpec = (),
              sequence_length: Optional[int] = None):
    """This is a helper method for generating signatures for Reverb tables.

    Signatures are useful for validating data types and shapes, see Reverb's
    documentation for details on how they are used.

    Args:
      environment_spec: A `specs.EnvironmentSpec` whose fields are nested
        structures with leaf nodes that have `.shape` and `.dtype` attributes.
        This should come from the environment that will be used to generate the
        data inserted into the Reverb table.
      extras_spec: A nested structure with leaf nodes that have `.shape` and
        `.dtype` attributes. The structure (and shapes/dtypes) of this must be
        the same as the `extras` passed into `ReverbAdder.add`.
      sequence_length: An optional integer representing the expected length of
        sequences that will be added to replay.

    Returns:
      A `Step` whose leaf nodes are `tf.TensorSpec` objects.
    """

    def _with_time_dim(paths: Iterable[str], spec: tf.TensorSpec):
        # Prepend the (possibly None) time dimension and name the spec after
        # its path within the nested structure.
        leaf_name = '/'.join(str(p) for p in paths)
        return tf.TensorSpec(
            shape=(sequence_length, *spec.shape),
            dtype=spec.dtype,
            name=leaf_name)

    trajectory_env_spec, trajectory_extras_spec = tree.map_structure_with_path(
        _with_time_dim, (environment_spec, extras_spec))

    start_of_episode_spec = tf.TensorSpec(
        shape=(sequence_length,), dtype=tf.bool, name='start_of_episode')

    return base.Step(
        *trajectory_env_spec,
        start_of_episode=start_of_episode_spec,
        extras=trajectory_extras_spec)
def _write(self):
    """Compute and insert one n-step transition from the writer's history.

    Reads the step window [`self._first_idx`, `self._last_idx`] from the
    writer's history, computes the n-step return and total discount, appends
    those back to the writer, and creates one prioritized item per table
    containing the resulting `types.Transition`.
    """
    # Convenient getters for use in tree operations.
    get_first = lambda x: x[self._first_idx]
    get_last = lambda x: x[self._last_idx]
    # Note: this getter is meant to be used on a TrajectoryWriter.history to
    # obtain its numpy values.
    get_all_np = lambda x: x[self._first_idx:self._last_idx].numpy()

    # Get the state, action, next_state, as well as possibly extras for the
    # transition that is about to be written.
    s, a = tree.map_structure(get_first,
                              (self._writer.history['observation'],
                               self._writer.history['action']))
    s_ = tree.map_structure(get_last, self._writer.history['observation'])

    # Maybe get extras to add to the transition later.
    if 'extras' in self._writer.history:
        extras = tree.map_structure(get_first,
                                    self._writer.history['extras'])

    # Note: at the beginning of an episode we will add the initial N-1
    # transitions (of size 1, 2, ...) and at the end of an episode (when
    # called from write_last) we will write the final transitions of size (N,
    # N-1, ...). See the Note in the docstring.
    # Get numpy view of the steps to be fed into the priority functions.
    history_np = tree.map_structure(
        get_all_np, {
            k: v for k, v in self._writer.history.items()
            if k in base.Step._fields
        })

    # Compute discounted return and geometric discount over n steps.
    n_step_return, total_discount = self._compute_cumulative_quantities(
        history_np)

    # Append the computed n-step return and total discount.
    # Note: if this call to _write() is within a call to _write_last(), then
    # this is the only data being appended and so it is not a partial append.
    self._writer.append(
        dict(n_step_return=n_step_return, total_discount=total_discount),
        partial_step=self._writer.episode_steps <= self._last_idx)

    # Form the n-step transition by using the following:
    # the first observation and action in the buffer, along with the cumulative
    # reward and discount computed above.
    # The values just appended are read back as the last entry of each column.
    n_step_return, total_discount = tree.map_structure(
        lambda x: x[-1], (self._writer.history['n_step_return'],
                          self._writer.history['total_discount']))
    transition = types.Transition(
        observation=s,
        action=a,
        reward=n_step_return,
        discount=total_discount,
        next_observation=s_,
        # extras is only bound when present in history; fall back to ().
        extras=(extras if 'extras' in self._writer.history else ()))

    # Calculate the priority for this transition.
    table_priorities = utils.calculate_priorities(self._priority_fns,
                                                  base.Step(**history_np))

    # Insert the transition into replay along with its priority.
    for table, priority in table_priorities.items():
        self._writer.create_item(
            table=table, priority=priority, trajectory=transition)