  def test_kl_constraint_loss_gradients(self):
    """Tests the gradients in the `kl_constraint_loss` function."""
    kl = jnp.array(1., jnp.float32)
    alpha = jnp.array(1., jnp.float32)
    _, _, alpha = mpo_ops.kl_constraint_loss(
        kl,
        mpo_ops.LagrangePenalty(
            alpha=alpha, epsilon=_EPSILON_MEAN_BOUND, per_dimension=False),
        _PROJECTION_OPERATOR)

    def alpha_loss_fn(alpha_):
      penalty = mpo_ops.LagrangePenalty(
          alpha=alpha_, epsilon=_EPSILON_MEAN_BOUND, per_dimension=False)
      _, alpha_loss, _ = mpo_ops.kl_constraint_loss(
          kl, penalty, _PROJECTION_OPERATOR)
      return alpha_loss
    alpha_gradients = jax.grad(alpha_loss_fn)(alpha)
    actual_alpha_gradients = _EPSILON_MEAN_BOUND - kl

    def kl_loss_fn(kl_):
      penalty = mpo_ops.LagrangePenalty(
          alpha=alpha, epsilon=_EPSILON_MEAN_BOUND, per_dimension=False)
      kl_loss, _, _ = mpo_ops.kl_constraint_loss(
          kl_, penalty, _PROJECTION_OPERATOR)
      return kl_loss
    kl_gradients = jax.grad(kl_loss_fn)(kl)
    actual_kl_gradients = alpha

    self.assertAlmostEqual(kl_gradients, actual_kl_gradients)
    self.assertAlmostEqual(alpha_gradients, actual_alpha_gradients)
  def test_kl_constraint_loss_stop_gradients(self):
    """Tests the stop gradients in the `kl_constraint_loss` function.

    The `alpha_loss` term should not affect the KL and the `kl` term should
    not affect `alpha`.
    """
    kl = jnp.array(1., jnp.float32)
    alpha = jnp.array(1., jnp.float32)
    _, _, alpha = mpo_ops.kl_constraint_loss(
        kl,
        mpo_ops.LagrangePenalty(
            alpha=alpha, epsilon=_EPSILON_MEAN_BOUND, per_dimension=False),
        _PROJECTION_OPERATOR)

    def kl_loss_fn(alpha_):
      penalty = mpo_ops.LagrangePenalty(
          alpha=alpha_, epsilon=_EPSILON_MEAN_BOUND, per_dimension=False)
      kl_loss, _, _ = mpo_ops.kl_constraint_loss(
          kl, penalty, _PROJECTION_OPERATOR)
      return kl_loss
    kl_gradients = jax.grad(kl_loss_fn)(alpha)

    def alpha_loss_fn(kl_):
      penalty = mpo_ops.LagrangePenalty(
          alpha=alpha, epsilon=_EPSILON_MEAN_BOUND, per_dimension=False)
      _, alpha_loss, _ = mpo_ops.kl_constraint_loss(
          kl_, penalty, _PROJECTION_OPERATOR)
      return alpha_loss
    alpha_gradients = jax.grad(alpha_loss_fn)(kl)

    # Test that there are no gradients of KL w.r.t alpha
    self.assertEqual(kl_gradients, 0.)
    # Test that there are no gradients of alpha w.r.t kl
    self.assertEqual(alpha_gradients, 0.)
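  # A minimal sketch (an assumption, not the mpo_ops implementation) of a
  # Lagrange-penalty loss with the gradient structure asserted by the two
  # tests above: d(kl_loss)/d(kl) == alpha with no gradient flowing to alpha,
  # and d(alpha_loss)/d(alpha) == epsilon - kl with no gradient flowing to kl.
  # The helper name `_reference_kl_constraint_loss` is hypothetical.
  def _reference_kl_constraint_loss(self, kl, alpha, epsilon):
    # Penalise the KL with a fixed (stop-gradient) Lagrange multiplier.
    kl_loss = jax.lax.stop_gradient(alpha) * kl
    # Adapt alpha against the (stop-gradient) constraint violation.
    alpha_loss = alpha * jax.lax.stop_gradient(epsilon - kl)
    return kl_loss, alpha_loss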