Example 1
    def actor_loss(self,
                   time_steps,
                   rb_actions=None,
                   weights=None,
                   q_combinator='min',
                   entropy_coef=1e-4):
        """Computes the actor_loss for SAC training.

    Args:
      time_steps: A batch of timesteps.
      rb_actions: Actions from the replay buffer. While not used in the main RCE
        method, we used these actions to train a behavior policy for the
        ablation experiment studying how to sample actions for the success
        examples.
      weights: Optional scalar or elementwise (per-batch-entry) importance
        weights.
      q_combinator: Whether to combine the two Q-functions by taking the 'min'
        (as in TD3) or the 'max'.
      entropy_coef: Coefficient for entropy regularization term. We found that
        1e-4 worked well for all environments.
    Returns:
      actor_loss: A scalar actor loss.
    """
        with tf.name_scope('actor_loss'):
            nest_utils.assert_same_structure(time_steps, self.time_step_spec)

            actions, log_pi = self._actions_and_log_probs(time_steps)
            target_input = (time_steps.observation, actions)

            target_q_values1, _ = self._critic_network_1(target_input,
                                                         time_steps.step_type,
                                                         training=False)
            target_q_values2, _ = self._critic_network_2(target_input,
                                                         time_steps.step_type,
                                                         training=False)
            if q_combinator == 'min':
                target_q_values = tf.minimum(target_q_values1,
                                             target_q_values2)
            else:
                assert q_combinator == 'max'
                target_q_values = tf.maximum(target_q_values1,
                                             target_q_values2)
            if entropy_coef == 0:
                actor_loss = -target_q_values
            else:
                actor_loss = entropy_coef * log_pi - target_q_values
            if actor_loss.shape.rank > 1:
                # Sum over the time dimension.
                actor_loss = tf.reduce_sum(actor_loss,
                                           axis=range(1,
                                                      actor_loss.shape.rank))
            reg_loss = self._actor_network.losses if self._actor_network else None
            agg_loss = common.aggregate_losses(per_example_loss=actor_loss,
                                               sample_weight=weights,
                                               regularization_loss=reg_loss)
            actor_loss = agg_loss.total_loss
            self._actor_loss_debug_summaries(actor_loss, actions, log_pi,
                                             target_q_values, time_steps)

            return actor_loss
Example 2
    def actor_loss(self,
                   time_steps: ts.TimeStep,
                   actions: types.Tensor,
                   weights: Optional[types.Tensor] = None,
                   training: Optional[bool] = True) -> types.Tensor:
        """Computes actor_loss equivalent to the SAC actor_loss.

    Uses behavioral cloning for the first `self._num_bc_steps` of training.

    Args:
      time_steps: A batch of timesteps.
      actions: A batch of actions.
      weights: Optional scalar or elementwise (per-batch-entry) importance
        weights.
      training: Whether training should be applied.

    Returns:
      actor_loss: A scalar actor loss.
    """
        with tf.name_scope('actor_loss'):
            nest_utils.assert_same_structure(time_steps, self.time_step_spec)

            sampled_actions, sampled_log_pi = self._actions_and_log_probs(
                time_steps, training=training)

            # Behavioral cloning: train the policy to reproduce actions from
            # the dataset.
            if self.train_step_counter < self._num_bc_steps:
                distribution, _ = self._actor_network(time_steps.observation,
                                                      time_steps.step_type, ())
                actor_log_prob = distribution.log_prob(actions)
                actor_loss = tf.exp(
                    self._log_alpha) * sampled_log_pi - actor_log_prob
                target_q_values = tf.zeros(tf.shape(sampled_log_pi))
            else:
                target_input = (time_steps.observation, sampled_actions)
                target_q_values1, _ = self._critic_network_1(
                    target_input, time_steps.step_type, training=False)
                target_q_values2, _ = self._critic_network_2(
                    target_input, time_steps.step_type, training=False)
                target_q_values = tf.minimum(target_q_values1,
                                             target_q_values2)
                actor_loss = tf.exp(
                    self._log_alpha) * sampled_log_pi - target_q_values

            if actor_loss.shape.rank > 1:
                # Sum over the time dimension.
                actor_loss = tf.reduce_sum(actor_loss,
                                           axis=range(1,
                                                      actor_loss.shape.rank))
            reg_loss = self._actor_network.losses if self._actor_network else None
            agg_loss = common.aggregate_losses(per_example_loss=actor_loss,
                                               sample_weight=weights,
                                               regularization_loss=reg_loss)
            actor_loss = agg_loss.total_loss
            self._actor_loss_debug_summaries(actor_loss, sampled_actions,
                                             sampled_log_pi, target_q_values,
                                             time_steps)

            return actor_loss
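
A minimal, stand-alone sketch of the behavioral-cloning term used during the first
`num_bc_steps` above: the policy's log-probability of the dataset actions. The Normal
policy head and the toy tensors are assumptions for illustration, not the agent's
actual actor network.

import tensorflow as tf
import tensorflow_probability as tfp

observations = tf.random.normal([4, 3])                # toy batch of observations
dataset_actions = tf.random.uniform([4, 2], -1., 1.)   # actions from the dataset

# Stand-in for the actor network's output distribution.
loc = tf.keras.layers.Dense(2)(observations)
distribution = tfp.distributions.Normal(loc=loc, scale=1.0)

# log pi(a_data | s); the BC branch above subtracts this term, so minimizing the
# loss pushes the policy toward reproducing the dataset actions.
actor_log_prob = tf.reduce_sum(distribution.log_prob(dataset_actions), axis=-1)
per_example_bc_loss = -actor_log_prob   # later fed to common.aggregate_losses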
Example 3
    def alpha_loss(self, time_steps, weights=None):
        """Computes the alpha_loss for EC-SAC training.

    Args:
      time_steps: A batch of timesteps.
      weights: Optional scalar or elementwise (per-batch-entry) importance
        weights.

    Returns:
      alpha_loss: A scalar alpha loss.
    """
        with tf.name_scope('alpha_loss'):
            tf.nest.assert_same_structure(time_steps, self.time_step_spec)

            unused_actions, log_pi = self._actions_and_log_probs(time_steps)
            entropy_diff = tf.stop_gradient(-log_pi - self._target_entropy)
            alpha_loss = (self._log_alpha * entropy_diff)

            if nest_utils.is_batched_nested_tensors(time_steps,
                                                    self.time_step_spec,
                                                    num_outer_dims=2):
                # Sum over the time dimension.
                alpha_loss = tf.reduce_sum(input_tensor=alpha_loss, axis=1)
            else:
                alpha_loss = tf.expand_dims(alpha_loss, 0)

            agg_loss = common.aggregate_losses(per_example_loss=alpha_loss,
                                               sample_weight=weights)
            alpha_loss = agg_loss.total_loss

            self._alpha_loss_debug_summaries(alpha_loss, entropy_diff)

            return alpha_loss
 def test_aggregate_losses_with_time_dim_and_float_weights(self):
     per_example_loss = tf.constant([[4., 2., 3.], [1, 1, 1]])
     sample_weights = 0.5
     aggregated_losses = common.aggregate_losses(per_example_loss,
                                                 sample_weights)
     expected_per_example_loss = 0.5 * (4 + 2 + 3 + 1 + 1 + 1) / 6
     self.assertAlmostEqual(self.evaluate(aggregated_losses.total_loss),
                            expected_per_example_loss)
 def test_aggregate_losses_three_dimensions(self):
     per_example_loss = tf.constant([[[4., 2., 3.], [1, 1, 1]],
                                     [[8., 4., 6.], [2, 2, 2]]])
     aggregated_losses = common.aggregate_losses(per_example_loss)
     expected_per_example_loss = (4 + 2 + 3 + 1 + 1 + 1 + 8 + 4 + 6 + 2 +
                                  2 + 2) / 12
     self.assertAlmostEqual(self.evaluate(aggregated_losses.total_loss),
                            expected_per_example_loss)
 def test_aggregate_losses_with_time_dim_and_weights_with_batch_dim(self):
     per_example_loss = tf.constant([[4., 2., 3.], [1, 1, 1]])
     sample_weights = tf.constant([
         1.,
         0.,
     ])
     aggregated_losses = common.aggregate_losses(per_example_loss,
                                                 sample_weights)
     expected_per_example_loss = (4 + 2 + 3) / 6
     self.assertAlmostEqual(self.evaluate(aggregated_losses.total_loss),
                            expected_per_example_loss)
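
The expected values in these tests encode one rule: the total loss is the mean over
every element (batch and time alike) of the weighted per-example loss. A rough
stand-alone check of that arithmetic for the first test's tensors; this mirrors the
expected value, not the internals of common.aggregate_losses.

import tensorflow as tf

per_example_loss = tf.constant([[4., 2., 3.], [1., 1., 1.]])
sample_weight = 0.5
manual_total = tf.reduce_mean(per_example_loss * sample_weight)
print(float(manual_total))  # 0.5 * (4 + 2 + 3 + 1 + 1 + 1) / 6 = 1.0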
    def _loss(self, experience, weights=None, training: bool = False):
        experience = self._as_trajectory(experience)

        per_example_loss = self._bc_loss_fn(experience, training=training)
        aggregated_losses = common.aggregate_losses(
            per_example_loss=per_example_loss,
            sample_weight=weights,
            regularization_loss=self._cloning_network.losses)

        return tf_agent.LossInfo(
            loss=aggregated_losses.total_loss,
            extra=BehavioralCloningLossInfo(per_example_loss))
  def actor_loss(self,
                 time_steps,
                 actions,
                 weights = None,
                 ce_loss = False):
    """Computes the actor_loss for C-learning training.

    Args:
      time_steps: A batch of timesteps.
      actions: A batch of actions.
      weights: Optional scalar or elementwise (per-batch-entry) importance
        weights.
      ce_loss: (bool) Whether to update the actor using the cross entropy loss,
        which corresponds to using the log C-value. The default actor loss
        differs by not including the log. Empirically we observed no difference.

    Returns:
      actor_loss: A scalar actor loss.
    """
    with tf.name_scope('actor_loss'):
      nest_utils.assert_same_structure(time_steps, self.time_step_spec)

      sampled_actions, log_pi = self._actions_and_log_probs(time_steps)
      target_input = (time_steps.observation, sampled_actions)
      target_q_values1, _ = self._critic_network_1(
          target_input, time_steps.step_type, training=False)
      target_q_values2, _ = self._critic_network_2(
          target_input, time_steps.step_type, training=False)
      target_q_values = tf.minimum(target_q_values1, target_q_values2)
      if ce_loss:
        actor_loss = tf.keras.losses.binary_crossentropy(
            tf.ones_like(target_q_values), target_q_values)
      else:
        actor_loss = -1.0 * target_q_values

      if actor_loss.shape.rank > 1:
        # Sum over the time dimension.
        actor_loss = tf.reduce_sum(
            actor_loss, axis=range(1, actor_loss.shape.rank))
      reg_loss = self._actor_network.losses if self._actor_network else None
      agg_loss = common.aggregate_losses(
          per_example_loss=actor_loss,
          sample_weight=weights,
          regularization_loss=reg_loss)
      actor_loss = agg_loss.total_loss
      self._actor_loss_debug_summaries(actor_loss, actions, log_pi,
                                       target_q_values, time_steps)

      return actor_loss
Example 9
    def _train(self, experience, weights=None):
        with tf.GradientTape() as tape:
            per_example_loss = self._loss_fn(experience)

            aggregated_losses = common.aggregate_losses(
                per_example_loss=per_example_loss,
                sample_weight=weights,
                regularization_loss=self._cloning_network.losses)

        self._apply_loss(aggregated_losses,
                         self._cloning_network.trainable_weights, tape,
                         self._optimizer)
        self.train_step_counter.assign_add(1)
        return tf_agent.LossInfo(aggregated_losses.total_loss,
                                 BehavioralCloningLossInfo(per_example_loss))
Example 10
 def test_aggregate_4d_losses_and_2d_weights(self):
   per_example_loss = tf.constant([[[[4., 2., 3.], [1, 1, 1]],
                                    [[8., 4., 6.], [2, 2, 2]]],
                                   [[[4., 2., 3.], [1, 1, 1]],
                                    [[8., 4., 6.], [2, 2, 2]]]])  # 2x2x2x3
   sample_weights = tf.constant([[
       1.,
       0.,
   ], [
       0.,
       0.,
   ]])
   aggregated_losses = common.aggregate_losses(per_example_loss,
                                               sample_weights)
   expected_per_example_loss = (4 + 2 + 3 + 1 + 1 + 1) / 24
   self.assertAlmostEqual(
       self.evaluate(aggregated_losses.total_loss), expected_per_example_loss)
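
Here the 2-D weights have one entry per [batch, time] pair and, judging from the
expected value, are broadcast over the two trailing dimensions before the same
all-element mean is taken. A hand computation under that assumption:

import tensorflow as tf

per_example_loss = tf.constant([[[[4., 2., 3.], [1., 1., 1.]],
                                 [[8., 4., 6.], [2., 2., 2.]]],
                                [[[4., 2., 3.], [1., 1., 1.]],
                                 [[8., 4., 6.], [2., 2., 2.]]]])  # shape [2, 2, 2, 3]
sample_weights = tf.constant([[1., 0.], [0., 0.]])
weights = tf.reshape(sample_weights, [2, 2, 1, 1])   # broadcast over trailing dims
manual_total = tf.reduce_mean(per_example_loss * weights)
print(float(manual_total))  # (4 + 2 + 3 + 1 + 1 + 1) / 24 = 0.5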
Example 11
    def actor_loss(self, time_steps, weights=None):
        """Computes the actor_loss for SAC training.

    Args:
      time_steps: A batch of timesteps.
      weights: Optional scalar or elementwise (per-batch-entry) importance
        weights.

    Returns:
      actor_loss: A scalar actor loss.
    """
        with tf.name_scope('actor_loss'):
            tf.nest.assert_same_structure(time_steps, self.time_step_spec)

            actions, log_pi = self._actions_and_log_probs(time_steps)
            target_input = (time_steps.observation, actions)
            target_q_values1, _ = self._critic_network_1(target_input,
                                                         time_steps.step_type,
                                                         training=False)
            target_q_values2, _ = self._critic_network_2(target_input,
                                                         time_steps.step_type,
                                                         training=False)
            target_q_values = tf.minimum(target_q_values1, target_q_values2)
            # Stop gradients to avoid updates to shared layers between critic and
            # actor. They could still be updated through the actor if desired, but we
            # do not want gradients to flow to shared variables through the critic.
            target_q_values = tf.stop_gradient(target_q_values)

            actor_loss = tf.exp(self._log_alpha) * log_pi - target_q_values

            if nest_utils.is_batched_nested_tensors(time_steps,
                                                    self.time_step_spec,
                                                    num_outer_dims=2):
                # Sum over the time dimension.
                actor_loss = tf.reduce_sum(input_tensor=actor_loss, axis=1)
            reg_loss = self._actor_network.losses if self._actor_network else None
            agg_loss = common.aggregate_losses(per_example_loss=actor_loss,
                                               sample_weight=weights,
                                               regularization_loss=reg_loss)
            actor_loss = agg_loss.total_loss
            self._actor_loss_debug_summaries(actor_loss, actions, log_pi,
                                             target_q_values, time_steps)

            return actor_loss
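
The stop_gradient comment above is the key detail of this variant. A toy
illustration of the behavior it relies on; the two scalar variables are placeholders
for a weight shared with the critic and an actor-only weight.

import tensorflow as tf

shared = tf.Variable(2.0)       # stands in for a weight shared with the critic
actor_only = tf.Variable(3.0)   # stands in for an actor-only weight

with tf.GradientTape() as tape:
    critic_out = shared * 5.0                # depends on the shared variable
    blocked = tf.stop_gradient(critic_out)   # actor-loss gradients stop here
    loss = actor_only * blocked

grads = tape.gradient(loss, [shared, actor_only])
print(grads[0])  # None: nothing flows back to the shared variable via the critic
print(grads[1])  # tf.Tensor(10.0, ...): the actor variable still gets a gradient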
Example 12
 def _add_auxiliary_losses(self, transition, weights, losses_dict):
     """Computes auxiliary losses, updating losses_dict in place."""
     total_auxiliary_loss = 0
     if self._auxiliary_loss_fns is not None:
         for auxiliary_loss_fn in self._auxiliary_loss_fns:
             auxiliary_loss, auxiliary_reg_loss = auxiliary_loss_fn(
                 network=self._q_network, transition=transition)
             agg_auxiliary_loss = common.aggregate_losses(
                 per_example_loss=auxiliary_loss,
                 sample_weight=weights,
                 regularization_loss=auxiliary_reg_loss)
             total_auxiliary_loss += agg_auxiliary_loss.total_loss
             losses_dict.update({
                 'auxiliary_loss_{}'.format(auxiliary_loss_fn.__name__):
                 agg_auxiliary_loss.weighted,
                 'auxiliary_reg_loss_{}'.format(auxiliary_loss_fn.__name__):
                 agg_auxiliary_loss.regularization,
             })
     return total_auxiliary_loss
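
The loop above expects each auxiliary_loss_fn to accept network= and transition=
keyword arguments and return a (per-example loss, regularization loss) pair. A
purely hypothetical example with that shape; the L2 penalty itself is invented for
illustration and is not part of the agent above.

import tensorflow as tf

def l2_q_value_penalty(network, transition):
    # Hypothetical auxiliary loss matching the call signature used above.
    time_steps, _, _ = transition
    q_values, _ = network(time_steps.observation, time_steps.step_type)
    per_example_loss = tf.reduce_sum(tf.square(q_values), axis=-1)  # toy penalty
    regularization_loss = (tf.add_n(network.losses)
                           if network.losses else tf.constant(0.))
    return per_example_loss, regularization_loss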
Example 13
    def actor_loss(self,
                   time_steps: ts.TimeStep,
                   weights: Optional[types.Tensor] = None,
                   training: Optional[bool] = True) -> types.Tensor:
        """Computes the actor_loss for SAC training.

    Args:
      time_steps: A batch of timesteps.
      weights: Optional scalar or elementwise (per-batch-entry) importance
        weights.
      training: Whether training should be applied.

    Returns:
      actor_loss: A scalar actor loss.
    """
        with tf.name_scope('actor_loss'):
            nest_utils.assert_same_structure(time_steps, self.time_step_spec)

            actions, log_pi = self._actions_and_log_probs(time_steps,
                                                          training=training)
            target_input = (time_steps.observation, actions)
            # We do not update critic during actor loss.
            target_q_values1, _ = self._critic_network_1(
                target_input, step_type=time_steps.step_type, training=False)
            target_q_values2, _ = self._critic_network_2(
                target_input, step_type=time_steps.step_type, training=False)
            target_q_values = tf.minimum(target_q_values1, target_q_values2)
            actor_loss = tf.exp(self._log_alpha) * log_pi - target_q_values
            if actor_loss.shape.rank > 1:
                # Sum over the time dimension.
                actor_loss = tf.reduce_sum(actor_loss,
                                           axis=range(1,
                                                      actor_loss.shape.rank))
            reg_loss = self._actor_network.losses if self._actor_network else None
            agg_loss = common.aggregate_losses(per_example_loss=actor_loss,
                                               sample_weight=weights,
                                               regularization_loss=reg_loss)
            actor_loss = agg_loss.total_loss
            self._actor_loss_debug_summaries(actor_loss, actions, log_pi,
                                             target_q_values, time_steps)

            return actor_loss
Example 14
    def alpha_loss(self,
                   time_steps: ts.TimeStep,
                   weights: Optional[types.Tensor] = None,
                   training: bool = False) -> types.Tensor:
        """Computes the alpha_loss for EC-SAC training.

    Args:
      time_steps: A batch of timesteps.
      weights: Optional scalar or elementwise (per-batch-entry) importance
        weights.
      training: Whether this loss is being used during training.

    Returns:
      alpha_loss: A scalar alpha loss.
    """
        with tf.name_scope('alpha_loss'):
            nest_utils.assert_same_structure(time_steps, self.time_step_spec)

            # We do not update actor during alpha loss.
            unused_actions, log_pi = self._actions_and_log_probs(
                time_steps, training=False)
            entropy_diff = tf.stop_gradient(-log_pi - self._target_entropy)
            if self._use_log_alpha_in_alpha_loss:
                alpha_loss = (self._log_alpha * entropy_diff)
            else:
                alpha_loss = (tf.exp(self._log_alpha) * entropy_diff)

            if alpha_loss.shape.rank > 1:
                # Sum over the time dimension.
                alpha_loss = tf.reduce_sum(alpha_loss,
                                           axis=range(1,
                                                      alpha_loss.shape.rank))

            agg_loss = common.aggregate_losses(per_example_loss=alpha_loss,
                                               sample_weight=weights)
            alpha_loss = agg_loss.total_loss

            self._alpha_loss_debug_summaries(alpha_loss, entropy_diff)

            return alpha_loss
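
The only difference between the two branches above is how the gradient with respect
to log_alpha is scaled: with the log-parameterization it equals the (stopped) entropy
gap itself, while with the exp-parameterization it is additionally multiplied by
alpha. A quick check with toy values:

import tensorflow as tf

log_alpha = tf.Variable(0.5)
entropy_diff = tf.constant(2.0)  # stands in for stop_gradient(-log_pi - target_entropy)

with tf.GradientTape(persistent=True) as tape:
    loss_log = log_alpha * entropy_diff            # use_log_alpha_in_alpha_loss=True
    loss_exp = tf.exp(log_alpha) * entropy_diff    # use_log_alpha_in_alpha_loss=False

print(tape.gradient(loss_log, log_alpha))  # entropy_diff          -> 2.0
print(tape.gradient(loss_exp, log_alpha))  # alpha * entropy_diff  -> exp(0.5) * 2.0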
Example 15
    def actor_loss(self,
                   time_steps: ts.TimeStep,
                   weights: Optional[types.Tensor] = None) -> types.Tensor:
        """Computes the actor_loss for SAC training.

    Args:
      time_steps: A batch of timesteps.
      weights: Optional scalar or elementwise (per-batch-entry) importance
        weights.

    Returns:
      actor_loss: A scalar actor loss.
    """
        with tf.name_scope('actor_loss'):
            nest_utils.assert_same_structure(time_steps, self.time_step_spec)

            actions, log_pi = self._actions_and_log_probs(time_steps)
            target_input = (time_steps.observation, actions)
            target_q_values1, _ = self._critic_network_1(target_input,
                                                         time_steps.step_type,
                                                         training=False)
            target_q_values2, _ = self._critic_network_2(target_input,
                                                         time_steps.step_type,
                                                         training=False)
            target_q_values = tf.minimum(target_q_values1, target_q_values2)
            actor_loss = tf.exp(self._log_alpha) * log_pi - target_q_values
            if nest_utils.is_batched_nested_tensors(time_steps,
                                                    self.time_step_spec,
                                                    num_outer_dims=2):
                # Sum over the time dimension.
                actor_loss = tf.reduce_sum(input_tensor=actor_loss, axis=1)
            reg_loss = self._actor_network.losses if self._actor_network else None
            agg_loss = common.aggregate_losses(per_example_loss=actor_loss,
                                               sample_weight=weights,
                                               regularization_loss=reg_loss)
            actor_loss = agg_loss.total_loss
            self._actor_loss_debug_summaries(actor_loss, actions, log_pi,
                                             target_q_values, time_steps)

            return actor_loss
Example 16
    def actor_loss(self,
                   time_steps: ts.TimeStep,
                   weights: Optional[types.Tensor] = None,
                   training: bool = False) -> types.Tensor:
        """Computes the actor_loss for TD3 training.

    Args:
      time_steps: A batch of timesteps.
      weights: Optional scalar or element-wise (per-batch-entry) importance
        weights.
      training: Whether this loss is being used for training.
      # TODO(b/124383618): Add an action norm regularizer.
    Returns:
      actor_loss: A scalar actor loss.
    """
        with tf.name_scope('actor_loss'):
            actions, _ = self._actor_network(time_steps.observation,
                                             time_steps.step_type,
                                             training=training)

            q_values, _ = self._critic_network_1(
                (time_steps.observation, actions),
                time_steps.step_type,
                training=False)
            actor_loss = -q_values
            # Sum over the time dimension.
            if actor_loss.shape.rank > 1:
                actor_loss = tf.reduce_sum(actor_loss,
                                           axis=range(1,
                                                      actor_loss.shape.rank))
            actor_loss = common.aggregate_losses(
                per_example_loss=actor_loss, sample_weight=weights).total_loss

            with tf.name_scope('Losses/'):
                tf.compat.v2.summary.scalar(name='actor_loss',
                                            data=actor_loss,
                                            step=self.train_step_counter)

        return actor_loss
Example 17
  def _loss(self,
            experience,
            td_errors_loss_fn=common.element_wise_huber_loss,
            gamma=1.0,
            reward_scale_factor=1.0,
            weights=None,
            training=False):
    """Computes loss for DQN training.

    Args:
      experience: A batch of experience data in the form of a `Trajectory` or
        `Transition`. The structure of `experience` must match that of
        `self.collect_policy.step_spec`.

        If a `Trajectory`, all tensors in `experience` must be shaped
        `[B, T, ...]` where `T` must be equal to `self.train_sequence_length`
        if that property is not `None`.
      td_errors_loss_fn: A function(td_targets, predictions) to compute the
        element wise loss.
      gamma: Discount for future rewards.
      reward_scale_factor: Multiplicative factor to scale rewards.
      weights: Optional scalar or elementwise (per-batch-entry) importance
        weights.  The output td_loss will be scaled by these weights, and
        the final scalar loss is the mean of these values.
      training: Whether this loss is being used for training.

    Returns:
      loss: An instance of `DqnLossInfo`.
    Raises:
      ValueError:
        if the number of actions is greater than 1.
    """
    transition = self._as_transition(experience)
    time_steps, policy_steps, next_time_steps = transition
    actions = policy_steps.action

    with tf.name_scope('loss'):
      q_values = self._compute_q_values(time_steps, actions, training=training)

      next_q_values = self._compute_next_q_values(
          next_time_steps, policy_steps.info)

      # This applies to any value of n_step_update and also in the RNN-DQN case.
      # In the RNN-DQN case, inputs and outputs contain a time dimension.
      td_targets = compute_td_targets(
          next_q_values,
          rewards=reward_scale_factor * next_time_steps.reward,
          discounts=gamma * next_time_steps.discount)

      valid_mask = tf.cast(~time_steps.is_last(), tf.float32)
      td_error = valid_mask * (td_targets - q_values)

      td_loss = valid_mask * td_errors_loss_fn(td_targets, q_values)

      if nest_utils.is_batched_nested_tensors(
          time_steps, self.time_step_spec, num_outer_dims=2):
        # Do a sum over the time dimension.
        td_loss = tf.reduce_sum(input_tensor=td_loss, axis=1)

      # Aggregate across the elements of the batch and add regularization loss.
      # Note: We use an element wise loss above to ensure each element is always
      #   weighted by 1/N where N is the batch size, even when some of the
      #   weights are zero due to boundary transitions. Weighting by 1/K where K
      #   is the actual number of non-zero weight would artificially increase
      #   their contribution in the loss. Think about what would happen as
      #   the number of boundary samples increases.

      agg_loss = common.aggregate_losses(
          per_example_loss=td_loss,
          sample_weight=weights,
          regularization_loss=self._q_network.losses)
      total_loss = agg_loss.total_loss

      losses_dict = {'td_loss': agg_loss.weighted,
                     'reg_loss': agg_loss.regularization,
                     'total_loss': total_loss}

      common.summarize_scalar_dict(losses_dict,
                                   step=self.train_step_counter,
                                   name_scope='Losses/')

      if self._summarize_grads_and_vars:
        with tf.name_scope('Variables/'):
          for var in self._q_network.trainable_weights:
            tf.compat.v2.summary.histogram(
                name=var.name.replace(':', '_'),
                data=var,
                step=self.train_step_counter)

      if self._debug_summaries:
        diff_q_values = q_values - next_q_values
        common.generate_tensor_summaries('td_error', td_error,
                                         self.train_step_counter)
        common.generate_tensor_summaries('td_loss', td_loss,
                                         self.train_step_counter)
        common.generate_tensor_summaries('q_values', q_values,
                                         self.train_step_counter)
        common.generate_tensor_summaries('next_q_values', next_q_values,
                                         self.train_step_counter)
        common.generate_tensor_summaries('diff_q_values', diff_q_values,
                                         self.train_step_counter)

      return tf_agent.LossInfo(total_loss, DqnLossInfo(td_loss=td_loss,
                                                       td_error=td_error))
    def _loss_h(self,
                experience,
                td_errors_loss_fn=common.element_wise_huber_loss,
                gamma=1.0,
                reward_scale_factor=1.0,
                weights=None,
                training=False):

        transition = self._as_transition(experience)
        time_steps, policy_steps, next_time_steps = transition
        actions = policy_steps.action

        valid_mask = tf.cast(~time_steps.is_last(), tf.float32)

        with tf.name_scope('loss'):
            # q_values is already gathered by actions
            h_values = self._compute_h_values(time_steps,
                                              actions,
                                              training=training)

            multi_dim_actions = self._action_spec.shape.rank > 0

            next_q_all_values = self._compute_next_all_q_values(
                next_time_steps, policy_steps.info)

            next_h_all_values = self._compute_next_all_h_values(
                next_time_steps, policy_steps.info)

            next_h_actions = tf.argmax(next_h_all_values, axis=1)

            # next_h_values here is used only for logging
            next_h_values = self._compute_next_h_values(
                next_time_steps, policy_steps.info)

            # next_q_values refer to Q(r,s') in Eqs.(4),(5)
            next_q_values = common.index_with_actions(
                next_q_all_values, tf.cast(next_h_actions, dtype=tf.int32),
                multi_dim_actions)

            h_target_all_values = self._compute_next_all_h_values(
                time_steps, policy_steps.info)

            h_target_values = common.index_with_actions(
                h_target_all_values, tf.cast(actions, dtype=tf.int32),
                multi_dim_actions)

            td_targets = compute_momentum_td_targets(
                q_target_values=next_q_values,
                h_target_values=h_target_values,
                beta=self.beta())

            td_error = valid_mask * (td_targets - h_values)

            td_loss = valid_mask * td_errors_loss_fn(td_targets, h_values)

            if nest_utils.is_batched_nested_tensors(time_steps,
                                                    self.time_step_spec,
                                                    num_outer_dims=2):
                # Do a sum over the time dimension.
                td_loss = tf.reduce_sum(input_tensor=td_loss, axis=1)

            # Aggregate across the elements of the batch and add regularization loss.
            # Note: We use an element wise loss above to ensure each element is always
            #   weighted by 1/N where N is the batch size, even when some of the
            #   weights are zero due to boundary transitions. Weighting by 1/K where K
            #   is the actual number of non-zero weight would artificially increase
            #   their contribution in the loss. Think about what would happen as
            #   the number of boundary samples increases.

            agg_loss = common.aggregate_losses(
                per_example_loss=td_loss,
                sample_weight=weights,
                regularization_loss=self._q_network.losses)
            total_loss = agg_loss.total_loss

            losses_dict = {
                'td_loss': agg_loss.weighted,
                'reg_loss': agg_loss.regularization,
                'total_loss': total_loss
            }

            common.summarize_scalar_dict(losses_dict,
                                         step=self.train_step_counter,
                                         name_scope='Losses/')

            if self._summarize_grads_and_vars:
                with tf.name_scope('Variables/'):
                    for var in self._h_network.trainable_weights:
                        tf.compat.v2.summary.histogram(
                            name=var.name.replace(':', '_'),
                            data=var,
                            step=self.train_step_counter)

            if self._debug_summaries:
                diff_h_values = h_values - next_h_values
                common.generate_tensor_summaries('td_error_h', td_error,
                                                 self.train_step_counter)
                common.generate_tensor_summaries('td_loss_h', td_loss,
                                                 self.train_step_counter)
                common.generate_tensor_summaries('h_values', h_values,
                                                 self.train_step_counter)
                common.generate_tensor_summaries('next_h_values',
                                                 next_h_values,
                                                 self.train_step_counter)
                common.generate_tensor_summaries('diff_h_values',
                                                 diff_h_values,
                                                 self.train_step_counter)

            return tf_agent.LossInfo(
                total_loss, DqnLossInfo(td_loss=td_loss, td_error=td_error))
Example 19
    def _loss(self,
              experience,
              td_errors_loss_fn=common.element_wise_huber_loss,
              gamma=1.0,
              reward_scale_factor=1.0,
              weights=None,
              training=False):
        """Computes loss for DQN training.

    Args:
      experience: A batch of experience data in the form of a `Trajectory`. The
        structure of `experience` must match that of `self.policy.step_spec`.
        All tensors in `experience` must be shaped `[batch, time, ...]` where
        `time` must be equal to `self.train_sequence_length` if that
        property is not `None`.
      td_errors_loss_fn: A function(td_targets, predictions) to compute the
        element wise loss.
      gamma: Discount for future rewards.
      reward_scale_factor: Multiplicative factor to scale rewards.
      weights: Optional scalar or elementwise (per-batch-entry) importance
        weights.  The output td_loss will be scaled by these weights, and
        the final scalar loss is the mean of these values.
      training: Whether this loss is being used for training.

    Returns:
      loss: An instance of `DqnLossInfo`.
    Raises:
      ValueError:
        if the number of actions is greater than 1.
    """
        # Check that `experience` includes two outer dimensions [B, T, ...]. This
        # method requires a time dimension to compute the loss properly.
        self._check_trajectory_dimensions(experience)

        squeeze_time_dim = not self._q_network.state_spec
        if self._n_step_update == 1:
            time_steps, policy_steps, next_time_steps = (
                trajectory.experience_to_transitions(experience,
                                                     squeeze_time_dim))
            actions = policy_steps.action
        else:
            # To compute n-step returns, we need the first time steps, the first
            # actions, and the last time steps. Therefore we extract the first and
            # last transitions from our Trajectory.
            first_two_steps = tf.nest.map_structure(lambda x: x[:, :2],
                                                    experience)
            last_two_steps = tf.nest.map_structure(lambda x: x[:, -2:],
                                                   experience)
            time_steps, policy_steps, _ = (
                trajectory.experience_to_transitions(first_two_steps,
                                                     squeeze_time_dim))
            actions = policy_steps.action
            _, _, next_time_steps = (trajectory.experience_to_transitions(
                last_two_steps, squeeze_time_dim))

        with tf.name_scope('loss'):
            q_values = self._compute_q_values(time_steps,
                                              actions,
                                              training=training)

            next_q_values = self._compute_next_q_values(
                next_time_steps, policy_steps.info)

            if self._n_step_update == 1:
                # Special case for n = 1 to avoid a loss of performance.
                td_targets = compute_td_targets(
                    next_q_values,
                    rewards=reward_scale_factor * next_time_steps.reward,
                    discounts=gamma * next_time_steps.discount)
            else:
                # When computing discounted return, we need to throw out the last time
                # index of both reward and discount, which are filled with dummy values
                # to match the dimensions of the observation.
                rewards = reward_scale_factor * experience.reward[:, :-1]
                discounts = gamma * experience.discount[:, :-1]

                # TODO(b/134618876): Properly handle Trajectories that include episode
                # boundaries with nonzero discount.

                td_targets = value_ops.discounted_return(
                    rewards=rewards,
                    discounts=discounts,
                    final_value=next_q_values,
                    time_major=False,
                    provide_all_returns=False)

            valid_mask = tf.cast(~time_steps.is_last(), tf.float32)
            td_error = valid_mask * (td_targets - q_values)

            td_loss = valid_mask * td_errors_loss_fn(td_targets, q_values)

            if nest_utils.is_batched_nested_tensors(time_steps,
                                                    self.time_step_spec,
                                                    num_outer_dims=2):
                # Do a sum over the time dimension.
                td_loss = tf.reduce_sum(input_tensor=td_loss, axis=1)

            # Aggregate across the elements of the batch and add regularization loss.
            # Note: We use an element wise loss above to ensure each element is always
            #   weighted by 1/N where N is the batch size, even when some of the
            #   weights are zero due to boundary transitions. Weighting by 1/K where K
            #   is the actual number of non-zero weight would artificially increase
            #   their contribution in the loss. Think about what would happen as
            #   the number of boundary samples increases.

            agg_loss = common.aggregate_losses(
                per_example_loss=td_loss,
                sample_weight=weights,
                regularization_loss=self._q_network.losses)
            total_loss = agg_loss.total_loss

            losses_dict = {
                'td_loss': agg_loss.weighted,
                'reg_loss': agg_loss.regularization,
                'total_loss': total_loss
            }

            common.summarize_scalar_dict(losses_dict,
                                         step=self.train_step_counter,
                                         name_scope='Losses/')

            if self._summarize_grads_and_vars:
                with tf.name_scope('Variables/'):
                    for var in self._q_network.trainable_weights:
                        tf.compat.v2.summary.histogram(
                            name=var.name.replace(':', '_'),
                            data=var,
                            step=self.train_step_counter)

            if self._debug_summaries:
                diff_q_values = q_values - next_q_values
                common.generate_tensor_summaries('td_error', td_error,
                                                 self.train_step_counter)
                common.generate_tensor_summaries('td_loss', td_loss,
                                                 self.train_step_counter)
                common.generate_tensor_summaries('q_values', q_values,
                                                 self.train_step_counter)
                common.generate_tensor_summaries('next_q_values',
                                                 next_q_values,
                                                 self.train_step_counter)
                common.generate_tensor_summaries('diff_q_values',
                                                 diff_q_values,
                                                 self.train_step_counter)

            return tf_agent.LossInfo(
                total_loss, DqnLossInfo(td_loss=td_loss, td_error=td_error))
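
Both DQN losses above build their targets either with compute_td_targets
(n_step_update == 1) or with value_ops.discounted_return (n > 1). A hand computation
of what those targets amount to, assuming the usual 1-step form
r + gamma * d * Q(s', a') that the call sites suggest; the numbers are made up.

import tensorflow as tf

rewards = tf.constant([1.0, 0.0, 2.0])
discounts = tf.constant([0.99, 0.99, 0.0])    # zero at an episode boundary
next_q_values = tf.constant([5.0, 4.0, 3.0])

# 1-step targets (up to the stop_gradient the agent applies).
one_step_targets = rewards + discounts * next_q_values
print(one_step_targets.numpy())   # [5.95, 3.96, 2.0]

# 2-step targets: the bootstrap value is discounted through both steps, matching the
# value_ops.discounted_return(..., final_value=next_q_values) call in the n > 1 branch.
r = tf.constant([1.0, 0.5])       # rewards after dropping the dummy last index
d = tf.constant([0.99, 0.99])
final_value = tf.constant(4.0)
two_step_target = r[0] + d[0] * (r[1] + d[1] * final_value)
print(float(two_step_target))     # 1 + 0.99 * (0.5 + 0.99 * 4) ~= 5.415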
  def critic_loss(self,
                  time_steps,
                  actions,
                  next_time_steps,
                  augmented_obs,
                  augmented_next_obs,
                  td_errors_loss_fn,
                  gamma=1.0,
                  reward_scale_factor=1.0,
                  weights=None,
                  training=False):
    """Computes the critic loss for SAC training.

    Args:
      time_steps: A batch of timesteps.
      actions: A batch of actions.
      next_time_steps: A batch of next timesteps.
      augmented_obs: List of observations.
      augmented_next_obs: List of next_observations.
      td_errors_loss_fn: A function(td_targets, predictions) to compute
        elementwise (per-batch-entry) loss.
      gamma: Discount for future rewards.
      reward_scale_factor: Multiplicative factor to scale rewards.
      weights: Optional scalar or elementwise (per-batch-entry) importance
        weights.
      training: Whether this loss is being used for training.

    Returns:
      critic_loss: A scalar critic loss.
    """
    with tf.name_scope('critic_loss'):
      nest_utils.assert_same_structure(actions, self.action_spec)
      nest_utils.assert_same_structure(time_steps, self.time_step_spec)
      nest_utils.assert_same_structure(next_time_steps, self.time_step_spec)

      td_targets = self._compute_td_targets(next_time_steps,
                                            reward_scale_factor, gamma)

      # Compute td_targets with augmentations.
      for i in range(self._num_augmentations - 1):
        augmented_next_time_steps = next_time_steps._replace(
            observation=augmented_next_obs[i])

        augmented_td_targets = self._compute_td_targets(
            augmented_next_time_steps, reward_scale_factor, gamma)

        td_targets = td_targets + augmented_td_targets

      # Average td_target estimation over augmentations.
      if self._num_augmentations > 1:
        td_targets = td_targets / self._num_augmentations

      pred_td_targets1, pred_td_targets2, critic_loss = (
          self._compute_prediction_critic_loss(
              (time_steps.observation, actions), td_targets, time_steps,
              training, td_errors_loss_fn))

      # Add Q Augmentations to the critic loss.
      for i in range(self._num_augmentations - 1):
        augmented_time_steps = time_steps._replace(observation=augmented_obs[i])
        _, _, loss = (
            self._compute_prediction_critic_loss(
                (augmented_time_steps.observation, actions), td_targets,
                augmented_time_steps, training, td_errors_loss_fn))
        critic_loss = critic_loss + loss

      agg_loss = common.aggregate_losses(
          per_example_loss=critic_loss,
          sample_weight=weights,
          regularization_loss=(self._critic_network_1.losses +
                               self._critic_network_2.losses))
      critic_loss = agg_loss.total_loss

      self._critic_loss_debug_summaries(td_targets, pred_td_targets1,
                                        pred_td_targets2)

      return critic_loss
Example 21
    def _loss(self,
              experience,
              td_errors_loss_fn=tf.compat.v1.losses.huber_loss,
              gamma=1.0,
              reward_scale_factor=1.0,
              weights=None,
              training=False):
        """Computes critic loss for CategoricalDQN training.

    See Algorithm 1 and the discussion immediately preceding it on page 6 of
    "A Distributional Perspective on Reinforcement Learning"
      Bellemare et al., 2017
      https://arxiv.org/abs/1707.06887

    Args:
      experience: A batch of experience data in the form of a `Trajectory`. The
        structure of `experience` must match that of `self.policy.step_spec`.
        All tensors in `experience` must be shaped `[batch, time, ...]` where
        `time` must be equal to `self.required_experience_time_steps` if that
        property is not `None`.
      td_errors_loss_fn: A function(td_targets, predictions) to compute loss.
      gamma: Discount for future rewards.
      reward_scale_factor: Multiplicative factor to scale rewards.
      weights: Optional weights used for importance sampling.
      training: Whether the loss is being used for training.
    Returns:
      critic_loss: A scalar critic loss.
    Raises:
      ValueError:
        if the number of actions is greater than 1.
    """
        # Check that `experience` includes two outer dimensions [B, T, ...]. This
        # method requires a time dimension to compute the loss properly.
        self._check_trajectory_dimensions(experience)

        squeeze_time_dim = not self._q_network.state_spec
        if self._n_step_update == 1:
            time_steps, policy_steps, next_time_steps = (
                trajectory.experience_to_transitions(experience,
                                                     squeeze_time_dim))
            actions = policy_steps.action
        else:
            # To compute n-step returns, we need the first time steps, the first
            # actions, and the last time steps. Therefore we extract the first and
            # last transitions from our Trajectory.
            first_two_steps = tf.nest.map_structure(lambda x: x[:, :2],
                                                    experience)
            last_two_steps = tf.nest.map_structure(lambda x: x[:, -2:],
                                                   experience)
            time_steps, policy_steps, _ = (
                trajectory.experience_to_transitions(first_two_steps,
                                                     squeeze_time_dim))
            actions = policy_steps.action
            _, _, next_time_steps = (trajectory.experience_to_transitions(
                last_two_steps, squeeze_time_dim))

        with tf.name_scope('critic_loss'):
            nest_utils.assert_same_structure(actions, self.action_spec)
            nest_utils.assert_same_structure(time_steps, self.time_step_spec)
            nest_utils.assert_same_structure(next_time_steps,
                                             self.time_step_spec)

            rank = nest_utils.get_outer_rank(time_steps.observation,
                                             self._time_step_spec.observation)

            # If inputs have a time dimension and the q_network is stateful,
            # combine the batch and time dimension.
            batch_squash = (None if rank <= 1 or self._q_network.state_spec
                            in ((), None) else network_utils.BatchSquash(rank))

            network_observation = time_steps.observation

            if self._observation_and_action_constraint_splitter is not None:
                network_observation, _ = (
                    self._observation_and_action_constraint_splitter(
                        network_observation))

            # q_logits contains the Q-value logits for all actions.
            q_logits, _ = self._q_network(network_observation,
                                          time_steps.step_type,
                                          training=training)

            if batch_squash is not None:
                # Squash outer dimensions into a single dimension to facilitate
                # computing the loss below. Required to support temporal inputs,
                # for example.
                q_logits = batch_squash.flatten(q_logits)
                actions = batch_squash.flatten(actions)
                next_time_steps = tf.nest.map_structure(
                    batch_squash.flatten, next_time_steps)

            next_q_distribution = self._next_q_distribution(next_time_steps)

            if actions.shape.rank > 1:
                actions = tf.squeeze(actions,
                                     list(range(1, actions.shape.rank)))

            # Project the sample Bellman update \hat{T}Z_{\theta} onto the original
            # support of Z_{\theta} (see Figure 1 in paper).
            batch_size = q_logits.shape[0] or tf.shape(q_logits)[0]
            tiled_support = tf.tile(self._support, [batch_size])
            tiled_support = tf.reshape(tiled_support,
                                       [batch_size, self._num_atoms])

            if self._n_step_update == 1:
                discount = next_time_steps.discount
                if discount.shape.rank == 1:
                    # We expect discount to have a shape of [batch_size], while
                    # tiled_support will have a shape of [batch_size, num_atoms]. To
                    # multiply these, we add a second dimension of 1 to the discount.
                    discount = tf.expand_dims(discount, -1)
                next_value_term = tf.multiply(discount,
                                              tiled_support,
                                              name='next_value_term')

                reward = next_time_steps.reward
                if reward.shape.rank == 1:
                    # See the explanation above.
                    reward = tf.expand_dims(reward, -1)
                reward_term = tf.multiply(reward_scale_factor,
                                          reward,
                                          name='reward_term')

                target_support = tf.add(reward_term,
                                        gamma * next_value_term,
                                        name='target_support')
            else:
                # When computing discounted return, we need to throw out the last time
                # index of both reward and discount, which are filled with dummy values
                # to match the dimensions of the observation.
                rewards = reward_scale_factor * experience.reward[:, :-1]
                discounts = gamma * experience.discount[:, :-1]

                # TODO(b/134618876): Properly handle Trajectories that include episode
                # boundaries with nonzero discount.

                discounted_returns = value_ops.discounted_return(
                    rewards=rewards,
                    discounts=discounts,
                    final_value=tf.zeros([batch_size], dtype=discounts.dtype),
                    time_major=False,
                    provide_all_returns=False)

                # Convert discounted_returns from [batch_size] to [batch_size, 1]
                discounted_returns = tf.expand_dims(discounted_returns, -1)

                final_value_discount = tf.reduce_prod(discounts, axis=1)
                final_value_discount = tf.expand_dims(final_value_discount, -1)

                # Save the values of discounted_returns and final_value_discount in
                # order to check them in unit tests.
                self._discounted_returns = discounted_returns
                self._final_value_discount = final_value_discount

                target_support = tf.add(discounted_returns,
                                        final_value_discount * tiled_support,
                                        name='target_support')

            target_distribution = tf.stop_gradient(
                project_distribution(target_support, next_q_distribution,
                                     self._support))

            # Obtain the current Q-value logits for the selected actions.
            indices = tf.range(batch_size)
            indices = tf.cast(indices, actions.dtype)
            reshaped_actions = tf.stack([indices, actions], axis=-1)
            chosen_action_logits = tf.gather_nd(q_logits, reshaped_actions)

            # Compute the cross-entropy loss between the logits. If inputs have
            # a time dimension, compute the sum over the time dimension before
            # computing the mean over the batch dimension.
            if batch_squash is not None:
                target_distribution = batch_squash.unflatten(
                    target_distribution)
                chosen_action_logits = batch_squash.unflatten(
                    chosen_action_logits)
                critic_loss = tf.reduce_sum(
                    tf.compat.v1.nn.softmax_cross_entropy_with_logits_v2(
                        labels=target_distribution,
                        logits=chosen_action_logits),
                    axis=1)
            else:
                critic_loss = tf.compat.v1.nn.softmax_cross_entropy_with_logits_v2(
                    labels=target_distribution, logits=chosen_action_logits)

            agg_loss = common.aggregate_losses(
                per_example_loss=critic_loss,
                regularization_loss=self._q_network.losses)
            total_loss = agg_loss.total_loss

            dict_losses = {
                'critic_loss': agg_loss.weighted,
                'reg_loss': agg_loss.regularization,
                'total_loss': total_loss
            }

            common.summarize_scalar_dict(dict_losses,
                                         step=self.train_step_counter,
                                         name_scope='Losses/')

            if self._debug_summaries:
                distribution_errors = target_distribution - chosen_action_logits
                with tf.name_scope('distribution_errors'):
                    common.generate_tensor_summaries(
                        'distribution_errors',
                        distribution_errors,
                        step=self.train_step_counter)
                    tf.compat.v2.summary.scalar(
                        'mean',
                        tf.reduce_mean(distribution_errors),
                        step=self.train_step_counter)
                    tf.compat.v2.summary.scalar(
                        'mean_abs',
                        tf.reduce_mean(tf.abs(distribution_errors)),
                        step=self.train_step_counter)
                    tf.compat.v2.summary.scalar(
                        'max',
                        tf.reduce_max(distribution_errors),
                        step=self.train_step_counter)
                    tf.compat.v2.summary.scalar(
                        'min',
                        tf.reduce_min(distribution_errors),
                        step=self.train_step_counter)
                with tf.name_scope('target_distribution'):
                    common.generate_tensor_summaries(
                        'target_distribution',
                        target_distribution,
                        step=self.train_step_counter)

            # TODO(b/127318640): Give appropriate values for td_loss and td_error for
            # prioritized replay.
            return tf_agent.LossInfo(
                total_loss, dqn_agent.DqnLossInfo(td_loss=(), td_error=()))
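
For n_step_update == 1 the target support is built atom by atom as
r + gamma * d * z before project_distribution maps the probability mass back onto
the fixed support. A toy version of that shift; the three support atoms and the
reward/discount values are made up.

import tensorflow as tf

support = tf.constant([-1.0, 0.0, 1.0])   # toy atoms, num_atoms = 3
reward = tf.constant([[0.5]])             # [batch_size, 1]
discount = tf.constant([[0.99]])          # [batch_size, 1]
gamma = 1.0

target_support = reward + gamma * discount * support   # broadcasts to [1, 3]
print(target_support.numpy())   # [[-0.49, 0.5, 1.49]]
# The shifted atoms generally fall between the original ones, which is why the loss
# then projects the target distribution back onto `support`.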
Example 22
    def _critic_loss_with_optional_entropy_term(
            self,
            time_steps: ts.TimeStep,
            actions: types.Tensor,
            next_time_steps: ts.TimeStep,
            td_errors_loss_fn: types.LossFn,
            gamma: types.Float = 1.0,
            reward_scale_factor: types.Float = 1.0,
            weights: Optional[types.Tensor] = None,
            training: bool = False) -> types.Tensor:
        r"""Computes the critic loss for CQL-SAC training.

    The original SAC critic loss is:
    ```
    (q(s, a) - (r(s, a) + \gamma q(s', a') - \gamma \alpha \log \pi(a'|s')))^2
    ```

    The CQL-SAC critic loss makes the entropy term optional.
    CQL may value unseen actions higher since it lower-bounds the value of
    seen actions. This makes the policy entropy potentially redundant in the
    target term, since it will further enhance unseen actions' effects.

    If self._include_critic_entropy_term is False, this loss equation becomes:
    ```
    (q(s, a) - (r(s, a) + \gamma q(s', a')))^2
    ```

    Args:
      time_steps: A batch of timesteps.
      actions: A batch of actions.
      next_time_steps: A batch of next timesteps.
      td_errors_loss_fn: A function(td_targets, predictions) to compute
        elementwise (per-batch-entry) loss.
      gamma: Discount for future rewards.
      reward_scale_factor: Multiplicative factor to scale rewards.
      weights: Optional scalar or elementwise (per-batch-entry) importance
        weights.
      training: Whether this loss is being used for training.

    Returns:
      critic_loss: A scalar critic loss.
    """
        with tf.name_scope('critic_loss'):
            nest_utils.assert_same_structure(actions, self.action_spec)
            nest_utils.assert_same_structure(time_steps, self.time_step_spec)
            nest_utils.assert_same_structure(next_time_steps,
                                             self.time_step_spec)

            # We do not update actor or target networks in critic loss.
            next_actions, next_log_pis = self._actions_and_log_probs(
                next_time_steps, training=False)
            target_input = (next_time_steps.observation, next_actions)
            target_q_values1, unused_network_state1 = self._target_critic_network_1(
                target_input, next_time_steps.step_type, training=False)
            target_q_values2, unused_network_state2 = self._target_critic_network_2(
                target_input, next_time_steps.step_type, training=False)
            target_q_values = tf.minimum(target_q_values1, target_q_values2)

            if self._include_critic_entropy_term:
                target_q_values -= (tf.exp(self._log_alpha) * next_log_pis)

            reward = next_time_steps.reward
            if self._reward_noise_variance > 0:
                reward_noise = tf.random.normal(
                    tf.shape(reward),
                    0.0,
                    self._reward_noise_variance,
                    seed=self._reward_seed_stream())
                reward += reward_noise

            td_targets = tf.stop_gradient(reward_scale_factor * reward +
                                          gamma * next_time_steps.discount *
                                          target_q_values)

            pred_input = (time_steps.observation, actions)
            pred_td_targets1, _ = self._critic_network_1(pred_input,
                                                         time_steps.step_type,
                                                         training=training)
            pred_td_targets2, _ = self._critic_network_2(pred_input,
                                                         time_steps.step_type,
                                                         training=training)
            critic_loss1 = td_errors_loss_fn(td_targets, pred_td_targets1)
            critic_loss2 = td_errors_loss_fn(td_targets, pred_td_targets2)
            critic_loss = critic_loss1 + critic_loss2

            if critic_loss.shape.rank > 1:
                # Sum over the time dimension.
                critic_loss = tf.reduce_sum(critic_loss,
                                            axis=range(1,
                                                       critic_loss.shape.rank))

            agg_loss = common.aggregate_losses(
                per_example_loss=critic_loss,
                sample_weight=weights,
                regularization_loss=(self._critic_network_1.losses +
                                     self._critic_network_2.losses))
            critic_loss = agg_loss.total_loss

            self._critic_loss_debug_summaries(td_targets, pred_td_targets1,
                                              pred_td_targets2)

            return critic_loss
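
A minimal sketch of the switch documented in Example 22 (all tensors and the value of alpha below are made up for illustration): with `include_critic_entropy_term`, the entropy bonus `exp(log_alpha) * next_log_pis` is subtracted from the target Q-values before discounting; without it, the plain bootstrapped target is used.

import tensorflow as tf

# Illustrative stand-ins for the quantities computed inside the agent.
log_alpha = tf.Variable(0.0)               # alpha = exp(log_alpha) = 1.0 (assumed)
target_q_values = tf.constant([1.5, 2.0])  # min of the two target critics (made up)
next_log_pis = tf.constant([-0.7, -1.2])   # log pi(a'|s') (made up)
reward = tf.constant([0.1, 0.3])
discount = tf.constant([1.0, 1.0])
gamma = 0.99

for include_critic_entropy_term in (True, False):
    q = target_q_values
    if include_critic_entropy_term:
        q = q - tf.exp(log_alpha) * next_log_pis
    td_targets = tf.stop_gradient(reward + gamma * discount * q)
    print(include_critic_entropy_term, td_targets.numpy())
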
Example 23
  def critic_loss(self,
                  time_steps,
                  expert_experience,
                  actions,
                  next_time_steps,
                  future_time_steps,
                  td_errors_loss_fn,
                  gamma = 1.0,
                  reward_scale_factor = 1.0,
                  weights = None,
                  training = False,
                  loss_name='c',
                  use_done=False,
                  q_combinator='min'):
    """Computes the critic loss for SAC training.

    Args:
      time_steps: A batch of timesteps.
      expert_experience: An array of success examples.
      actions: A batch of actions.
      next_time_steps: A batch of next timesteps.
      future_time_steps: A batch of future timesteps, used for n-step returns.
      td_errors_loss_fn: A function(td_targets, predictions) to compute
        elementwise (per-batch-entry) loss.
      gamma: Discount for future rewards.
      reward_scale_factor: Multiplicative factor to scale rewards.
      weights: Optional scalar or elementwise (per-batch-entry) importance
        weights.
      training: Whether this loss is being used for training.
      loss_name: Which loss function to use. Use 'c' for RCE and 'q' for SQIL.
      use_done: Whether to use the terminal flag from the environment in the
        Bellman backup. We found that omitting it led to better results.
      q_combinator: Whether to combine the two Q-functions by taking the 'min'
        (as in TD3) or the 'max'.

    Returns:
      critic_loss: A scalar critic loss.
    """
    assert weights is None
    with tf.name_scope('critic_loss'):
      nest_utils.assert_same_structure(actions, self.action_spec)
      nest_utils.assert_same_structure(time_steps, self.time_step_spec)
      nest_utils.assert_same_structure(next_time_steps, self.time_step_spec)

      next_actions, _ = self._actions_and_log_probs(next_time_steps)
      target_input = (next_time_steps.observation, next_actions)
      target_q_values1, unused_network_state1 = self._target_critic_network_1(
          target_input, next_time_steps.step_type, training=False)
      target_q_values2, unused_network_state2 = self._target_critic_network_2(
          target_input, next_time_steps.step_type, training=False)
      if self._n_step is not None:
        future_actions, _ = self._actions_and_log_probs(future_time_steps)
        future_input = (future_time_steps.observation, future_actions)
        future_q_values1, _ = self._target_critic_network_1(
            future_input, future_time_steps.step_type, training=False)
        future_q_values2, _ = self._target_critic_network_2(
            future_input, future_time_steps.step_type, training=False)

        gamma_n = gamma**self._n_step  # Discount for n-step returns
        target_q_values1 = (target_q_values1 + gamma_n * future_q_values1) / 2.0
        target_q_values2 = (target_q_values2 + gamma_n * future_q_values2) / 2.0

      if q_combinator == 'min':
        target_q_values = tf.minimum(target_q_values1, target_q_values2)
      else:
        assert q_combinator == 'max'
        target_q_values = tf.maximum(target_q_values1, target_q_values2)

      batch_size = time_steps.observation.shape[0]
      if loss_name == 'q':
        if use_done:
          td_targets = gamma * next_time_steps.discount * target_q_values
        else:
          td_targets = gamma * target_q_values
      else:
        assert loss_name == 'c'
        w = target_q_values / (1 - target_q_values)
        td_targets = gamma * w / (gamma * w + 1)
        if use_done:
          td_targets = next_time_steps.discount * td_targets
        weights = tf.concat([1 + gamma * w, (1 - gamma) * tf.ones(batch_size)],
                            axis=0)

      td_targets = tf.stop_gradient(td_targets)
      td_targets = tf.concat([td_targets, tf.ones(batch_size)], axis=0)

      # Note that the actions only depend on the observations. We create the
      # expert_time_steps object simply to make this look like a time step
      # object.
      expert_time_steps = time_steps._replace(observation=expert_experience)
      if self._use_behavior_policy:
        policy_state = self._train_policy.get_initial_state(batch_size)
        action_distribution = self._behavior_policy.distribution(
            time_steps, policy_state=policy_state).action
        # Sample actions from the transformed distribution.
        expert_actions = tf.nest.map_structure(lambda d: d.sample(),
                                               action_distribution)
      else:
        expert_actions, _ = self._actions_and_log_probs(expert_time_steps)

      observation = time_steps.observation
      pred_input = (tf.concat([observation, expert_experience], axis=0),
                    tf.concat([actions, expert_actions], axis=0))

      pred_td_targets1, _ = self._critic_network_1(
          pred_input, time_steps.step_type, training=training)
      pred_td_targets2, _ = self._critic_network_2(
          pred_input, time_steps.step_type, training=training)

      self._critic_loss_debug_summaries(td_targets, pred_td_targets1,
                                        pred_td_targets2)

      critic_loss1 = td_errors_loss_fn(td_targets, pred_td_targets1)
      critic_loss2 = td_errors_loss_fn(td_targets, pred_td_targets2)
      critic_loss = critic_loss1 + critic_loss2

      if critic_loss.shape.rank > 1:
        # Sum over the time dimension.
        critic_loss = tf.reduce_sum(
            critic_loss, axis=range(1, critic_loss.shape.rank))

      agg_loss = common.aggregate_losses(
          per_example_loss=critic_loss,
          sample_weight=weights,
          regularization_loss=(self._critic_network_1.losses +
                               self._critic_network_2.losses))
      critic_loss = agg_loss.total_loss

      self._critic_loss_debug_summaries(td_targets, pred_td_targets1,
                                        pred_td_targets2)

      return critic_loss
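
The heart of Example 23 is the RCE target construction. A toy sketch with made-up classifier outputs (everything below is illustrative): the classifier value C is turned into w = C / (1 - C), replay transitions get the soft label gamma * w / (gamma * w + 1) with per-example weight 1 + gamma * w, and success examples get label 1 with weight 1 - gamma.

import tensorflow as tf

gamma = 0.99
batch_size = 3
target_q_values = tf.constant([0.2, 0.5, 0.8])  # classifier outputs in (0, 1), made up

w = target_q_values / (1 - target_q_values)
# Soft labels for replay transitions, followed by label 1 for the success examples.
td_targets = tf.concat([gamma * w / (gamma * w + 1), tf.ones(batch_size)], axis=0)
# Per-example weights: 1 + gamma*w for replay transitions, 1 - gamma for success examples.
weights = tf.concat([1 + gamma * w, (1 - gamma) * tf.ones(batch_size)], axis=0)
print(td_targets.numpy(), weights.numpy())
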
Example 24
    def critic_loss(
        self,
        time_steps,
        actions,
        next_time_steps,
        td_errors_loss_fn,
        gamma=1.0,
        weights=None,
        training=False,
        w_clipping=None,
        self_normalized=False,
        lambda_fix=False,
    ):
        """Computes the critic loss for C-learning training.

    Args:
      time_steps: A batch of timesteps.
      actions: A batch of actions.
      next_time_steps: A batch of next timesteps.
      td_errors_loss_fn: A function(td_targets, predictions) to compute
        elementwise (per-batch-entry) loss.
      gamma: Discount for future rewards.
      weights: Optional scalar or elementwise (per-batch-entry) importance
        weights.
      training: Whether this loss is being used for training.
      w_clipping: Maximum value used for clipping the weights. Use -1 to do no
        clipping; use None to use the recommended value of 1 / (1 - gamma).
      self_normalized: Whether to normalize the weights so that their average is 1.
        Empirically this usually hurts performance.
      lambda_fix: Whether to include the adjustment when using future positives.
        Empirically this has little effect.

    Returns:
      critic_loss: A scalar critic loss.
    """
        del weights
        if w_clipping is None:
            w_clipping = 1 / (1 - gamma)
        rfp = gin.query_parameter('goal_fn.relabel_future_prob')
        rnp = gin.query_parameter('goal_fn.relabel_next_prob')
        assert rfp + rnp == 0.5
        with tf.name_scope('critic_loss'):
            nest_utils.assert_same_structure(actions, self.action_spec)
            nest_utils.assert_same_structure(time_steps, self.time_step_spec)
            nest_utils.assert_same_structure(next_time_steps,
                                             self.time_step_spec)

            next_actions, _ = self._actions_and_log_probs(next_time_steps)
            target_input = (next_time_steps.observation, next_actions)
            target_q_values1, unused_network_state1 = self._target_critic_network_1(
                target_input, next_time_steps.step_type, training=False)
            target_q_values2, unused_network_state2 = self._target_critic_network_2(
                target_input, next_time_steps.step_type, training=False)
            target_q_values = tf.minimum(target_q_values1, target_q_values2)

            w = tf.stop_gradient(target_q_values / (1 - target_q_values))
            if w_clipping >= 0:
                w = tf.clip_by_value(w, 0, w_clipping)
            tf.debugging.assert_all_finite(w,
                                           'Not all elements of w are finite')
            if self_normalized:
                w = w / tf.reduce_mean(w)

            batch_size = nest_utils.get_outer_shape(time_steps,
                                                    self._time_step_spec)[0]
            half_batch = batch_size // 2
            float_batch_size = tf.cast(batch_size, float)
            num_next = tf.cast(tf.round(float_batch_size * rnp), tf.int32)
            num_future = tf.cast(tf.round(float_batch_size * rfp), tf.int32)
            if lambda_fix:
                lambda_coef = 2 * rnp
                weights = tf.concat([
                    tf.fill((num_next, ), (1 - gamma)),
                    tf.fill((num_future, ), 1.0),
                    (1 + lambda_coef * gamma * w)[half_batch:]
                ],
                                    axis=0)
            else:
                weights = tf.concat([
                    tf.fill((num_next, ), (1 - gamma)),
                    tf.fill((num_future, ), 1.0), (1 + gamma * w)[half_batch:]
                ],
                                    axis=0)

            # Note that we assume that episodes never terminate. If they do, then
            # we need to include next_time_steps.discount in the (negative) TD target.
            # We exclude the termination here so that we can use termination to
            # indicate task success during evaluation. In the evaluation setting,
            # task success depends on the task, but we don't want the termination
            # here to depend on the task. Hence, we ignore it.
            if lambda_fix:
                lambda_coef = 2 * rnp
                y = lambda_coef * gamma * w / (1 + lambda_coef * gamma * w)
            else:
                y = gamma * w / (1 + gamma * w)
            td_targets = tf.stop_gradient(next_time_steps.reward +
                                          (1 - next_time_steps.reward) * y)
            if rfp > 0:
                td_targets = tf.concat(
                    [tf.ones(half_batch), td_targets[half_batch:]], axis=0)

            observation = time_steps.observation
            pred_input = (observation, actions)
            pred_td_targets1, _ = self._critic_network_1(pred_input,
                                                         time_steps.step_type,
                                                         training=training)
            pred_td_targets2, _ = self._critic_network_2(pred_input,
                                                         time_steps.step_type,
                                                         training=training)

            critic_loss1 = td_errors_loss_fn(td_targets, pred_td_targets1)
            critic_loss2 = td_errors_loss_fn(td_targets, pred_td_targets2)
            critic_loss = critic_loss1 + critic_loss2

            if critic_loss.shape.rank > 1:
                # Sum over the time dimension.
                critic_loss = tf.reduce_sum(critic_loss,
                                            axis=range(1,
                                                       critic_loss.shape.rank))

            agg_loss = common.aggregate_losses(
                per_example_loss=critic_loss,
                sample_weight=weights,
                regularization_loss=(self._critic_network_1.losses +
                                     self._critic_network_2.losses))
            critic_loss = agg_loss.total_loss
            self._critic_loss_debug_summaries(td_targets, pred_td_targets1,
                                              pred_td_targets2, weights)

            return critic_loss
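
A toy sketch of the C-learning target in Example 24 (classifier outputs, rewards, and gamma are made up): w = C / (1 - C) is clipped to the recommended 1 / (1 - gamma), and the target is r + (1 - r) * gamma * w / (1 + gamma * w), shown here without the lambda_fix adjustment.

import tensorflow as tf

gamma = 0.99
w_clipping = 1 / (1 - gamma)                      # recommended default
target_q_values = tf.constant([0.1, 0.6, 0.999])  # classifier outputs C(s', a', g), made up
reward = tf.constant([0.0, 0.0, 1.0])             # 1 when the next state reaches the goal

w = tf.stop_gradient(target_q_values / (1 - target_q_values))
w = tf.clip_by_value(w, 0, w_clipping)
y = gamma * w / (1 + gamma * w)
td_targets = tf.stop_gradient(reward + (1 - reward) * y)
print(td_targets.numpy())
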
Example 25
    def critic_loss(self,
                    time_steps: ts.TimeStep,
                    actions: types.Tensor,
                    next_time_steps: ts.TimeStep,
                    td_errors_loss_fn: types.LossFn,
                    gamma: types.Float = 1.0,
                    reward_scale_factor: types.Float = 1.0,
                    weights: Optional[types.Tensor] = None,
                    training: bool = False) -> types.Tensor:
        """Computes the critic loss for SAC training.

    Args:
      time_steps: A batch of timesteps.
      actions: A batch of actions.
      next_time_steps: A batch of next timesteps.
      td_errors_loss_fn: A function(td_targets, predictions) to compute
        elementwise (per-batch-entry) loss.
      gamma: Discount for future rewards.
      reward_scale_factor: Multiplicative factor to scale rewards.
      weights: Optional scalar or elementwise (per-batch-entry) importance
        weights.
      training: Whether this loss is being used for training.

    Returns:
      critic_loss: A scalar critic loss.
    """
        with tf.name_scope('critic_loss'):
            nest_utils.assert_same_structure(actions, self.action_spec)
            nest_utils.assert_same_structure(time_steps, self.time_step_spec)
            nest_utils.assert_same_structure(next_time_steps,
                                             self.time_step_spec)

            next_actions, next_log_pis = self._actions_and_log_probs(
                next_time_steps)
            target_input = (next_time_steps.observation, next_actions)
            target_q_values1, unused_network_state1 = self._target_critic_network_1(
                target_input, next_time_steps.step_type, training=False)
            target_q_values2, unused_network_state2 = self._target_critic_network_2(
                target_input, next_time_steps.step_type, training=False)
            target_q_values = (tf.minimum(target_q_values1, target_q_values2) -
                               tf.exp(self._log_alpha) * next_log_pis)

            td_targets = tf.stop_gradient(
                reward_scale_factor * next_time_steps.reward +
                gamma * next_time_steps.discount * target_q_values)

            pred_input = (time_steps.observation, actions)
            pred_td_targets1, _ = self._critic_network_1(pred_input,
                                                         time_steps.step_type,
                                                         training=training)
            pred_td_targets2, _ = self._critic_network_2(pred_input,
                                                         time_steps.step_type,
                                                         training=training)
            critic_loss1 = td_errors_loss_fn(td_targets, pred_td_targets1)
            critic_loss2 = td_errors_loss_fn(td_targets, pred_td_targets2)
            critic_loss = critic_loss1 + critic_loss2

            if critic_loss.shape.rank > 1:
                # Sum over the time dimension.
                critic_loss = tf.reduce_sum(critic_loss,
                                            axis=range(1,
                                                       critic_loss.shape.rank))

            agg_loss = common.aggregate_losses(
                per_example_loss=critic_loss,
                sample_weight=weights,
                regularization_loss=(self._critic_network_1.losses +
                                     self._critic_network_2.losses))
            critic_loss = agg_loss.total_loss

            self._critic_loss_debug_summaries(td_targets, pred_td_targets1,
                                              pred_td_targets2)

            return critic_loss
Example 26
 def test_aggregate_losses_with_time_dimension(self):
     per_example_loss = tf.constant([[4., 2., 3.], [1, 1, 1]])
     aggregated_losses = common.aggregate_losses(per_example_loss)
     expected_per_example_loss = (4 + 2 + 3 + 1 + 1 + 1) / 6
     self.assertAlmostEqual(self.evaluate(aggregated_losses.total_loss),
                            expected_per_example_loss)
Example 27
 def test_aggregate_losses_without_time_dimension_with_weights(self):
     per_example_loss = tf.constant([4., 2., 3.])
     sample_weights = tf.constant([1., 1., 0.])
     aggregated_losses = common.aggregate_losses(per_example_loss,
                                                 sample_weights)
     self.assertAlmostEqual(self.evaluate(aggregated_losses.total_loss), 2)
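
Extending Examples 26-27, the sketch below (toy numbers; single-replica execution assumed) combines sample weights with a regularization term, which is how the agents above obtain `agg_loss.weighted`, `agg_loss.regularization`, and `agg_loss.total_loss`.

import tensorflow as tf
from tf_agents.utils import common

per_example_loss = tf.constant([4.0, 2.0, 3.0])
sample_weights = tf.constant([1.0, 1.0, 0.0])    # third example masked out
regularization_losses = [tf.constant(0.5)]       # stand-in for network.losses

agg = common.aggregate_losses(per_example_loss=per_example_loss,
                              sample_weight=sample_weights,
                              regularization_loss=regularization_losses)
# agg.weighted averages the weighted losses over the full batch ((4 + 2 + 0) / 3 = 2.0),
# agg.regularization is 0.5, and agg.total_loss is their sum (2.5), assuming one replica.
print(agg.weighted.numpy(), agg.regularization.numpy(), agg.total_loss.numpy())
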
Example 28
    def critic_no_entropy_loss(self,
                               time_steps,
                               actions,
                               next_time_steps,
                               td_errors_loss_fn,
                               gamma=1.0,
                               reward_scale_factor=1.0,
                               weights=None,
                               training=False):
        """Computes the critic loss for SAC training.

    Args:
      time_steps: A batch of timesteps.
      actions: A batch of actions.
      next_time_steps: A batch of next timesteps.
      td_errors_loss_fn: A function(td_targets, predictions) to compute
        elementwise (per-batch-entry) loss.
      gamma: Discount for future rewards.
      reward_scale_factor: Multiplicative factor to scale rewards.
      weights: Optional scalar or elementwise (per-batch-entry) importance
        weights.
      training: Whether this loss is being used for training.

    Returns:
      critic_loss: A scalar critic loss.
    """
        with tf.name_scope('critic_no_entropy_loss'):
            nest_utils.assert_same_structure(actions, self.action_spec)
            nest_utils.assert_same_structure(time_steps, self.time_step_spec)
            nest_utils.assert_same_structure(next_time_steps,
                                             self.time_step_spec)

            next_actions, _ = self._actions_and_log_probs(next_time_steps)
            target_input = (next_time_steps.observation, next_actions)
            target_q_values1, unused_network_state1 = self._target_critic_network_no_entropy_1(
                target_input, next_time_steps.step_type, training=False)
            target_q_values2, unused_network_state2 = self._target_critic_network_no_entropy_2(
                target_input, next_time_steps.step_type, training=False)
            target_q_values = tf.minimum(
                target_q_values1, target_q_values2
            )  # entropy has been removed from the target critic function

            td_targets = tf.stop_gradient(
                reward_scale_factor * next_time_steps.reward +
                gamma * next_time_steps.discount * target_q_values)

            pred_input = (time_steps.observation, actions)
            pred_td_targets1, _ = self._critic_network_no_entropy_1(
                pred_input, time_steps.step_type, training=training)
            pred_td_targets2, _ = self._critic_network_no_entropy_2(
                pred_input, time_steps.step_type, training=training)
            critic_loss1 = td_errors_loss_fn(td_targets, pred_td_targets1)
            critic_loss2 = td_errors_loss_fn(td_targets, pred_td_targets2)
            critic_loss = critic_loss1 + critic_loss2

            if nest_utils.is_batched_nested_tensors(time_steps,
                                                    self.time_step_spec,
                                                    num_outer_dims=2):
                # Sum over the time dimension.
                critic_loss = tf.reduce_sum(input_tensor=critic_loss, axis=1)

            agg_loss = common.aggregate_losses(
                per_example_loss=critic_loss,
                sample_weight=weights,
                regularization_loss=(self._critic_network_no_entropy_1.losses +
                                     self._critic_network_no_entropy_2.losses))
            critic_no_entropy_loss = agg_loss.total_loss

            self._critic_no_entropy_loss_debug_summaries(
                td_targets, pred_td_targets1, pred_td_targets2)

            return critic_no_entropy_loss
Example 29
    def _loss(self,
              experience,
              td_errors_loss_fn=common.element_wise_huber_loss,
              gamma=1.0,
              reward_scale_factor=1.0,
              weights=None,
              training=False):

        transition = self._as_transition(experience)
        time_steps, policy_steps, next_time_steps = transition
        actions = policy_steps.action

        valid_mask = tf.cast(~time_steps.is_last(), tf.float32)

        with tf.name_scope('loss'):
            # q_values is already gathered by actions
            q_values = self._compute_q_values(time_steps,
                                              actions,
                                              training=training)

            next_q_values = self._compute_next_all_q_values(
                next_time_steps, policy_steps.info)

            q_target_values = self._compute_next_all_q_values(
                time_steps, policy_steps.info)

            # This applies to any value of n_step_update and also in the RNN-DQN case.
            # In the RNN-DQN case, inputs and outputs contain a time dimension.
            #td_targets = compute_td_targets(
            #    next_q_values,
            #    rewards=reward_scale_factor * next_time_steps.reward,
            #    discounts=gamma * next_time_steps.discount)

            td_targets = compute_munchausen_td_targets(
                next_q_values=next_q_values,
                q_target_values=q_target_values,
                actions=actions,
                rewards=reward_scale_factor * next_time_steps.reward,
                discounts=gamma * next_time_steps.discount,
                multi_dim_actions=self._action_spec.shape.rank > 0,
                alpha=self.alpha,
                entropy_tau=self.entropy_tau)

            td_error = valid_mask * (td_targets - q_values)

            td_loss = valid_mask * td_errors_loss_fn(td_targets, q_values)

            if nest_utils.is_batched_nested_tensors(time_steps,
                                                    self.time_step_spec,
                                                    num_outer_dims=2):
                # Do a sum over the time dimension.
                td_loss = tf.reduce_sum(input_tensor=td_loss, axis=1)

            # Aggregate across the elements of the batch and add regularization loss.
            # Note: We use an element wise loss above to ensure each element is always
            #   weighted by 1/N where N is the batch size, even when some of the
            #   weights are zero due to boundary transitions. Weighting by 1/K where K
            #   is the actual number of non-zero weight would artificially increase
            #   their contribution in the loss. Think about what would happen as
            #   the number of boundary samples increases.

            agg_loss = common.aggregate_losses(
                per_example_loss=td_loss,
                sample_weight=weights,
                regularization_loss=self._q_network.losses)
            total_loss = agg_loss.total_loss

            losses_dict = {
                'td_loss': agg_loss.weighted,
                'reg_loss': agg_loss.regularization,
                'total_loss': total_loss
            }

            common.summarize_scalar_dict(losses_dict,
                                         step=self.train_step_counter,
                                         name_scope='Losses/')

            if self._summarize_grads_and_vars:
                with tf.name_scope('Variables/'):
                    for var in self._q_network.trainable_weights:
                        tf.compat.v2.summary.histogram(
                            name=var.name.replace(':', '_'),
                            data=var,
                            step=self.train_step_counter)

            if self._debug_summaries:
                diff_q_values = q_values - next_q_values
                common.generate_tensor_summaries('td_error', td_error,
                                                 self.train_step_counter)
                common.generate_tensor_summaries('td_loss', td_loss,
                                                 self.train_step_counter)
                common.generate_tensor_summaries('q_values', q_values,
                                                 self.train_step_counter)
                common.generate_tensor_summaries('next_q_values',
                                                 next_q_values,
                                                 self.train_step_counter)
                common.generate_tensor_summaries('diff_q_values',
                                                 diff_q_values,
                                                 self.train_step_counter)

            return tf_agent.LossInfo(
                total_loss, DqnLossInfo(td_loss=td_loss, td_error=td_error))
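
Example 29 relies on `compute_munchausen_td_targets`, which is not shown above. The sketch below is a hedged, illustrative reconstruction of the Munchausen-DQN target for scalar discrete actions, r + alpha * clip(tau * log pi(a|s)) + gamma * discount * sum_a' pi(a'|s') * (q(s',a') - tau * log pi(a'|s')) with pi = softmax(q / tau); it is not the repository's actual helper, and the clip range and default coefficients are assumptions.

import tensorflow as tf

def munchausen_td_targets_sketch(next_q_values, q_target_values, actions,
                                 rewards, discounts, alpha=0.9, entropy_tau=0.03):
    # Soft policy from the target network at the current state.
    log_pi_current = tf.nn.log_softmax(q_target_values / entropy_tau, axis=-1)
    # Log-probability of the action actually taken (the "Munchausen" bonus), clipped.
    taken_log_pi = tf.gather(log_pi_current, actions, batch_dims=1)
    munchausen_bonus = alpha * tf.clip_by_value(entropy_tau * taken_log_pi, -1.0, 0.0)

    # Soft (entropy-regularized) value of the next state.
    log_pi_next = tf.nn.log_softmax(next_q_values / entropy_tau, axis=-1)
    pi_next = tf.exp(log_pi_next)
    soft_value = tf.reduce_sum(pi_next * (next_q_values - entropy_tau * log_pi_next),
                               axis=-1)

    return tf.stop_gradient(rewards + munchausen_bonus + discounts * soft_value)

# Toy call with assumed shapes: batch of 2, 3 discrete actions.
targets = munchausen_td_targets_sketch(
    next_q_values=tf.constant([[1.0, 0.5, 0.2], [0.3, 0.9, 0.1]]),
    q_target_values=tf.constant([[0.8, 0.4, 0.1], [0.2, 1.0, 0.3]]),
    actions=tf.constant([0, 1]),
    rewards=tf.constant([1.0, 0.0]),
    discounts=tf.constant([0.99, 0.99]))
print(targets.numpy())
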
Example 30
    def _loss(self, experience, weights=None):
        """Computes loss for behavioral cloning.

    Args:
      experience: A `Trajectory` containing experience.
      weights: Optional scalar or element-wise (per-batch-entry) importance
        weights.

    Returns:
      loss: A `LossInfo` struct.

    Raises:
      ValueError:
        If the number of actions is greater than 1.
    """
        with tf.name_scope('loss'):
            if self._nested_actions:
                actions = experience.action
            else:
                actions = tf.nest.flatten(experience.action)[0]

            batch_size = (tf.compat.dimension_value(
                experience.step_type.shape[0])
                          or tf.shape(experience.step_type)[0])
            logits, _ = self._cloning_network(
                experience.observation,
                experience.step_type,
                training=True,
                network_state=self._cloning_network.get_initial_state(
                    batch_size))

            error = self._loss_fn(logits, actions)
            error_dtype = tf.nest.flatten(error)[0].dtype
            boundary_weights = tf.cast(~experience.is_boundary(), error_dtype)
            error *= boundary_weights

            if nest_utils.is_batched_nested_tensors(experience.action,
                                                    self.action_spec,
                                                    num_outer_dims=2):
                # Do a sum over the time dimension.
                error = tf.reduce_sum(input_tensor=error, axis=1)

            # Average across the elements of the batch.
            # Note: We use an element wise loss above to ensure each element is always
            #   weighted by 1/N where N is the batch size, even when some of the
            #   weights are zero due to boundary transitions. Weighting by 1/K where K
            #   is the actual number of non-zero weight would artificially increase
            #   their contribution in the loss. Think about what would happen as
            #   the number of boundary samples increases.

            agg_loss = common.aggregate_losses(
                per_example_loss=error,
                sample_weight=weights,
                regularization_loss=self._cloning_network.losses)
            total_loss = agg_loss.total_loss

            dict_losses = {
                'loss': agg_loss.weighted,
                'reg_loss': agg_loss.regularization,
                'total_loss': total_loss
            }

            common.summarize_scalar_dict(dict_losses,
                                         step=self.train_step_counter,
                                         name_scope='Losses/')

            if self._summarize_grads_and_vars:
                with tf.name_scope('Variables/'):
                    for var in self._cloning_network.trainable_weights:
                        tf.compat.v2.summary.histogram(
                            name=var.name.replace(':', '_'),
                            data=var,
                            step=self.train_step_counter)

            if self._debug_summaries:
                common.generate_tensor_summaries('errors', error,
                                                 self.train_step_counter)

            return tf_agent.LossInfo(total_loss,
                                     BehavioralCloningLossInfo(loss=error))
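
A toy illustration (made-up numbers) of the boundary masking in Example 30: per-example errors at boundary transitions are zeroed, but the aggregation still divides by the full batch size N, matching the comment in the code and the behavior tested in Example 27.

import tensorflow as tf

error = tf.constant([4.0, 2.0, 3.0])
is_boundary = tf.constant([False, False, True])
masked_error = error * tf.cast(~is_boundary, error.dtype)  # -> [4., 2., 0.]
mean_loss = tf.reduce_mean(masked_error)                   # (4 + 2 + 0) / 3 = 2.0
print(mean_loss.numpy())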