def test_safe_epsilon_softmax_equivalence(self, variant):
    """safe_epsilon_softmax with a positive temperature reproduces the fixture.

    Builds ``distributions.safe_epsilon_softmax(epsilon=0.1, temperature=10.)``
    and checks its ``probs`` on the batched ``self.logits`` against
    ``self.expected_probs`` (presumably the plain softmax probabilities at the
    same temperature -- confirm against the class's setUp).
    """
    safe_dist = distributions.safe_epsilon_softmax(
        epsilon=0.1, temperature=10.)
    probs_of = variant(safe_dist.probs)

    # Evaluate the whole batch of logits in a single call.
    batched_probs = probs_of(self.logits)
    np.testing.assert_allclose(
        self.expected_probs, batched_probs, atol=1e-4)
def test_safe_epsilon_softmax_equivalence(self, variant):
    """safe_epsilon_softmax with temperature=0 matches the expected fixtures.

    Exercises ``probs``, ``logprob`` and ``sample`` of
    ``distributions.safe_epsilon_softmax(epsilon=self.epsilon, temperature=0)``
    on the batched ``self.preferences``. Results are compared to
    ``self.expected_probs`` / ``self.expected_logprob``, which are presumably
    the epsilon-greedy expectations prepared in setUp (confirm there).
    """
    distrib = distributions.safe_epsilon_softmax(
        epsilon=self.epsilon, temperature=0)

    probs_fn = variant(distrib.probs)
    # Action probabilities for the whole batch of preferences.
    actual = probs_fn(self.preferences)
    np.testing.assert_allclose(self.expected_probs, actual, atol=1e-4)

    logprob_fn = variant(distrib.logprob)
    # Log-probabilities of the fixed samples, evaluated for the whole batch.
    actual = logprob_fn(self.samples, self.preferences)
    np.testing.assert_allclose(self.expected_logprob, actual, atol=1e-4)

    sample_fn = variant(distrib.sample)
    # Fixed key (presumably a raw JAX PRNG key of two uint32 words) so the
    # draw is reproducible across runs.
    key = np.array([1, 2], dtype=np.uint32)
    actions = sample_fn(key, self.preferences)
    # The exact sampled actions depend on the PRNG implementation, so they are
    # not pinned. Check the shape instead: (2,) assumes the preferences batch
    # has 2 rows -- confirm against setUp.
    self.assertEqual(actions.shape, (2, ))
    # A shape-only check would accept garbage values, so also require every
    # sample to be a valid action index along the last preferences axis.
    num_actions = np.shape(self.preferences)[-1]
    actions = np.asarray(actions)
    self.assertTrue(np.all((actions >= 0) & (actions < num_actions)))