def eight_schools(key):
    """Samples treatment effects from the eight-schools hierarchical model.

    The model is written in non-centered form: per-school effects are drawn
    as standard normals and then shifted by `avg_effect` and scaled by
    `exp(avg_stddev)`. `treatment_stddevs` is read from the enclosing scope.

    Args:
      key: a JAX PRNG key; it is split into four keys, one per random
        variable.

    Returns:
      The sample of the random variable named 'te'.
    """
    effect_key, stddev_key, standard_key, treatment_key = random.split(key, 4)

    avg_effect = ppl.random_variable(
        bd.Normal(loc=0., scale=10.), name='avg_effect')(effect_key)
    # Used as a log-scale: it is exponentiated before multiplying.
    avg_stddev = ppl.random_variable(
        bd.Normal(loc=5., scale=1.), name='avg_stddev')(stddev_key)

    standard_prior = bd.Independent(
        bd.Normal(loc=np.zeros(8), scale=np.ones(8)),
        reinterpreted_batch_ndims=1)
    school_effects_standard = ppl.random_variable(
        standard_prior, name='se_standard')(standard_key)

    # The trailing new axis lets (possibly batched) hyperparameters broadcast
    # against the length-8 vector of school effects.
    school_means = (avg_effect[..., np.newaxis]
                    + np.exp(avg_stddev[..., np.newaxis])
                    * school_effects_standard)
    likelihood = bd.Independent(
        bd.Normal(loc=school_means, scale=treatment_stddevs),
        reinterpreted_batch_ndims=1)
    return ppl.random_variable(likelihood, name='te')(treatment_key)
def test_normal_log_prob(self):
    """log_prob of a bare `random_normal` sampler equals Normal(0, 1).log_prob."""
    def sampler(rng):
        return random_normal(rng)

    sampler_log_prob = log_prob(sampler)
    for value in (0., 1.):
        self.assertEqual(sampler_log_prob(value),
                         bd.Normal(0., 1.).log_prob(value))
def forward(key, x):
    """Samples dense-layer weights and bias, then applies the layer to `x`.

    `dim_out` and `name` come from the enclosing scope. The weight and bias
    random variables are named f'{name}_w' and f'{name}_b'.

    Args:
      key: a JAX PRNG key, split into one key for weights and one for bias.
      x: input whose last axis is the input dimension.

    Returns:
      `np.dot(w, x) + b`, where `w` has shape `(dim_out, dim_in)` and `b`
      has shape `(dim_out,)`, both with standard-normal entries.
    """
    # NOTE(review): np.dot(w, x) contracts w's last axis with x's first axis
    # when x is 1-D, so this presumably expects an unbatched x -- confirm.
    input_dim = x.shape[-1]
    weight_key, bias_key = random.split(key)

    weight_prior = bd.Sample(bd.Normal(0., 1.),
                             sample_shape=(dim_out, input_dim))
    weights = ppl.random_variable(weight_prior, name=f'{name}_w')(weight_key)

    bias_prior = bd.Sample(bd.Normal(0., 1.), sample_shape=(dim_out,))
    bias = ppl.random_variable(bias_prior, name=f'{name}_b')(bias_key)

    return np.dot(weights, x) + bias
def test_log_normal_log_prob(self):
    """exp() of a standard-normal draw is scored as Exp-transformed Normal."""
    def log_normal_sampler(rng):
        normal_draw = random_normal(rng)
        return np.exp(normal_draw)

    expected_dist = bd.TransformedDistribution(bd.Normal(0., 1.), bb.Exp())
    sampler_log_prob = log_prob(log_normal_sampler)
    self.assertEqual(sampler_log_prob(2.), expected_dist.log_prob(2.))
