Esempio n. 1
0
    def update(
        self,
        policy_state: types.NestedTensor,
        trajectories: Trajectory,
        number_of_particles: int,
    ) -> types.NestedTensor:
        """
        Update the policy state at the end of each iteration.

        The `self._num_elites` action sequences with the highest particle-averaged returns are
        selected, and the sampling mean and variance are moved towards the elite mean and
        (biased) elite variance using `self._lr` as the interpolation weight. A new batch of
        actions is then sampled from the updated distribution and the step index is reset to 0.

        Note that each of the trajectories in the batch should be of the same length.
        Trajectories cannot terminate and restart.

        :param policy_state: A nest of tensors with details about policy, unpacked here as
            ``(mean, variance, low, high, actions, step_index)``.
        :param trajectories: A time-stacked trajectory object.
        :param number_of_particles: Number of monte-carlo rollouts of each action trajectory.
        :return: A nest with the same structure as `policy_state` holding the updated mean and
            variance, the unchanged `low`/`high` bounds, newly sampled actions and a zeroed
            step index.
        """
        # NOTE(review): this bound uses the full batch size; if particles are stacked along
        # axis 0, `returns.shape[0]` (used below as the population size) may be the tighter
        # bound — confirm against `averaged_particle_returns`.
        assert (self._num_elites <= trajectories.discount.shape[0]
                ), "num_elites needs to be smaller than population size"

        is_boundary = trajectories.is_boundary()
        # A single boundary before the final step is already invalid, so check with
        # `reduce_any`. `reduce_all` would only trip when *every* earlier step is a boundary,
        # and would wrongly fail horizon-1 batches (reduce_all of an empty slice is True).
        assert tf.equal(
            tf.reduce_any(is_boundary[:, :-1]), False
        ), "No trajectories in the batch should contain a terminal state before the final step."
        assert tf.equal(
            tf.reduce_all(is_boundary[:, -1]), True
        ), "All trajectories in the batch must end in a terminal state."

        returns = averaged_particle_returns(trajectories.reward,
                                            trajectories.discount,
                                            number_of_particles)

        # Indices of the best-performing action sequences, highest return first.
        sorted_idx = tf.argsort(returns, direction="DESCENDING")
        elite_idx = sorted_idx[:self._num_elites]
        elites = tf.gather(
            trajectories.action, elite_idx
        )  # shape = (number of elites, horizon) + action_spec.shape

        elites_mean = tf.reduce_mean(
            elites, axis=0)  # shape = (horizon,) + action_spec.shape
        # Biased (population) variance of the elites around their mean.
        elites_var = tf.reduce_mean(
            tf.math.square(elites - elites_mean),
            axis=0)  # shape = (horizon,) + action_spec.shape

        old_mean, old_var, low, high, _, step_index = policy_state

        # Exponential smoothing between the previous distribution and the elite statistics.
        new_mean = (
            1.0 - self._lr
        ) * old_mean + self._lr * elites_mean  # shape = (horizon,) + action_spec.shape
        new_var = (
            1.0 - self._lr
        ) * old_var + self._lr * elites_var  # shape = (horizon,) + action_spec.shape

        new_actions = sample_action_batch(new_mean, new_var, low, high,
                                          returns.shape[0])

        return tf.nest.pack_sequence_as(
            policy_state,
            [
                new_mean, new_var, low, high, new_actions,
                tf.zeros_like(step_index)
            ],
        )
