コード例 #1
0
 def test_gaussian_logprob_batch(self, variant):
     """Checks the batched log-probability against precomputed values.

     The whole batch of samples is scored in one call of the
     ``variant``-wrapped ``logprob`` and compared with
     ``self.expected_logprob_a``.
     """
     wrapped_logprob = variant(distributions.gaussian_diagonal().logprob)
     batch_logprob = wrapped_logprob(self.sample, self.mu, self.sigma)
     np.testing.assert_allclose(
         self.expected_logprob_a, batch_logprob, atol=1e-4)
コード例 #2
0
 def test_gaussian_entropy_batch(self, variant):
     """Checks the batched entropy against precomputed values.

     All ``(mu, sigma)`` rows are passed at once to the ``variant``-wrapped
     ``entropy`` and the result is compared with ``self.expected_entropy``.
     """
     wrapped_entropy = variant(distributions.gaussian_diagonal().entropy)
     batch_entropy = wrapped_entropy(self.mu, self.sigma)
     np.testing.assert_allclose(
         self.expected_entropy, batch_entropy, atol=1e-4)
コード例 #3
0
 def test_gaussian_kl_to_std_normal_batch(self):
     """Checks the batched KL to a standard normal against expected values.

     ``kl_to_standard_normal`` is wrapped with ``self.variant``, evaluated on
     the full batch of ``(mu, sigma)`` and compared with
     ``self.expected_kl_to_std_normal``.
     """
     gaussian = distributions.gaussian_diagonal()
     wrapped_kl = self.variant(gaussian.kl_to_standard_normal)
     batch_kl = wrapped_kl(self.mu, self.sigma)
     np.testing.assert_allclose(
         self.expected_kl_to_std_normal, batch_kl, atol=1e-4)
コード例 #4
0
 def test_gaussian_entropy(self, variant):
     """Checks entropy one batch element at a time.

     Each ``(mu, sigma)`` row is passed separately through the
     ``variant``-wrapped ``entropy`` and compared with the matching row of
     ``self.expected_entropy``.
     """
     wrapped_entropy = variant(distributions.gaussian_diagonal().entropy)
     rows = zip(self.mu, self.sigma, self.expected_entropy)
     for row_mu, row_sigma, row_expected in rows:
         row_entropy = wrapped_entropy(row_mu, row_sigma)
         np.testing.assert_allclose(row_expected, row_entropy, atol=1e-4)
コード例 #5
0
 def test_gaussian_entropy_batch(self, compile_fn, place_fn):
     """Checks batched entropy after optional compilation and placement.

     Args:
       compile_fn: wrapper applied to ``entropy`` (optionally compiles; the
         exact transformation is decided by the fixture).
       place_fn: mapped over ``(self.mu, self.sigma)`` (optionally converts
         them to device arrays).
     """
     entropy_fn = compile_fn(distributions.gaussian_diagonal().entropy)
     placed_mu, placed_sigma = tree_map(place_fn, (self.mu, self.sigma))
     batch_entropy = entropy_fn(placed_mu, placed_sigma)
     np.testing.assert_allclose(
         self.expected_entropy, batch_entropy, atol=1e-4)
コード例 #6
0
ファイル: distributions_test.py プロジェクト: wwxFromTju/rlax
 def test_gaussian_logprob_batch(self, compile_fn, place_fn):
   """Checks batched log-probability after optional compilation/placement.

   Args:
     compile_fn: wrapper applied to ``logprob`` (optionally compiles; the
       exact transformation is decided by the fixture).
     place_fn: mapped over ``(mu, sigma, sample)`` (optionally converts them
       to device arrays).
   """
   compiled_logprob = compile_fn(distributions.gaussian_diagonal().logprob)
   inputs = (self.mu, self.sigma, self.sample)
   placed_mu, placed_sigma, placed_sample = tree_map(place_fn, inputs)
   batch_logprob = compiled_logprob(placed_sample, placed_mu, placed_sigma)
   np.testing.assert_allclose(
       self.expected_logprob_a, batch_logprob, atol=1e-4)