def test_conditional_log(self):
    """Extra positional args are conditioned on when scoring the output.

    Scoring output 0.1 of `random_normal(rng) + x` with x=1.0 must equal the
    standard-normal log-density at 0.1 - 1.0 = -0.9.
    """
    def shifted_normal(rng, x):
        return random_normal(rng) + x

    shifted_log_prob = log_prob(shifted_normal)
    actual = shifted_log_prob(0.1, 1.0)
    self.assertEqual(actual, bd.Normal(0., 1.).log_prob(-0.9))
def test_joint_distribution(self):
    """A two-variable model must go through joint_sample before scoring."""
    def model(key):
        z_key, x_key = random.split(key)
        z = ppl.random_variable(bd.Normal(0., 1.), name='z')(z_key)
        return ppl.random_variable(bd.Normal(z, 1.), name='x')(x_key)

    # Scoring only the returned `x` raises ValueError, presumably because the
    # latent `z` cannot be recovered from `x` alone.
    with self.assertRaises(ValueError):
        core.log_prob(model)(0.1)

    joint_model = ppl.joint_sample(model)
    observed = {'z': 1., 'x': 2.}
    joint_log_prob = core.log_prob(joint_model)(observed)
    expected = (bd.Normal(0., 1.).log_prob(1.)
                + bd.Normal(1., 1.).log_prob(2.))
    self.assertEqual(joint_log_prob, expected)
def test_log_prob_in_call(self):
    """log_prob sees a `random_normal` sampled inside a `call` wrapper."""
    def f(rng):
        return call(lambda k: random_normal(k, name='z'))(rng)

    scorer = log_prob(f)
    drawn = f(random.PRNGKey(0))
    self.assertEqual(scorer(drawn), bd.Normal(0., 1.).log_prob(drawn))
def random_normal_log_prob(_, outval, name=None):
    """Returns the standard-normal log-density of `outval`.

    The first positional argument (presumably the PRNG key) and `name` are
    accepted but unused.
    """
    del name
    standard_normal = bd.Normal(0., 1.)
    return standard_normal.log_prob(outval)
def _sample(key, state):
    """Draws from a Normal centered at `state` with scale `scale`.

    Every dimension of `state` is reinterpreted as an event dimension, so the
    draw is a single joint sample over the whole state. `scale` comes from
    the enclosing scope.
    """
    event_ndims = np.ndim(state)
    centered_normal = bd.Independent(
        bd.Normal(state, scale), reinterpreted_batch_ndims=event_ndims)
    return ppl.random_variable(centered_normal)(key)
def _sample(key, s):
    """Draws i.i.d. standard-normal values shaped like `s`, cast to `s.dtype`."""
    target_shape = s.shape
    standard_draw = ppl.random_variable(
        bd.Sample(bd.Normal(0., 1.), sample_shape=target_shape))(key)
    return standard_draw.astype(s.dtype)
def _sample(key, state):
    """Draws from a Normal centered at `state` with scale `scale`.

    Every dimension of `state` is reinterpreted as an event dimension, so the
    draw is a single joint sample over the whole state. `scale` comes from
    the enclosing scope.
    """
    event_ndims = np.ndim(state)
    centered_normal = bd.Independent(  # pytype: disable=module-attr
        bd.Normal(state, scale),  # pytype: disable=module-attr
        reinterpreted_batch_ndims=event_ndims)
    return ppl.random_variable(centered_normal)(key)
def _sample(key, s):
    """Draws i.i.d. standard-normal values shaped like `s`, cast to `s.dtype`."""
    standard_normal = bd.Normal(0., 1.)  # pytype: disable=module-attr
    standard_draw = ppl.random_variable(
        bd.Sample(standard_normal, sample_shape=s.shape))(key)
    return standard_draw.astype(s.dtype)
def random_normal_log_prob_rule(incells, outcells, **_):
    """Log-prob rule for `random_normal` over input/output cells.

    Cells are returned unchanged. The third element is None unless the single
    output cell's `top()` is true; in that case it is the standard-normal
    log-density of that cell's `val`.

    Raises:
      ValueError: if `outcells` does not hold exactly one cell (unpacking).
    """
    (outcell,) = outcells
    if outcell.top():
        return incells, outcells, bd.Normal(0., 1.).log_prob(outcell.val)
    return incells, outcells, None