Esempio n. 2
0
    def __call__(self, trajectory: trajectory_lib.Trajectory) -> None:
        """Cache one un-batched step and write the episode to Reverb when it ends.

        The trajectory may be a (possibly nested) trajectory object or a flattened
        version of one; it must not carry a batch dimension.

        If a step arrives after `max_sequence_length` steps of the current episode
        are already cached, the episode is flagged as overflowing. Either an error
        is logged and the rest of the episode is ignored (`bypass_partial_episodes`
        is true), or a ValueError is raised. When an overflowing episode reaches
        its boundary step, the cached steps are dropped without being written.

        Args:
          trajectory: The trajectory to be written which could be (possibly nested)
            trajectory object or a flattened version of a trajectory. It assumes
            there is *no* batch dimension.

        Raises:
          ValueError: If `bypass_partial_episodes` == False and episode length
            is > `max_sequence_length`.
        """
        # TODO(b/176494855): Raise an error if an invalid trajectory is passed in.
        # Currently, invalid `traj` values (mid->first, last->last) are not
        # specially handled and are treated as normal mid->mid steps.
        limit_reached = self._cached_steps >= self._max_sequence_length
        if limit_reached and not self._overflow_episode:
            # The flag is set before a possible raise, so a caller that catches
            # the error still sees the episode marked as overflowing.
            self._overflow_episode = True
            if not self._bypass_partial_episodes:
                raise ValueError(
                    "The number of trajectories within the same episode "
                    "exceeds `max_sequence_length`. Consider increasing the "
                    "`max_sequence_length` or set `bypass_partial_episodes` to true "
                    "to bypass the episodes with length more than "
                    "`max_sequence_length`.")
            logging.error(
                "The number of trajectories within the same episode exceeds "
                "`max_sequence_length`. This episode is bypassed and will NOT "
                "be written into the replay buffer. Consider increasing the "
                "`max_sequence_length`.")

        if self._overflow_episode:
            # Steps of an overflowing episode are never cached. At its boundary,
            # discard the incomplete episode and reset the writer.
            if trajectory.is_boundary():
                self.reset(write_cached_steps=False)
            return

        self._writer.append(trajectory)
        self._writer_has_data = True
        self._cached_steps += 1

        # End of episode: write the cached item to Reverb and clear the cache.
        if trajectory.is_boundary():
            self.reset(write_cached_steps=True)
Esempio n. 3
0
    def __call__(self, trajectory: Trajectory):
        """
        Update the step/episode counters and periodically print progress.

        Boundary trajectories increment the episode counter; all other trajectories increment
        the step counter. Whenever the step counter is a multiple of `log_period`, a progress
        line ending in a carriage return is printed so it overwrites itself on the terminal.
        This can also happen right after a boundary, since the step counter is unchanged then.

        :param trajectory: The trajectory observed at the current step.
        """
        # NOTE(review): `episode_couinter` looks like a typo of `episode_counter`, but it must
        # match the attribute initialised elsewhere in the class — rename both together.
        if trajectory.is_boundary():
            self.episode_couinter += 1
        else:
            self.step_counter += 1

        if self.step_counter % self.log_period != 0:
            return

        print(
            f"...Step {self.step_counter:12} of Episode {self.episode_couinter+1:8}",
            end="\r",
        )
Esempio n. 4
0
def extract_transitions_from_trajectories(
    trajectory: Trajectory,
    observation_spec: TensorSpec,
    action_spec: TensorSpec,
    predict_state_difference: bool,
) -> Transition:
    """
    TF-Agents returns a batch of trajectories from a buffer as a `Trajectory` object. This function
    transforms the data in the batch into a `Transition` tuple which can be used for training
    the model.

    Each transition pairs time step ``t`` with step ``t + 1`` of the same batch row. Pairs whose
    first step is a boundary (so the pair would straddle two episodes) are dropped. The remaining
    transitions are flattened across the batch and time dimensions by `tf.boolean_mask`.

    :param trajectory: The TF-Agents trajectory object
    :param observation_spec: The `TensorSpec` object which defines the observation tensors
    :param action_spec: The `TensorSpec` object which defines the action tensors
    :param predict_state_difference: Boolean to specify whether the transition model should
        return the next (latent) state or the difference between the current (latent) state and
        the next (latent) state

    :return: A `Transition` tuple which contains the observations and actions which can be used to
            train the model.
    :raises: An error from `tf.ensure_shape` if the observation, action or reward tensors do not
            match the expected ``[batch_size, time_dim, ...]`` shapes.
    """
    # True where step t is not a boundary, i.e. (t, t + 1) lies within one episode.
    mask = ~trajectory.is_boundary()[:, :-1]

    # The results of `tf.ensure_shape` are used so that the shape check is part of the
    # data flow (it is not pruned inside `tf.function`) and the static shape propagates.
    # [batch_size, time_dim, features...]
    trajectory_observation = tf.ensure_shape(
        trajectory.observation, [None, None] + observation_spec.shape
    )
    next_observation = tf.boolean_mask(trajectory_observation[:, 1:, ...], mask)
    observation = tf.boolean_mask(trajectory_observation[:, :-1, ...], mask)

    # [batch_size, time_dim, features...]
    trajectory_action = tf.ensure_shape(trajectory.action, [None, None] + action_spec.shape)
    action = tf.boolean_mask(trajectory_action[:, :-1, ...], mask)

    # [batch_size, time_dim]
    trajectory_reward = tf.ensure_shape(trajectory.reward, [None, None])
    reward = tf.boolean_mask(trajectory_reward[:, :-1], mask)

    if predict_state_difference:
        next_observation -= observation

    return Transition(
        observation=observation,
        action=action,
        reward=reward,
        next_observation=next_observation,
    )
