예제 #1
0
 def test_greedy_logprob_batch(self, variant):
     """Checks batched epsilon-greedy log-probabilities against the fixture.

     Builds the distribution from `self.epsilon`, evaluates `logprob` on
     `self.samples` / `self.preferences` in one call and compares the result
     with `self.expected_logprob` (absolute tolerance 1e-4).

     Args:
       variant: callable applied to `distrib.logprob` before it is invoked
         (presumably a jit/vmap-style transform injected by the test
         harness -- TODO confirm against the decorator on this test).
     """
     distrib = distributions.epsilon_greedy(self.epsilon)
     logprob_fn = variant(distrib.logprob)
     # Test greedy output in batch.
     actual = logprob_fn(self.samples, self.preferences)
     # assert_allclose is declared as (actual, desired, ...): pass the computed
     # value first so the relative tolerance is scaled by the expected value
     # and failure reports label the two arrays correctly.
     np.testing.assert_allclose(actual, self.expected_logprob, atol=1e-4)
예제 #2
0
 def test_greedy_probs_batch(self):
     """Checks batched epsilon-greedy probabilities against the fixture.

     Builds the distribution from `self.epsilon`, wraps `probs` with
     `self.variant`, evaluates it on `self.preferences` in one call and
     compares the result with `self.expected_probs` (absolute tolerance 1e-4).
     """
     distrib = distributions.epsilon_greedy(self.epsilon)
     probs_fn = self.variant(distrib.probs)
     # Test greedy output in batch.
     actual = probs_fn(self.preferences)
     # assert_allclose is declared as (actual, desired, ...): pass the computed
     # value first so the relative tolerance is scaled by the expected value
     # and failure reports label the two arrays correctly.
     np.testing.assert_allclose(actual, self.expected_probs, atol=1e-4)
예제 #3
0
 def test_greedy_entropy(self):
   """Checks epsilon-greedy entropy one batch element at a time.

   Wraps `entropy` with `self.variant` and, for every pair drawn from
   `self.preferences` and `self.expected_entropy`, compares the computed
   entropy with the expected one (absolute tolerance 1e-4).
   """
   distrib = distributions.epsilon_greedy(self.epsilon)
   entropy_fn = self.variant(distrib.entropy)
   # For each element in the batch.
   for index, (preferences, expected) in enumerate(
       zip(self.preferences, self.expected_entropy)):
     # Test outputs.
     actual = entropy_fn(preferences)
     # assert_allclose is declared as (actual, desired, ...): computed value
     # goes first. The index in err_msg tells which element failed.
     np.testing.assert_allclose(
         actual, expected, atol=1e-4,
         err_msg=f'Mismatch at batch element {index}.')
예제 #4
0
 def test_greedy_entropy_batch(self, compile_fn, place_fn):
     """Checks batched epsilon-greedy entropy against the fixture.

     Evaluates `entropy` on the whole of `self.preferences` in one call and
     compares the result with `self.expected_entropy` (absolute tolerance
     1e-4).

     Args:
       compile_fn: callable applied to `distrib.entropy` before it is invoked
         (per the original comment: vmap and optionally compile).
       place_fn: callable applied to the input array before the call (per the
         original comment: optionally convert to a device array).
     """
     distrib = distributions.epsilon_greedy(self.epsilon)
     # Vmap and optionally compile.
     entropy_fn = compile_fn(distrib.entropy)
     # Optionally convert to device array.
     preferences = place_fn(self.preferences)
     # Test greedy output in batch.
     actual = entropy_fn(preferences)
     # assert_allclose is declared as (actual, desired, ...): pass the computed
     # value first so the relative tolerance is scaled by the expected value
     # and failure reports label the two arrays correctly.
     np.testing.assert_allclose(actual, self.expected_entropy, atol=1e-4)
