def _write_last(self):
  if (self._padding_fn is not None and
      self._writer.episode_steps < self._max_sequence_length):
    history = self._writer.history
    padding_step = dict(
        observation=history['observation'],
        action=history['action'],
        reward=history['reward'],
        discount=history['discount'],
        extras=history.get('extras', ()))
    # Get shapes and dtypes from the last element.
    padding_step = tree.map_structure(
        lambda col: self._padding_fn(col[-1].shape, col[-1].dtype),
        padding_step)
    padding_step['start_of_episode'] = False
    while self._writer.episode_steps < self._max_sequence_length:
      self._writer.append(padding_step)

  trajectory = tree.map_structure(lambda x: x[:], self._writer.history)

  # Pack the history into a base.Step structure and get a numpy-converted
  # variant for priority computation.
  trajectory = base.Step(**trajectory)
  trajectory_np = tree.map_structure(lambda x: x.numpy(), trajectory)

  # Calculate the priority for this episode.
  table_priorities = utils.calculate_priorities(self._priority_fns,
                                                trajectory_np)

  # Create a prioritized item for each table.
  for table_name, priority in table_priorities.items():
    self._writer.create_item(table_name, priority, trajectory)
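# Illustrative sketch (not part of the adder): how zero-padding brings a
# short episode up to a fixed sequence length. `zero_pad` stands in for the
# adder's `self._padding_fn`; all names here are hypothetical stand-ins, and
# the history is modeled as plain lists of numpy values.
import numpy as np

def zero_pad(shape, dtype):
  return np.zeros(shape, dtype)

def pad_episode(history, max_sequence_length):
  # `history` maps column names to lists of per-step numpy values.
  num_steps = len(history['reward'])
  # Build one padding step from the shape/dtype of each column's last entry.
  padding_step = {
      k: zero_pad(v[-1].shape, v[-1].dtype) for k, v in history.items()
  }
  for _ in range(max_sequence_length - num_steps):
    for k, column in history.items():
      column.append(padding_step[k])
  return history

history = {
    'observation': [np.ones(3), np.ones(3)],
    'reward': [np.float32(1.), np.float32(0.5)],
}
padded = pad_episode(history, max_sequence_length=4)
assert len(padded['reward']) == 4 and padded['reward'][-1] == 0.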
def _maybe_add_priorities(self):
  if not (
      # Write the first time we hit the max sequence length...
      self._step == self._max_sequence_length or
      # ... or every `period`th time after hitting max length.
      (self._step > self._max_sequence_length and
       (self._step - self._max_sequence_length) % self._period == 0)):
    return

  # Compute priorities for the buffer.
  steps = list(self._buffer)
  num_steps = len(steps)
  table_priorities = utils.calculate_priorities(self._priority_fns, steps)

  # Create a prioritized item for each table.
  for table_name, priority in table_priorities.items():
    self._writer.create_item(table_name, num_steps, priority)
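# Illustrative sketch (not part of the adder): which step counts satisfy the
# write condition above for a given max sequence length and period. Pure
# Python, with a hypothetical standalone function name.
def should_write(step, max_sequence_length, period):
  # First write happens exactly at max_sequence_length; afterwards, a new
  # item is created every `period` steps.
  return (step == max_sequence_length or
          (step > max_sequence_length and
           (step - max_sequence_length) % period == 0))

# With max_sequence_length=4 and period=2, items are created at steps
# 4, 6, 8, ...
assert [s for s in range(1, 10) if should_write(s, 4, 2)] == [4, 6, 8]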
def _write(self):
  # NOTE: we do not check that the buffer is of length N here. This means
  # that at the beginning of an episode we will add the initial N-1
  # transitions (of size 1, 2, ...) and at the end of an episode (when
  # called from write_last) we will write the final transitions of size
  # (N, N-1, ...). See the Note in the docstring.

  # Form the n-step transition given the steps.
  observation = self._buffer[0].observation
  action = self._buffer[0].action
  extras = self._buffer[0].extras
  next_observation = self._next_observation

  # Initialize the n-step return and the discount accumulators. We make a
  # copy of the first reward/discount so that when we add/multiply in place
  # it won't change the actual reward or discount.
  n_step_return = copy.deepcopy(self._buffer[0].reward)
  total_discount = copy.deepcopy(self._buffer[0].discount)

  # NOTE: total_discount accumulates one fewer discount factor than there
  # are step.discounts. This is so that when the learner/update applies an
  # additional discount we don't apply it twice. Inside the following loop
  # we apply that extra factor right before summing into n_step_return.
  for step in itertools.islice(self._buffer, 1, None):
    total_discount *= self._discount
    n_step_return += step.reward * total_discount
    total_discount *= step.discount

  transition = (observation, action, n_step_return, total_discount,
                next_observation, extras)

  # Create a list of steps.
  final_step = utils.final_step_like(self._buffer[0], next_observation)
  steps = list(self._buffer) + [final_step]

  # Calculate the priority for this transition.
  table_priorities = utils.calculate_priorities(self._priority_fns, steps)

  # Insert the transition into replay along with its priority.
  self._writer.append(transition)
  for table, priority in table_priorities.items():
    self._writer.create_item(table=table, num_timesteps=1, priority=priority)
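# Illustrative sketch (not part of the adder): the quantity the loop above
# computes. For steps (r_0, d_0), ..., (r_{n-1}, d_{n-1}) and discount gamma,
# the return is r_0 + sum_{k>=1} r_k * gamma^k * d_0 * ... * d_{k-1}, and
# total_discount ends as gamma^{n-1} * d_0 * ... * d_{n-1}. Standalone
# function name and inputs are stand-ins.
import numpy as np

def n_step_return(rewards, discounts, gamma):
  n_step = np.copy(rewards[0])
  total_discount = np.copy(discounts[0])
  for r, d in zip(rewards[1:], discounts[1:]):
    total_discount *= gamma
    n_step += r * total_discount
    total_discount *= d
  return n_step, total_discount

rewards = np.array([1., 2., 3.])
discounts = np.array([1., 1., 0.])  # Episode terminates on the last step.
ret, disc = n_step_return(rewards, discounts, gamma=0.9)
assert np.isclose(ret, 1. + 2. * 0.9 + 3. * 0.81)
assert np.isclose(disc, 0.)  # Terminal step: nothing to bootstrap from.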
def _write_last(self):
  # Append a zero-filled final step.
  final_step = utils.final_step_like(self._buffer[0], self._next_observation)
  self._writer.append(final_step)

  # The length of the sequence we will be adding is the size of the buffer
  # plus one due to the final step.
  steps = list(self._buffer) + [final_step]
  num_steps = len(steps)

  # Calculate the priority for this episode.
  table_priorities = utils.calculate_priorities(self._priority_fns, steps)

  # Create a prioritized item for each table.
  for table_name, priority in table_priorities.items():
    self._writer.create_item(table_name, num_steps, priority)
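# Illustrative sketch (not part of the adder): one plausible reading of what
# a zero-filled final step looks like, assuming `utils.final_step_like`
# keeps the terminal observation and zeroes out the remaining fields. This
# standalone version uses a simple dict-based step and hypothetical names.
import numpy as np
import tree

def final_step_like(step, next_observation):
  zeroed = tree.map_structure(np.zeros_like, step)
  zeroed['observation'] = next_observation
  return zeroed

step = dict(observation=np.ones(2), action=np.int32(1), reward=np.float32(3.))
final = final_step_like(step, next_observation=np.full(2, 5.))
assert final['reward'] == 0. and final['action'] == 0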
def _maybe_create_item(self,
                       sequence_length: int,
                       *,
                       end_of_episode: bool = False,
                       force: bool = False):
  # Check conditions under which a new item is created.
  first_write = self._writer.episode_steps == sequence_length
  # NOTE(bshahr): the following line assumes that the only way
  # sequence_length can be less than self._sequence_length is if the episode
  # is shorter than self._sequence_length.
  period_reached = (
      self._writer.episode_steps > self._sequence_length and
      ((self._writer.episode_steps - self._sequence_length) % self._period
       == 0))

  if not first_write and not period_reached and not force:
    return

  # TODO(b/183945808): will need to change to adhere to the new protocol.
  if not end_of_episode:
    get_traj = operator.itemgetter(slice(-sequence_length - 1, -1))
  else:
    get_traj = operator.itemgetter(slice(-sequence_length, None))

  get_traj_np = lambda x: get_traj(x).numpy()

  history = self._writer.history
  trajectory = base.Step(**tree.map_structure(get_traj, history))
  trajectory_np = base.Step(**tree.map_structure(get_traj_np, history))

  # Compute priorities for the buffer.
  table_priorities = utils.calculate_priorities(self._priority_fns,
                                                trajectory_np)

  # Create a prioritized item for each table.
  for table_name, priority in table_priorities.items():
    self._writer.create_item(table_name, priority, trajectory)
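# Illustrative sketch (not part of the adder): the two trajectory slices used
# above, demonstrated on a plain list standing in for one history column.
import operator

history_column = list(range(10))  # Steps 0..9.
sequence_length = 4

# Mid-episode the newest history row is excluded (it is still being written),
# so the slice ends at -1 and starts one step further back.
mid_episode = operator.itemgetter(slice(-sequence_length - 1, -1))
assert mid_episode(history_column) == [5, 6, 7, 8]

# At end of episode the last row is complete, so the slice takes the final
# `sequence_length` steps outright.
end_of_episode = operator.itemgetter(slice(-sequence_length, None))
assert end_of_episode(history_column) == [6, 7, 8, 9]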
def _write(self):
  # NOTE: we do not check that the buffer is of length N here. This means
  # that at the beginning of an episode we will add the initial N-1
  # transitions (of size 1, 2, ...) and at the end of an episode (when
  # called from write_last) we will write the final transitions of size
  # (N, N-1, ...). See the Note in the docstring.

  # Form the n-step transition given the steps.
  observation = self._buffer[0].observation
  action = self._buffer[0].action
  extras = self._buffer[0].extras
  next_observation = self._next_observation

  # Give the same tree structure to the n-step return accumulator,
  # n-step discount accumulator, and self.discount, so that they can be
  # iterated in parallel using tree.map_structure.
  (n_step_return, total_discount,
   self_discount) = tree_utils.broadcast_structures(
       self._buffer[0].reward, self._buffer[0].discount, self._discount)

  # Copy total_discount, so that accumulating into it doesn't affect
  # _buffer[0].discount.
  total_discount = tree.map_structure(np.copy, total_discount)

  # Broadcast n_step_return to have the broadcasted shape of
  # reward * discount. Also copy, to avoid accumulating into
  # _buffer[0].reward.
  n_step_return = tree.map_structure(
      lambda r, d: np.copy(np.broadcast_to(r, np.broadcast(r, d).shape)),
      n_step_return, total_discount)

  # NOTE: total_discount accumulates one fewer discount factor than there
  # are step.discounts. This is so that when the learner/update applies an
  # additional discount we don't apply it twice. Inside the following loop
  # we apply that extra factor right before summing into n_step_return.
  for step in itertools.islice(self._buffer, 1, None):
    (step_discount, step_reward,
     total_discount) = tree_utils.broadcast_structures(
         step.discount, step.reward, total_discount)

    # Equivalent to: `total_discount *= self._discount`.
    tree.map_structure(operator.imul, total_discount, self_discount)

    # Equivalent to: `n_step_return += step.reward * total_discount`.
    tree.map_structure(lambda nsr, sr, td: operator.iadd(nsr, sr * td),
                       n_step_return, step_reward, total_discount)

    # Equivalent to: `total_discount *= step.discount`.
    tree.map_structure(operator.imul, total_discount, step_discount)

  transition = types.Transition(
      observation=observation,
      action=action,
      reward=n_step_return,
      discount=total_discount,
      next_observation=next_observation,
      extras=extras)

  # Create a list of steps.
  if self._final_step_placeholder is None:
    # utils.final_step_like is expensive (around 0.085ms) to run every time
    # so we cache its output.
    self._final_step_placeholder = utils.final_step_like(
        self._buffer[0], next_observation)
  final_step: base.Step = self._final_step_placeholder._replace(
      observation=next_observation)
  steps = list(self._buffer) + [final_step]

  # Calculate the priority for this transition.
  table_priorities = utils.calculate_priorities(self._priority_fns, steps)

  # Insert the transition into replay along with its priority.
  self._writer.append(transition)
  for table, priority in table_priorities.items():
    self._writer.create_item(table=table, num_timesteps=1, priority=priority)
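# Illustrative sketch (not part of the adder): why the accumulators are
# broadcast and copied before the in-place ops. `operator.imul`/`iadd`
# mutate their first argument, so it must be a writable array of the final
# broadcast shape. Names and shapes here are standalone stand-ins.
import operator
import numpy as np

reward = np.float32(1.)   # Scalar reward.
discount = np.ones(3)     # Vector-valued discount, e.g. one per channel.
gamma = 0.99

# Broadcast the scalar reward up to the shape of reward * discount, then
# copy: np.broadcast_to returns a read-only view, and copying also prevents
# accumulating into the original buffers.
shape = np.broadcast(reward, discount).shape
n_step_return = np.copy(np.broadcast_to(reward, shape))
total_discount = np.copy(discount)

# One loop iteration, element-wise and in place:
operator.imul(total_discount, gamma)                 # total_discount *= gamma
operator.iadd(n_step_return, 2. * total_discount)    # += reward * discount
assert n_step_return.shape == (3,)
assert np.allclose(n_step_return, 1. + 2. * 0.99)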
def _write(self):
  # Convenient getters for use in tree operations.
  get_first = lambda x: x[self._first_idx]
  get_last = lambda x: x[self._last_idx]
  # Note: this getter is meant to be used on a TrajectoryWriter.history to
  # obtain its numpy values.
  get_all_np = lambda x: x[self._first_idx:self._last_idx].numpy()

  # Get the state, action, next_state, as well as possibly extras for the
  # transition that is about to be written.
  s, a = tree.map_structure(get_first,
                            (self._writer.history['observation'],
                             self._writer.history['action']))
  s_ = tree.map_structure(get_last, self._writer.history['observation'])

  # Maybe get extras to add to the transition later.
  if 'extras' in self._writer.history:
    extras = tree.map_structure(get_first, self._writer.history['extras'])

  # Note: at the beginning of an episode we will add the initial N-1
  # transitions (of size 1, 2, ...) and at the end of an episode (when
  # called from write_last) we will write the final transitions of size
  # (N, N-1, ...). See the Note in the docstring.
  # Get numpy view of the steps to be fed into the priority functions.
  history_np = tree.map_structure(
      get_all_np, {
          k: v for k, v in self._writer.history.items()
          if k in base.Step._fields
      })

  # Compute discounted return and geometric discount over n steps.
  n_step_return, total_discount = self._compute_cumulative_quantities(
      history_np)

  # Append the computed n-step return and total discount.
  # Note: if this call to _write() is within a call to _write_last(), then
  # this is the only data being appended and so it is not a partial append.
  self._writer.append(
      dict(n_step_return=n_step_return, total_discount=total_discount),
      partial_step=self._writer.episode_steps <= self._last_idx)

  # Form the n-step transition from the first observation and action in the
  # buffer, along with the cumulative reward and discount computed above.
  n_step_return, total_discount = tree.map_structure(
      lambda x: x[-1], (self._writer.history['n_step_return'],
                        self._writer.history['total_discount']))
  transition = types.Transition(
      observation=s,
      action=a,
      reward=n_step_return,
      discount=total_discount,
      next_observation=s_,
      extras=(extras if 'extras' in self._writer.history else ()))

  # Calculate the priority for this transition.
  table_priorities = utils.calculate_priorities(self._priority_fns,
                                                base.Step(**history_np))

  # Insert the transition into replay along with its priority.
  for table, priority in table_priorities.items():
    self._writer.create_item(
        table=table, priority=priority, trajectory=transition)
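# Illustrative sketch (not part of the adder): selecting only the Step fields
# out of a writer history that also contains derived columns, such as the
# n_step_return/total_discount appended above. `Step` and `history` are
# standalone stand-ins for base.Step and TrajectoryWriter.history.
import collections

Step = collections.namedtuple(
    'Step', ['observation', 'action', 'reward', 'discount'])

history = {
    'observation': [0], 'action': [0], 'reward': [0.], 'discount': [1.],
    'n_step_return': [0.], 'total_discount': [1.],
}
step_columns = {k: v for k, v in history.items() if k in Step._fields}
assert sorted(step_columns) == ['action', 'discount', 'observation', 'reward']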