def test_softmax_probs_batch(self, variant):
  """Checks epsilon-softmax probabilities on the whole batch in one call.

  The `probs` function of an epsilon-softmax distribution (epsilon=0.1,
  temperature=10.) is wrapped with the test `variant` and applied to
  `self.logits` at once. The result is compared with `self.expected_probs`
  to an absolute tolerance of 1e-4.
  """
  probs_fn = variant(
      distributions.epsilon_softmax(epsilon=0.1, temperature=10.).probs)
  batch_probs = probs_fn(self.logits)
  np.testing.assert_allclose(self.expected_probs, batch_probs, atol=1e-4)
def test_softmax_probs(self, variant):
  """Checks epsilon-softmax probabilities one batch element at a time.

  Pairs each row of `self.logits` with the matching row of
  `self.expected_probs`. Each row is passed separately through the
  variant-wrapped `probs` function and compared to an absolute tolerance of
  1e-4.
  """
  probs_fn = variant(
      distributions.epsilon_softmax(epsilon=0.1, temperature=10.).probs)
  for row_logits, row_expected in zip(self.logits, self.expected_probs):
    np.testing.assert_allclose(row_expected, probs_fn(row_logits), atol=1e-4)
def test_softmax_probs_batch(self, compile_fn, place_fn):
  """Tests for a full batch.

  Applies the `probs` function of an epsilon-softmax distribution
  (epsilon=0.1, temperature=10.) to all of `self.logits` in a single call.
  The result is compared with `self.expected_probs` to an absolute tolerance
  of 1e-4.

  Args:
    compile_fn: wrapper applied to `distrib.probs`; presumably either compiles
      it or returns it unchanged, depending on the test parameterization.
    place_fn: applied to the logits before the call; presumably converts them
      to a device array or leaves them as-is.
  """
  distrib = distributions.epsilon_softmax(epsilon=0.1, temperature=10.)
  # Optionally compile. No vmap is involved: `probs` is applied directly to
  # the whole batch of logits below.
  softmax = compile_fn(distrib.probs)
  # Optionally convert to device array.
  logits = place_fn(self.logits)
  # Test softmax output in batch.
  actual = softmax(logits)
  np.testing.assert_allclose(self.expected_probs, actual, atol=1e-4)
def test_softmax_probs(self, compile_fn, place_fn):
  """Tests for a single element.

  Applies the `probs` function of an epsilon-softmax distribution
  (epsilon=0.1, temperature=10.) to each row of `self.logits` separately.
  Each result is compared with the matching row of `self.expected_probs` to
  an absolute tolerance of 1e-4.

  Args:
    compile_fn: wrapper applied to `distrib.probs`; presumably either compiles
      it or returns it unchanged, depending on the test parameterization.
    place_fn: applied to each row of logits before the call; presumably
      converts it to a device array or leaves it as-is.
  """
  distrib = distributions.epsilon_softmax(epsilon=0.1, temperature=10.)
  # Optionally compile.
  softmax = compile_fn(distrib.probs)
  # For each element in the batch.
  for logits, expected in zip(self.logits, self.expected_probs):
    # Optionally convert to device array. Bind to a new name instead of
    # rebinding the loop variable.
    placed_logits = place_fn(logits)
    # Test outputs.
    actual = softmax(placed_logits)
    np.testing.assert_allclose(expected, actual, atol=1e-4)