def test_top_k_fraction(self, top_k_fraction, scaled_advantages,
                        expected_top_k_weights):
    """Test that only the top k fraction are used.

    Calls `mpo_ops.get_top_k_weights` with an all-ones array shaped like
    `scaled_advantages` as its second argument. The result is compared
    element-wise against `expected_top_k_weights`.
    """
    # NOTE(review): the second positional argument is presumably a per-sample
    # weight/mask; an all-ones array keeps it neutral here — confirm against
    # get_top_k_weights' signature.
    uniform = jnp.ones_like(scaled_advantages)
    actual_weights = mpo_ops.get_top_k_weights(
        top_k_fraction, uniform, scaled_advantages)
    np.testing.assert_allclose(actual=actual_weights,
                               desired=expected_top_k_weights)
def test_top_k_fraction_too_low(self):
    """Test if the top k fraction returns 0 advantages we raise an error.

    Passes a fraction of 0.01 with (3, 2) inputs. That fraction presumably
    selects zero of the 6 elements, and `get_top_k_weights` must then raise
    ValueError.
    """
    shape = (3, 2)
    self.assertRaises(ValueError, mpo_ops.get_top_k_weights,
                      0.01, jnp.ones(shape), jnp.ones(shape))