Esempio n. 5
0
def make_trajectory_mask(batched_traj: trajectory.Trajectory) -> types.Tensor:
  """Mask boundary trajectories and those with invalid returns and advantages.

  Args:
    batched_traj: Trajectory, doubly-batched [batch_dim, time_dim,...]. It must
      be preprocessed already; its `policy_info` is read at the keys 'return'
      and 'normalized_advantage'.

  Returns:
    A tf.float32 mask that is 1.0 for usable steps and 0.0 wherever the step is
      between episodes (batched_traj.step_type is LAST, as reported by
      `is_boundary()`) or both its return and normalized advantage are exactly
      0, which is treated as the return value being unavailable.
  """
  is_boundary = batched_traj.is_boundary()

  # Preprocessing leaves both return and advantage at 0 for the last item, as
  # insufficient information was available for calculating its advantage.
  policy_info = batched_traj.policy_info
  missing_value = (
      tf.equal(policy_info['return'], 0)
      & tf.equal(policy_info['normalized_advantage'], 0))

  # Keep a step only if it is neither a boundary nor missing its value.
  return tf.cast(~(is_boundary | missing_value), tf.float32)
Esempio n. 6
0
    def __call__(self, trajectory: trajectory_lib.Trajectory) -> None:
        """Append one un-batched step to the writer and emit full sequences.

        The trajectory may be a (possibly nested) trajectory object or a flattened
        version of one; it must not carry a batch dimension. The most recent
        trajectory is remembered in `self._last_trajectory`.

        Args:
          trajectory: The trajectory to be written which could be (possibly nested)
            trajectory object or a flattened version of a trajectory. It assumes
            there is *no* batch dimension.
        """
        self._last_trajectory = trajectory
        self._writer.append(trajectory)
        self._cached_steps += 1

        # Write a sequence whenever the fixed sequence length has been reached.
        self._write_cached_steps()

        if not trajectory.is_boundary():
            return

        # End of episode: clear the cache. The remaining steps are padded and
        # written to Reverb only when `_pad_end_of_episodes` is enabled.
        self.reset(write_cached_steps=bool(self._pad_end_of_episodes))
Esempio n. 7
0
    def call(self, trajectory: traj.Trajectory):
        """Add the number of non-boundary steps in `trajectory` to the running total.

        An un-batched trajectory (scalar `step_type`) is first given a batch
        dimension, so single and batched inputs are counted the same way.
        """
        batched = (nest_utils.batch_nested_array(trajectory)
                   if trajectory.step_type.ndim == 0 else trajectory)
        is_env_step = ~batched.is_boundary()
        self._np_state.environment_steps += np.sum(is_env_step.astype(np.int64))
