def testKLVersusNormal(self):
    """KL to a LogStddevNormal prior agrees with the equivalent Normal KL.

    Builds the same N(2, 2) distribution and the same standard-normal prior
    through both parameterizations and checks all three KL pairings match.
    """
    loc = jnp.array([2.0])
    scale = jnp.array([2.0])
    log_scale = jnp.log(scale)
    # The same standard-normal prior, expressed through each class.
    prior_lsn = lsn.LogStddevNormal(jnp.array([0.0]), jnp.array([0.0]))
    prior_normal = normal.Normal(jnp.array([0.0]), jnp.array([1.0]))
    dist_lsn = lsn.LogStddevNormal(loc, log_scale)
    dist_normal = normal.Normal(loc, scale)
    kl_lsn_vs_lsn = tfp.distributions.kl_divergence(dist_lsn, prior_lsn)
    kl_normal_vs_lsn = tfp.distributions.kl_divergence(dist_normal, prior_lsn)
    kl_normal_vs_normal = tfp.distributions.kl_divergence(dist_normal, prior_normal)
    # Every pairing must yield the same divergence value.
    np.testing.assert_allclose(kl_normal_vs_lsn, kl_lsn_vs_lsn)
    np.testing.assert_allclose(kl_normal_vs_normal, kl_normal_vs_lsn)
    np.testing.assert_allclose(kl_lsn_vs_lsn, kl_normal_vs_normal)
def test_jitable(self):
    """log_prob of a LogStddevNormal can be traced through jax.jit."""

    @jax.jit
    def compiled_log_prob(event, dist):
        return dist.log_prob(event)

    dist = lsn.LogStddevNormal(np.array([0.0]), np.array([0.0]))
    sample = dist.sample(seed=jax.random.PRNGKey(0))
    # Success criterion is simply that tracing/compilation does not raise.
    compiled_log_prob(sample, dist)
def testSamplingScalar(self, mean, stddev):
    """Sample moments of a scalar distribution match the given parameters."""
    dist = lsn.LogStddevNormal(mean, np.log(stddev))
    n_samples = 1000000
    draws = dist.sample(
        seed=jax.random.PRNGKey(1331), sample_shape=n_samples)
    chex.assert_shape(draws, (n_samples,))
    # Monte-Carlo estimates; tolerance sized for one million draws.
    np.testing.assert_allclose(jnp.mean(draws), mean, atol=4e-2)
    np.testing.assert_allclose(jnp.std(draws), stddev, atol=4e-2)
def testSamplingVector(self, mean, stddev):
    """Sample moments of a vector distribution match element-wise."""
    mean = np.array(mean)
    log_stddev = np.log(stddev)
    # Parameters must broadcast trivially, i.e. have identical shapes.
    assert mean.shape == log_stddev.shape
    dist = lsn.LogStddevNormal(mean, log_stddev)
    n_samples = 1000000
    draws = dist.sample(
        seed=jax.random.PRNGKey(1331), sample_shape=n_samples)
    chex.assert_shape(draws, (n_samples,) + mean.shape)
    # Per-component Monte-Carlo estimates over the sample axis.
    np.testing.assert_allclose(jnp.mean(draws, axis=0), mean, atol=4e-2)
    np.testing.assert_allclose(jnp.std(draws, axis=0), stddev, atol=4e-2)
def testSamplingBatchedCustomDim(self):
    """Sampling from a (3, 2)-batched distribution matches per-entry moments."""
    locs = np.array([[3.0, 4.0], [-5, 48.0], [58, 64.0]])
    scales = np.array([[1, 2], [2, 4], [4, 8]])
    dist = lsn.LogStddevNormal(locs, np.log(scales))
    n_samples = 1000000
    draws = dist.sample(
        seed=jax.random.PRNGKey(1331), sample_shape=n_samples)
    chex.assert_shape(draws, (n_samples, 3, 2))
    # Reduce over the leading sample axis to recover each entry's moments.
    np.testing.assert_allclose(jnp.mean(draws, axis=0), locs, atol=4e-2)
    np.testing.assert_allclose(jnp.std(draws, axis=0), scales, atol=4e-2)
def test_log_scale_property(self, mean, log_stddev, expected):
    """The log_scale property matches the expected value and shape."""
    log_scale = lsn.LogStddevNormal(mean, log_stddev).log_scale
    assert log_scale.shape == expected.shape
    np.testing.assert_allclose(log_scale, expected, atol=1e-4)
def testCallingCustomKL(self):
    """tfp.distributions.kl_divergence dispatches to the registered custom KL.

    A companion test registers/inspects a KL table entry for this pair of
    types; that implementation is expected to return the sentinel 42, so
    observing 42 here confirms dispatch goes through that table.
    """
    standard_params = (jnp.array([0.0]), jnp.array([0.0]))
    dist_a = lsn.LogStddevNormal(*standard_params)
    dist_b = lsn.LogStddevNormal(*standard_params)
    self.assertEqual(tfp.distributions.kl_divergence(dist_a, dist_b), 42)
def test_sample_dtype(self, dtype):
    """Samples carry the dtype of the distribution's parameters."""
    dist = lsn.LogStddevNormal(
        loc=jnp.zeros((), dtype), log_scale=jnp.zeros((), dtype))
    # self.variant applies the chex test variant (e.g. jitted vs. plain).
    draw = self.variant(dist.sample)(seed=jax.random.PRNGKey(0))
    self.assertEqual(draw.dtype, dist.dtype)
    chex.assert_type(draw, dtype)