コード例 #1
0
 def test_softmax_probs_batch(self, variant):
     """Tests epsilon-softmax probabilities on a full batch of logits.

     Args:
       variant: wrapper applied to `distrib.probs` before it is called
         (presumably a chex test variant such as with/without jit — TODO
         confirm against the test decorator).
     """
     distrib = distributions.epsilon_softmax(epsilon=0.1, temperature=10.)
     softmax = variant(distrib.probs)
     # Evaluate the whole batch of logits in a single call.
     actual = softmax(self.logits)
     # numpy's signature is assert_allclose(actual, desired, ...): with this
     # order rtol is taken relative to the expected values and the failure
     # message labels ACTUAL/DESIRED correctly.
     np.testing.assert_allclose(actual, self.expected_probs, atol=1e-4)
コード例 #2
0
 def test_softmax_probs(self, variant):
     """Tests epsilon-softmax probabilities one batch element at a time.

     Args:
       variant: wrapper applied to `distrib.probs` before it is called
         (presumably a chex test variant such as with/without jit — TODO
         confirm against the test decorator).
     """
     distrib = distributions.epsilon_softmax(epsilon=0.1, temperature=10.)
     softmax = variant(distrib.probs)
     # zip() stops silently at the shorter input; make sure every logits row
     # has a matching expected row so no element is skipped unnoticed.
     np.testing.assert_equal(len(self.logits), len(self.expected_probs))
     for row_logits, expected in zip(self.logits, self.expected_probs):
         actual = softmax(row_logits)
         # Order is (actual, desired), matching numpy's signature.
         np.testing.assert_allclose(actual, expected, atol=1e-4)
コード例 #3
0
ファイル: distributions_test.py プロジェクト: wwxFromTju/rlax
 def test_softmax_probs_batch(self, compile_fn, place_fn):
   """Tests epsilon-softmax probabilities on a full batch of logits.

   Args:
     compile_fn: applied to `distrib.probs`; optionally compiles it.
     place_fn: applied to the logits; optionally converts them to a device
       array.
   """
   distrib = distributions.epsilon_softmax(epsilon=0.1,
                                           temperature=10.)
   # Optionally compile. No vmap is applied: the whole batch is passed to a
   # single call below.
   softmax = compile_fn(distrib.probs)
   # Optionally convert to device array.
   logits = place_fn(self.logits)
   # Test softmax output in batch.
   actual = softmax(logits)
   # numpy's signature is assert_allclose(actual, desired, ...): with this
   # order rtol is relative to the expected values and the failure message
   # labels ACTUAL/DESIRED correctly.
   np.testing.assert_allclose(actual, self.expected_probs, atol=1e-4)
コード例 #4
0
ファイル: distributions_test.py プロジェクト: wwxFromTju/rlax
 def test_softmax_probs(self, compile_fn, place_fn):
   """Tests epsilon-softmax probabilities one batch element at a time.

   Args:
     compile_fn: applied to `distrib.probs`; optionally compiles it.
     place_fn: applied to each logits row; optionally converts it to a
       device array.
   """
   distrib = distributions.epsilon_softmax(epsilon=0.1,
                                           temperature=10.)
   # Optionally compile.
   softmax = compile_fn(distrib.probs)
   # zip() stops silently at the shorter input; make sure every logits row
   # has a matching expected row so no element is skipped unnoticed.
   np.testing.assert_equal(len(self.logits), len(self.expected_probs))
   for row_logits, expected in zip(self.logits, self.expected_probs):
     # Optionally convert to device array.
     placed_logits = place_fn(row_logits)
     actual = softmax(placed_logits)
     # Order is (actual, desired), matching numpy's signature.
     np.testing.assert_allclose(actual, expected, atol=1e-4)