Ejemplo n.º 1
0
    def testKLVersusNormal(self):
        """KL divergences agree across LogStddevNormal / Normal parameterisations.

        `lsn_dist` and `n_dist` describe the same Gaussian (loc=2, scale=2,
        the former via log_scale=log(2)); `lsn_prior` (log_scale=0) and
        `n_prior` (scale=1) likewise describe the same prior. KL divergences
        computed with any mix of the two parameterisations must match.
        """
        loc, scale = jnp.array([2.0]), jnp.array([2.0])
        log_scale = jnp.log(scale)
        lsn_prior = lsn.LogStddevNormal(jnp.array([0.0]), jnp.array([0.0]))
        n_prior = normal.Normal(jnp.array([0.0]), jnp.array([1.0]))
        lsn_dist = lsn.LogStddevNormal(loc, log_scale)
        n_dist = normal.Normal(loc, scale)

        kl1 = tfp.distributions.kl_divergence(lsn_dist, lsn_prior)
        kl2 = tfp.distributions.kl_divergence(n_dist, lsn_prior)
        kl3 = tfp.distributions.kl_divergence(n_dist, n_prior)
        # These arrays are float32 unless JAX x64 is enabled, and the three
        # type pairs may dispatch to different registered KL functions.
        # numpy's default rtol=1e-7 is below float32 machine epsilon
        # (~1.2e-7), so a single-ulp disagreement would fail spuriously.
        rtol = 1e-6
        np.testing.assert_allclose(kl2, kl1, rtol=rtol)
        np.testing.assert_allclose(kl3, kl2, rtol=rtol)
        np.testing.assert_allclose(kl1, kl3, rtol=rtol)
Ejemplo n.º 2
0
    def test_jitable(self):
        """A LogStddevNormal can be passed as an argument to a jitted function.

        Also checks that the jitted `log_prob` agrees with eager evaluation,
        so a wrong-but-finite result is caught, not only a tracing failure.
        """
        @jax.jit
        def jitted_function(event, dist):
            return dist.log_prob(event)

        dist = lsn.LogStddevNormal(np.array([0.0]), np.array([0.0]))
        event = dist.sample(seed=jax.random.PRNGKey(0))
        jitted_log_prob = jitted_function(event, dist)
        # A small rtol allows for benign numerical differences between the
        # compiled and eager execution paths.
        np.testing.assert_allclose(
            jitted_log_prob, dist.log_prob(event), rtol=1e-5)
Ejemplo n.º 3
0
    def testSamplingScalar(self, mean, stddev):
        """Empirical moments of many scalar draws match `mean` and `stddev`."""
        dist = lsn.LogStddevNormal(mean, np.log(stddev))

        n = 1000000
        draws = dist.sample(seed=jax.random.PRNGKey(1331), sample_shape=n)
        chex.assert_shape(draws, (n,))

        tolerance = 4e-2
        np.testing.assert_allclose(jnp.mean(draws), mean, atol=tolerance)
        np.testing.assert_allclose(jnp.std(draws), stddev, atol=tolerance)
Ejemplo n.º 4
0
    def testSamplingVector(self, mean, stddev):
        """Per-dimension empirical moments match vector `mean` / `stddev`."""
        mean = np.array(mean)
        log_stddev = np.log(stddev)
        # Sanity-check the test parameters themselves.
        assert mean.shape == log_stddev.shape
        dist = lsn.LogStddevNormal(mean, log_stddev)

        n = 1000000
        draws = dist.sample(seed=jax.random.PRNGKey(1331), sample_shape=n)
        chex.assert_shape(draws, (n,) + mean.shape)

        # Statistics are reduced over the leading sample axis only.
        empirical_mean = jnp.mean(draws, axis=0)
        empirical_std = jnp.std(draws, axis=0)
        np.testing.assert_allclose(empirical_mean, mean, atol=4e-2)
        np.testing.assert_allclose(empirical_std, stddev, atol=4e-2)
Ejemplo n.º 5
0
    def testSamplingBatchedCustomDim(self):
        """Sampling a (3, 2) batch recovers each element's mean and stddev."""
        means = np.array([[3.0, 4.0], [-5, 48.0], [58, 64.0]])
        stddevs = np.array([[1, 2], [2, 4], [4, 8]])
        dist = lsn.LogStddevNormal(means, np.log(stddevs))

        n = 1000000
        draws = dist.sample(seed=jax.random.PRNGKey(1331), sample_shape=n)
        chex.assert_shape(draws, (n, 3, 2))

        # Check the first and second moments in turn, reducing over the
        # leading sample axis.
        for statistic, target in ((jnp.mean, means), (jnp.std, stddevs)):
            np.testing.assert_allclose(
                statistic(draws, axis=0), target, atol=4e-2)
Ejemplo n.º 6
0
 def test_log_scale_property(self, mean, log_stddev, expected):
     """`log_scale` matches the parameterized `expected` in shape and value."""
     log_scale = lsn.LogStddevNormal(mean, log_stddev).log_scale
     assert log_scale.shape == expected.shape
     np.testing.assert_allclose(log_scale, expected, atol=1e-4)
Ejemplo n.º 7
0
 def testCallingCustomKL(self):
     """`tfp.distributions.kl_divergence` dispatches to the custom KL entry.

     NOTE(review): 42 is presumably a sentinel returned by a custom KL
     function registered elsewhere in this test case (not visible here);
     confirm against the registration-table check that precedes this test.
     """
     # Two independent but identical LogStddevNormal(loc=0, log_scale=0).
     dists = [lsn.LogStddevNormal(jnp.array([0.0]), jnp.array([0.0]))
              for _ in range(2)]
     kl = tfp.distributions.kl_divergence(*dists)
     self.assertEqual(kl, 42)
Ejemplo n.º 8
0
 def test_sample_dtype(self, dtype):
     """Samples carry the requested `dtype`, consistent with `dist.dtype`."""
     zero = jnp.zeros((), dtype)
     dist = lsn.LogStddevNormal(loc=zero, log_scale=zero)
     sample_fn = self.variant(dist.sample)
     samples = sample_fn(seed=jax.random.PRNGKey(0))
     self.assertEqual(samples.dtype, dist.dtype)
     chex.assert_type(samples, dtype)