def test_gaussian_logprob_batch(self, variant):
  """Tests for a full batch."""
  distrib = distributions.gaussian_diagonal()
  logprob_fn = variant(distrib.logprob)
  # Test output in batch.
  actual = logprob_fn(self.sample, self.mu, self.sigma)
  np.testing.assert_allclose(self.expected_logprob_a, actual, atol=1e-4)
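# A minimal numpy reference for the quantity checked above, assuming
# `distributions.gaussian_diagonal().logprob` sums the per-dimension Gaussian
# log-densities over the last axis (an assumption for illustration, not a
# statement about the library's exact reduction). Handy for re-deriving values
# such as `self.expected_logprob_a` by hand; relies on the `numpy as np`
# import used by the tests in this file.
def _reference_diag_gaussian_logprob(sample, mu, sigma):
  """Log-density of a diagonal Gaussian, summed over the event dimension."""
  var = np.square(sigma)
  per_dim = -0.5 * np.log(2.0 * np.pi * var) - np.square(sample - mu) / (2.0 * var)
  return np.sum(per_dim, axis=-1)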
def test_gaussian_entropy_batch(self, variant):
  """Tests for a full batch."""
  distrib = distributions.gaussian_diagonal()
  entropy_fn = variant(distrib.entropy)
  # Test output in batch.
  actual = entropy_fn(self.mu, self.sigma)
  np.testing.assert_allclose(self.expected_entropy, actual, atol=1e-4)
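# A minimal numpy sketch of the entropy being tested: the per-dimension
# Gaussian entropy 0.5 * log(2 * pi * e * sigma^2), summed over the last axis.
# Whether `gaussian_diagonal().entropy` applies exactly this reduction is an
# assumption made here for illustration; relies on the `numpy as np` import
# used by the tests in this file.
def _reference_diag_gaussian_entropy(sigma):
  """Entropy of a diagonal Gaussian, summed over the event dimension."""
  per_dim = 0.5 * np.log(2.0 * np.pi * np.e * np.square(sigma))
  return np.sum(per_dim, axis=-1)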
def test_gaussian_kl_to_std_normal_batch(self):
  """Tests for a full batch."""
  distrib = distributions.gaussian_diagonal()
  kl_fn = self.variant(distrib.kl_to_standard_normal)
  # Test output in batch.
  actual = kl_fn(self.mu, self.sigma)
  np.testing.assert_allclose(self.expected_kl_to_std_normal, actual, atol=1e-4)
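# A hedged reference for the KL checked above, using the closed form
# KL(N(mu, diag(sigma^2)) || N(0, I)) = 0.5 * sum(sigma^2 + mu^2 - 1 - log(sigma^2)).
# That `kl_to_standard_normal` sums over the event dimension in the same way is
# an assumption for illustration; relies on the `numpy as np` import used by
# the tests in this file.
def _reference_kl_to_std_normal(mu, sigma):
  """Analytic KL from a diagonal Gaussian to the standard normal."""
  var = np.square(sigma)
  per_dim = 0.5 * (var + np.square(mu) - 1.0 - np.log(var))
  return np.sum(per_dim, axis=-1)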
def test_gaussian_entropy(self, variant):
  """Tests for a single element."""
  distrib = distributions.gaussian_diagonal()
  entropy_fn = variant(distrib.entropy)
  # For each element in the batch.
  for mu, sigma, expected in zip(self.mu, self.sigma, self.expected_entropy):
    # Test outputs.
    actual = entropy_fn(mu, sigma)
    np.testing.assert_allclose(expected, actual, atol=1e-4)
def test_gaussian_entropy_batch(self, compile_fn, place_fn):
  """Tests for a full batch."""
  distrib = distributions.gaussian_diagonal()
  # Vmap and optionally compile.
  entropy_fn = compile_fn(distrib.entropy)
  # Optionally convert to device array.
  mu, sigma = tree_map(place_fn, (self.mu, self.sigma))
  # Test output in batch.
  actual = entropy_fn(mu, sigma)
  np.testing.assert_allclose(self.expected_entropy, actual, atol=1e-4)
def test_gaussian_logprob_batch(self, compile_fn, place_fn):
  """Tests for a full batch."""
  distrib = distributions.gaussian_diagonal()
  # Vmap and optionally compile.
  logprob_fn = compile_fn(distrib.logprob)
  # Optionally convert to device array.
  mu, sigma, sample = tree_map(place_fn, (self.mu, self.sigma, self.sample))
  # Test output in batch.
  actual = logprob_fn(sample, mu, sigma)
  np.testing.assert_allclose(self.expected_logprob_a, actual, atol=1e-4)
def test_gaussian_logprob(self):
  """Tests for a single element."""
  distrib = distributions.gaussian_diagonal()
  logprob_fn = self.variant(distrib.logprob)
  # For each element in the batch.
  for mu, sigma, sample, expected in zip(
      self.mu, self.sigma, self.sample, self.expected_logprob_a):
    # Test output.
    actual = logprob_fn(sample, mu, sigma)
    np.testing.assert_allclose(expected, actual, atol=1e-4)
def test_gaussian_entropy(self, compile_fn, place_fn):
  """Tests for a single element."""
  distrib = distributions.gaussian_diagonal()
  # Optionally compile.
  entropy_fn = compile_fn(distrib.entropy)
  # For each element in the batch.
  for mu, sigma, sample, expected in zip(
      self.mu, self.sigma, self.sample, self.expected_entropy):
    # Optionally convert to device array.
    mu, sigma, sample = tree_map(place_fn, (mu, sigma, sample))
    # Test outputs.
    actual = entropy_fn(mu, sigma)
    np.testing.assert_allclose(expected, actual, atol=1e-4)
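# A sketch of one way the (compile_fn, place_fn) pairs consumed by the tests
# above could be parameterized; the names and wiring here are hypothetical,
# not this suite's actual setup. `compile_fn` either returns the function
# unchanged or jits it, and `place_fn` either keeps inputs as host numpy
# arrays or places them on an accelerator device.
import itertools

import jax

_COMPILE_FNS = [lambda fn: fn, jax.jit]      # Run op-by-op or XLA-compiled.
_PLACE_FNS = [lambda x: x, jax.device_put]   # Keep on host or move to device.

# Cartesian product of both axes, suitable for feeding to a test
# parameterization decorator.
_TEST_VARIANTS = list(itertools.product(_COMPILE_FNS, _PLACE_FNS))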
_EPSILON_BOUND = 0.01
_EPSILON_MEAN_BOUND = 10.0
_EPSILON_COVARIANCE_BOUND = 1e-12
_NUM_ITERATIONS = 5000
_TARGET_UPDATE_PERIOD = 100
_RANDOM_SEED = 42
# Offset to ensure the policy is initially not close to 0.
_MEAN_OFFSET = 2.0
# The final action should optimize down to be close to 0.0.
_MAX_ACTION_ERROR = 0.2
_MAX_KL_ERROR = 1e-6

_DIAGONAL_GAUSSIAN_DIST = distributions.gaussian_diagonal()
_PROJECTION_OPERATOR = functools.partial(jnp.clip, a_min=1e-10)


def _hk_mock_policy_params(s_tm1):
  """Returns mock policy params."""
  # Outputs of the network are mu and sigma. Both shaped [B, ACTION_DIM].
  pi_out = hk.nets.MLP(
      output_sizes=[2 * ACTION_DIM],
      w_init=hk.initializers.VarianceScaling(1e-3),
      activation=jnp.tanh,
      activate_final=False,
      name='online_policy')(s_tm1)
  pi_mean, pi_cov = jnp.split(pi_out, 2, axis=-1)
  pi_cov = jax.nn.softplus(pi_cov)
  pi_mean = pi_mean + _MEAN_OFFSET