def model(key):
    """Samples z ~ Normal(0, 1), then x ~ Normal(z, 1); returns x.

    Random variables are named 'z' and 'x'.
    """
    z_key, x_key = random.split(key)
    z = ppl.random_variable(bd.Normal(0., 1.), name='z')(z_key)
    x_dist = bd.Normal(z, 1.)
    return ppl.random_variable(x_dist, name='x')(x_key)
('normal_scalar_kwargs', bd.Normal, (), { 'loc': 0., 'scale': 1. }, 0., [0., 1.]), ('mvn_diag_args', bd.MultivariateNormalDiag, (onp.zeros(5, dtype=onp.float32), onp.ones(5, dtype=onp.float32)), {}, onp.zeros(5, dtype=onp.float32), [onp.zeros(5, dtype=onp.float32), onp.ones(5, dtype=onp.float32)]), ('mvn_diag_kwargs', bd.MultivariateNormalDiag, (), { 'loc': onp.zeros(5, dtype=onp.float32), 'scale_diag': onp.ones(5, dtype=onp.float32) }, onp.zeros(5, dtype=onp.float32), [onp.zeros(5, dtype=onp.float32), onp.ones(5, dtype=onp.float32)]), ('independent_normal_args', bd.Independent, (bd.Normal( onp.zeros(5, dtype=onp.float32), onp.ones(5, dtype=onp.float32)),), { 'reinterpreted_batch_ndims': 1 }, onp.zeros(5, dtype=onp.float32), [onp.zeros(5, dtype=onp.float32), onp.ones(5, dtype=onp.float32)]), ('independent_normal_args2', bd.Independent, (bd.Normal( loc=onp.zeros(5, dtype=onp.float32), scale=onp.ones(5, dtype=onp.float32)),), { 'reinterpreted_batch_ndims': 1 }, onp.zeros(5, dtype=onp.float32), [onp.zeros(5, dtype=onp.float32), onp.ones(5, dtype=onp.float32)]), ('independent_normal_kwargs', bd.Independent, (), { 'reinterpreted_batch_ndims': 1, 'distribution':
'loc': 0., 'scale': 1. }, 0., [0., 1.]), ('mvn_diag_args', bd.MultivariateNormalDiag, lambda: (onp.zeros(5, dtype=onp.float32), onp.ones(5, dtype=onp.float32)), lambda: {}, onp.zeros(5, dtype=onp.float32), [onp.zeros(5, dtype=onp.float32), onp.ones(5, dtype=onp.float32)]), ('mvn_diag_kwargs', bd.MultivariateNormalDiag, lambda: (), lambda: { 'loc': onp.zeros(5, dtype=onp.float32), 'scale_diag': onp.ones(5, dtype=onp.float32) }, onp.zeros(5, dtype=onp.float32), [onp.zeros(5, dtype=onp.float32), onp.ones(5, dtype=onp.float32)]), ('independent_normal_args', bd.Independent, lambda: (bd.Normal(onp.zeros(5, dtype=onp.float32), onp.ones(5, dtype=onp.float32) ), ), lambda: { 'reinterpreted_batch_ndims': 1 }, onp.zeros(5, dtype=onp.float32), [onp.zeros(5, dtype=onp.float32), onp.ones(5, dtype=onp.float32)]), ('independent_normal_args2', bd.Independent, lambda: (bd.Normal(loc=onp.zeros(5, dtype=onp.float32), scale=onp.ones(5, dtype=onp.float32)), ), lambda: { 'reinterpreted_batch_ndims': 1 }, onp.zeros(5, dtype=onp.float32), [onp.zeros(5, dtype=onp.float32), onp.ones(5, dtype=onp.float32)]), ('independent_normal_kwargs', bd.Independent, lambda: (), lambda: { 'reinterpreted_batch_ndims': 1, 'distribution':