def test_independent():
    """A 10-wide batch of Normals wrapped in Independent(ndims=1) has event shape (10,) and no batch shape."""
    from numpyro.contrib.tfp import distributions as tfd

    loc = jnp.zeros(10)
    scale = jnp.ones(10)
    base = tfd.Normal(loc, scale)

    # Fold the single batch dimension of `base` into the event.
    independent = tfd.Independent(base, reinterpreted_batch_ndims=1)

    assert independent.event_shape == (10,)
    assert independent.batch_shape == ()
def test_sample_tfp_distributions():
    """A TFP-wrapped Normal can be sampled through numpyro and reports no intermediates."""
    from numpyro.contrib.tfp import distributions as tfd

    normal = tfd.Normal(0, 1)

    # Sampling a site under a seed handler must not raise.
    with numpyro.handlers.seed(rng_seed=random.PRNGKey(0)):
        numpyro.sample("normal", normal)

    # With sample_intermediates=True the call yields (value, intermediates);
    # only the intermediates are checked here and they must be empty.
    key = random.PRNGKey(0)
    _, intermediates = normal(sample_intermediates=True, rng_key=key)
    assert intermediates == []
def test_transformed_distributions():
    """Three constructions of an exp-transformed standard Normal agree on ``log_prob``.

    Compared constructions:
      * ``d``  -- NumPyro ``TransformedDistribution`` with ``ExpTransform``;
      * ``d1`` -- TFP ``TransformedDistribution`` with the ``Exp`` bijector;
      * ``d2`` -- NumPyro ``TransformedDistribution`` wrapping the TFP ``Exp``
        bijector via ``tfd.BijectorTransform``.
    """
    from tensorflow_probability.substrates.jax import bijectors as tfb

    from numpyro.contrib.tfp import distributions as tfd

    d = dist.TransformedDistribution(dist.Normal(0, 1), dist.transforms.ExpTransform())
    d1 = tfd.TransformedDistribution(tfd.Normal(0, 1), tfb.Exp())
    d2 = dist.TransformedDistribution(dist.Normal(0, 1), tfd.BijectorTransform(tfb.Exp()))

    # Evaluate only inside the support (0, inf) of the exp-transformed
    # distribution. Raw standard-normal draws are about half negative, and
    # there the inverse transform log(x) is NaN. Every summed log_prob would
    # then be NaN, and assert_allclose (equal_nan=True by default) would pass
    # without checking anything.
    x = jnp.exp(random.normal(random.PRNGKey(0), (1000,)))

    d_x = d.log_prob(x)
    d1_x = d1.log_prob(x)
    d2_x = d2.log_prob(x)

    # Guard against the comparison degenerating into NaN == NaN.
    assert jnp.all(jnp.isfinite(d_x))

    # Compare pointwise rather than only the sums; the tolerance suits float32.
    assert_allclose(d_x, d1_x, rtol=1e-5, atol=1e-6)
    assert_allclose(d_x, d2_x, rtol=1e-5, atol=1e-6)
def f(x):
    # Return a `tfd.Normal` built from `x` and 1 (loc, scale in TFP's
    # positional order).
    # NOTE(review): `tfd` is not imported here. It must be bound in an
    # enclosing scope, presumably a local
    # `from numpyro.contrib.tfp import distributions as tfd` in a surrounding
    # test whose header is not visible. Confirm against the caller.
    return tfd.Normal(x, 1)