Example No. 1
  def _policy_loss(
      self, mean, logstd, old_mean, old_logstd, action, advantage, length):
    """Compute the policy loss composed of multiple components.

    1. The policy gradient loss is importance sampled from the data-collecting
       policy at the beginning of training.
    2. The second term is a KL penalty between the policy at the beginning of
       training and the current policy.
    3. Additionally, if this KL already changed more than twice the target
       amount, we activate a strong penalty discouraging further divergence.

    Args:
      mean: Sequences of action means of the current policy.
      logstd: Sequences of action log stddevs of the current policy.
      old_mean: Sequences of action means of the behavioral policy.
      old_logstd: Sequences of action log stddevs of the behavioral policy.
      action: Sequences of actions.
      advantage: Sequences of advantages.
      length: Batch of sequence lengths.

    Returns:
      Tuple of loss tensor and summary tensor.
    """
    with tf.name_scope('policy_loss'):
      entropy = utility.diag_normal_entropy(mean, logstd)
      kl = tf.reduce_mean(self._mask(utility.diag_normal_kl(
          old_mean, old_logstd, mean, logstd), length), 1)
      policy_gradient = tf.exp(
          utility.diag_normal_logpdf(mean, logstd, action) -
          utility.diag_normal_logpdf(old_mean, old_logstd, action))
      surrogate_loss = -tf.reduce_mean(self._mask(
          policy_gradient * tf.stop_gradient(advantage), length), 1)
      kl_penalty = self._penalty * kl
      cutoff_threshold = self._config.kl_target * self._config.kl_cutoff_factor
      cutoff_count = tf.reduce_sum(
          tf.cast(kl > cutoff_threshold, tf.int32))
      # The builtin `int` serves as the no-op false branch of tf.cond: calling
      # int() returns 0, matching the zero constant printed in the true branch.
      with tf.control_dependencies([tf.cond(
          cutoff_count > 0,
          lambda: tf.Print(0, [cutoff_count], 'kl cutoff! '), int)]):
        kl_cutoff = (
            self._config.kl_cutoff_coef *
            tf.cast(kl > cutoff_threshold, tf.float32) *
            (kl - cutoff_threshold) ** 2)
      policy_loss = surrogate_loss + kl_penalty + kl_cutoff
      summary = tf.summary.merge([
          tf.summary.histogram('entropy', entropy),
          tf.summary.histogram('kl', kl),
          tf.summary.histogram('surrogate_loss', surrogate_loss),
          tf.summary.histogram('kl_penalty', kl_penalty),
          tf.summary.histogram('kl_cutoff', kl_cutoff),
          tf.summary.histogram('kl_penalty_combined', kl_penalty + kl_cutoff),
          tf.summary.histogram('policy_loss', policy_loss),
          tf.summary.scalar('avg_surr_loss', tf.reduce_mean(surrogate_loss)),
          tf.summary.scalar('avg_kl_penalty', tf.reduce_mean(kl_penalty)),
          tf.summary.scalar('avg_policy_loss', tf.reduce_mean(policy_loss))])
      policy_loss = tf.reduce_mean(policy_loss, 0)
      return tf.check_numerics(policy_loss, 'policy_loss'), summary
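
For comparison, the same three loss terms can be written down outside the TensorFlow graph. Below is a minimal NumPy sketch that assumes the standard diagonal-Gaussian formulas behind utility.diag_normal_logpdf and utility.diag_normal_kl; the sequence masking, summaries, and cutoff printout are omitted, and the keyword defaults are illustrative placeholders rather than values taken from the actual config.

import numpy as np

def diag_normal_logpdf(mean, logstd, action):
  # Log density of a diagonal Gaussian, summed over the action dimensions.
  var = np.exp(2 * logstd)
  return np.sum(
      -0.5 * np.log(2 * np.pi) - logstd - 0.5 * (action - mean) ** 2 / var,
      axis=-1)

def diag_normal_kl(mean0, logstd0, mean1, logstd1):
  # KL(p0 || p1) between two diagonal Gaussians, summed over dimensions.
  var0, var1 = np.exp(2 * logstd0), np.exp(2 * logstd1)
  return np.sum(
      logstd1 - logstd0 + (var0 + (mean0 - mean1) ** 2) / (2 * var1) - 0.5,
      axis=-1)

def policy_loss(mean, logstd, old_mean, old_logstd, action, advantage,
                penalty, kl_target, kl_cutoff_factor=2.0,
                kl_cutoff_coef=1000.0):
  # 1. Importance-sampled policy gradient term.
  ratio = np.exp(diag_normal_logpdf(mean, logstd, action) -
                 diag_normal_logpdf(old_mean, old_logstd, action))
  surrogate_loss = -(ratio * advantage)
  # 2. Soft KL penalty between the behavioral and the current policy.
  kl = diag_normal_kl(old_mean, old_logstd, mean, logstd)
  kl_penalty = penalty * kl
  # 3. Hard quadratic cutoff once the KL exceeds kl_cutoff_factor * kl_target.
  cutoff = kl_target * kl_cutoff_factor
  kl_cutoff = kl_cutoff_coef * (kl > cutoff) * (kl - cutoff) ** 2
  return np.mean(surrogate_loss + kl_penalty + kl_cutoff)

As long as the measured KL stays below the cutoff threshold, the third term is exactly zero and only the soft penalty constrains the update.
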
Example No. 2
  def _adjust_penalty(self, observ, old_mean, old_logstd, length):
    """Adjust the KL penalty based on the divergence from the behavioral policy.

    Compute how much the policy actually changed during the multiple
    update steps. Adjust the penalty strength for the next training phase if we
    overshot or undershot the target divergence too much.

    Args:
      observ: Sequences of observations.
      old_mean: Sequences of action means of the behavioral policy.
      old_logstd: Sequences of action log stddevs of the behavioral policy.
      length: Batch of sequence lengths.

    Returns:
      Summary tensor.
    """
    with tf.name_scope('adjust_penalty'):
      network = self._network(observ, length)
      assert_change = tf.assert_equal(
          tf.reduce_all(tf.equal(network.mean, old_mean)), False,
          message='policy should change')
      print_penalty = tf.Print(0, [self._penalty], 'current penalty: ')
      with tf.control_dependencies([assert_change, print_penalty]):
        kl_change = tf.reduce_mean(self._mask(utility.diag_normal_kl(
            old_mean, old_logstd, network.mean, network.logstd), length))
        kl_change = tf.Print(kl_change, [kl_change], 'kl change: ')
        # The builtin `float` is the no-op false branch: float() returns 0.0.
        maybe_increase = tf.cond(
            kl_change > 1.3 * self._config.kl_target,
            # pylint: disable=g-long-lambda
            lambda: tf.Print(self._penalty.assign(self._penalty * 1.5),
                             [0], 'increase penalty '),
            float)
        maybe_decrease = tf.cond(
            kl_change < 0.7 * self._config.kl_target,
            # pylint: disable=g-long-lambda
            lambda: tf.Print(self._penalty.assign(self._penalty / 1.5),
                             [0], 'decrease penalty '),
            float)
      with tf.control_dependencies([maybe_increase, maybe_decrease]):
        return tf.summary.merge([
            tf.summary.scalar('kl_change', kl_change),
            tf.summary.scalar('penalty', self._penalty)])
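
The adjustment rule itself is simple enough to state outside the graph. Here is a minimal sketch using the same 1.3 / 0.7 tolerance band and 1.5 scaling factor as _adjust_penalty; the function name and plain-float penalty are stand-ins for the tf.Variable used above.

def adjust_penalty(penalty, kl_change, kl_target):
  # Policy moved further than intended: penalize divergence more strongly.
  if kl_change > 1.3 * kl_target:
    return penalty * 1.5
  # Policy barely moved: relax the penalty so later updates can be larger.
  if kl_change < 0.7 * kl_target:
    return penalty / 1.5
  # Within the tolerance band around kl_target: keep the penalty unchanged.
  return penalty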