def eight_schools(key):
  """Draws one joint sample from the non-centered eight-schools model.

  Four random variables are created through `ppl.random_variable`, tagged
  'avg_effect', 'avg_stddev', 'se_standard' and 'te' (in that order), each
  consuming its own sub-key derived from `key`.

  Args:
    key: PRNG key; split into four sub-keys, one per random variable.

  Returns:
    The sample drawn for the 'te' (treatment effects) variable.
  """
  # One sub-key per random variable, in declaration order.
  effect_key, stddev_key, standard_key, treatment_key = random.split(key, 4)

  avg_effect = ppl.random_variable(
      bd.Normal(loc=0., scale=10.), name='avg_effect')(effect_key)
  avg_stddev = ppl.random_variable(
      bd.Normal(loc=5., scale=1.), name='avg_stddev')(stddev_key)

  # Standard-normal per-school offsets; the 8 entries are reinterpreted as a
  # single event via Independent(reinterpreted_batch_ndims=1).
  offsets_prior = bd.Independent(
      bd.Normal(loc=np.zeros(8), scale=np.ones(8)),
      reinterpreted_batch_ndims=1)
  school_offsets = ppl.random_variable(
      offsets_prior, name='se_standard')(standard_key)

  # Non-centered parameterization: avg_stddev enters through exp(), so it acts
  # as a log-scale. `[..., np.newaxis]` appends a trailing axis so these
  # values broadcast against the per-school offsets vector.
  shifted = avg_effect[..., np.newaxis]
  spread = np.exp(avg_stddev[..., np.newaxis])
  school_means = shifted + spread * school_offsets

  # NOTE(review): `treatment_stddevs` is read from the enclosing/module scope,
  # presumably the 8 per-school observation stddevs -- confirm at definition.
  treatment_effects = ppl.random_variable(
      bd.Independent(
          bd.Normal(loc=school_means, scale=treatment_stddevs),
          reinterpreted_batch_ndims=1),
      name='te')(treatment_key)
  return treatment_effects
def _sample(key, state):
  """Draws a Normal sample centered at `state`.

  Every dimension of `state` is folded into a single event, so the returned
  draw has the same shape as `state` and one joint log-density.

  `scale` is not a parameter: it is looked up in the enclosing scope.
  NOTE(review): another `_sample` follows in this file; if both live in the
  same scope the later one shadows this -- confirm they are nested separately.

  Args:
    key: PRNG key consumed by the random variable.
    state: Location of the Normal; its full rank is the event rank.

  Returns:
    The sampled value from `ppl.random_variable`.
  """
  centered = bd.Normal(state, scale)
  event_rank = np.ndim(state)
  joint = bd.Independent(centered, reinterpreted_batch_ndims=event_rank)
  draw = ppl.random_variable(joint)
  return draw(key)
def _sample(key, state):
  """Draws a Normal sample centered at `state`, one event over all its dims.

  `scale` is resolved from the enclosing scope rather than passed in.

  Args:
    key: PRNG key consumed by the random variable.
    state: Location of the Normal; `np.ndim(state)` batch dims are
      reinterpreted as event dims.

  Returns:
    The sampled value from `ppl.random_variable`.
  """
  normal = bd.Normal(state, scale)  # pytype: disable=module-attr
  independent = bd.Independent(  # pytype: disable=module-attr
      normal, reinterpreted_batch_ndims=np.ndim(state))
  sampler = ppl.random_variable(independent)
  return sampler(key)