def model(data):
    """Toy model whose ``loc`` prior has its base log-density masked out.

    ``alpha`` is drawn from ``Uniform(0, 1)``. ``loc`` is a
    ``TransformedDistribution`` whose base is ``Uniform(0, 1)`` wrapped with
    ``mask(False)``, scaled by ``AffineTransform(0, alpha)``. The ``loc`` site
    is rewritten by ``TransformReparam``. Observations are ``Normal(loc, 0.1)``.

    Args:
        data: values observed at the ``obs`` site.
    """
    alpha = numpyro.sample('alpha', dist.Uniform(0, 1))
    loc_config = {'loc': TransformReparam()}
    with reparam(config=loc_config):
        # mask(False) disables the base distribution's log-prob contribution;
        # presumably the intent is an improper/flat prior on `loc` — confirm.
        masked_base = dist.Uniform(0, 1).mask(False)
        scale_by_alpha = AffineTransform(0, alpha)
        loc = numpyro.sample('loc', dist.TransformedDistribution(masked_base, scale_by_alpha))
    numpyro.sample('obs', dist.Normal(loc, 0.1), obs=data)
def test_model_with_transformed_distribution():
    """Check TransformReparam against a hand-computed reference.

    Unconstrained ``params`` are pushed through ``biject_to`` of each prior's
    support to get the expected constrained samples. The expected potential
    energy is the negative sum of the prior log-densities and the
    log-abs-det-Jacobians of those maps.

    The reparameterized model is then evaluated with the unconstrained value
    for ``y`` supplied under the key ``'y_base'``. Its constrained samples and
    potential energy must match the reference.
    """
    x_prior = dist.HalfNormal(2)
    y_prior = dist.LogNormal(scale=3.)  # transformed distribution

    def model():
        numpyro.sample('x', x_prior)
        numpyro.sample('y', y_prior)

    params = {'x': jnp.array(-5.), 'y': jnp.array(7.)}
    seeded_model = handlers.seed(model, random.PRNGKey(0))

    # Maps from unconstrained space onto each prior's support.
    to_constrained = {
        'x': biject_to(x_prior.support),
        'y': biject_to(y_prior.support),
    }
    expected_samples = transform_fn(to_constrained, params)

    x_log_prob = x_prior.log_prob(expected_samples['x'])
    y_log_prob = y_prior.log_prob(expected_samples['y'])
    x_log_det = to_constrained['x'].log_abs_det_jacobian(params['x'], expected_samples['x'])
    y_log_det = to_constrained['y'].log_abs_det_jacobian(params['y'], expected_samples['y'])
    expected_potential_energy = -x_log_prob - y_log_prob - x_log_det - y_log_det

    reparameterized = reparam(seeded_model, {'y': TransformReparam()})
    base_params = {'x': params['x'], 'y_base': params['y']}
    actual_samples = constrain_fn(handlers.seed(reparameterized, random.PRNGKey(0)),
                                  (), {}, base_params, return_deterministic=True)
    actual_potential_energy = potential_energy(reparameterized, (), {}, base_params)

    for site in ('x', 'y'):
        assert_allclose(expected_samples[site], actual_samples[site])
    assert_allclose(actual_potential_energy, expected_potential_energy)
def reparam_model(dim=10):
    """Neal's-funnel-style model with ``x`` reparameterized via TransformReparam.

    ``y ~ Normal(0, 3)``. ``x`` is written as ``AffineTransform(0, exp(y / 2))``
    applied to a standard-normal base with ``dim - 1`` independent entries.
    Wrapping the site in ``TransformReparam`` moves sampling onto that base.

    Args:
        dim: total latent dimension (``y`` plus the ``dim - 1`` entries of ``x``).
    """
    y = numpyro.sample('y', dist.Normal(0, 3))
    x_scale = jnp.exp(y / 2)
    with reparam(config={'x': TransformReparam()}):
        standard_base = dist.Normal(jnp.zeros(dim - 1), 1)
        numpyro.sample('x', dist.TransformedDistribution(standard_base,
                                                         AffineTransform(0, x_scale)))
def test_log_normal(shape):
    """Compare moments of ``exp(loc + scale * z)`` with and without TransformReparam.

    ``z`` is standard normal, and ``x`` is sampled as a ``TransformedDistribution``
    (affine, then exp) inside a particle plate. The test draws once from the
    plain model and once with ``TransformReparam`` applied to site ``"x"``. It
    checks that the reparameterized site is deterministic and that both draws
    have matching moments.

    Args:
        shape: batch shape of the random ``loc``/``scale`` parameters
            (presumably supplied by a pytest parametrization — confirm).
    """
    num_particles = 100000
    loc = np.random.rand(*shape) * 2 - 1  # uniform in [-1, 1)
    scale = np.random.rand(*shape) + 0.5  # uniform in [0.5, 1.5)

    def model():
        with numpyro.plate_stack("plates", shape):
            with numpyro.plate("particles", num_particles):
                return numpyro.sample("x",
                                      dist.TransformedDistribution(
                                          dist.Normal(jnp.zeros_like(loc),
                                                      jnp.ones_like(scale)),
                                          [AffineTransform(loc, scale), ExpTransform()]
                                      ).expand_by([num_particles]))

    # Reference draw from the un-reparameterized model. No trace is needed
    # here: only the returned value is used.
    value = handlers.seed(model, 0)()
    expected_moments = get_moments(value)

    with reparam(config={"x": TransformReparam()}):
        with handlers.trace() as tr:
            value = handlers.seed(model, 0)()
    # After TransformReparam, "x" is recorded as a deterministic site.
    assert tr["x"]["type"] == "deterministic"
    actual_moments = get_moments(value)
    assert_allclose(actual_moments, expected_moments, atol=0.05)
def model(data):
    """Toy model with ``loc ~ Uniform(0, alpha)`` rewritten by TransformReparam.

    ``alpha`` is drawn from ``Uniform(0, 1)``. ``loc`` is drawn from
    ``Uniform(0, alpha)`` under a ``TransformReparam`` config. Observations are
    ``Normal(loc, 0.1)``.

    Args:
        data: values observed at the ``obs`` site.
    """
    alpha = numpyro.sample('alpha', dist.Uniform(0, 1))
    loc_config = {'loc': TransformReparam()}
    with reparam(config=loc_config):
        # NOTE(review): TransformReparam operates on TransformedDistribution
        # sites; this relies on dist.Uniform being implemented as one in the
        # NumPyro version in use — confirm.
        loc_prior = dist.Uniform(0, alpha)
        loc = numpyro.sample('loc', loc_prior)
    likelihood = dist.Normal(loc, 0.1)
    numpyro.sample('obs', likelihood, obs=data)