예제 #5
0
 def test_greedy_logprob_batch(self, compile_fn, place_fn):
   """Checks batched epsilon-greedy log-probabilities against the fixture.

   Evaluates `logprob` on the whole of `self.samples` / `self.preferences` in
   one call and compares the result with `self.expected_logprob` (absolute
   tolerance 1e-4).

   Args:
     compile_fn: callable applied to `distrib.logprob` before it is invoked
       (per the original comment: vmap and optionally compile).
     place_fn: callable mapped over both input arrays before the call (per
       the original comment: optionally convert to a device array).
   """
   distrib = distributions.epsilon_greedy(self.epsilon)
   # Vmap and optionally compile.
   logprob_fn = compile_fn(distrib.logprob)
   # Optionally convert to device array.
   preferences, samples = tree_map(place_fn, (self.preferences, self.samples))
   # Test greedy output in batch.
   actual = logprob_fn(samples, preferences)
   # assert_allclose is declared as (actual, desired, ...): pass the computed
   # value first so the relative tolerance is scaled by the expected value
   # and failure reports label the two arrays correctly.
   np.testing.assert_allclose(actual, self.expected_logprob, atol=1e-4)
예제 #6
0
 def test_greedy_logprob(self):
   """Checks epsilon-greedy log-probabilities one batch element at a time.

   Wraps `logprob` with `self.variant` and, for every triple drawn from
   `self.preferences`, `self.samples` and `self.expected_logprob`, compares
   the computed log-probability with the expected one (absolute tolerance
   1e-4).
   """
   distrib = distributions.epsilon_greedy(self.epsilon)
   logprob_fn = self.variant(distrib.logprob)
   # For each element in the batch.
   for index, (preferences, samples, expected) in enumerate(zip(
       self.preferences, self.samples, self.expected_logprob)):
     # Test output.
     actual = logprob_fn(samples, preferences)
     # assert_allclose is declared as (actual, desired, ...): computed value
     # goes first. The index in err_msg tells which element failed.
     np.testing.assert_allclose(
         actual, expected, atol=1e-4,
         err_msg=f'Mismatch at batch element {index}.')
예제 #7
0
 def test_greedy_entropy(self, compile_fn, place_fn):
   """Checks epsilon-greedy entropy one batch element at a time.

   For every pair drawn from `self.preferences` and `self.expected_entropy`,
   compares the computed entropy with the expected one (absolute tolerance
   1e-4).

   Args:
     compile_fn: callable applied to `distrib.entropy` before it is invoked
       (per the original comment: optionally compile).
     place_fn: callable applied to each element's preferences before the call
       (per the original comment: optionally convert to a device array).
   """
   distrib = distributions.epsilon_greedy(self.epsilon)
   # Optionally compile.
   entropy_fn = compile_fn(distrib.entropy)
   # For each element in the batch.
   for index, (preferences, expected) in enumerate(
       zip(self.preferences, self.expected_entropy)):
     # Optionally convert to device array.
     preferences = place_fn(preferences)
     # Test outputs.
     actual = entropy_fn(preferences)
     # assert_allclose is declared as (actual, desired, ...): computed value
     # goes first. The index in err_msg tells which element failed.
     np.testing.assert_allclose(
         actual, expected, atol=1e-4,
         err_msg=f'Mismatch at batch element {index}.')
예제 #8
0
 def test_greedy_logprob(self, compile_fn, place_fn):
   """Checks epsilon-greedy log-probabilities one batch element at a time.

   For every triple drawn from `self.preferences`, `self.samples` and
   `self.expected_logprob`, compares the computed log-probability with the
   expected one (absolute tolerance 1e-4).

   Args:
     compile_fn: callable applied to `distrib.logprob` before it is invoked
       (per the original comment: optionally compile).
     place_fn: callable mapped over each element's inputs before the call
       (per the original comment: optionally convert to a device array).
   """
   distrib = distributions.epsilon_greedy(self.epsilon)
   # Optionally compile.
   logprob_fn = compile_fn(distrib.logprob)
   # For each element in the batch.
   for index, (preferences, samples, expected) in enumerate(zip(
       self.preferences, self.samples, self.expected_logprob)):
     # Optionally convert to device array.
     preferences, samples = tree_map(place_fn, (preferences, samples))
     # Test output.
     actual = logprob_fn(samples, preferences)
     # assert_allclose is declared as (actual, desired, ...): computed value
     # goes first. The index in err_msg tells which element failed.
     np.testing.assert_allclose(
         actual, expected, atol=1e-4,
         err_msg=f'Mismatch at batch element {index}.')