Example No. 1
  def test_kl_constraint_loss_gradients(self):
    """Tests the gradients in the `_kl_constraint_loss` method."""
    kl = jnp.array(1., jnp.float32)
    alpha = jnp.array(1., jnp.float32)
    # Reuse the alpha returned by `kl_constraint_loss` in the gradient checks below.
    _, _, alpha = mpo_ops.kl_constraint_loss(
        kl,
        mpo_ops.LagrangePenalty(
            alpha=alpha, epsilon=_EPSILON_MEAN_BOUND, per_dimension=False),
        _PROJECTION_OPERATOR)

    def alpha_loss_fn(alpha_):
      penalty = mpo_ops.LagrangePenalty(
          alpha=alpha_, epsilon=_EPSILON_MEAN_BOUND, per_dimension=False)
      _, alpha_loss, _ = mpo_ops.kl_constraint_loss(
          kl, penalty, _PROJECTION_OPERATOR)
      return alpha_loss
    alpha_gradients = jax.grad(alpha_loss_fn)(alpha)
    actual_alpha_gradients = _EPSILON_MEAN_BOUND - kl

    def kl_loss_fn(kl_):
      penalty = mpo_ops.LagrangePenalty(
          alpha=alpha, epsilon=_EPSILON_MEAN_BOUND, per_dimension=False)
      kl_loss, _, _ = mpo_ops.kl_constraint_loss(
          kl_, penalty, _PROJECTION_OPERATOR)
      return kl_loss
    kl_gradients = jax.grad(kl_loss_fn)(kl)
    actual_kl_gradients = alpha

    self.assertAlmostEqual(kl_gradients, actual_kl_gradients)
    self.assertAlmostEqual(alpha_gradients, actual_alpha_gradients)
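
The method above comes from a larger test class and refers to names the listing does not show (`mpo_ops`, `_EPSILON_MEAN_BOUND`, `_PROJECTION_OPERATOR`). A minimal setup sketch is given below; the import path, the constant values, and the test-class name are assumptions for illustration, not part of the original example.

# Minimal setup sketch; the import path, constants and class name are assumed.
import jax
import jax.numpy as jnp
from absl.testing import absltest
from rlax._src import mpo_ops  # assumed location of `mpo_ops`

_EPSILON_MEAN_BOUND = 1e-3  # hypothetical KL bound used by the tests
_PROJECTION_OPERATOR = lambda x: jnp.maximum(x, 1e-6)  # keeps alpha positive


class KLConstraintLossTest(absltest.TestCase):
  """Hypothetical container class; the test methods in these examples go here."""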
Example No. 2
  def test_kl_constraint_loss_stop_gradients(self):
    """Tests the stop gradients in the `kl_constraint_loss` function.

      The `alpha_loss` term should not affect the KL and the `kl` term should
      not affect `alpha`.
    """
    kl = jnp.array(1., jnp.float32)
    alpha = jnp.array(1., jnp.float32)
    # Reuse the alpha returned by `kl_constraint_loss` in the stop-gradient checks.
    _, _, alpha = mpo_ops.kl_constraint_loss(
        kl,
        mpo_ops.LagrangePenalty(
            alpha=alpha, epsilon=_EPSILON_MEAN_BOUND, per_dimension=False),
        _PROJECTION_OPERATOR)

    def kl_loss_fn(alpha_):
      penalty = mpo_ops.LagrangePenalty(
          alpha=alpha_, epsilon=_EPSILON_MEAN_BOUND, per_dimension=False)
      kl_loss, _, _ = mpo_ops.kl_constraint_loss(
          kl, penalty, _PROJECTION_OPERATOR)
      return kl_loss

    kl_gradients = jax.grad(kl_loss_fn)(alpha)

    def alpha_loss_fn(kl_):
      penalty = mpo_ops.LagrangePenalty(
          alpha=alpha, epsilon=_EPSILON_MEAN_BOUND, per_dimension=False)
      _, alpha_loss, _ = mpo_ops.kl_constraint_loss(
          kl_, penalty, _PROJECTION_OPERATOR)
      return alpha_loss
    alpha_gradients = jax.grad(alpha_loss_fn)(kl)

    # Test that the KL loss has no gradient w.r.t. alpha.
    self.assertEqual(kl_gradients, 0.)

    # Test that the alpha loss has no gradient w.r.t. the KL.
    self.assertEqual(alpha_gradients, 0.)
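
The decoupling this test checks (the KL loss receives no gradient from alpha, and the alpha loss receives no gradient from the KL) is the behaviour one gets from `jax.lax.stop_gradient`. The sketch below is not `mpo_ops`' actual implementation, only a minimal illustration of that mechanism with arbitrary epsilon and clipping values; note it also reproduces the gradients checked in Example No. 1 (epsilon minus kl for alpha, and alpha for the KL).

# Minimal sketch of stop-gradient decoupling (not the library's implementation).
import jax
import jax.numpy as jnp

def sketch_kl_constraint_loss(kl, alpha, epsilon):
  alpha = jnp.maximum(alpha, 1e-6)                     # stand-in projection
  kl_loss = jax.lax.stop_gradient(alpha) * kl          # gradient flows to kl only
  alpha_loss = alpha * (epsilon - jax.lax.stop_gradient(kl))  # to alpha only
  return kl_loss, alpha_loss, alpha

# Both cross-gradients are zero, matching the assertions above.
print(jax.grad(lambda a: sketch_kl_constraint_loss(1., a, 1e-3)[0])(1.))  # 0.0
print(jax.grad(lambda k: sketch_kl_constraint_loss(k, 1., 1e-3)[1])(1.))  # 0.0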
Example No. 3
def alpha_loss_fn(kl_):
  penalty = mpo_ops.LagrangePenalty(
      alpha=alpha, epsilon=_EPSILON_MEAN_BOUND, per_dimension=False)
  _, alpha_loss, _ = mpo_ops.kl_constraint_loss(
      kl_, penalty, _PROJECTION_OPERATOR)
  return alpha_loss
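
This fragment is the inner helper from Example No. 2 and is not standalone: it still needs `alpha`, `_EPSILON_MEAN_BOUND`, `_PROJECTION_OPERATOR` and `mpo_ops` in scope. With those defined, it is differentiated exactly as in that example:

# Usage sketch, assuming the surrounding names from Example No. 2 are in scope.
alpha_gradients = jax.grad(alpha_loss_fn)(jnp.array(1., jnp.float32))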