def eight_schools(key):
    """Sample treatment effects from the non-centered eight-schools model.

    A global average effect and a global spread (on the log scale) are drawn,
    together with eight standard-normal per-school offsets. The observed
    treatment effects are then sampled around
    ``avg_effect + exp(avg_stddev) * offsets``.

    ``treatment_stddevs`` is read from the enclosing (module) scope; it is not
    defined in this function.

    Args:
      key: PRNG key, split four ways (one sub-key per random variable).

    Returns:
      The sampled treatment effects (random variable ``'te'``), a single
      event with one entry per school.
    """
    keys = random.split(key, 4)
    effect_key, stddev_key, standard_key, treatment_key = keys

    # Global mean effect shared by all schools.
    avg_effect = ppl.random_variable(
        bd.Normal(loc=0., scale=10.), name='avg_effect')(effect_key)
    # Drawn on the log scale: it is exponentiated before being used below.
    avg_stddev = ppl.random_variable(
        bd.Normal(loc=5., scale=1.), name='avg_stddev')(stddev_key)

    # Eight standard-normal offsets, reinterpreted as one event of shape (8,).
    standard_prior = bd.Independent(
        bd.Normal(loc=np.zeros(8), scale=np.ones(8)),
        reinterpreted_batch_ndims=1)
    school_effects_standard = ppl.random_variable(
        standard_prior, name='se_standard')(standard_key)

    # Non-centered parameterization: shift and scale the standard offsets.
    # The trailing new axis lets any leading batch dims broadcast against (8,).
    school_means = (avg_effect[..., np.newaxis]
                    + np.exp(avg_stddev[..., np.newaxis])
                    * school_effects_standard)
    treatment_prior = bd.Independent(
        bd.Normal(loc=school_means, scale=treatment_stddevs),
        reinterpreted_batch_ndims=1)
    return ppl.random_variable(treatment_prior, name='te')(treatment_key)
    def test_normal_log_prob(self):
        """log_prob of a bare ``random_normal`` draw equals Normal(0, 1)."""
        def sample_fn(rng):
            return random_normal(rng)

        sample_log_prob = log_prob(sample_fn)
        reference = bd.Normal(0., 1.)
        for value in (0., 1.):
            self.assertEqual(sample_log_prob(value), reference.log_prob(value))
 def forward(key, x):
     """Dense layer whose weight and bias are sampled as named random variables.

     ``dim_out`` and ``name`` are free variables read from an enclosing scope
     that is not part of this function.

     Args:
       key: PRNG key; split into one sub-key for the weight and one for the
         bias.
       x: input array whose last axis has length ``dim_in``. Any leading axes
         are treated as batch axes.

     Returns:
       ``x @ w.T + b`` with shape ``x.shape[:-1] + (dim_out,)``. Here ``w`` has
       shape ``(dim_out, dim_in)`` (random variable ``f'{name}_w'``) and ``b``
       has shape ``(dim_out,)`` (random variable ``f'{name}_b'``).
     """
     dim_in = x.shape[-1]
     w_key, b_key = random.split(key)
     w = ppl.random_variable(
         bd.Sample(bd.Normal(0., 1.), sample_shape=(dim_out, dim_in)),
         name=f'{name}_w')(w_key)
     b = ppl.random_variable(
         bd.Sample(bd.Normal(0., 1.), sample_shape=(dim_out,)),
         name=f'{name}_b')(b_key)
     # Contract over x's LAST axis, the one dim_in was read from.
     # np.dot(w, x) would contract over x's second-to-last axis whenever
     # x.ndim >= 2, breaking batched inputs. For 1-D x both forms compute the
     # same matrix-vector product.
     return np.dot(x, w.T) + b
    def test_log_normal_log_prob(self):
        """log_prob of ``exp(random_normal(rng))`` equals a log-normal density."""
        def sample_fn(rng):
            standard = random_normal(rng)
            return np.exp(standard)

        log_normal = bd.TransformedDistribution(bd.Normal(0., 1.), bb.Exp())
        sample_log_prob = log_prob(sample_fn)
        observed = 2.
        self.assertEqual(sample_log_prob(observed), log_normal.log_prob(observed))
