Example #1
    def _write_last(self):
        if self._padding_fn is not None and self._writer.episode_steps < self._max_sequence_length:
            history = self._writer.history
            padding_step = dict(observation=history['observation'],
                                action=history['action'],
                                reward=history['reward'],
                                discount=history['discount'],
                                extras=history.get('extras', ()))
            # Get shapes and dtypes from the last element.
            padding_step = tree.map_structure(
                lambda col: self._padding_fn(col[-1].shape, col[-1].dtype),
                padding_step)
            padding_step['start_of_episode'] = False
            while self._writer.episode_steps < self._max_sequence_length:
                self._writer.append(padding_step)

        trajectory = tree.map_structure(lambda x: x[:], self._writer.history)

        # Pack the history into a base.Step structure and get a numpy-converted
        # variant for priority computation.
        trajectory = base.Step(**trajectory)
        trajectory_np = tree.map_structure(lambda x: x.numpy(), trajectory)

        # Calculate the priority for this episode.
        table_priorities = utils.calculate_priorities(self._priority_fns,
                                                      trajectory_np)

        # Create a prioritized item for each table.
        for table_name, priority in table_priorities.items():
            self._writer.create_item(table_name, priority, trajectory)
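The padding_fn above is called with the per-step shape and dtype of each history column (taken from the column's last element) and must return one padding step's worth of data. A minimal zero-padding sketch compatible with that call, assuming numpy (the padding_fn actually configured on the adder may differ):

    import numpy as np

    def zero_pad(shape, dtype):
        # Return a zero-filled value with the per-step shape/dtype of the
        # column; the adder appends this until max_sequence_length is reached.
        return np.zeros(shape, dtype)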
Example #2
  def _maybe_add_priorities(self):
    if not (
        # Write the first time we hit the max sequence length...
        self._step == self._max_sequence_length or
        # ... or every `period`th time after hitting max length.
        (self._step > self._max_sequence_length and
         (self._step - self._max_sequence_length) % self._period == 0)):
      return

    # Compute priorities for the buffer.
    steps = list(self._buffer)
    num_steps = len(steps)
    table_priorities = utils.calculate_priorities(self._priority_fns, steps)

    # Create a prioritized item for each table.
    for table_name, priority in table_priorities.items():
      self._writer.create_item(table_name, num_steps, priority)
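The early return above means items are only created the first time the episode reaches max_sequence_length and then every period-th step after that. A stand-alone illustration of the same predicate (should_write is a hypothetical helper, not part of the adder):

    def should_write(step, max_sequence_length, period):
        # True at the first step that reaches max_sequence_length, then every
        # `period` steps afterwards.
        return (step == max_sequence_length or
                (step > max_sequence_length and
                 (step - max_sequence_length) % period == 0))

    # With max_sequence_length=5 and period=2 this fires at steps 5, 7, 9, 11.
    assert [s for s in range(1, 12) if should_write(s, 5, 2)] == [5, 7, 9, 11]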
Example #3
    def _write(self):
        # NOTE: we do not check that the buffer is of length N here. This means
        # that at the beginning of an episode we will add the initial N-1
        # transitions (of size 1, 2, ...) and at the end of an episode (when
        # called from write_last) we will write the final transitions of size (N,
        # N-1, ...). See the Note in the docstring.

        # Form the n-step transition given the steps.
        observation = self._buffer[0].observation
        action = self._buffer[0].action
        extras = self._buffer[0].extras
        next_observation = self._next_observation

        # Initialize the n-step return and the discount accumulators. We make a
        # copy of the first reward/discount so that when we add/multiply in place
        # it won't change the actual reward or discount.
        n_step_return = copy.deepcopy(self._buffer[0].reward)
        total_discount = copy.deepcopy(self._buffer[0].discount)

        # NOTE: total_discount folds in one fewer agent discount
        # (self._discount) than the number of step discounts it accumulates.
        # This is so that when the learner/update applies an additional
        # discount we don't apply it twice. Inside the following loop the agent
        # discount is applied right before adding to the n_step_return.
        for step in itertools.islice(self._buffer, 1, None):
            total_discount *= self._discount
            n_step_return += step.reward * total_discount
            total_discount *= step.discount

        transition = (observation, action, n_step_return, total_discount,
                      next_observation, extras)

        # Create a list of steps.
        final_step = utils.final_step_like(self._buffer[0], next_observation)
        steps = list(self._buffer) + [final_step]

        # Calculate the priority for this transition.
        table_priorities = utils.calculate_priorities(self._priority_fns,
                                                      steps)

        # Insert the transition into replay along with its priority.
        self._writer.append(transition)
        for table, priority in table_priorities.items():
            self._writer.create_item(table=table,
                                     num_timesteps=1,
                                     priority=priority)
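The accumulation loop above can be traced with plain numpy values. A minimal sketch with made-up rewards and discounts and an agent discount of 0.9 (not taken from any real adder configuration):

    import numpy as np

    gamma = 0.9
    rewards = [np.array(r, dtype=np.float32) for r in (1.0, 0.5, 2.0)]
    discounts = [np.array(d, dtype=np.float32) for d in (1.0, 1.0, 0.0)]

    n_step_return = np.copy(rewards[0])
    total_discount = np.copy(discounts[0])
    for reward, discount in zip(rewards[1:], discounts[1:]):
        total_discount *= gamma        # agent discount, applied just before the sum
        n_step_return += reward * total_discount
        total_discount *= discount     # environment discount from this step

    # n_step_return  = 1.0 + 0.5 * 0.9 + 2.0 * 0.81 = 3.07
    # total_discount = 1.0 * 0.9 * 1.0 * 0.9 * 0.0  = 0.0 (episode terminated)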
Example #4
    def _write_last(self):
        # Append a zero-filled final step.
        final_step = utils.final_step_like(self._buffer[0],
                                           self._next_observation)
        self._writer.append(final_step)

        # The length of the sequence we will be adding is the size of the buffer
        # plus one due to the final step.
        steps = list(self._buffer) + [final_step]
        num_steps = len(steps)

        # Calculate the priority for this episode.
        table_priorities = utils.calculate_priorities(self._priority_fns,
                                                      steps)

        # Create a prioritized item for each table.
        for table_name, priority in table_priorities.items():
            self._writer.create_item(table_name, num_steps, priority)
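utils.final_step_like is not shown here; it builds a terminal step from the structure of an existing step and the final observation. A rough sketch of the idea on a hypothetical Step namedtuple, assuming numpy and dm-tree (the real helper in acme may construct the step differently):

    import collections
    import numpy as np
    import tree

    Step = collections.namedtuple('Step', ['observation', 'action', 'reward', 'discount'])

    def final_step_like_sketch(step, next_observation):
        # Keep the structure of `step`, zero every field, then swap in the
        # final observation.
        zeroed = tree.map_structure(np.zeros_like, step)
        return zeroed._replace(observation=next_observation)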
Example #5
    def _maybe_create_item(self,
                           sequence_length: int,
                           *,
                           end_of_episode: bool = False,
                           force: bool = False):

        # Check conditions under which a new item is created.
        first_write = self._writer.episode_steps == sequence_length
        # NOTE(bshahr): the following line assumes that the only way
        # sequence_length can be less than self._sequence_length is if the
        # episode is shorter than self._sequence_length.
        period_reached = (
            self._writer.episode_steps > self._sequence_length
            and ((self._writer.episode_steps - self._sequence_length) %
                 self._period == 0))

        if not first_write and not period_reached and not force:
            return

        # TODO(b/183945808): will need to change to adhere to the new protocol.
        if not end_of_episode:
            get_traj = operator.itemgetter(slice(-sequence_length - 1, -1))
        else:
            get_traj = operator.itemgetter(slice(-sequence_length, None))

        get_traj_np = lambda x: get_traj(x).numpy()

        history = self._writer.history
        trajectory = base.Step(**tree.map_structure(get_traj, history))
        trajectory_np = base.Step(**tree.map_structure(get_traj_np, history))

        # Compute priorities for the buffer.
        table_priorities = utils.calculate_priorities(self._priority_fns,
                                                      trajectory_np)

        # Create a prioritized item for each table.
        for table_name, priority in table_priorities.items():
            self._writer.create_item(table_name, priority, trajectory)
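The two itemgetter slices above select a window of sequence_length rows from the end of the writer history: mid-episode the newest row is excluded (it is typically still being filled in), while at the end of an episode it is included. A stand-alone illustration on a plain list standing in for one history column:

    import operator

    history_column = list(range(10))  # stand-in for one column of writer.history
    sequence_length = 4

    mid_episode = operator.itemgetter(slice(-sequence_length - 1, -1))(history_column)
    end_of_episode = operator.itemgetter(slice(-sequence_length, None))(history_column)

    assert mid_episode == [5, 6, 7, 8]
    assert end_of_episode == [6, 7, 8, 9]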
Example #6
    def _write(self):
        # NOTE: we do not check that the buffer is of length N here. This means
        # that at the beginning of an episode we will add the initial N-1
        # transitions (of size 1, 2, ...) and at the end of an episode (when
        # called from write_last) we will write the final transitions of size (N,
        # N-1, ...). See the Note in the docstring.

        # Form the n-step transition given the steps.
        observation = self._buffer[0].observation
        action = self._buffer[0].action
        extras = self._buffer[0].extras
        next_observation = self._next_observation

        # Give the same tree structure to the n-step return accumulator,
        # n-step discount accumulator, and self.discount, so that they can be
        # iterated in parallel using tree.map_structure.
        (n_step_return, total_discount,
         self_discount) = tree_utils.broadcast_structures(
             self._buffer[0].reward, self._buffer[0].discount, self._discount)

        # Copy total_discount, so that accumulating into it doesn't affect
        # _buffer[0].discount.
        total_discount = tree.map_structure(np.copy, total_discount)

        # Broadcast n_step_return to have the broadcasted shape of
        # reward * discount. Also copy, to avoid accumulating into
        # _buffer[0].reward.
        n_step_return = tree.map_structure(
            lambda r, d: np.copy(np.broadcast_to(r,
                                                 np.broadcast(r, d).shape)),
            n_step_return, total_discount)

        # NOTE: total_discount folds in one fewer agent discount
        # (self._discount) than the number of step discounts it accumulates.
        # This is so that when the learner/update applies an additional
        # discount we don't apply it twice. Inside the following loop the agent
        # discount is applied right before adding to the n_step_return.
        for step in itertools.islice(self._buffer, 1, None):
            (step_discount, step_reward,
             total_discount) = tree_utils.broadcast_structures(
                 step.discount, step.reward, total_discount)

            # Equivalent to: `total_discount *= self._discount`.
            tree.map_structure(operator.imul, total_discount, self_discount)

            # Equivalent to: `n_step_return += step.reward * total_discount`.
            tree.map_structure(lambda nsr, sr, td: operator.iadd(nsr, sr * td),
                               n_step_return, step_reward, total_discount)

            # Equivalent to: `total_discount *= step.discount`.
            tree.map_structure(operator.imul, total_discount, step_discount)

        transition = types.Transition(observation=observation,
                                      action=action,
                                      reward=n_step_return,
                                      discount=total_discount,
                                      next_observation=next_observation,
                                      extras=extras)

        # Create a list of steps.
        if self._final_step_placeholder is None:
            # utils.final_step_like is expensive (around 0.085ms) to run every time
            # so we cache its output.
            self._final_step_placeholder = utils.final_step_like(
                self._buffer[0], next_observation)
        final_step: base.Step = self._final_step_placeholder._replace(
            observation=next_observation)
        steps = list(self._buffer) + [final_step]

        # Calculate the priority for this transition.
        table_priorities = utils.calculate_priorities(self._priority_fns,
                                                      steps)

        # Insert the transition into replay along with its priority.
        self._writer.append(transition)
        for table, priority in table_priorities.items():
            self._writer.create_item(table=table,
                                     num_timesteps=1,
                                     priority=priority)
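The tree.map_structure calls with operator.imul / operator.iadd above apply the scalar updates leaf by leaf, in place, across arbitrarily nested reward and discount structures. A minimal sketch of the same pattern on a small nested structure, assuming numpy and dm-tree:

    import operator
    import numpy as np
    import tree

    total_discount = {'head': np.array([1.0, 1.0]), 'aux': np.array(1.0)}
    step_discount = {'head': np.array([0.5, 0.0]), 'aux': np.array(0.9)}

    # Equivalent to `total_discount *= step_discount`, leaf by leaf and in place.
    tree.map_structure(operator.imul, total_discount, step_discount)

    np.testing.assert_allclose(total_discount['head'], [0.5, 0.0])
    np.testing.assert_allclose(total_discount['aux'], 0.9)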
Example #7
    def _write(self):
        # Convenient getters for use in tree operations.
        get_first = lambda x: x[self._first_idx]
        get_last = lambda x: x[self._last_idx]
        # Note: this getter is meant to be used on a TrajectoryWriter.history to
        # obtain its numpy values.
        get_all_np = lambda x: x[self._first_idx:self._last_idx].numpy()

        # Get the state, action, next_state, as well as possibly extras for the
        # transition that is about to be written.
        s, a = tree.map_structure(get_first,
                                  (self._writer.history['observation'],
                                   self._writer.history['action']))
        s_ = tree.map_structure(get_last, self._writer.history['observation'])

        # Maybe get extras to add to the transition later.
        if 'extras' in self._writer.history:
            extras = tree.map_structure(get_first,
                                        self._writer.history['extras'])

        # Note: at the beginning of an episode we will add the initial N-1
        # transitions (of size 1, 2, ...) and at the end of an episode (when
        # called from write_last) we will write the final transitions of size (N,
        # N-1, ...). See the Note in the docstring.
        # Get numpy view of the steps to be fed into the priority functions.
        history_np = tree.map_structure(
            get_all_np, {
                k: v
                for k, v in self._writer.history.items()
                if k in base.Step._fields
            })

        # Compute discounted return and geometric discount over n steps.
        n_step_return, total_discount = self._compute_cumulative_quantities(
            history_np)

        # Append the computed n-step return and total discount.
        # Note: if this call to _write() is within a call to _write_last(), then
        # this is the only data being appended and so it is not a partial append.
        self._writer.append(
            dict(n_step_return=n_step_return, total_discount=total_discount),
            partial_step=self._writer.episode_steps <= self._last_idx)

        # Form the n-step transition from the first observation and action in
        # the buffer, together with the cumulative reward and discount computed
        # above.
        n_step_return, total_discount = tree.map_structure(
            lambda x: x[-1], (self._writer.history['n_step_return'],
                              self._writer.history['total_discount']))
        transition = types.Transition(
            observation=s,
            action=a,
            reward=n_step_return,
            discount=total_discount,
            next_observation=s_,
            extras=(extras if 'extras' in self._writer.history else ()))

        # Calculate the priority for this transition.
        table_priorities = utils.calculate_priorities(self._priority_fns,
                                                      base.Step(**history_np))

        # Insert the transition into replay along with its priority.
        for table, priority in table_priorities.items():
            self._writer.create_item(table=table,
                                     priority=priority,
                                     trajectory=transition)
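The history_np dict above keeps only the columns named in base.Step._fields and slices out the rows between first_idx and last_idx. A stand-alone sketch of that filtering and slicing on a plain dict of numpy arrays (the field names and indices here are made up for illustration):

    import numpy as np

    step_fields = ('observation', 'action', 'reward', 'discount')  # stand-in for base.Step._fields
    history = {
        'observation': np.arange(6),
        'action': np.arange(6),
        'reward': np.ones(6),
        'discount': np.ones(6),
        # Extra columns appended by the adder itself are filtered out.
        'n_step_return': np.zeros(6),
        'total_discount': np.ones(6),
    }
    first_idx, last_idx = 1, 5

    history_np = {k: v[first_idx:last_idx] for k, v in history.items() if k in step_fields}

    assert sorted(history_np) == sorted(step_fields)
    assert all(len(v) == last_idx - first_idx for v in history_np.values())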