# Example No. 1
# 0
def test_independent():
    """Wrapping a batched Normal in ``Independent`` folds the batch dim into the event.

    The base Normal is built from length-10 parameter vectors; with
    ``reinterpreted_batch_ndims=1`` the result is expected to have event shape
    ``(10,)`` and an empty batch shape.
    """
    from numpyro.contrib.tfp import distributions as tfd

    loc = jnp.zeros(10)
    scale = jnp.ones(10)
    base = tfd.Normal(loc, scale)
    independent = tfd.Independent(base, reinterpreted_batch_ndims=1)

    assert independent.event_shape == (10, )
    assert independent.batch_shape == ()
# Example No. 2
# 0
def test_sample_tfp_distributions():
    """A TFP-backed distribution can be used as a numpyro sample site.

    Also checks that calling the distribution directly with
    ``sample_intermediates=True`` yields an empty intermediates list.
    """
    from numpyro.contrib.tfp import distributions as tfd

    normal = tfd.Normal(0, 1)

    # Drawing a sample site under a seeded handler must not raise.
    seeded = numpyro.handlers.seed(rng_seed=random.PRNGKey(0))
    with seeded:
        numpyro.sample("normal", normal)

    # The call returns a (value, intermediates) pair; only intermediates are checked.
    _, intermediates = normal(rng_key=random.PRNGKey(0), sample_intermediates=True)
    assert intermediates == []
# Example No. 3
# 0
def test_transformed_distributions():
    """Check that numpyro and TFP exp-transformed Normals agree on ``log_prob``.

    Three constructions of exp(Normal(0, 1)) are compared:

    * ``d``  -- numpyro ``TransformedDistribution`` with numpyro's ``ExpTransform``;
    * ``d1`` -- TFP ``TransformedDistribution`` with the TFP ``Exp`` bijector;
    * ``d2`` -- numpyro ``TransformedDistribution`` with the TFP ``Exp`` bijector
      wrapped by ``tfd.BijectorTransform``.
    """
    import math

    from tensorflow_probability.substrates.jax import bijectors as tfb
    from numpyro.contrib.tfp import distributions as tfd

    d = dist.TransformedDistribution(dist.Normal(0, 1), dist.transforms.ExpTransform())
    d1 = tfd.TransformedDistribution(tfd.Normal(0, 1), tfb.Exp())
    d2 = dist.TransformedDistribution(dist.Normal(0, 1), tfd.BijectorTransform(tfb.Exp()))
    # Evaluate at points inside the support (x > 0). Raw standard-normal draws
    # are about half negative, where the inverse log(x) is NaN. All three sums
    # would then be NaN, and assert_allclose (assuming numpy.testing's, whose
    # equal_nan defaults to True) would pass without comparing anything.
    x = d.sample(random.PRNGKey(0), (1000,))
    d_x = d.log_prob(x).sum()
    d1_x = d1.log_prob(x).sum()
    d2_x = d2.log_prob(x).sum()
    # Guard against the comparison silently degenerating into NaN == NaN.
    assert math.isfinite(float(d_x))
    # float32 sums produced through different code paths: allow a small rtol.
    assert_allclose(d_x, d1_x, rtol=1e-5)
    assert_allclose(d_x, d2_x, rtol=1e-5)
# Example No. 4
# 0
 def f(x):
     """Return a TFP Normal with location ``x`` and unit scale."""
     unit_scale = 1
     return tfd.Normal(x, unit_scale)