def eight_schools(key):
  """Sample treatment effects from the eight-schools hierarchical model.

  The model is written in non-centered form. Each school's effect is the
  shared mean plus a per-school standard-normal offset scaled by
  ``exp(avg_stddev)``, so ``avg_stddev`` acts as a log standard deviation.
  Every latent is registered through ``ppl.random_variable`` under a fixed
  name: 'avg_effect', 'avg_stddev', 'se_standard' and 'te'.

  Args:
    key: PRNG key. It is split into four sub-keys, one per random variable.

  Returns:
    The sampled treatment effects, one entry per school.

  Note:
    ``treatment_stddevs`` is not a parameter; it is read from the enclosing
    scope. It is presumably the eight observed per-school standard errors;
    confirm at the definition site.
  """
  num_schools = 8
  avg_effect_key, avg_stddev_key, standard_key, effects_key = random.split(
      key, 4)

  # Hyperpriors on the population mean and the (log-scale) population spread.
  avg_effect = ppl.random_variable(
      bd.Normal(loc=0., scale=10.), name='avg_effect')(avg_effect_key)
  avg_stddev = ppl.random_variable(
      bd.Normal(loc=5., scale=1.), name='avg_stddev')(avg_stddev_key)

  # Standard-normal per-school offsets, treated as a single vector event.
  standard_prior = bd.Independent(
      bd.Normal(loc=np.zeros(num_schools), scale=np.ones(num_schools)),
      reinterpreted_batch_ndims=1)
  school_effects_standard = ppl.random_variable(
      standard_prior, name='se_standard')(standard_key)

  # The trailing axis lets the hyperparameters broadcast across schools.
  mean_effect = avg_effect[..., np.newaxis]
  spread = np.exp(avg_stddev[..., np.newaxis])
  treatment_prior = bd.Independent(
      bd.Normal(loc=mean_effect + spread * school_effects_standard,
                scale=treatment_stddevs),
      reinterpreted_batch_ndims=1)
  return ppl.random_variable(treatment_prior, name='te')(effects_key)
# Example #2
 def _sample(key, state):
     """Draw one sample from a Normal distribution centred on ``state``.

     The standard deviation is ``scale``, read from the enclosing scope. All
     ``np.ndim(state)`` axes are reinterpreted as event dimensions. Assuming
     ``scale`` broadcasts against ``state``, the draw therefore has the same
     shape as ``state`` and is scored as one joint event.
     """
     event_ndims = np.ndim(state)
     gaussian = bd.Independent(bd.Normal(state, scale),
                               reinterpreted_batch_ndims=event_ndims)
     return ppl.random_variable(gaussian)(key)
# Example #3
 def _sample(key, state):
     """Draw one sample from a Normal distribution centred on ``state``.

     The standard deviation is ``scale``, read from the enclosing scope. Every
     axis of ``state`` is folded into the event shape via ``bd.Independent``.
     The pytype pragmas silence module-attr false positives on ``bd``.
     """
     base = bd.Normal(state, scale)  # pytype: disable=module-attr
     event_ndims = np.ndim(state)
     joint = bd.Independent(  # pytype: disable=module-attr
         base, reinterpreted_batch_ndims=event_ndims)
     return ppl.random_variable(joint)(key)