def eight_schools(key):
    """Samples the 'te' treatment effects of the eight-schools hierarchical model.

    Four named random variables are drawn, each from its own split of `key`:
    'avg_effect', 'avg_stddev', 'se_standard' and 'te'. The per-school means
    are built in non-centered form: eight standard-normal draws are scaled by
    exp(avg_stddev) and shifted by avg_effect.

    `treatment_stddevs` is read from the enclosing scope and used as the scale
    of the 'te' likelihood.

    Args:
      key: PRNG key consumed by `random.split(key, 4)`.

    Returns:
      The sample of the random variable named 'te'.
    """
    avg_effect_key, avg_stddev_key, standard_key, treatment_key = random.split(
        key, 4)

    mean_prior = tfd.Normal(loc=0., scale=10.)
    avg_effect = ppl.random_variable(mean_prior, name='avg_effect')(
        avg_effect_key)

    # avg_stddev is exponentiated below, so this prior is on its log scale.
    log_scale_prior = tfd.Normal(loc=5., scale=1.)
    avg_stddev = ppl.random_variable(log_scale_prior, name='avg_stddev')(
        avg_stddev_key)

    standard_prior = tfd.Independent(
        tfd.Normal(loc=jnp.zeros(8), scale=jnp.ones(8)),
        reinterpreted_batch_ndims=1)
    school_effects_standard = ppl.random_variable(
        standard_prior, name='se_standard')(standard_key)

    # A trailing axis on the hyperparameters lets them broadcast against the
    # 8 per-school standard draws.
    per_school_loc = (
        avg_effect[..., jnp.newaxis]
        + jnp.exp(avg_stddev[..., jnp.newaxis]) * school_effects_standard)
    likelihood = tfd.Independent(
        tfd.Normal(loc=per_school_loc, scale=treatment_stddevs),
        reinterpreted_batch_ndims=1)
    return ppl.random_variable(likelihood, name='te')(treatment_key)
def forward(key, x):
    """Applies an affine map whose weight and bias are sampled random variables.

    `dim_out` and `name` are read from the enclosing scope. The weight is the
    random variable f'{name}_w' with shape (dim_out, dim_in); the bias is
    f'{name}_b' with shape (dim_out,). Both have standard-normal entries.

    Args:
      key: PRNG key, split into one key for the weight and one for the bias.
      x: Input array; its last axis size is used as dim_in.

    Returns:
      np.dot(w, x) + b.
    """
    weight_key, bias_key = random.split(key)
    dim_in = x.shape[-1]

    weight_prior = bd.Sample(bd.Normal(0., 1.), sample_shape=(dim_out, dim_in))
    bias_prior = bd.Sample(bd.Normal(0., 1.), sample_shape=(dim_out,))

    w = ppl.random_variable(weight_prior, name=f'{name}_w')(weight_key)
    b = ppl.random_variable(bias_prior, name=f'{name}_b')(bias_key)
    # NOTE(review): for x with ndim > 1, np.dot contracts w's last axis with
    # x's second-to-last axis (not its last); confirm only 1-D x is expected.
    return np.dot(w, x) + b
def wrapped(key):
    """Samples `dist` by binding `random_variable_p`, optionally naming the result.

    `dist`, `name`, `_sample_distribution` and `random_variable_p` are read
    from the enclosing scope. The primitive is tagged with the class name of
    `dist` through the `distribution_name` parameter.

    Args:
      key: PRNG key passed to `_sample_distribution` together with `dist`.

    Returns:
      The sampled value. When `name` is not None, it is additionally passed
      through `ppl.random_variable(..., name=name)`.
    """
    bind = primitive.initial_style_bind(
        random_variable_p, distribution_name=dist.__class__.__name__)
    draw = bind(_sample_distribution)(key, dist)
    if name is None:
        return draw
    return ppl.random_variable(draw, name=name)
