def test_eight_schools(self):
    """Smoke-test joint sampling and scoring of the `eight_schools` model.

    The test makes no assertions: it passes when building the joint
    sampler, drawing from it with a fixed PRNG key, and evaluating the
    log-probability of that draw all complete without raising.
    """
    school_scales = jnp.array(
        [15, 10, 16, 11, 9, 11, 10, 18], dtype=jnp.float32)

    def eight_schools(key):
        """Draw the 'te' variable from three upstream named variables."""
        # One independent PRNG key per named random variable.
        effect_key, stddev_key, standard_key, treatment_key = random.split(
            key, 4)

        effect_prior = tfd.Normal(loc=0., scale=10.)
        avg_effect = ppl.random_variable(
            effect_prior, name='avg_effect')(effect_key)

        stddev_prior = tfd.Normal(loc=5., scale=1.)
        avg_stddev = ppl.random_variable(
            stddev_prior, name='avg_stddev')(stddev_key)

        standard_prior = tfd.Independent(
            tfd.Normal(loc=jnp.zeros(8), scale=jnp.ones(8)),
            reinterpreted_batch_ndims=1)
        school_effects_standard = ppl.random_variable(
            standard_prior, name='se_standard')(standard_key)

        # A trailing axis is added to both scalars so they broadcast against
        # the length-8 vector of standardized effects.
        shift = avg_effect[..., jnp.newaxis]
        spread = jnp.exp(avg_stddev[..., jnp.newaxis])
        treatment_loc = shift + spread * school_effects_standard

        treatment_dist = tfd.Independent(
            tfd.Normal(loc=treatment_loc, scale=school_scales),
            reinterpreted_batch_ndims=1)
        return ppl.random_variable(treatment_dist, name='te')(treatment_key)

    jd_sample = ppl.joint_sample(eight_schools)
    jd_log_prob = core.log_prob(jd_sample)
    draw = jd_sample(random.PRNGKey(0))
    jd_log_prob(draw)
def test_joint_distribution(self):
    """Check log_prob of a two-variable model before and after joint_sample.

    Asserts that `core.log_prob` applied to the raw model raises
    `ValueError`, and that after wrapping with `ppl.joint_sample` the
    log-probability of `{'z': 1., 'x': 2.}` equals the sum of the two
    Normal log-densities.
    """

    def model(key):
        """Sample 'z' ~ Normal(0, 1), then return 'x' ~ Normal(z, 1)."""
        z_key, x_key = random.split(key)
        z = ppl.random_variable(bd.Normal(0., 1.), name='z')(z_key)
        return ppl.random_variable(bd.Normal(z, 1.), name='x')(x_key)

    # Scoring the raw model on its output alone is expected to fail
    # (presumably because 'z' is not part of the output -- TODO confirm).
    with self.assertRaises(ValueError):
        raw_log_prob = core.log_prob(model)
        raw_log_prob(0.1)

    sample = ppl.joint_sample(model)
    observed = core.log_prob(sample)({'z': 1., 'x': 2.})
    z_term = bd.Normal(0., 1.).log_prob(1.)
    x_term = bd.Normal(1., 1.).log_prob(2.)
    self.assertEqual(observed, z_term + x_term)
def test_bnn(self):
    """Smoke-test joint sampling and scoring of an MLP with random weights.

    The test makes no assertions: it passes when sampling from the joint
    model with a fixed PRNG key and evaluating the log-probability of that
    sample complete without raising.
    """

    def dense(dim_out, name):
        """Return a layer whose weight and bias are named random variables."""

        def apply_dense(key, x):
            dim_in = x.shape[-1]
            weight_key, bias_key = random.split(key)
            weight_dist = tfd.Sample(
                tfd.Normal(0., 1.), sample_shape=(dim_out, dim_in))
            w = ppl.random_variable(weight_dist, name=f'{name}_w')(weight_key)
            bias_dist = tfd.Sample(tfd.Normal(0., 1.), sample_shape=(dim_out,))
            b = ppl.random_variable(bias_dist, name=f'{name}_b')(bias_key)
            return jnp.dot(w, x) + b

        return apply_dense

    def mlp(hidden_size, num_layers):
        """Return a network of `num_layers` ReLU layers plus a 10-unit head."""

        def apply_mlp(key, x):
            # Layer names are 1-based: dense_1 ... dense_{num_layers}.
            for layer_number in range(1, num_layers + 1):
                key, layer_key = random.split(key)
                hidden_layer = dense(
                    hidden_size, 'dense_{}'.format(layer_number))
                x = jax.nn.relu(hidden_layer(layer_key, x))
            # The remaining key drives the output layer, which has no ReLU.
            output_layer = dense(10, 'dense_{}'.format(num_layers + 1))
            return output_layer(key, x)

        return apply_mlp

    def predict(network):
        """Return a sampler drawing 'y' from a Categorical over the logits."""

        def apply_predict(key, x):
            logits_key, label_key = random.split(key)
            logits = network(logits_key, x)
            label_dist = tfd.Categorical(logits=logits)
            return label_dist.sample(seed=label_key, name='y')

        return apply_predict

    sample = ppl.joint_sample(predict(mlp(200, 2)))
    log_prob_fn = core.log_prob(sample)
    features = jnp.ones(784)
    draw = sample(random.PRNGKey(0), features)
    log_prob_fn(draw, features)