def test_greedy_logprob_batch(self, variant):
    """Check epsilon-greedy log-probabilities computed on the full batch.

    Args:
      variant: Callable applied to `distrib.logprob` before it is invoked
        (presumably injected by a test-variant decorator -- confirm).
    """
    distrib = distributions.epsilon_greedy(self.epsilon)
    logprob_fn = variant(distrib.logprob)
    # Evaluate the whole batch in a single call.
    actual = logprob_fn(self.samples, self.preferences)
    # `assert_allclose` takes (actual, desired); this order labels the values
    # correctly on failure and scales rtol by the reference value.
    np.testing.assert_allclose(actual, self.expected_logprob, atol=1e-4)
def test_greedy_probs_batch(self):
    """Check epsilon-greedy probabilities computed on the full batch.

    The function under test is wrapped with `self.variant` and called once on
    `self.preferences`; the result must match `self.expected_probs` within
    an absolute tolerance of 1e-4.
    """
    distrib = distributions.epsilon_greedy(self.epsilon)
    probs_fn = self.variant(distrib.probs)
    # Evaluate the whole batch in a single call.
    actual = probs_fn(self.preferences)
    # `assert_allclose` takes (actual, desired); this order labels the values
    # correctly on failure and scales rtol by the reference value.
    np.testing.assert_allclose(actual, self.expected_probs, atol=1e-4)
def test_greedy_entropy(self):
    """Check the epsilon-greedy entropy one batch element at a time.

    The function under test is wrapped with `self.variant`, then called on
    each entry of `self.preferences` and compared with the matching entry of
    `self.expected_entropy` within an absolute tolerance of 1e-4.
    """
    distrib = distributions.epsilon_greedy(self.epsilon)
    entropy_fn = self.variant(distrib.entropy)
    # Walk the batch element by element.
    for preferences, expected in zip(self.preferences, self.expected_entropy):
        actual = entropy_fn(preferences)
        # `assert_allclose` takes (actual, desired); this order labels the
        # values correctly on failure and scales rtol by the reference value.
        np.testing.assert_allclose(actual, expected, atol=1e-4)
def test_greedy_entropy_batch(self, compile_fn, place_fn):
    """Check the epsilon-greedy entropy computed on the full batch.

    Args:
      compile_fn: Callable applied to `distrib.entropy` before it is invoked
        (per the original comment: vmaps and optionally compiles).
      place_fn: Callable applied to the inputs before the call (per the
        original comment: optionally converts to a device array).
    """
    distrib = distributions.epsilon_greedy(self.epsilon)
    # Vmap and optionally compile.
    entropy_fn = compile_fn(distrib.entropy)
    # Optionally convert to device array.
    preferences = place_fn(self.preferences)
    # Evaluate the whole batch in a single call.
    actual = entropy_fn(preferences)
    # `assert_allclose` takes (actual, desired); this order labels the values
    # correctly on failure and scales rtol by the reference value.
    np.testing.assert_allclose(actual, self.expected_entropy, atol=1e-4)
def test_greedy_logprob_batch(self, compile_fn, place_fn):
    """Check epsilon-greedy log-probabilities computed on the full batch.

    Args:
      compile_fn: Callable applied to `distrib.logprob` before it is invoked
        (per the original comment: vmaps and optionally compiles).
      place_fn: Callable mapped over the inputs before the call (per the
        original comment: optionally converts to a device array).
    """
    distrib = distributions.epsilon_greedy(self.epsilon)
    # Vmap and optionally compile.
    logprob_fn = compile_fn(distrib.logprob)
    # Optionally convert to device array.
    preferences, samples = tree_map(
        place_fn, (self.preferences, self.samples))
    # Evaluate the whole batch in a single call.
    actual = logprob_fn(samples, preferences)
    # `assert_allclose` takes (actual, desired); this order labels the values
    # correctly on failure and scales rtol by the reference value.
    np.testing.assert_allclose(actual, self.expected_logprob, atol=1e-4)
def test_greedy_logprob(self):
    """Check epsilon-greedy log-probabilities one batch element at a time.

    The function under test is wrapped with `self.variant`, then called on
    each (samples, preferences) pair and compared with the matching entry of
    `self.expected_logprob` within an absolute tolerance of 1e-4.
    """
    distrib = distributions.epsilon_greedy(self.epsilon)
    logprob_fn = self.variant(distrib.logprob)
    # Walk the batch element by element.
    for preferences, samples, expected in zip(
            self.preferences, self.samples, self.expected_logprob):
        actual = logprob_fn(samples, preferences)
        # `assert_allclose` takes (actual, desired); this order labels the
        # values correctly on failure and scales rtol by the reference value.
        np.testing.assert_allclose(actual, expected, atol=1e-4)
def test_greedy_entropy(self, compile_fn, place_fn):
    """Check the epsilon-greedy entropy one batch element at a time.

    Args:
      compile_fn: Callable applied to `distrib.entropy` before it is invoked
        (per the original comment: optionally compiles).
      place_fn: Callable applied to each input before the call (per the
        original comment: optionally converts to a device array).
    """
    distrib = distributions.epsilon_greedy(self.epsilon)
    # Optionally compile.
    entropy_fn = compile_fn(distrib.entropy)
    # Walk the batch element by element.
    for preferences, expected in zip(self.preferences, self.expected_entropy):
        # Optionally convert to device array.
        preferences = place_fn(preferences)
        actual = entropy_fn(preferences)
        # `assert_allclose` takes (actual, desired); this order labels the
        # values correctly on failure and scales rtol by the reference value.
        np.testing.assert_allclose(actual, expected, atol=1e-4)
def test_greedy_logprob(self, compile_fn, place_fn):
    """Check epsilon-greedy log-probabilities one batch element at a time.

    Args:
      compile_fn: Callable applied to `distrib.logprob` before it is invoked
        (per the original comment: optionally compiles).
      place_fn: Callable mapped over each input pair before the call (per
        the original comment: optionally converts to a device array).
    """
    distrib = distributions.epsilon_greedy(self.epsilon)
    # Optionally compile.
    logprob_fn = compile_fn(distrib.logprob)
    # Walk the batch element by element.
    for preferences, samples, expected in zip(
            self.preferences, self.samples, self.expected_logprob):
        # Optionally convert to device array.
        preferences, samples = tree_map(place_fn, (preferences, samples))
        actual = logprob_fn(samples, preferences)
        # `assert_allclose` takes (actual, desired); this order labels the
        # values correctly on failure and scales rtol by the reference value.
        np.testing.assert_allclose(actual, expected, atol=1e-4)