Esempio n. 8
0
    def total_loss(self,
                   experience: traj.Trajectory,
                   returns: types.Tensor,
                   weights: types.Tensor,
                   training: bool = False) -> tf_agent.LossInfo:
        """Computes the REINFORCE loss (plus optional baseline losses) for a batch.

        Args:
          experience: A batch of trajectories, indexed as [batch, time] (the
            partial-episode mask below is a reverse cumulative sum over axis 1).
          returns: Per-step returns, passed to `self._advantage_fn` together with
            the value predictions (or None when there is no baseline).
          weights: Optional per-step weights. They are multiplied by the mask of
            steps that belong to a complete episode. May be None.
          training: Whether the value network (when a baseline is used) runs in
            training mode.

        Returns:
          A `tf_agent.LossInfo` holding the total loss and a
          `ReinforceAgentLossInfo` breakdown of its components.

        Raises:
          tf.errors.InvalidArgumentError: (from `tf.debugging.assert_greater`) if
            the batch contains no step marked LAST, i.e. no complete episode.
        """
        # Rewards and discounts are replaced with zeros; the value network below
        # reads only the observation and step type.
        time_steps = ts.TimeStep(experience.step_type,
                                 tf.zeros_like(experience.reward),
                                 tf.zeros_like(experience.discount),
                                 experience.observation)

        # Ensure we see at least one full episode.
        is_last = experience.is_last()
        num_episodes = tf.reduce_sum(tf.cast(is_last, tf.float32))
        tf.debugging.assert_greater(
            num_episodes,
            0.0,
            message=
            'No complete episode found. REINFORCE requires full episodes '
            'to compute losses.')

        # Mask out partial episodes at the end of each batch of time_steps.
        # NOTE: We use is_last rather than is_boundary because the last transition
        # is the transition with the last valid reward.  In other words, the
        # reward on the boundary transitions do not have valid rewards.  Since
        # REINFORCE is calculating a loss w.r.t. the returns (and not bootstrapping)
        # keeping the boundary transitions is irrelevant.
        # A reverse cumulative sum over time is > 0 exactly for steps at or before
        # a LAST step in the same row; trailing steps of an unfinished episode get 0.
        valid_mask = tf.cast(is_last, dtype=tf.float32)
        valid_mask = tf.math.cumsum(valid_mask, axis=1, reverse=True)
        valid_mask = tf.cast(valid_mask > 0, dtype=tf.float32)
        if weights is not None:
            weights *= valid_mask
        else:
            weights = valid_mask

        value_preds = None
        if self._baseline:
            # Respect the caller's `training` flag rather than always running the
            # value network in training mode.
            value_preds, _ = self._value_network(time_steps.observation,
                                                 time_steps.step_type,
                                                 training=training)
            if self._debug_summaries:
                tf.compat.v2.summary.histogram(name='value_preds',
                                               data=value_preds,
                                               step=self.train_step_counter)

        advantages = self._advantage_fn(returns, value_preds)
        if self._debug_summaries:
            tf.compat.v2.summary.histogram(name='advantages',
                                           data=advantages,
                                           step=self.train_step_counter)

        # TODO(b/126592060): replace with tensor normalizer.
        if self._normalize_returns:
            advantages = _standard_normalize(advantages, axes=(0, 1))
            if self._debug_summaries:
                tf.compat.v2.summary.histogram(
                    name='normalized_%s' %
                    ('advantages' if self._baseline else 'returns'),
                    data=advantages,
                    step=self.train_step_counter)

        nest_utils.assert_same_structure(time_steps, self.time_step_spec)
        policy_state = _get_initial_policy_state(self.collect_policy,
                                                 time_steps)
        actions_distribution = self.collect_policy.distribution(
            time_steps, policy_state=policy_state).action

        policy_gradient_loss = self.policy_gradient_loss(
            actions_distribution,
            experience.action,
            experience.is_boundary(),
            advantages,
            num_episodes,
            weights,
        )

        entropy_regularization_loss = self.entropy_regularization_loss(
            actions_distribution, weights)

        network_regularization_loss = tf.nn.scale_regularization_loss(
            self._actor_network.losses)

        total_loss = (policy_gradient_loss + network_regularization_loss +
                      entropy_regularization_loss)

        # Value-related entries default to 0.0 and are overwritten with a baseline.
        losses_dict = {
            'policy_gradient_loss': policy_gradient_loss,
            'policy_network_regularization_loss': network_regularization_loss,
            'entropy_regularization_loss': entropy_regularization_loss,
            'value_estimation_loss': 0.0,
            'value_network_regularization_loss': 0.0,
        }

        if self._baseline:
            value_estimation_loss = self.value_estimation_loss(
                value_preds, returns, num_episodes, weights)
            value_network_regularization_loss = tf.nn.scale_regularization_loss(
                self._value_network.losses)
            total_loss += value_estimation_loss + value_network_regularization_loss
            losses_dict['value_estimation_loss'] = value_estimation_loss
            losses_dict['value_network_regularization_loss'] = (
                value_network_regularization_loss)

        loss_info_extra = ReinforceAgentLossInfo(**losses_dict)

        losses_dict[
            'total_loss'] = total_loss  # Total loss not in loss_info_extra.

        common.summarize_scalar_dict(losses_dict,
                                     self.train_step_counter,
                                     name_scope='Losses/')

        return tf_agent.LossInfo(total_loss, loss_info_extra)