Пример #5
0
 def test_conditional_log(self):
   """log_prob of ``random_normal(rng) + x``, conditioned on the shift ``x``."""
   def shifted(rng, x):
     return random_normal(rng) + x

   shifted_lp = log_prob(shifted)
   outval, shift = 0.1, 1.0
   # Undoing the shift of 1.0 from the output 0.1 leaves the draw -0.9.
   self.assertEqual(shifted_lp(outval, shift),
                    bd.Normal(0., 1.).log_prob(-0.9))
    def test_joint_distribution(self):
        """log_prob raises on a model with a latent, but scores its joint."""
        def model(key):
            z_key, x_key = random.split(key)
            z = ppl.random_variable(bd.Normal(0., 1.), name='z')(z_key)
            return ppl.random_variable(bd.Normal(z, 1.), name='x')(x_key)

        # Scoring only the returned x (with z unobserved) is expected to fail.
        with self.assertRaises(ValueError):
            core.log_prob(model)(0.1)

        joint = ppl.joint_sample(model)
        observation = {'z': 1., 'x': 2.}
        actual = core.log_prob(joint)(observation)
        expected = (bd.Normal(0., 1.).log_prob(1.)
                    + bd.Normal(1., 1.).log_prob(2.))
        self.assertEqual(actual, expected)
    def test_log_prob_in_call(self):
        """log_prob traces through a random variable wrapped in ``call``."""
        def sample_fn(rng):
            wrapped = call(lambda k: random_normal(k, name='z'))
            return wrapped(rng)

        sample_log_prob = log_prob(sample_fn)
        draw = sample_fn(random.PRNGKey(0))
        self.assertEqual(sample_log_prob(draw), bd.Normal(0., 1.).log_prob(draw))
def random_normal_log_prob(_, outval, name=None):
    """Return the Normal(0, 1) log-density of ``outval``.

    The first positional argument (presumably the PRNG key — verify against
    callers) and ``name`` are accepted but ignored.
    """
    del name  # Unused.
    standard_normal = bd.Normal(0., 1.)
    return standard_normal.log_prob(outval)
Пример #9
0
 def _sample(key, state):
     """Draw a Normal(state, scale) sample treated as a single event.

     ``scale`` is a free variable from an enclosing scope not shown here.
     """
     dist = bd.Independent(bd.Normal(state, scale),
                           reinterpreted_batch_ndims=np.ndim(state))
     sample_fn = ppl.random_variable(dist)
     return sample_fn(key)
Пример #10
0
 def _sample(key, s):
     """Draw i.i.d. standard-normal noise with the shape and dtype of ``s``."""
     noise_dist = bd.Sample(bd.Normal(0., 1.), sample_shape=s.shape)
     noise = ppl.random_variable(noise_dist)(key)
     return noise.astype(s.dtype)
Пример #11
0
 def _sample(key, state):
     """Draw a Normal(state, scale) sample treated as a single event.

     ``scale`` is a free variable from an enclosing scope not shown here.
     """
     base = bd.Normal(state, scale)  # pytype: disable=module-attr
     event_dist = bd.Independent(  # pytype: disable=module-attr
         base, reinterpreted_batch_ndims=np.ndim(state))
     return ppl.random_variable(event_dist)(key)
Пример #12
0
 def _sample(key, s):
     """Draw i.i.d. standard-normal noise with the shape and dtype of ``s``."""
     base = bd.Normal(0., 1.)  # pytype: disable=module-attr
     noise = ppl.random_variable(bd.Sample(base, sample_shape=s.shape))(key)
     return noise.astype(s.dtype)
