def test_plate_reduces_over_named_axes(self):
    """Log-prob of a plated rv under a named vmap matches the summed log-prob.

    The rv is built with plate='foo' and its log_prob is vmapped with
    axis_name='foo' and out_axes=None; the result is compared against the sum
    of elementwise Normal(0, 1) log-probs of the same inputs.
    """
    xs = jnp.arange(3.)
    plated = ppl.rv(tfd.Normal(0., 1.), plate='foo')
    batched_log_prob = jax.vmap(
        ppl.log_prob(plated), axis_name='foo', out_axes=None)
    actual = batched_log_prob(xs)
    expected = tfd.Normal(0., 1.).log_prob(xs).sum()
    # Argument order (reference value first) is kept as in the original call.
    np.testing.assert_allclose(expected, actual)
  def test_plate_produces_independent_samples(self):
    """Values from a plated rv differ pairwise across the mapped axis.

    The same PRNG key is passed to every mapped instance (in_axes is None for
    the key), with the vmap axis named 'foo' to match the rv's plate.
    """
    plated = ppl.rv(tfd.Normal(0., 1.), plate='foo')

    def draw(_, key):
      # The first argument only provides the mapped axis; its value is unused.
      return plated(key)

    draws = jax.vmap(draw, in_axes=(0, None), axis_name='foo')(
        jnp.ones(3), random.PRNGKey(0))
    # Every ordered pair of distinct indices, in row-major order.
    distinct_pairs = [(i, j) for i in range(3) for j in range(3) if i != j]
    for i, j in distinct_pairs:
      self.assertNotAlmostEqual(draws[i], draws[j])
 def test_cannot_use_distribution_with_nontrivial_batch_shape(self):
   """Expects ValueError for a ppl.rv-wrapped Normal whose loc is jnp.ones(2)."""
   batched_loc = jnp.ones(2)
   key = random.PRNGKey(0)
   with self.assertRaises(ValueError):
     # Distribution construction, wrapping and the call all stay inside the
     # context so an error from any of those steps is still caught.
     ppl.rv(tfd.Normal(batched_loc, 1.))(key)
 def f(key, x):
   """Apply a ppl.rv-wrapped Normal(x, 1.) to `key` and return the result."""
   dist = tfd.Normal(x, 1.)
   wrapped = ppl.rv(dist)
   return wrapped(key)
 def model(key):
   """vmap a ppl.rv-wrapped Normal(0., 1.) over the keys from random.split(key)."""
   keys = random.split(key)
   sampler = ppl.rv(tfd.Normal(0., 1.))
   return jax.vmap(sampler)(keys)