def test_plate_reduces_over_named_axes(self):
  """Checks that a plated rv's log-prob is the sum over the named axis.

  The model is vmapped under axis name 'foo' with `out_axes=None`. The
  single unmapped output must match the summed per-element log-probs of
  Normal(0, 1).
  """
  model = ppl.rv(tfd.Normal(0., 1.), plate='foo')
  values = jnp.arange(3.)
  # out_axes=None: the plate reduction makes the result identical across
  # the mapped axis, so a single (unbatched) value is returned.
  batched_log_prob = jax.vmap(
      ppl.log_prob(model), axis_name='foo', out_axes=None)
  out = batched_log_prob(values)
  expected = tfd.Normal(0., 1.).log_prob(values).sum()
  np.testing.assert_allclose(expected, out)
def test_plate_produces_independent_samples(self):
  """Checks that samples under a plate differ across the named axis.

  Every mapped element receives the same PRNG key (`in_axes=(0, None)`),
  yet each pair of drawn values must not be almost equal.
  """
  model = ppl.rv(tfd.Normal(0., 1.), plate='foo')
  shared_key = random.PRNGKey(0)
  # Map over a dummy length-3 vector; broadcast the single key unchanged.
  sample_fn = jax.vmap(
      lambda _, key: model(key), in_axes=(0, None), axis_name='foo')
  samples = sample_fn(jnp.ones(3), shared_key)
  distinct_pairs = ((a, b) for a in range(3) for b in range(3) if a != b)
  for a, b in distinct_pairs:
    self.assertNotAlmostEqual(samples[a], samples[b])
def test_cannot_use_distribution_with_nontrivial_batch_shape(self):
  """Expects a ValueError for an rv over a distribution with batch shape.

  The loc is `jnp.ones(2)` and the scale is a scalar, so the distribution
  is batched. Building or sampling the rv must raise ValueError.
  """
  with self.assertRaises(ValueError):
    # Kept inside the context so any ValueError from construction counts too.
    batched_normal = tfd.Normal(jnp.ones(2), 1.)
    ppl.rv(batched_normal)(random.PRNGKey(0))
def f(key, x):
  """Draws one sample from a Normal(x, 1.) random variable using `key`."""
  normal = tfd.Normal(x, 1.)
  return ppl.rv(normal)(key)
def model(key):
  """Samples a Normal(0., 1.) rv once per key produced by `random.split(key)`."""
  sampler = ppl.rv(tfd.Normal(0., 1.))
  keys = random.split(key)
  # vmap maps the sampler over the leading axis of the split keys.
  return jax.vmap(sampler)(keys)