Example No. 1
    def _train(self,
               experience,
               weights,
               episode_data=None,
               augmented_obs=None,
               augmented_next_obs=None):
        """Returns a train op to update the agent's networks.

    This method trains with the provided batched experience.

    Args:
      experience: A time-stacked trajectory object. If `num_augmentations > 1`,
        a tuple of the form `(trajectory, [augmentation_1, ...,
        augmentation_{K-1}])` is expected.
      weights: Optional scalar or elementwise (per-batch-entry) importance
        weights.
      episode_data: Tuple of (episode, episode, metric) for contrastive loss.
      augmented_obs: List of length num_augmentations - 1 of random crops of the
        trajectory's observation.
      augmented_next_obs: List of length num_augmentations - 1 of random crops
        of the trajectory's next_observation.

    Returns:
      A train_op.

    Raises:
      ValueError: If optimizers are None and no default value was provided to
        the constructor.
    """

        squeeze_time_dim = not self._critic_network_1.state_spec

        time_steps, policy_steps, next_time_steps = (
            trajectory.experience_to_transitions(experience, squeeze_time_dim))

        actions = policy_steps.action

        trainable_critic_variables = (
            self._critic_network_1.trainable_variables +
            self._critic_network_2.trainable_variables)

        with tf.GradientTape(watch_accessed_variables=False) as tape:
            assert trainable_critic_variables, (
                'No trainable critic variables to '
                'optimize.')
            tape.watch(trainable_critic_variables)

            critic_loss = self._critic_loss_weight * self.critic_loss(
                time_steps,
                actions,
                next_time_steps,
                augmented_obs,
                augmented_next_obs,
                td_errors_loss_fn=self._td_errors_loss_fn,
                gamma=self._gamma,
                reward_scale_factor=self._reward_scale_factor,
                weights=weights,
                training=True)

        tf.debugging.check_numerics(critic_loss, 'Critic loss is inf or nan.')
        critic_grads = tape.gradient(critic_loss, trainable_critic_variables)
        self._apply_gradients(critic_grads, trainable_critic_variables,
                              self._critic_optimizer)

        total_loss = critic_loss
        actor_loss = tf.constant(0.0, tf.float32)
        alpha_loss = tf.constant(0.0, tf.float32)

        with tf.name_scope('Losses'):
            tf.compat.v2.summary.scalar(name='critic_loss',
                                        data=critic_loss,
                                        step=self.train_step_counter)

        # Only perform actor and alpha updates periodically
        if self.train_step_counter % self._actor_update_frequency == 0:
            trainable_actor_variables = self._actor_network.trainable_variables
            with tf.GradientTape(watch_accessed_variables=False) as tape:
                assert trainable_actor_variables, (
                    'No trainable actor variables to '
                    'optimize.')
                tape.watch(trainable_actor_variables)
                actor_loss = self._actor_loss_weight * self.actor_loss(
                    time_steps, weights=weights)
            tf.debugging.check_numerics(actor_loss,
                                        'Actor loss is inf or nan.')
            actor_grads = tape.gradient(actor_loss, trainable_actor_variables)
            self._apply_gradients(actor_grads, trainable_actor_variables,
                                  self._actor_optimizer)

            alpha_variable = [self._log_alpha]
            with tf.GradientTape(watch_accessed_variables=False) as tape:
                assert alpha_variable, 'No alpha variable to optimize.'
                tape.watch(alpha_variable)
                alpha_loss = self._alpha_loss_weight * self.alpha_loss(
                    time_steps, weights=weights)
            tf.debugging.check_numerics(alpha_loss,
                                        'Alpha loss is inf or nan.')
            alpha_grads = tape.gradient(alpha_loss, alpha_variable)
            self._apply_gradients(alpha_grads, alpha_variable,
                                  self._alpha_optimizer)

            with tf.name_scope('Losses'):
                tf.compat.v2.summary.scalar(name='actor_loss',
                                            data=actor_loss,
                                            step=self.train_step_counter)
                tf.compat.v2.summary.scalar(name='alpha_loss',
                                            data=alpha_loss,
                                            step=self.train_step_counter)

            total_loss = critic_loss + actor_loss + alpha_loss

        # Contrastive loss for PSEs
        contrastive_loss = 0.0
        if self._contrastive_loss_weight > 0:
            contrastive_vars = self._actor_network.encoder_variables
            with tf.GradientTape(watch_accessed_variables=True,
                                 persistent=True) as tape:
                contrastive_loss = (self._contrastive_loss_weight *
                                    self.contrastive_metric_loss(episode_data))
            total_loss = total_loss + contrastive_loss
            tf.debugging.check_numerics(contrastive_loss,
                                        'Contrastive loss is inf or nan.')

            contrastive_grads = tape.gradient(contrastive_loss,
                                              contrastive_vars)
            self._apply_gradients(contrastive_grads, contrastive_vars,
                                  self._contrastive_optimizer)
            del tape

        self.train_step_counter.assign_add(1)
        self._update_target()

        # NOTE: Consider keeping track of previous actor/alpha loss.
        extra = SacContrastiveLossInfo(critic_loss=critic_loss,
                                       actor_loss=actor_loss,
                                       alpha_loss=alpha_loss,
                                       contrastive_loss=contrastive_loss)

        return tf_agent.LossInfo(loss=total_loss, extra=extra)
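
A hedged sketch of where the augmented observation lists in the docstring above could come from: `num_augmentations - 1` random crops of the same padded image batch. The helper name, pad size, and shapes below are illustrative assumptions, not the agent's own preprocessing code.

import tensorflow as tf

def make_random_crops(observations, num_augmentations, pad=4):
    """Returns num_augmentations - 1 random crops of a [B, H, W, C] batch."""
    batch, height, width, channels = observations.shape
    # Pad symmetrically, then crop back to the original size so each
    # augmentation is a randomly shifted view of the same frames.
    padded = tf.pad(observations, [[0, 0], [pad, pad], [pad, pad], [0, 0]],
                    mode='SYMMETRIC')
    return [
        tf.image.random_crop(padded, size=[batch, height, width, channels])
        for _ in range(num_augmentations - 1)
    ]

obs = tf.random.uniform([32, 84, 84, 3])
augmented_obs = make_random_crops(obs, num_augmentations=2)  # list of length 1
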
Example No. 2
    def _loss(self, experience, weights=None):
        """Computes loss for behavioral cloning.

    Args:
      experience: A `Trajectory` containing experience.
      weights: Optional scalar or element-wise (per-batch-entry) importance
        weights.

    Returns:
      loss: A `LossInfo` struct.

    Raises:
      ValueError:
        If the number of actions is greater than 1.
    """
        with tf.name_scope('loss'):
            if self._nested_actions:
                actions = experience.action
            else:
                actions = tf.nest.flatten(experience.action)[0]

            logits, _ = self._cloning_network(experience.observation,
                                              experience.step_type,
                                              training=True)

            error = self._loss_fn(logits, actions)
            error_dtype = tf.nest.flatten(error)[0].dtype
            boundary_weights = tf.cast(~experience.is_boundary(), error_dtype)
            error *= boundary_weights

            if nest_utils.is_batched_nested_tensors(experience.action,
                                                    self.action_spec,
                                                    num_outer_dims=2):
                # Do a sum over the time dimension.
                error = tf.reduce_sum(input_tensor=error, axis=1)

            # Average across the elements of the batch.
            # Note: We use an element wise loss above to ensure each element is always
            #   weighted by 1/N where N is the batch size, even when some of the
            #   weights are zero due to boundary transitions. Weighting by 1/K where K
            #   is the actual number of non-zero weights would artificially increase
            #   their contribution in the loss. Think about what would happen as
            #   the number of boundary samples increases.
            if weights is not None:
                error *= weights
            loss = tf.reduce_mean(input_tensor=error)

            with tf.name_scope('Losses/'):
                tf.compat.v2.summary.scalar(name='loss',
                                            data=loss,
                                            step=self.train_step_counter)

            if self._summarize_grads_and_vars:
                with tf.name_scope('Variables/'):
                    for var in self._cloning_network.trainable_weights:
                        tf.compat.v2.summary.histogram(
                            name=var.name.replace(':', '_'),
                            data=var,
                            step=self.train_step_counter)

            if self._debug_summaries:
                common.generate_tensor_summaries('errors', error,
                                                 self.train_step_counter)

            return tf_agent.LossInfo(loss,
                                     BehavioralCloningLossInfo(loss=error))
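
The long comment above argues for averaging by the batch size N (via reduce_mean over zero-weighted boundary entries) rather than by the count K of non-zero entries. A small numeric sketch of the difference, with made-up error values:

import tensorflow as tf

per_example_error = tf.constant([2.0, 4.0, 0.0, 0.0])  # last two rows are boundaries
n = tf.cast(tf.size(per_example_error), tf.float32)            # N = 4
k = tf.reduce_sum(tf.cast(per_example_error > 0, tf.float32))  # K = 2 non-zero rows

loss_over_n = tf.reduce_sum(per_example_error) / n  # 1.5, what reduce_mean computes
loss_over_k = tf.reduce_sum(per_example_error) / k  # 3.0, inflates as K shrinks
print(loss_over_n.numpy(), loss_over_k.numpy())
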
Example No. 3
  def total_loss(self, time_steps, actions, returns, weights):
    # Ensure we see at least one full episode.
    is_last = time_steps.is_last()
    num_episodes = tf.reduce_sum(tf.cast(is_last, tf.float32))
    tf.debugging.assert_greater(
        num_episodes, 0.0,
        message='No complete episode found. REINFORCE requires full episodes '
        'to compute losses.')

    # Mask out partial episodes at the end of each batch of time_steps.
    valid_mask = tf.cast(is_last, dtype=tf.float32)
    valid_mask = tf.math.cumsum(valid_mask, axis=1, reverse=True)
    valid_mask = tf.cast(valid_mask > 0, dtype=tf.float32)
    if weights is not None:
      weights *= valid_mask
    else:
      weights = valid_mask

    advantages = returns
    if self._baseline:
      value_preds, _ = self._value_network(
          time_steps.observation, time_steps.step_type)
      advantages = returns - value_preds
      if self._debug_summaries:
        tf.compat.v2.summary.histogram(
            name='value_preds', data=value_preds, step=self.train_step_counter)
        tf.compat.v2.summary.histogram(
            name='advantages', data=advantages, step=self.train_step_counter)

    # TODO(b/126592060): replace with tensor normalizer.
    if self._normalize_returns:
      advantages = _standard_normalize(advantages, axes=(0, 1))
      if self._debug_summaries:
        tf.compat.v2.summary.histogram(
            name='normalized_%s' %
            ('advantages' if self._baseline else 'returns'),
            data=advantages,
            step=self.train_step_counter)

    tf.nest.assert_same_structure(time_steps, self.time_step_spec)
    policy_state = _get_initial_policy_state(self.collect_policy, time_steps)
    actions_distribution = self.collect_policy.distribution(
        time_steps, policy_state=policy_state).action

    policy_gradient_loss = self.policy_gradient_loss(actions_distribution,
                                                     actions,
                                                     is_last,
                                                     advantages,
                                                     num_episodes,
                                                     weights)
    entropy_regularization_loss = self.entropy_regularization_loss(
        actions_distribution, weights)

    total_loss = policy_gradient_loss + entropy_regularization_loss

    if self._baseline:
      value_estimation_loss = self.value_estimation_loss(
          value_preds, returns, num_episodes, weights)
      total_loss += value_estimation_loss

    with tf.name_scope('Losses/'):
      tf.compat.v2.summary.scalar(
          name='policy_gradient_loss',
          data=policy_gradient_loss,
          step=self.train_step_counter)
      tf.compat.v2.summary.scalar(
          name='entropy_regularization_loss',
          data=entropy_regularization_loss,
          step=self.train_step_counter)
      if self._baseline:
        tf.compat.v2.summary.scalar(
            name='value_estimation_loss',
            data=value_estimation_loss,
            step=self.train_step_counter)
      tf.compat.v2.summary.scalar(
          name='total_loss', data=total_loss, step=self.train_step_counter)

    return tf_agent.LossInfo(total_loss, ())
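
The partial-episode mask above is built from a reverse cumulative sum of `is_last`: entries after the final episode boundary in each row sum to zero and are masked out. A toy batch (values invented) showing the effect:

import tensorflow as tf

is_last = tf.constant([[0., 0., 1., 0., 0.],
                       [0., 1., 0., 0., 1.]])
valid_mask = tf.math.cumsum(is_last, axis=1, reverse=True)
valid_mask = tf.cast(valid_mask > 0, tf.float32)
print(valid_mask.numpy())
# [[1. 1. 1. 0. 0.]   <- the trailing partial episode is masked out
#  [1. 1. 1. 1. 1.]]  <- this row ends on an episode boundary, all steps valid
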
Example No. 4
    def _train(self, experience, weights):
        """Returns a train op to update the agent's networks.

    This method trains with the provided batched experience.

    Args:
      experience: A time-stacked trajectory object.
      weights: Optional scalar or elementwise (per-batch-entry) importance
        weights.

    Returns:
      A train_op.

    Raises:
      ValueError: If optimizers are None and no default value was provided to
        the constructor.
    """
        transition = self._as_transition(experience)
        time_steps, policy_steps, next_time_steps = transition
        actions = policy_steps.action

        trainable_critic_variables = list(
            object_identity.ObjectIdentitySet(
                self._critic_network_1.trainable_variables +
                self._critic_network_2.trainable_variables))

        with tf.GradientTape(watch_accessed_variables=False) as tape:
            assert trainable_critic_variables, (
                'No trainable critic variables to '
                'optimize.')
            tape.watch(trainable_critic_variables)
            critic_loss = self._critic_loss_with_optional_entropy_term(
                time_steps,
                actions,
                next_time_steps,
                td_errors_loss_fn=self._td_errors_loss_fn,
                gamma=self._gamma,
                reward_scale_factor=self._reward_scale_factor,
                weights=weights,
                training=True)
            critic_loss *= self._critic_loss_weight

            cql_alpha = self._get_cql_alpha()
            cql_loss = self._cql_loss(time_steps, actions, training=True)

            if self._bc_debug_mode:
                cql_critic_loss = cql_loss * cql_alpha
            else:
                cql_critic_loss = critic_loss + (cql_loss * cql_alpha)

        tf.debugging.check_numerics(critic_loss, 'Critic loss is inf or nan.')
        tf.debugging.check_numerics(cql_loss, 'CQL loss is inf or nan.')
        critic_grads = tape.gradient(cql_critic_loss,
                                     trainable_critic_variables)
        self._apply_gradients(critic_grads, trainable_critic_variables,
                              self._critic_optimizer)

        trainable_actor_variables = self._actor_network.trainable_variables
        with tf.GradientTape(watch_accessed_variables=False) as tape:
            assert trainable_actor_variables, (
                'No trainable actor variables to '
                'optimize.')
            tape.watch(trainable_actor_variables)
            actor_loss = self._actor_loss_weight * self.actor_loss(
                time_steps, actions=actions, weights=weights)
        tf.debugging.check_numerics(actor_loss, 'Actor loss is inf or nan.')
        actor_grads = tape.gradient(actor_loss, trainable_actor_variables)
        self._apply_gradients(actor_grads, trainable_actor_variables,
                              self._actor_optimizer)

        alpha_variable = [self._log_alpha]
        with tf.GradientTape(watch_accessed_variables=False) as tape:
            assert alpha_variable, 'No alpha variable to optimize.'
            tape.watch(alpha_variable)
            alpha_loss = self._alpha_loss_weight * self.alpha_loss(
                time_steps, weights=weights)
        tf.debugging.check_numerics(alpha_loss, 'Alpha loss is inf or nan.')
        alpha_grads = tape.gradient(alpha_loss, alpha_variable)
        self._apply_gradients(alpha_grads, alpha_variable,
                              self._alpha_optimizer)

        # Based on equation (24), which automates CQL alpha with the "budget"
        # parameter tau. CQL(H) is now CQL-Lagrange(H):
        # ```
        # min_Q max_{alpha >= 0} alpha * (log_sum_exp(Q(s, a')) - Q(s, a) - tau)
        # ```
        # If the expected difference in Q-values is less than tau, alpha
        # will adjust to be closer to 0. If the difference is higher than tau,
        # alpha is likely to take on high values and more aggressively penalize
        # Q-values.
        cql_alpha_loss = tf.constant(0.)
        if self._use_lagrange_cql_alpha:
            cql_alpha_variable = [self._log_cql_alpha]
            with tf.GradientTape(watch_accessed_variables=False) as tape:
                tape.watch(cql_alpha_variable)
                cql_alpha_loss = -self._get_cql_alpha() * (cql_loss -
                                                           self._cql_tau)
            tf.debugging.check_numerics(cql_alpha_loss,
                                        'CQL alpha loss is inf or nan.')
            cql_alpha_gradients = tape.gradient(cql_alpha_loss,
                                                cql_alpha_variable)
            self._apply_gradients(cql_alpha_gradients, cql_alpha_variable,
                                  self._cql_alpha_optimizer)

        with tf.name_scope('Losses'):
            tf.compat.v2.summary.scalar(name='critic_loss',
                                        data=critic_loss,
                                        step=self.train_step_counter)
            tf.compat.v2.summary.scalar(name='actor_loss',
                                        data=actor_loss,
                                        step=self.train_step_counter)
            tf.compat.v2.summary.scalar(name='alpha_loss',
                                        data=alpha_loss,
                                        step=self.train_step_counter)
            tf.compat.v2.summary.scalar(name='cql_loss',
                                        data=cql_loss,
                                        step=self.train_step_counter)
            if self._use_lagrange_cql_alpha:
                tf.compat.v2.summary.scalar(name='cql_alpha_loss',
                                            data=cql_alpha_loss,
                                            step=self.train_step_counter)
        tf.compat.v2.summary.scalar(name='cql_alpha',
                                    data=cql_alpha,
                                    step=self.train_step_counter)
        tf.compat.v2.summary.scalar(name='sac_alpha',
                                    data=tf.exp(self._log_alpha),
                                    step=self.train_step_counter)

        self.train_step_counter.assign_add(1)
        self._update_target()

        total_loss = cql_critic_loss + actor_loss + alpha_loss

        extra = CqlSacLossInfo(critic_loss=critic_loss,
                               actor_loss=actor_loss,
                               alpha_loss=alpha_loss,
                               cql_loss=cql_loss,
                               cql_alpha=cql_alpha,
                               cql_alpha_loss=cql_alpha_loss)

        return tf_agent.LossInfo(loss=total_loss, extra=extra)
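
A standalone sketch of the CQL-Lagrange alpha behaviour described in the comment above: minimizing -alpha * (cql_gap - tau) pushes alpha up while the CQL gap exceeds the budget tau, and toward zero otherwise. The gap, tau, and optimizer below are toy assumptions; the agent's own `_get_cql_alpha` and optimizers are not shown.

import tensorflow as tf

log_cql_alpha = tf.Variable(0.0)
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
tau, cql_gap = 5.0, 8.0  # gap above the budget, so alpha should grow

for _ in range(3):
    with tf.GradientTape() as tape:
        cql_alpha = tf.exp(log_cql_alpha)
        cql_alpha_loss = -cql_alpha * (cql_gap - tau)
    grads = tape.gradient(cql_alpha_loss, [log_cql_alpha])
    optimizer.apply_gradients(zip(grads, [log_cql_alpha]))
    print(float(tf.exp(log_cql_alpha)))  # alpha increases step after step
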
Example No. 5
  def loss(self,
           observations,
           actions,
           rewards,
           weights=None,
           training=False):
    """Computes loss for reward prediction training.

    Args:
      observations: A batch of observations.
      actions: A batch of actions.
      rewards: A batch of rewards.
      weights: Optional scalar or elementwise (per-batch-entry) importance
        weights.  The output batch loss will be scaled by these weights, and
        the final scalar loss is the mean of these values.
      training: Whether the loss is being used for training.

    Returns:
      loss: A `LossInfo` containing the loss for the training step.
    Raises:
      ValueError:
        if the number of actions is greater than 1.
    """
    with tf.name_scope('loss'):
      sample_weights = weights if weights is not None else 1
      if self._heteroscedastic:
        predictions, _ = self._reward_network(observations,
                                              training=training)
        predicted_values = predictions.q_value_logits
        predicted_log_variance = predictions.log_variance
        action_predicted_log_variance = common.index_with_actions(
            predicted_log_variance, tf.cast(actions, dtype=tf.int32))
        sample_weights = sample_weights * 0.5 * tf.exp(
            -action_predicted_log_variance)

        loss = 0.5 * tf.reduce_mean(action_predicted_log_variance)
        # loss = 1/(2 * var(x)) * (y - f(x))^2 + 1/2 * log var(x)
        # Kendall, Alex, and Yarin Gal. "What Uncertainties Do We Need in
        # Bayesian Deep Learning for Computer Vision?." Advances in Neural
        # Information Processing Systems. 2017. https://arxiv.org/abs/1703.04977
      else:
        predicted_values, _ = self._reward_network(observations,
                                                   training=training)
        loss = tf.constant(0.0)

      action_predicted_values = common.index_with_actions(
          predicted_values,
          tf.cast(actions, dtype=tf.int32))

      # Apply Laplacian smoothing on the estimated rewards, if applicable.
      if self._laplacian_matrix is not None:
        smoothness_batched = tf.matmul(
            predicted_values,
            tf.matmul(self._laplacian_matrix, predicted_values,
                      transpose_b=True))
        loss += (self._laplacian_smoothing_weight * tf.reduce_mean(
            tf.linalg.tensor_diag_part(smoothness_batched) * sample_weights))

      loss += self._error_loss_fn(
          rewards,
          action_predicted_values,
          sample_weights,
          reduction=tf.compat.v1.losses.Reduction.MEAN)

    return tf_agent.LossInfo(loss, extra=())
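
The heteroscedastic branch above folds the Kendall & Gal objective, 1/(2 * var(x)) * (y - f(x))^2 + 1/2 * log var(x), into a sample weight plus a constant term. A direct evaluation of the same formula on toy tensors (no reward network involved):

import tensorflow as tf

rewards = tf.constant([1.0, 0.0, 2.0])             # y
predicted_values = tf.constant([0.8, 0.4, 1.5])    # f(x)
predicted_log_variance = tf.constant([-1.0, 0.5, 0.0])

precision = 0.5 * tf.exp(-predicted_log_variance)  # 1 / (2 * var(x))
loss = tf.reduce_mean(precision * tf.square(rewards - predicted_values)
                      + 0.5 * predicted_log_variance)
print(float(loss))
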
Example No. 6
  def get_epoch_loss(self, time_steps, actions, act_log_probs, returns,
                     normalized_advantages, action_distribution_parameters,
                     weights, train_step, debug_summaries):
    """Compute the loss and create optimization op for one training epoch.

    All tensors should have a single batch dimension.

    Args:
      time_steps: A minibatch of TimeStep tuples.
      actions: A minibatch of actions.
      act_log_probs: A minibatch of action probabilities (probability under the
        sampling policy).
      returns: A minibatch of per-timestep returns.
      normalized_advantages: A minibatch of normalized per-timestep advantages.
      action_distribution_parameters: Parameters of data-collecting action
        distribution. Needed for KL computation.
      weights: Optional scalar or element-wise (per-batch-entry) importance
        weights.  Includes a mask for invalid timesteps.
      train_step: A train_step variable to increment for each train step.
        Typically the global_step.
      debug_summaries: True if debug summaries should be created.

    Returns:
      A tf_agent.LossInfo named tuple with the total_loss and all intermediate
        losses in the extra field contained in a PPOLossInfo named tuple.
    """
    # Evaluate the current policy on timesteps.

    # batch_size from time_steps
    batch_size = nest_utils.get_outer_shape(time_steps, self._time_step_spec)[0]
    policy_state = self._collect_policy.get_initial_state(batch_size)
    distribution_step = self._collect_policy.distribution(
        time_steps, policy_state)
    # TODO(eholly): Rename policy distributions to something clear and uniform.
    current_policy_distribution = distribution_step.action

    # Call all loss functions and add all loss values.
    value_estimation_loss = self.value_estimation_loss(time_steps, returns,
                                                       weights, debug_summaries)
    policy_gradient_loss = self.policy_gradient_loss(
        time_steps,
        actions,
        tf.stop_gradient(act_log_probs),
        tf.stop_gradient(normalized_advantages),
        current_policy_distribution,
        weights,
        debug_summaries=debug_summaries)

    if self._policy_l2_reg > 0.0 or self._value_function_l2_reg > 0.0:
      l2_regularization_loss = self.l2_regularization_loss(debug_summaries)
    else:
      l2_regularization_loss = tf.zeros_like(policy_gradient_loss)

    if self._entropy_regularization > 0.0:
      entropy_regularization_loss = self.entropy_regularization_loss(
          time_steps, current_policy_distribution, weights, debug_summaries)
    else:
      entropy_regularization_loss = tf.zeros_like(policy_gradient_loss)

    kl_penalty_loss = self.kl_penalty_loss(
        time_steps, action_distribution_parameters, current_policy_distribution,
        weights, debug_summaries)

    total_loss = (
        policy_gradient_loss + value_estimation_loss + l2_regularization_loss +
        entropy_regularization_loss + kl_penalty_loss)

    return tf_agent.LossInfo(
        total_loss,
        PPOLossInfo(
            policy_gradient_loss=policy_gradient_loss,
            value_estimation_loss=value_estimation_loss,
            l2_regularization_loss=l2_regularization_loss,
            entropy_regularization_loss=entropy_regularization_loss,
            kl_penalty_loss=kl_penalty_loss,
        ))
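
The epoch loss above stops gradients through the sampled log-probabilities and advantages so they act as constants of the surrogate objective. A minimal sketch of a clipped PPO surrogate using that pattern; the toy tensors and the use of a plain Variable for the new log-probabilities are simplifying assumptions, not the agent's policy_gradient_loss:

import tensorflow as tf

new_log_probs = tf.Variable([-1.0, -0.5, -2.0])   # stands in for the current policy
old_log_probs = tf.constant([-1.2, -0.6, -1.8])   # from the data-collecting policy
advantages = tf.constant([1.0, -0.5, 2.0])
epsilon = 0.2

with tf.GradientTape() as tape:
    ratio = tf.exp(new_log_probs - tf.stop_gradient(old_log_probs))
    clipped = tf.clip_by_value(ratio, 1.0 - epsilon, 1.0 + epsilon)
    surrogate = -tf.reduce_mean(
        tf.minimum(ratio * advantages, clipped * advantages))
print(tape.gradient(surrogate, new_log_probs).numpy())
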
Example No. 7
    def _train(self, experience, weights):
        """Returns a train op to update the agent's networks.

    This method trains with the provided batched experience.

    Args:
      experience: A time-stacked trajectory object.
      weights: Optional scalar or elementwise (per-batch-entry) importance
        weights.

    Returns:
      A train_op.

    Raises:
      ValueError: If optimizers are None and no default value was provided to
        the constructor.
    """
        transition = self._as_transition(experience)
        time_steps, policy_steps, next_time_steps = transition
        actions = policy_steps.action

        trainable_critic_variables = list(
            object_identity.ObjectIdentitySet(
                self._critic_network_1.trainable_variables +
                self._critic_network_2.trainable_variables))

        with tf.GradientTape(watch_accessed_variables=False) as tape:
            assert trainable_critic_variables, (
                'No trainable critic variables to '
                'optimize.')
            tape.watch(trainable_critic_variables)
            critic_loss = self._critic_loss_weight * self.critic_loss(
                time_steps,
                actions,
                next_time_steps,
                td_errors_loss_fn=self._td_errors_loss_fn,
                gamma=self._gamma,
                reward_scale_factor=self._reward_scale_factor,
                weights=weights,
                training=True)

        tf.debugging.check_numerics(critic_loss, 'Critic loss is inf or nan.')
        critic_grads = tape.gradient(critic_loss, trainable_critic_variables)
        self._apply_gradients(critic_grads, trainable_critic_variables,
                              self._critic_optimizer)

        trainable_actor_variables = self._actor_network.trainable_variables
        with tf.GradientTape(watch_accessed_variables=False) as tape:
            assert trainable_actor_variables, (
                'No trainable actor variables to '
                'optimize.')
            tape.watch(trainable_actor_variables)
            actor_loss = self._actor_loss_weight * self.actor_loss(
                time_steps, weights=weights)
        tf.debugging.check_numerics(actor_loss, 'Actor loss is inf or nan.')
        actor_grads = tape.gradient(actor_loss, trainable_actor_variables)
        self._apply_gradients(actor_grads, trainable_actor_variables,
                              self._actor_optimizer)

        alpha_variable = [self._log_alpha]
        with tf.GradientTape(watch_accessed_variables=False) as tape:
            assert alpha_variable, 'No alpha variable to optimize.'
            tape.watch(alpha_variable)
            alpha_loss = self._alpha_loss_weight * self.alpha_loss(
                time_steps, weights=weights)
        tf.debugging.check_numerics(alpha_loss, 'Alpha loss is inf or nan.')
        alpha_grads = tape.gradient(alpha_loss, alpha_variable)
        self._apply_gradients(alpha_grads, alpha_variable,
                              self._alpha_optimizer)

        with tf.name_scope('Losses'):
            tf.compat.v2.summary.scalar(name='critic_loss',
                                        data=critic_loss,
                                        step=self.train_step_counter)
            tf.compat.v2.summary.scalar(name='actor_loss',
                                        data=actor_loss,
                                        step=self.train_step_counter)
            tf.compat.v2.summary.scalar(name='alpha_loss',
                                        data=alpha_loss,
                                        step=self.train_step_counter)

        self.train_step_counter.assign_add(1)
        self._update_target()

        total_loss = critic_loss + actor_loss + alpha_loss

        extra = SacLossInfo(critic_loss=critic_loss,
                            actor_loss=actor_loss,
                            alpha_loss=alpha_loss)

        return tf_agent.LossInfo(loss=total_loss, extra=extra)
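
Each update above uses the same tape pattern: watch_accessed_variables=False plus an explicit tape.watch, so variables that merely participate in the forward pass (for example the actor inside the critic loss) receive no gradient. A stripped-down illustration with two toy variables:

import tensorflow as tf

watched = tf.Variable(2.0)
not_watched = tf.Variable(3.0)

with tf.GradientTape(watch_accessed_variables=False) as tape:
    tape.watch([watched])
    loss = watched * not_watched  # both variables are used in the loss

grads = tape.gradient(loss, [watched, not_watched])
print(grads[0].numpy(), grads[1])  # 3.0 and None: only `watched` gets a gradient
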
Example No. 8
    def _train(self, experience, weights=None):
        # TODO(b/120034503): Move the conversion to transitions to the base class.
        time_steps, actions, next_time_steps = self._experience_to_transitions(
            experience)

        # TODO(kbanoop): Apply a loss mask or filter boundary transitions.
        critic_loss = self.critic_loss(time_steps,
                                       actions,
                                       next_time_steps,
                                       weights=weights)

        actor_loss = self.actor_loss(time_steps, weights=weights)

        def clip_and_summarize_gradients(grads_and_vars):
            """Clips gradients, and summarizes gradients and variables."""
            if self._gradient_clipping is not None:
                grads_and_vars = eager_utils.clip_gradient_norms_fn(
                    self._gradient_clipping)(grads_and_vars)

            if self._summarize_grads_and_vars:
                # TODO(kbanoop): Move gradient summaries to train_op after we switch to
                # eager train op, and move variable summaries to critic_loss.
                for grad, var in grads_and_vars:
                    with tf.name_scope('Gradients/'):
                        if grad is not None:
                            tf.compat.v2.summary.histogram(
                                name=grad.op.name,
                                data=grad,
                                step=self.train_step_counter)
                    with tf.name_scope('Variables/'):
                        if var is not None:
                            tf.compat.v2.summary.histogram(
                                name=var.op.name,
                                data=var,
                                step=self.train_step_counter)
            return grads_and_vars

        critic_train_op = eager_utils.create_train_op(
            critic_loss,
            self._critic_optimizer,
            global_step=self.train_step_counter,
            transform_grads_fn=clip_and_summarize_gradients,
            variables_to_train=self._critic_network_1.trainable_weights +
            self._critic_network_2.trainable_weights,
        )

        actor_train_op = eager_utils.create_train_op(
            actor_loss,
            self._actor_optimizer,
            global_step=None,
            transform_grads_fn=clip_and_summarize_gradients,
            variables_to_train=self._actor_network.trainable_weights,
        )

        with tf.control_dependencies([critic_train_op, actor_train_op]):
            update_targets_op = self._update_targets(
                self._target_update_tau, self._target_update_period)

        with tf.control_dependencies([update_targets_op]):
            total_loss = actor_loss + critic_loss

        # TODO(kbanoop): Compute per element TD loss and return in loss_info.
        return tf_agent.LossInfo(total_loss, Td3Info(actor_loss, critic_loss))
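
The transform_grads_fn above clips and summarizes (gradient, variable) pairs before they are applied. A self-contained sketch of the clipping half using only core TensorFlow; global-norm clipping is chosen here for brevity, whereas the original relies on eager_utils.clip_gradient_norms_fn:

import tensorflow as tf

def clip_grads_and_vars(grads_and_vars, clip_norm=1.0):
    grads, variables = zip(*grads_and_vars)
    clipped, _ = tf.clip_by_global_norm(list(grads), clip_norm)
    return list(zip(clipped, variables))

v = tf.Variable([3.0, 4.0])
with tf.GradientTape() as tape:
    loss = tf.reduce_sum(v * v)  # gradient is [6, 8], global norm 10
grads_and_vars = list(zip(tape.gradient(loss, [v]), [v]))
print(clip_grads_and_vars(grads_and_vars)[0][0].numpy())  # rescaled to norm 1
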
Example No. 9
    def _loss(self,
              experience,
              td_errors_loss_fn=common.element_wise_huber_loss,
              gamma=1.0,
              reward_scale_factor=1.0,
              weights=None):
        """Computes loss for DQN training.

    Args:
      experience: A batch of experience data in the form of a `Trajectory`. The
        structure of `experience` must match that of `self.policy.step_spec`.
        All tensors in `experience` must be shaped `[batch, time, ...]` where
        `time` must be equal to `self.train_sequence_length` if that
        property is not `None`.
      td_errors_loss_fn: A function(td_targets, predictions) to compute the
        element wise loss.
      gamma: Discount for future rewards.
      reward_scale_factor: Multiplicative factor to scale rewards.
      weights: Optional scalar or elementwise (per-batch-entry) importance
        weights.  The output td_loss will be scaled by these weights, and
        the final scalar loss is the mean of these values.

    Returns:
      loss: An instance of `DqnLossInfo`.
    Raises:
      ValueError:
        if the number of actions is greater than 1.
    """
        # Check that `experience` includes two outer dimensions [B, T, ...]. This
        # method requires a time dimension to compute the loss properly.
        self._check_trajectory_dimensions(experience)

        if self._n_step_update == 1:
            time_steps, actions, next_time_steps = self._experience_to_transitions(
                experience)
        else:
            # To compute n-step returns, we need the first time steps, the first
            # actions, and the last time steps. Therefore we extract the first and
            # last transitions from our Trajectory.
            first_two_steps = tf.nest.map_structure(lambda x: x[:, :2],
                                                    experience)
            last_two_steps = tf.nest.map_structure(lambda x: x[:, -2:],
                                                   experience)
            time_steps, actions, _ = self._experience_to_transitions(
                first_two_steps)
            _, _, next_time_steps = self._experience_to_transitions(
                last_two_steps)

        with tf.name_scope('loss'):
            q_values = self._compute_q_values(time_steps, actions)

            next_q_values = self._compute_next_q_values(next_time_steps)

            if self._n_step_update == 1:
                # Special case for n = 1 to avoid a loss of performance.
                td_targets = compute_td_targets(
                    next_q_values,
                    rewards=reward_scale_factor * next_time_steps.reward,
                    discounts=gamma * next_time_steps.discount)
            else:
                # When computing discounted return, we need to throw out the last time
                # index of both reward and discount, which are filled with dummy values
                # to match the dimensions of the observation.
                rewards = reward_scale_factor * experience.reward[:, :-1]
                discounts = gamma * experience.discount[:, :-1]

                # TODO(b/134618876): Properly handle Trajectories that include episode
                # boundaries with nonzero discount.

                td_targets = value_ops.discounted_return(
                    rewards=rewards,
                    discounts=discounts,
                    final_value=next_q_values,
                    time_major=False,
                    provide_all_returns=False)

            valid_mask = tf.cast(~time_steps.is_last(), tf.float32)
            td_error = valid_mask * (td_targets - q_values)

            td_loss = valid_mask * td_errors_loss_fn(td_targets, q_values)

            if nest_utils.is_batched_nested_tensors(time_steps,
                                                    self.time_step_spec,
                                                    num_outer_dims=2):
                # Do a sum over the time dimension.
                td_loss = tf.reduce_sum(input_tensor=td_loss, axis=1)

            if weights is not None:
                td_loss *= weights

            # Average across the elements of the batch.
            # Note: We use an element wise loss above to ensure each element is always
            #   weighted by 1/N where N is the batch size, even when some of the
            #   weights are zero due to boundary transitions. Weighting by 1/K where K
            #   is the actual number of non-zero weights would artificially increase
            #   their contribution in the loss. Think about what would happen as
            #   the number of boundary samples increases.
            loss = tf.reduce_mean(input_tensor=td_loss)

            # Add network loss (such as regularization loss)
            if self._q_network.losses:
                loss = loss + tf.reduce_mean(self._q_network.losses)

            with tf.name_scope('Losses/'):
                tf.compat.v2.summary.scalar(name='loss',
                                            data=loss,
                                            step=self.train_step_counter)

            if self._summarize_grads_and_vars:
                with tf.name_scope('Variables/'):
                    for var in self._q_network.trainable_weights:
                        tf.compat.v2.summary.histogram(
                            name=var.name.replace(':', '_'),
                            data=var,
                            step=self.train_step_counter)

            if self._debug_summaries:
                diff_q_values = q_values - next_q_values
                common.generate_tensor_summaries('td_error', td_error,
                                                 self.train_step_counter)
                common.generate_tensor_summaries('td_loss', td_loss,
                                                 self.train_step_counter)
                common.generate_tensor_summaries('q_values', q_values,
                                                 self.train_step_counter)
                common.generate_tensor_summaries('next_q_values',
                                                 next_q_values,
                                                 self.train_step_counter)
                common.generate_tensor_summaries('diff_q_values',
                                                 diff_q_values,
                                                 self.train_step_counter)

            return tf_agent.LossInfo(
                loss, DqnLossInfo(td_loss=td_loss, td_error=td_error))
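
For the n_step_update == 1 branch above, the TD target is r + gamma * discount * max_a' Q(s', a'). A toy computation of the target, the TD error, and an element-wise loss (numbers invented; the agent uses a Huber loss rather than the squared error shown here):

import tensorflow as tf

rewards = tf.constant([1.0, 0.0])
discounts = tf.constant([1.0, 0.0])       # the second transition ends its episode
next_q_values = tf.constant([5.0, 7.0])   # max over actions at s'
q_values = tf.constant([4.0, 1.5])        # Q(s, a) for the actions actually taken
gamma = 0.99

td_targets = rewards + gamma * discounts * next_q_values  # [5.95, 0.0]
td_error = td_targets - q_values                          # [1.95, -1.5]
td_loss = tf.square(td_error)
print(td_targets.numpy(), td_error.numpy(), td_loss.numpy())
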
Example No. 10
 def testBaseLossInfo(self):
     loss_info = tf_agent.LossInfo(0.0, ())
     self.assertEqual(loss_info.loss, 0.0)
     self.assertIsInstance(loss_info, tf_agent.LossInfo)
Example No. 11
    def _train(self, experience, weights=None):
        """Updates the policy based on the data in `experience`.

    Note that `experience` should only contain data points that this agent has
    not previously seen. If `experience` comes from a replay buffer, this buffer
    should be cleared between each call to `train`.

    Args:
      experience: A batch of experience data in the form of a `Trajectory`.
      weights: Unused.

    Returns:
        A `LossInfo` containing the loss *before* the training step is taken.
        In most cases, if `weights` is provided, the entries of this tuple will
        have been calculated with the weights.  Note that each Agent chooses
        its own method of applying weights.
    """
        del weights  # unused

        # If the experience comes from a replay buffer, the reward has shape:
        #     [batch_size, time_steps]
        # where `time_steps` is the number of driver steps executed in each
        # training loop.
        # We flatten the tensors below in order to reflect the effective batch size.

        reward, _ = nest_utils.flatten_multi_batched_nested_tensors(
            experience.reward, self._time_step_spec.reward)
        action, _ = nest_utils.flatten_multi_batched_nested_tensors(
            experience.action, self._action_spec)
        observation, _ = nest_utils.flatten_multi_batched_nested_tensors(
            experience.observation, self._time_step_spec.observation)

        if self._observation_and_action_constraint_splitter is not None:
            observation, _ = self._observation_and_action_constraint_splitter(
                observation)
        observation = tf.reshape(observation, [-1, self._context_dim])
        observation = tf.cast(observation, self._dtype)
        reward = tf.cast(reward, self._dtype)

        for k in range(self._num_actions):
            diag_mask = tf.linalg.tensor_diag(
                tf.cast(tf.equal(action, k), self._dtype))
            observations_for_arm = tf.matmul(diag_mask, observation)
            rewards_for_arm = tf.matmul(diag_mask, tf.reshape(reward, [-1, 1]))

            num_samples_for_arm_current = tf.reduce_sum(diag_mask)
            tf.compat.v1.assign_add(self._num_samples_list[k],
                                    num_samples_for_arm_current)
            num_samples_for_arm_total = self._num_samples_list[k].read_value()

            # Update the matrix A and b.
            # pylint: disable=cell-var-from-loop,g-long-lambda
            def update(cov_matrix, data_vector):
                return update_a_and_b_with_forgetting(cov_matrix, data_vector,
                                                      rewards_for_arm,
                                                      observations_for_arm,
                                                      self._gamma,
                                                      self._use_eigendecomp)

            a_new, b_new, eig_vals, eig_matrix = tf.cond(
                tf.squeeze(num_samples_for_arm_total) > 0, lambda: update(
                    self._cov_matrix_list[k], self._data_vector_list[k]),
                lambda: (self._cov_matrix_list[k], self._data_vector_list[k],
                         self._eig_vals_list[k], self._eig_matrix_list[k]))

            tf.compat.v1.assign(self._cov_matrix_list[k], a_new)
            tf.compat.v1.assign(self._data_vector_list[k], b_new)
            tf.compat.v1.assign(self._eig_vals_list[k], eig_vals)
            tf.compat.v1.assign(self._eig_matrix_list[k], eig_matrix)

        loss = -1. * tf.reduce_sum(experience.reward)
        self.compute_summaries(loss)

        batch_size = tf.cast(tf.compat.dimension_value(tf.shape(reward)[0]),
                             dtype=tf.int64)
        self._train_step_counter.assign_add(batch_size)

        return tf_agent.LossInfo(loss=(loss), extra=())
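
Inside the per-arm loop above, the diagonal mask zeroes out every row of the batch that did not choose arm k, so the covariance and data-vector updates reduce to X^T X and X^T r over that arm's rows. A plain version of one such update without the forgetting factor or eigendecomposition (toy batch):

import tensorflow as tf

observation = tf.constant([[1., 0., 2.], [0., 1., 1.], [2., 2., 0.]])
action = tf.constant([1, 0, 1])
reward = tf.constant([0.5, 1.0, 2.0])
k = 1  # the arm whose statistics are being updated

diag_mask = tf.linalg.tensor_diag(tf.cast(tf.equal(action, k), tf.float32))
observations_for_arm = tf.matmul(diag_mask, observation)  # other arms' rows zeroed
rewards_for_arm = tf.matmul(diag_mask, tf.reshape(reward, [-1, 1]))

a_update = tf.matmul(observations_for_arm, observations_for_arm,
                     transpose_a=True)  # X^T X contribution to the covariance
b_update = tf.matmul(observations_for_arm, rewards_for_arm,
                     transpose_a=True)  # X^T r contribution to the data vector
print(a_update.numpy(), tf.squeeze(b_update).numpy())
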
Example No. 12
    def _loss(self,
              experience,
              td_errors_loss_fn=tf.compat.v1.losses.huber_loss,
              gamma=1.0,
              reward_scale_factor=1.0,
              weights=None,
              training=False):
        """Computes critic loss for CategoricalDQN training.

    See Algorithm 1 and the discussion immediately preceding it on page 6 of
    "A Distributional Perspective on Reinforcement Learning"
      Bellemare et al., 2017
      https://arxiv.org/abs/1707.06887

    Args:
      experience: A batch of experience data in the form of a `Trajectory`. The
        structure of `experience` must match that of `self.policy.step_spec`.
        All tensors in `experience` must be shaped `[batch, time, ...]` where
        `time` must be equal to `self.required_experience_time_steps` if that
        property is not `None`.
      td_errors_loss_fn: A function(td_targets, predictions) to compute loss.
      gamma: Discount for future rewards.
      reward_scale_factor: Multiplicative factor to scale rewards.
      weights: Optional weights used for importance sampling.
      training: Whether the loss is being used for training.
    Returns:
      critic_loss: A scalar critic loss.
    Raises:
      ValueError:
        if the number of actions is greater than 1.
    """
        # Check that `experience` includes two outer dimensions [B, T, ...]. This
        # method requires a time dimension to compute the loss properly.
        self._check_trajectory_dimensions(experience)

        squeeze_time_dim = not self._q_network.state_spec
        if self._n_step_update == 1:
            time_steps, policy_steps, next_time_steps = (
                trajectory.experience_to_transitions(experience,
                                                     squeeze_time_dim))
            actions = policy_steps.action
        else:
            # To compute n-step returns, we need the first time steps, the first
            # actions, and the last time steps. Therefore we extract the first and
            # last transitions from our Trajectory.
            first_two_steps = tf.nest.map_structure(lambda x: x[:, :2],
                                                    experience)
            last_two_steps = tf.nest.map_structure(lambda x: x[:, -2:],
                                                   experience)
            time_steps, policy_steps, _ = (
                trajectory.experience_to_transitions(first_two_steps,
                                                     squeeze_time_dim))
            actions = policy_steps.action
            _, _, next_time_steps = (trajectory.experience_to_transitions(
                last_two_steps, squeeze_time_dim))

        with tf.name_scope('critic_loss'):
            nest_utils.assert_same_structure(actions, self.action_spec)
            nest_utils.assert_same_structure(time_steps, self.time_step_spec)
            nest_utils.assert_same_structure(next_time_steps,
                                             self.time_step_spec)

            rank = nest_utils.get_outer_rank(time_steps.observation,
                                             self._time_step_spec.observation)

            # If inputs have a time dimension and the q_network is stateful,
            # combine the batch and time dimension.
            batch_squash = (None if rank <= 1 or self._q_network.state_spec
                            in ((), None) else utils.BatchSquash(rank))

            network_observation = time_steps.observation

            if self._observation_and_action_constraint_splitter is not None:
                network_observation, _ = (
                    self._observation_and_action_constraint_splitter(
                        network_observation))

            # q_logits contains the Q-value logits for all actions.
            q_logits, _ = self._q_network(network_observation,
                                          time_steps.step_type,
                                          training=training)

            if batch_squash is not None:
                # Squash outer dimensions into a single dimension to facilitate
                # computing the loss below. Required for supporting temporal
                # inputs, for example.
                q_logits = batch_squash.flatten(q_logits)
                actions = batch_squash.flatten(actions)
                next_time_steps = tf.nest.map_structure(
                    batch_squash.flatten, next_time_steps)

            next_q_distribution = self._next_q_distribution(next_time_steps)

            if actions.shape.rank > 1:
                actions = tf.squeeze(actions,
                                     list(range(1, actions.shape.rank)))

            # Project the sample Bellman update \hat{T}Z_{\theta} onto the original
            # support of Z_{\theta} (see Figure 1 in paper).
            batch_size = q_logits.shape[0] or tf.shape(q_logits)[0]
            tiled_support = tf.tile(self._support, [batch_size])
            tiled_support = tf.reshape(tiled_support,
                                       [batch_size, self._num_atoms])

            if self._n_step_update == 1:
                discount = next_time_steps.discount
                if discount.shape.rank == 1:
                    # We expect discount to have a shape of [batch_size], while
                    # tiled_support will have a shape of [batch_size, num_atoms]. To
                    # multiply these, we add a second dimension of 1 to the discount.
                    discount = tf.expand_dims(discount, -1)
                next_value_term = tf.multiply(discount,
                                              tiled_support,
                                              name='next_value_term')

                reward = next_time_steps.reward
                if reward.shape.rank == 1:
                    # See the explanation above.
                    reward = tf.expand_dims(reward, -1)
                reward_term = tf.multiply(reward_scale_factor,
                                          reward,
                                          name='reward_term')

                target_support = tf.add(reward_term,
                                        gamma * next_value_term,
                                        name='target_support')
            else:
                # When computing discounted return, we need to throw out the last time
                # index of both reward and discount, which are filled with dummy values
                # to match the dimensions of the observation.
                rewards = reward_scale_factor * experience.reward[:, :-1]
                discounts = gamma * experience.discount[:, :-1]

                # TODO(b/134618876): Properly handle Trajectories that include episode
                # boundaries with nonzero discount.

                discounted_returns = value_ops.discounted_return(
                    rewards=rewards,
                    discounts=discounts,
                    final_value=tf.zeros([batch_size], dtype=discounts.dtype),
                    time_major=False,
                    provide_all_returns=False)

                # Convert discounted_returns from [batch_size] to [batch_size, 1]
                discounted_returns = tf.expand_dims(discounted_returns, -1)

                final_value_discount = tf.reduce_prod(discounts, axis=1)
                final_value_discount = tf.expand_dims(final_value_discount, -1)

                # Save the values of discounted_returns and final_value_discount in
                # order to check them in unit tests.
                self._discounted_returns = discounted_returns
                self._final_value_discount = final_value_discount

                target_support = tf.add(discounted_returns,
                                        final_value_discount * tiled_support,
                                        name='target_support')

            target_distribution = tf.stop_gradient(
                project_distribution(target_support, next_q_distribution,
                                     self._support))

            # Obtain the current Q-value logits for the selected actions.
            indices = tf.range(batch_size)
            indices = tf.cast(indices, actions.dtype)
            reshaped_actions = tf.stack([indices, actions], axis=-1)
            chosen_action_logits = tf.gather_nd(q_logits, reshaped_actions)

            # Compute the cross-entropy loss between the logits. If inputs have
            # a time dimension, compute the sum over the time dimension before
            # computing the mean over the batch dimension.
            if batch_squash is not None:
                target_distribution = batch_squash.unflatten(
                    target_distribution)
                chosen_action_logits = batch_squash.unflatten(
                    chosen_action_logits)
                critic_loss = tf.reduce_sum(
                    tf.compat.v1.nn.softmax_cross_entropy_with_logits_v2(
                        labels=target_distribution,
                        logits=chosen_action_logits),
                    axis=1)
            else:
                critic_loss = tf.compat.v1.nn.softmax_cross_entropy_with_logits_v2(
                    labels=target_distribution, logits=chosen_action_logits)

            agg_loss = common.aggregate_losses(
                per_example_loss=critic_loss,
                regularization_loss=self._q_network.losses)
            total_loss = agg_loss.total_loss

            dict_losses = {
                'critic_loss': agg_loss.weighted,
                'reg_loss': agg_loss.regularization,
                'total_loss': total_loss
            }

            common.summarize_scalar_dict(dict_losses,
                                         step=self.train_step_counter,
                                         name_scope='Losses/')

            if self._debug_summaries:
                distribution_errors = target_distribution - chosen_action_logits
                with tf.name_scope('distribution_errors'):
                    common.generate_tensor_summaries(
                        'distribution_errors',
                        distribution_errors,
                        step=self.train_step_counter)
                    tf.compat.v2.summary.scalar(
                        'mean',
                        tf.reduce_mean(distribution_errors),
                        step=self.train_step_counter)
                    tf.compat.v2.summary.scalar(
                        'mean_abs',
                        tf.reduce_mean(tf.abs(distribution_errors)),
                        step=self.train_step_counter)
                    tf.compat.v2.summary.scalar(
                        'max',
                        tf.reduce_max(distribution_errors),
                        step=self.train_step_counter)
                    tf.compat.v2.summary.scalar(
                        'min',
                        tf.reduce_min(distribution_errors),
                        step=self.train_step_counter)
                with tf.name_scope('target_distribution'):
                    common.generate_tensor_summaries(
                        'target_distribution',
                        target_distribution,
                        step=self.train_step_counter)

            # TODO(b/127318640): Give appropriate values for td_loss and td_error for
            # prioritized replay.
            return tf_agent.LossInfo(
                total_loss, dqn_agent.DqnLossInfo(td_loss=(), td_error=()))
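
In the one-step branch above, the target support is r + gamma * discount * z, broadcast over the fixed atom locations z before being projected back onto the original support. A toy construction of that target support with three atoms and two batch entries (the projection step itself is not shown):

import tensorflow as tf

support = tf.constant([-1.0, 0.0, 1.0])   # z: the fixed atom locations
reward = tf.constant([[0.5], [1.0]])      # shape [batch_size, 1]
discount = tf.constant([[1.0], [0.0]])    # the second entry is terminal
gamma, batch_size, num_atoms = 0.9, 2, 3

tiled_support = tf.reshape(tf.tile(support, [batch_size]),
                           [batch_size, num_atoms])
target_support = reward + gamma * discount * tiled_support
print(target_support.numpy())
# [[-0.4  0.5  1.4]
#  [ 1.   1.   1. ]]
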
Example No. 13
  def _train(self, experience, weights):
    """Returns a train op to update the agent's networks.

    This method trains with the provided batched experience.

    Args:
      experience: A time-stacked trajectory object.
      weights: Optional scalar or elementwise (per-batch-entry) importance
        weights.

    Returns:
      A train_op.

    Raises:
      ValueError: If optimizers are None and no default value was provided to
        the constructor.
    """
    time_steps, actions, next_time_steps = (
        self._experience_to_transitions(experience))

    trainable_critic_variables = (
        self._critic_network_1.trainable_variables +
        self._critic_network_2.trainable_variables)
    with tf.GradientTape(watch_accessed_variables=False) as tape:
      assert trainable_critic_variables, ('No trainable critic variables to '
                                          'optimize.')
      tape.watch(trainable_critic_variables)
      critic_loss = self.critic_loss(
          time_steps,
          actions,
          next_time_steps,
          td_errors_loss_fn=self._td_errors_loss_fn,
          gamma=self._gamma,
          reward_scale_factor=self._reward_scale_factor,
          weights=weights)

    tf.debugging.check_numerics(critic_loss, 'Critic loss is inf or nan.')
    critic_grads = tape.gradient(critic_loss, trainable_critic_variables)
    self._apply_gradients(critic_grads, trainable_critic_variables,
                          self._critic_optimizer)

    trainable_actor_variables = self._actor_network.trainable_variables
    with tf.GradientTape(watch_accessed_variables=False) as tape:
      assert trainable_actor_variables, ('No trainable actor variables to '
                                         'optimize.')
      tape.watch(trainable_actor_variables)
      actor_loss = self.actor_loss(time_steps, weights=weights)
    tf.debugging.check_numerics(actor_loss, 'Actor loss is inf or nan.')
    actor_grads = tape.gradient(actor_loss, trainable_actor_variables)
    self._apply_gradients(actor_grads, trainable_actor_variables,
                          self._actor_optimizer)

    alpha_variable = [self._log_alpha]
    with tf.GradientTape(watch_accessed_variables=False) as tape:
      assert alpha_variable, 'No alpha variable to optimize.'
      tape.watch(alpha_variable)
      alpha_loss = self.alpha_loss(time_steps, weights=weights)
    tf.debugging.check_numerics(alpha_loss, 'Alpha loss is inf or nan.')
    alpha_grads = tape.gradient(alpha_loss, alpha_variable)
    self._apply_gradients(alpha_grads, alpha_variable, self._alpha_optimizer)

    # updates safety critic if not training online
    safe_rew = next_time_steps.observation['task_agn_rew']
    sc_weight = None
    if self._fail_weight:
      sc_weight = tf.where(tf.cast(safe_rew, tf.bool), self._fail_weight / 0.5,
                           (1 - self._fail_weight) / 0.5)
    safety_critic_loss, lambda_loss = self.train_sc(
        experience, safe_rew, sc_weight,
        training=(not self._train_critic_online))

    with tf.name_scope('Losses'):
      tf.compat.v2.summary.scalar(
          name='critic_loss', data=critic_loss, step=self.train_step_counter)
      tf.compat.v2.summary.scalar(
          name='actor_loss', data=actor_loss, step=self.train_step_counter)
      tf.compat.v2.summary.scalar(
          name='alpha_loss', data=alpha_loss, step=self.train_step_counter)
      if lambda_loss is not None:
        tf.compat.v2.summary.scalar(
            name='lambda_loss', data=lambda_loss, step=self.train_step_counter)
      if safety_critic_loss is not None:
        tf.compat.v2.summary.scalar(
            name='safety_critic_loss',
            data=safety_critic_loss,
            step=self.train_step_counter)

    self.train_step_counter.assign_add(1)
    self._update_target()

    total_loss = critic_loss + actor_loss + alpha_loss

    extra = SafeSacLossInfo(
        critic_loss=critic_loss, actor_loss=actor_loss, alpha_loss=alpha_loss,
        safety_critic_loss=safety_critic_loss, lambda_loss=lambda_loss)

    return tf_agent.LossInfo(loss=total_loss, extra=extra)
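
The safety-critic weights above are a simple class re-weighting: transitions flagged as failures by the task-agnostic reward get weight fail_weight / 0.5, the rest get (1 - fail_weight) / 0.5, so fail_weight = 0.5 leaves every row at weight 1. A toy evaluation of that expression:

import tensorflow as tf

safe_rew = tf.constant([0.0, 1.0, 0.0, 1.0])  # 1.0 marks a failure transition
fail_weight = 0.8

sc_weight = tf.where(tf.cast(safe_rew, tf.bool),
                     fail_weight / 0.5,
                     (1 - fail_weight) / 0.5)
print(sc_weight.numpy())  # approximately [0.4 1.6 0.4 1.6]
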
Example No. 14
    def testTrain(self, num_epochs, use_td_lambda_return):
        agent = ppo_agent.PPOAgent(self._time_step_spec,
                                   self._action_spec,
                                   tf.train.AdamOptimizer(),
                                   actor_net=DummyActorNet(
                                       self._action_spec, ),
                                   value_net=DummyValueNet(outer_rank=2),
                                   normalize_observations=False,
                                   num_epochs=num_epochs,
                                   use_gae=use_td_lambda_return,
                                   use_td_lambda_return=use_td_lambda_return)
        observations = tf.constant([
            [[1, 2], [3, 4], [5, 6]],
            [[1, 2], [3, 4], [5, 6]],
        ],
                                   dtype=tf.float32)

        time_steps = ts.TimeStep(step_type=tf.constant([[1] * 3] * 2,
                                                       dtype=tf.int32),
                                 reward=tf.constant([[1] * 3] * 2,
                                                    dtype=tf.float32),
                                 discount=tf.constant([[1] * 3] * 2,
                                                      dtype=tf.float32),
                                 observation=observations)
        actions = tf.constant([[[0], [1], [1]], [[0], [1], [1]]],
                              dtype=tf.float32)

        action_distribution_parameters = {
            'loc': tf.constant([[[0.0]] * 3] * 2, dtype=tf.float32),
            'scale': tf.constant([[[1.0]] * 3] * 2, dtype=tf.float32),
        }

        policy_info = action_distribution_parameters

        experience = trajectory.Trajectory(time_steps.step_type, observations,
                                           actions, policy_info,
                                           time_steps.step_type,
                                           time_steps.reward,
                                           time_steps.discount)

        # Mock the build_train_op to return an op for incrementing this counter.
        counter = tf.train.get_or_create_global_step()
        zero = tf.constant(0, dtype=tf.float32)
        agent.build_train_op = (
            lambda *_, **__: tf_agent.LossInfo(
                counter.assign_add(1),  # pylint: disable=g-long-lambda
                ppo_agent.PPOLossInfo(*[zero] * 5)))

        train_op = agent.train(experience)

        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())

            # Assert that counter starts out at zero.
            counter_ = sess.run(counter)
            self.assertEqual(0, counter_)

            sess.run(train_op)

            # Assert that train_op ran increment_counter num_epochs times.
            counter_ = sess.run(counter)
            self.assertEqual(num_epochs, counter_)