def wrapped(key):
    """Samples `dist` via `random_variable_p`, nesting under `name` when given.

    `dist`, `name`, `_sample_distribution` and `random_variable_p` are read
    from the enclosing scope.

    Args:
      key: PRNG key forwarded to the inner sampler.

    Returns:
      With `name` None, the raw sample. Otherwise the sampler is wrapped in
      `harvest.nest(..., scope=name)`, called with `key`, and its result is
      passed through `ppl.random_variable(..., name=name)`.
    """
    def sample(subkey):
        # Bind the primitive, tagged with the distribution's class name.
        bound = primitive.initial_style_bind(
            random_variable_p, distribution_name=dist.__class__.__name__)
        return bound(_sample_distribution)(subkey, dist)

    if name is None:
        return sample(key)
    scoped_sample = harvest.nest(sample, scope=name)
    return ppl.random_variable(scoped_sample(key), name=name)
def model(key):
    """Two-level Gaussian model: z ~ Normal(0, 1), then x ~ Normal(z, 1).

    Args:
      key: PRNG key, split into one key per random variable.

    Returns:
      The sample of the random variable named 'x'.
    """
    z_key, x_key = random.split(key)
    latent = ppl.random_variable(tfd.Normal(0., 1.), name='z')(z_key)
    return ppl.random_variable(tfd.Normal(latent, 1.), name='x')(x_key)
def model(key):
    """Samples `p` (read from the enclosing scope) as the random variable 'x'."""
    named_sampler = ppl.random_variable(p, name='x')
    return named_sampler(key)
def sample(key):
    """Draws one unnamed sample of `p` (read from the enclosing scope)."""
    sampler = ppl.random_variable(p)
    return sampler(key)
def _sample(key, state):
    """Samples from a Normal centered at `state` with scale `scale`.

    `scale` is read from the enclosing scope. All of `state`'s dimensions are
    reinterpreted as event dimensions, so the draw is a single event
    (presumably of `state`'s shape when `scale` broadcasts to it; TODO
    confirm).

    Args:
      key: PRNG key for the draw.
      state: Location of the Normal.

    Returns:
      The unnamed sample.
    """
    event_ndims = np.ndim(state)
    centered = bd.Normal(state, scale)
    joint = bd.Independent(centered, reinterpreted_batch_ndims=event_ndims)
    return ppl.random_variable(joint)(key)
def _sample(key, s):
    """Draws standard-normal noise with the shape of `s`, cast to `s.dtype`.

    Args:
      key: PRNG key for the draw.
      s: Array whose `.shape` and `.dtype` define the output.

    Returns:
      The unnamed sample, converted with `.astype(s.dtype)`.
    """
    noise_dist = bd.Sample(bd.Normal(0., 1.), sample_shape=s.shape)
    noise = ppl.random_variable(noise_dist)(key)
    return noise.astype(s.dtype)
def wrapped(key):
    """Samples `dist` through `call_bind(random_variable_p)`, optionally naming it.

    `dist`, `name`, `_sample_distribution` and `random_variable_p` are read
    from the enclosing scope.

    Args:
      key: PRNG key passed to `_sample_distribution` together with `dist`.

    Returns:
      The sampled value. When `name` is not None, it is additionally passed
      through `ppl.random_variable(..., name=name)`.
    """
    bound = primitive.call_bind(random_variable_p)
    draw = bound(_sample_distribution)(key, dist)
    if name is None:
        return draw
    return ppl.random_variable(draw, name=name)
def _sample(key, state):
    """Samples from a Normal centered at `state` with scale `scale`.

    `scale` is read from the enclosing scope. Every dimension of `state` is
    reinterpreted as an event dimension, so one joint event is drawn.

    Args:
      key: PRNG key for the draw.
      state: Location of the Normal.

    Returns:
      The unnamed sample.
    """
    event_ndims = np.ndim(state)
    centered = bd.Normal(state, scale)  # pytype: disable=module-attr
    joint = bd.Independent(  # pytype: disable=module-attr
        centered, reinterpreted_batch_ndims=event_ndims)
    return ppl.random_variable(joint)(key)
def _sample(key, s):
    """Draws standard-normal noise shaped like `s`, cast to `s.dtype`.

    Args:
      key: PRNG key for the draw.
      s: Array whose `.shape` and `.dtype` define the output.

    Returns:
      The unnamed sample, converted with `.astype(s.dtype)`.
    """
    noise_dist = bd.Sample(
        bd.Normal(0., 1.),  # pytype: disable=module-attr
        sample_shape=s.shape)
    noise = ppl.random_variable(noise_dist)(key)
    return noise.astype(s.dtype)