コード例 #7
0
 def test_gaussian_logprob(self):
   """Checks log-probability one batch element at a time.

   ``logprob`` is wrapped with ``self.variant``; every
   ``(sample, mu, sigma)`` row is scored separately and compared with the
   matching row of ``self.expected_logprob_a``.
   """
   wrapped_logprob = self.variant(distributions.gaussian_diagonal().logprob)
   cases = zip(self.mu, self.sigma, self.sample, self.expected_logprob_a)
   for case_mu, case_sigma, case_sample, case_expected in cases:
     case_logprob = wrapped_logprob(case_sample, case_mu, case_sigma)
     np.testing.assert_allclose(case_expected, case_logprob, atol=1e-4)
コード例 #8
0
ファイル: distributions_test.py プロジェクト: wwxFromTju/rlax
 def test_gaussian_entropy(self, compile_fn, place_fn):
   """Tests entropy for a single element at a time.

   Each ``(mu, sigma)`` row is converted with ``place_fn``, evaluated with the
   ``compile_fn``-wrapped ``entropy`` and compared with the matching row of
   ``self.expected_entropy``.

   Args:
     compile_fn: wrapper applied to ``distrib.entropy`` (optionally compiles).
     place_fn: mapped over each row's ``(mu, sigma)`` (optionally converts
       them to device arrays).
   """
   distrib = distributions.gaussian_diagonal()
   # Optionally compile.
   entropy_fn = compile_fn(distrib.entropy)
   # The entropy call only takes mu and sigma, so only those (and the
   # expectations) are iterated. Zipping in self.sample as well would add an
   # unused fixture that could silently truncate the loop if its length
   # differed.
   for mu, sigma, expected in zip(
       self.mu, self.sigma, self.expected_entropy):
     # Optionally convert to device array.
     mu, sigma = tree_map(place_fn, (mu, sigma))
     actual = entropy_fn(mu, sigma)
     np.testing.assert_allclose(expected, actual, atol=1e-4)
コード例 #9
0
# Constraint bounds for the loss under test; presumably passed as the
# epsilon / epsilon_mean / epsilon_covariance arguments — TODO confirm
# against the tests that read them.
_EPSILON_BOUND = 0.01
_EPSILON_MEAN_BOUND = 10.0
_EPSILON_COVARIANCE_BOUND = 1e-12

# Optimisation schedule and PRNG seed for the test's training loop.
# _TARGET_UPDATE_PERIOD is assumed (from its name) to be how many steps pass
# between target-parameter refreshes — confirm against usage.
_NUM_ITERATIONS = 5000
_TARGET_UPDATE_PERIOD = 100
_RANDOM_SEED = 42

# Added to the policy mean in _hk_mock_policy_params so that initially the
# policy is not close to 0.
_MEAN_OFFSET = 2.0

# The final action should optimize down to be close to 0.0; this is the
# allowed tolerance for that check.
_MAX_ACTION_ERROR = 0.2
# Tolerance for a KL comparison — presumably checked at the end of training;
# confirm against usage.
_MAX_KL_ERROR = 1e-6

# Single diagonal-Gaussian distribution helper, built once at import time.
_DIAGONAL_GAUSSIAN_DIST = distributions.gaussian_diagonal()
# Clips its input from below at 1e-10 (no upper bound is given).
# NOTE(review): recent JAX releases deprecate the `a_min` keyword of
# `jnp.clip` in favour of `min`; confirm against the pinned JAX version.
_PROJECTION_OPERATOR = functools.partial(jnp.clip, a_min=1e-10)


def _hk_mock_policy_params(s_tm1):
  """Returns mock policy params."""
  # Outputs of the network are mu and sigma. Both shaped [B, ACTION_DIM].
  pi_out = hk.nets.MLP(
      output_sizes=[2 * ACTION_DIM],
      w_init=hk.initializers.VarianceScaling(1e-3),
      activation=jnp.tanh,
      activate_final=False,
      name='online_policy')(s_tm1)
  pi_mean, pi_cov = jnp.split(pi_out, 2, axis=-1)
  pi_cov = jax.nn.softplus(pi_cov)
  pi_mean = pi_mean + _MEAN_OFFSET