예제 #1
0
 def test_top_k_fraction(self, top_k_fraction, scaled_advantages,
                         expected_top_k_weights):
     """Check that get_top_k_weights keeps only the top-k fraction.

     An all-ones array shaped like `scaled_advantages` is supplied as the
     second argument, and the resulting weights are compared element-wise
     (within default tolerances) against `expected_top_k_weights`.
     """
     # NOTE(review): the second positional argument of get_top_k_weights is
     # presumably a per-sample weight/mask; confirm against mpo_ops.
     all_ones = jnp.ones_like(scaled_advantages)
     actual_weights = mpo_ops.get_top_k_weights(
         top_k_fraction, all_ones, scaled_advantages)
     np.testing.assert_allclose(actual_weights, expected_top_k_weights)
예제 #2
0
 def test_top_k_fraction_too_low(self):
   """Check that a top-k fraction yielding zero advantages raises ValueError."""
   tiny_fraction = 0.01
   # 0.01 of a 3x2 (six-element) input presumably selects no entries at all.
   input_shape = (3, 2)
   with self.assertRaises(ValueError):
     mpo_ops.get_top_k_weights(
         tiny_fraction, jnp.ones(input_shape), jnp.ones(input_shape))