Example #1
0
 def model(data):
     """Condition ``data`` on a latent ``loc`` whose scale is the latent ``alpha``.

     ``loc`` is ``AffineTransform(0, alpha)`` applied to a ``Uniform(0, 1)``
     base whose log-density is masked out with ``.mask(False)``;
     ``TransformReparam`` moves the latent variable onto that base site.
     Each observation in ``data`` is modelled as ``Normal(loc, 0.1)``.
     """
     alpha = numpyro.sample('alpha', dist.Uniform(0, 1))
     # mask(False) keeps the base's support but drops its log-prob term
     # (presumably making the prior on ``loc`` improper -- intentional here).
     unit_base = dist.Uniform(0, 1).mask(False)
     scale_by_alpha = AffineTransform(0, alpha)
     with reparam(config={'loc': TransformReparam()}):
         scaled_prior = dist.TransformedDistribution(unit_base, scale_by_alpha)
         loc = numpyro.sample('loc', scaled_prior)
     numpyro.sample('obs', dist.Normal(loc, 0.1), obs=data)
Example #2
0
def test_model_with_transformed_distribution():
    """Reparameterizing a transformed latent site must not change the model.

    ``y`` has a ``LogNormal`` prior (a transformed distribution). After
    ``TransformReparam`` splits it into a base site ``y_base`` and a
    deterministic ``y``, constraining the same unconstrained values must give
    the same samples. The potential energy must also match the hand-computed
    one for the original model: negative log-priors minus the log-Jacobians
    of the support bijections.
    """
    x_prior = dist.HalfNormal(2)
    y_prior = dist.LogNormal(scale=3.)  # transformed distribution
    priors = {'x': x_prior, 'y': y_prior}

    def model():
        numpyro.sample('x', x_prior)
        numpyro.sample('y', y_prior)

    # One unconstrained value per latent site.
    params = {'x': jnp.array(-5.), 'y': jnp.array(7.)}
    model = handlers.seed(model, random.PRNGKey(0))

    # Bijections from unconstrained space onto each prior's support.
    inv_transforms = {name: biject_to(prior.support)
                      for name, prior in priors.items()}
    expected_samples = transform_fn(inv_transforms, params)

    log_prior = {name: priors[name].log_prob(expected_samples[name])
                 for name in priors}
    log_jacobian = {
        name: inv_transforms[name].log_abs_det_jacobian(
            params[name], expected_samples[name])
        for name in priors
    }
    # Same term order as the hand-written sum: both priors, then both Jacobians.
    expected_potential_energy = (-log_prior['x'] - log_prior['y']
                                 - log_jacobian['x'] - log_jacobian['y'])

    reparam_model = reparam(model, {'y': TransformReparam()})
    # After reparameterization the unconstrained value for 'y' feeds 'y_base'.
    base_params = {'x': params['x'], 'y_base': params['y']}
    seeded_reparam_model = handlers.seed(reparam_model, random.PRNGKey(0))
    actual_samples = constrain_fn(seeded_reparam_model, (), {}, base_params,
                                  return_deterministic=True)
    actual_potential_energy = potential_energy(reparam_model, (), {},
                                               base_params)

    for name in ('x', 'y'):
        assert_allclose(expected_samples[name], actual_samples[name])
    assert_allclose(actual_potential_energy, expected_potential_energy)
Example #3
0
def reparam_model(dim=10):
    """Funnel-shaped model with ``x`` reparameterized onto a standard normal base.

    ``y ~ Normal(0, 3)``. ``x`` is a batch of ``dim - 1`` standard normals
    scaled by ``exp(y / 2)``, written as a ``TransformedDistribution`` so
    that ``TransformReparam`` can sample the unscaled base instead.

    Args:
        dim: total number of latent dimensions (``y`` plus ``dim - 1`` for ``x``).
    """
    y = numpyro.sample('y', dist.Normal(0, 3))
    standard_normal = dist.Normal(jnp.zeros(dim - 1), 1)
    scale_by_y = AffineTransform(0, jnp.exp(y / 2))
    x_prior = dist.TransformedDistribution(standard_normal, scale_by_y)
    with reparam(config={'x': TransformReparam()}):
        numpyro.sample('x', x_prior)
Example #4
0
def test_log_normal(shape):
    """TransformReparam of an explicit log-normal must preserve its moments.

    Site ``x`` is an elementwise standard normal pushed through
    ``AffineTransform(loc, scale)`` and then ``ExpTransform``. The test
    compares moments of draws from the plain model with draws from the same
    model under ``TransformReparam``. It also checks that the reparameterized
    ``x`` is recorded as a deterministic site.

    Args:
        shape: plate sizes for ``numpyro.plate_stack``; also the shape of the
            randomly drawn ``loc``/``scale`` arrays.
    """
    num_particles = 100000
    loc = np.random.rand(*shape) * 2 - 1
    scale = np.random.rand(*shape) + 0.5

    def model():
        with numpyro.plate_stack("plates", shape):
            with numpyro.plate("particles", num_particles):
                base = dist.Normal(jnp.zeros_like(loc), jnp.ones_like(scale))
                log_normal = dist.TransformedDistribution(
                    base, [AffineTransform(loc, scale), ExpTransform()])
                return numpyro.sample("x",
                                      log_normal.expand_by([num_particles]))

    # Reference moments from the model without reparameterization. No trace is
    # needed here; the previous trace's result was never read.
    expected_moments = get_moments(handlers.seed(model, 0)())

    with reparam(config={"x": TransformReparam()}):
        with handlers.trace() as tr:
            value = handlers.seed(model, 0)()
    # The reparam handler replaces the sample site with a deterministic one.
    assert tr["x"]["type"] == "deterministic"
    actual_moments = get_moments(value)
    assert_allclose(actual_moments, expected_moments, atol=0.05)
Example #5
0
 def model(data):
     """Condition ``data`` on a latent ``loc`` drawn uniformly from (0, alpha).

     ``TransformReparam`` only handles ``TransformedDistribution`` sites, and
     in current NumPyro ``dist.Uniform`` is a plain ``Distribution``. The
     ``Uniform(0, alpha)`` prior is therefore spelled out as ``Uniform(0, 1)``
     pushed through ``AffineTransform(0, alpha)``. This has the same density
     (1/alpha on (0, alpha)), and the reparam handler can then sample the
     unit-uniform base instead. Each observation is ``Normal(loc, 0.1)``.
     """
     alpha = numpyro.sample('alpha', dist.Uniform(0, 1))
     loc_prior = dist.TransformedDistribution(
         dist.Uniform(0, 1), dist.transforms.AffineTransform(0, alpha))
     with reparam(config={'loc': TransformReparam()}):
         loc = numpyro.sample('loc', loc_prior)
     numpyro.sample('obs', dist.Normal(loc, 0.1), obs=data)