Пример #13
0
def random_normal_log_prob_rule(incells, outcells, **_):
  """Log-density propagation rule for a standard-normal sample.

  Args:
    incells: input cells; returned unchanged.
    outcells: output cells. Must hold exactly one cell, otherwise the
      unpacking below raises ``ValueError``.
    **_: any extra keyword arguments are ignored.

  Returns:
    A ``(incells, outcells, log_density)`` triple. ``log_density`` is ``None``
    while ``outcell.top()`` is false. Otherwise it is the Normal(0, 1)
    log-density of ``outcell.val``.
  """
  (outcell,) = outcells
  log_density = None
  if outcell.top():
    # The value is only read once top() holds; presumably it is not a
    # concrete value before that.
    log_density = bd.Normal(0., 1.).log_prob(outcell.val)
  return incells, outcells, log_density
 def model(key):
   """Two-level normal model: z ~ N(0, 1), x | z ~ N(z, 1); returns x."""
   z_key, x_key = random.split(key)
   z = ppl.random_variable(bd.Normal(0., 1.), name='z')(z_key)
   x_dist = bd.Normal(z, 1.)
   return ppl.random_variable(x_dist, name='x')(x_key)
 ('normal_scalar_kwargs', bd.Normal, (), {
     'loc': 0.,
     'scale': 1.
 }, 0., [0., 1.]),
 ('mvn_diag_args', bd.MultivariateNormalDiag,
  (onp.zeros(5, dtype=onp.float32), onp.ones(5, dtype=onp.float32)), {},
  onp.zeros(5, dtype=onp.float32),
  [onp.zeros(5, dtype=onp.float32),
   onp.ones(5, dtype=onp.float32)]),
 ('mvn_diag_kwargs', bd.MultivariateNormalDiag, (), {
     'loc': onp.zeros(5, dtype=onp.float32),
     'scale_diag': onp.ones(5, dtype=onp.float32)
 }, onp.zeros(5, dtype=onp.float32),
  [onp.zeros(5, dtype=onp.float32),
   onp.ones(5, dtype=onp.float32)]),
 ('independent_normal_args', bd.Independent, (bd.Normal(
     onp.zeros(5, dtype=onp.float32), onp.ones(5, dtype=onp.float32)),), {
         'reinterpreted_batch_ndims': 1
     }, onp.zeros(5, dtype=onp.float32),
  [onp.zeros(5, dtype=onp.float32),
   onp.ones(5, dtype=onp.float32)]),
 ('independent_normal_args2', bd.Independent, (bd.Normal(
     loc=onp.zeros(5, dtype=onp.float32),
     scale=onp.ones(5, dtype=onp.float32)),), {
         'reinterpreted_batch_ndims': 1
     }, onp.zeros(5, dtype=onp.float32),
  [onp.zeros(5, dtype=onp.float32),
   onp.ones(5, dtype=onp.float32)]),
 ('independent_normal_kwargs', bd.Independent, (), {
     'reinterpreted_batch_ndims':
         1,
     'distribution':
     'loc': 0.,
     'scale': 1.
 }, 0., [0., 1.]),
 ('mvn_diag_args', bd.MultivariateNormalDiag, lambda:
  (onp.zeros(5, dtype=onp.float32), onp.ones(5, dtype=onp.float32)),
  lambda: {}, onp.zeros(5, dtype=onp.float32),
  [onp.zeros(5, dtype=onp.float32),
   onp.ones(5, dtype=onp.float32)]),
 ('mvn_diag_kwargs', bd.MultivariateNormalDiag, lambda: (), lambda: {
     'loc': onp.zeros(5, dtype=onp.float32),
     'scale_diag': onp.ones(5, dtype=onp.float32)
 }, onp.zeros(5, dtype=onp.float32),
  [onp.zeros(5, dtype=onp.float32),
   onp.ones(5, dtype=onp.float32)]),
 ('independent_normal_args', bd.Independent, lambda:
  (bd.Normal(onp.zeros(5, dtype=onp.float32), onp.ones(5, dtype=onp.float32)
             ), ), lambda: {
                 'reinterpreted_batch_ndims': 1
             }, onp.zeros(5, dtype=onp.float32),
  [onp.zeros(5, dtype=onp.float32),
   onp.ones(5, dtype=onp.float32)]),
 ('independent_normal_args2', bd.Independent, lambda:
  (bd.Normal(loc=onp.zeros(5, dtype=onp.float32),
             scale=onp.ones(5, dtype=onp.float32)), ), lambda: {
                 'reinterpreted_batch_ndims': 1
             }, onp.zeros(5, dtype=onp.float32),
  [onp.zeros(5, dtype=onp.float32),
   onp.ones(5, dtype=onp.float32)]),
 ('independent_normal_kwargs', bd.Independent, lambda: (), lambda: {
     'reinterpreted_batch_ndims':
     1,
     'distribution':