# NOTE: assumed imports, aliases, and test-class wrapper for this excerpt
# (the class name is hypothetical); the aliases follow Oryx's test
# conventions, with tfd as TFP-on-JAX distributions.
import jax
import jax.numpy as jnp

from absl.testing import absltest
from jax import random

from oryx import core
from oryx.core import ppl
from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions


class PPLTransformationsTest(absltest.TestCase):

  def test_eight_schools(self):
    treatment_stddevs = jnp.array([15, 10, 16, 11, 9, 11, 10, 18],
                                  dtype=jnp.float32)

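    # Non-centered eight-schools parameterization: `avg_stddev` is drawn on
    # the log scale and exponentiated below, keeping the per-school effect
    # scale positive, while `school_effects_standard` supplies standard-normal
    # deviates that are shifted and scaled into treatment effects.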
    def eight_schools(key):
      ae_key, as_key, se_key, te_key = random.split(key, 4)
      avg_effect = ppl.random_variable(
          tfd.Normal(loc=0., scale=10.), name='avg_effect')(
              ae_key)
      avg_stddev = ppl.random_variable(
          tfd.Normal(loc=5., scale=1.), name='avg_stddev')(
              as_key)
      school_effects_standard = ppl.random_variable(
          tfd.Independent(
              tfd.Normal(loc=jnp.zeros(8), scale=jnp.ones(8)),
              reinterpreted_batch_ndims=1),
          name='se_standard')(
              se_key)
      treatment_effects = ppl.random_variable(
          tfd.Independent(
              tfd.Normal(
                  loc=(avg_effect[..., jnp.newaxis] +
                       jnp.exp(avg_stddev[..., jnp.newaxis]) *
                       school_effects_standard),
                  scale=treatment_stddevs),
              reinterpreted_batch_ndims=1),
          name='te')(
              te_key)
      return treatment_effects

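    # `joint_sample` turns the program into one returning a dict of all named
    # latents; `log_prob` of that function evaluates their joint density.
    # Scoring the program's own sample is a smoke test that the transformed
    # model round-trips without error.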
    jd_sample = ppl.joint_sample(eight_schools)
    jd_log_prob = core.log_prob(jd_sample)
    jd_log_prob(jd_sample(random.PRNGKey(0)))

  def test_joint_distribution(self):

    def model(key):
      k1, k2 = random.split(key)
      z = ppl.random_variable(tfd.Normal(0., 1.), name='z')(k1)
      x = ppl.random_variable(tfd.Normal(z, 1.), name='x')(k2)
      return x

    # `model` returns only `x`; the latent `z` can't be recovered from the
    # output, so `log_prob` of the un-transformed program is ill-posed and
    # raises.
    with self.assertRaises(ValueError):
      core.log_prob(model)(0.1)
    # `joint_sample` exposes both latents by name, so the joint density
    # factorizes as p(z) * p(x | z).
    sample = ppl.joint_sample(model)
    self.assertEqual(
        core.log_prob(sample)({
            'z': 1.,
            'x': 2.
        }),
        tfd.Normal(0., 1.).log_prob(1.) + tfd.Normal(1., 1.).log_prob(2.))

  def test_bnn(self):

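    # Bayesian neural net: every weight and bias is a named random variable,
    # so `joint_sample` below recovers the full parameter set as a dict.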
    def dense(dim_out, name):

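      # Dense layer with standard-normal weights and bias; `tfd.Sample`
      # provides the parameter shapes, and per-layer names keep the
      # variables distinct.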
      def forward(key, x):
        dim_in = x.shape[-1]
        w_key, b_key = random.split(key)
        w = ppl.random_variable(
            tfd.Sample(tfd.Normal(0., 1.), sample_shape=(dim_out, dim_in)),
            name=f'{name}_w')(
                w_key)
        b = ppl.random_variable(
            tfd.Sample(tfd.Normal(0., 1.), sample_shape=(dim_out,)),
            name=f'{name}_b')(
                b_key)
        return jnp.dot(w, x) + b

      return forward

    def mlp(hidden_size, num_layers):

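      # Stacks `num_layers` ReLU hidden layers plus a final 10-way logit
      # layer, splitting the key so each layer's parameters get fresh
      # randomness.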
      def forward(key, x):
        for i in range(num_layers):
          key, subkey = random.split(key)
          x = dense(hidden_size, f'dense_{i + 1}')(subkey, x)
          x = jax.nn.relu(x)
        logits = dense(10, f'dense_{num_layers + 1}')(key, x)
        return logits

      return forward

    def predict(mlp):

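      # Categorical likelihood over the 10 classes; naming the draw 'y'
      # makes it part of the joint sample below.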
      def forward(key, x):
        k1, k2 = random.split(key)
        logits = mlp(k1, x)
        return ppl.random_variable(
            tfd.Categorical(logits=logits), name='y')(k2)

      return forward

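    # The joint sample maps (key, x) to a dict of every weight, bias, and
    # the label 'y'; `log_prob` then scores that dict given the same dummy
    # 784-dimensional input.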
    sample = ppl.joint_sample(predict(mlp(200, 2)))
    core.log_prob(sample)(sample(random.PRNGKey(0), jnp.ones(784)),
                          jnp.ones(784))
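    # For reference, the sampled dict should be keyed by the names above:
    # 'dense_1_w', 'dense_1_b', 'dense_2_w', 'dense_2_b', 'dense_3_w',
    # 'dense_3_b', and 'y'.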