Exemplo n.º 1
0
def vmpo_e_step_without_restarting_or_importance_weights(advantages, **kwargs):
    """Call the V-MPO weight/temperature-loss op with neutral per-sample weights.

    Both the restarting weights and the importance weights are arrays of ones
    with the same shape and dtype as `advantages`. They are passed to
    `mpo_ops.vmpo_compute_weights_and_temperature_loss`, so they should have
    no effect on the result.

    Args:
      advantages: Array of advantages, forwarded as the `advantages` keyword.
      **kwargs: Any further keyword arguments, forwarded unchanged to
        `mpo_ops.vmpo_compute_weights_and_temperature_loss`. Supplying
        `restarting_weights`, `importance_weights` or `advantages` here
        raises `TypeError` (duplicate keyword).

    Returns:
      Whatever `mpo_ops.vmpo_compute_weights_and_temperature_loss` returns.
    """
    # Built in the same order as before: restarting first, then importance.
    neutral_weights = {
        'restarting_weights': jnp.ones_like(advantages),
        'importance_weights': jnp.ones_like(advantages),
    }
    return mpo_ops.vmpo_compute_weights_and_temperature_loss(
        advantages=advantages, **neutral_weights, **kwargs)
Exemplo n.º 2
0
 def test_importance_weights(
     self, advantages, importance_weights, expected_temperature_loss):
   """Check the temperature loss obtained when importance weights vary."""
   # Restarting weights are held at one so only the importance weights differ
   # between parameterizations.
   restarting = np.ones_like(importance_weights)
   constraint = mpo_ops.LagrangePenalty(1.0, _EPSILON_BOUND)
   # Clips values from below at 1e-8 with no upper bound.
   lower_clip = functools.partial(np.clip, a_min=1e-8, a_max=None)
   # NOTE(review): trailing 1.0 presumably is the top-k fraction — confirm
   # against the mpo_ops signature.
   outputs = mpo_ops.vmpo_compute_weights_and_temperature_loss(
       advantages, restarting, importance_weights, constraint, lower_clip,
       1.0)
   # Unpacking enforces that exactly three values are returned.
   temperature_loss, _, _ = outputs
   self.assertAlmostEqual(
       temperature_loss, expected_temperature_loss, places=4)
Exemplo n.º 3
0
 def test_restarting_weights(
     self, advantages, restarting_weights, expected_temperature_loss):
   """Check the temperature loss when some restarting weights are zero."""
   # Importance weights are held at one so only the restarting weights differ
   # between parameterizations.
   importance = np.ones_like(restarting_weights)
   constraint = mpo_ops.LagrangePenalty(1.0, _EPSILON_BOUND)
   # Clips values from below at 1e-8 with no upper bound.
   lower_clip = functools.partial(np.clip, a_min=1e-8, a_max=None)
   # NOTE(review): trailing 1.0 presumably is the top-k fraction — confirm
   # against the mpo_ops signature.
   outputs = mpo_ops.vmpo_compute_weights_and_temperature_loss(
       advantages, restarting_weights, importance, constraint, lower_clip,
       1.0)
   # Unpacking enforces that exactly three values are returned.
   temperature_loss, _, _ = outputs
   self.assertAlmostEqual(
       temperature_loss, expected_temperature_loss, places=4)