コード例 #1
0
ファイル: test_control_flow.py プロジェクト: dirmeier/numpyro
def test_scan_constrain_reparam_compatible():
    """Check that a ``scan``-based model matches its unrolled counterpart.

    Two formulations of the same chain are compared:

    * ``unrolled_model`` samples one ``x_{i}`` / ``y_{i}`` site pair per step
      in a plain Python loop;
    * ``scanned_model`` expresses the identical transition with ``scan``, so
      its ``x`` / ``y`` sites receive length-``T`` arrays.

    Both are given the same numeric values (per-step scalars vs. stacked
    arrays), and ``potential_energy`` must agree between them.
    """

    def unrolled_model(T, q=1, r=1, phi=0.0, beta=0.0):
        latent, level = 0.0, 0.0
        for step in range(T):
            latent = numpyro.sample(f"x_{step}", dist.LogNormal(phi * latent, q))
            level = beta * level + latent
            numpyro.sample(f"y_{step}", dist.Normal(level, r))

    def scanned_model(T, q=1, r=1, phi=0.0, beta=0.0):
        def step_fn(carry, _):
            # The scanned element (a step index) is not used by the transition.
            prev_latent, prev_level = carry
            latent = numpyro.sample("x", dist.LogNormal(phi * prev_latent, q))
            level = beta * prev_level + latent
            numpyro.sample("y", dist.Normal(level, r))
            return (latent, level), None

        scan(step_fn, (0.0, 0.0), jnp.arange(T))

    num_steps = 10

    # One scalar per site for the unrolled model: x_i = (i + 1) / 10, y_i = -i / 5.
    # Keys are inserted pairwise (x_i, then y_i) for every step.
    per_step_params = {}
    for step in range(num_steps):
        per_step_params.update(
            {f"x_{step}": (step + 1.0) / 10, f"y_{step}": -step / 5}
        )

    # The same values stacked along the scan axis for the scanned model.
    stacked_params = {
        "x": jnp.arange(1, num_steps + 1) / 10,
        "y": -jnp.arange(num_steps) / 5,
    }

    scanned_energy = potential_energy(scanned_model, (num_steps,), {}, stacked_params)
    unrolled_energy = potential_energy(unrolled_model, (num_steps,), {}, per_step_params)
    assert_allclose(scanned_energy, unrolled_energy)
コード例 #2
0
def test_model_with_transformed_distribution():
    """Check ``TransformReparam`` against a hand-computed reference.

    ``y`` follows a ``LogNormal`` prior, which this test treats as a
    transformed distribution.

    The reference values are computed directly:

    * constrained values, by pushing the raw inputs through ``biject_to`` of
      each prior's support;
    * the potential energy, as the negated sum of the prior log-densities and
      the bijections' log-abs-det-Jacobian terms.

    The reparameterized model, driven by the base site ``y_base``, must
    reproduce both.
    """
    x_prior = dist.HalfNormal(2)
    y_prior = dist.LogNormal(scale=3.)  # transformed distribution

    def model():
        numpyro.sample('x', x_prior)
        numpyro.sample('y', y_prior)

    unconstrained = {'x': jnp.array(-5.), 'y': jnp.array(7.)}
    seeded_model = handlers.seed(model, random.PRNGKey(0))

    # Reference constrained values: map each raw input through its support bijection.
    bijections = {
        name: biject_to(prior.support)
        for name, prior in (('x', x_prior), ('y', y_prior))
    }
    expected_samples = transform_fn(bijections, unconstrained)

    # Reference potential energy: negation of (log-densities + log|det J| terms).
    log_joint = (
        x_prior.log_prob(expected_samples['x'])
        + y_prior.log_prob(expected_samples['y'])
        + bijections['x'].log_abs_det_jacobian(unconstrained['x'], expected_samples['x'])
        + bijections['y'].log_abs_det_jacobian(unconstrained['y'], expected_samples['y'])
    )
    expected_potential_energy = -log_joint

    # After reparameterization the model is fed through 'y_base'.
    # 'y' is read back as a deterministic site.
    reparam_model = handlers.reparam(seeded_model, {'y': TransformReparam()})
    base_params = {'x': unconstrained['x'], 'y_base': unconstrained['y']}
    actual_samples = constrain_fn(
        handlers.seed(reparam_model, random.PRNGKey(0)),
        (),
        {},
        base_params,
        return_deterministic=True,
    )
    actual_potential_energy = potential_energy(reparam_model, (), {}, base_params)

    for site in ('x', 'y'):
        assert_allclose(expected_samples[site], actual_samples[site])
    assert_allclose(actual_potential_energy, expected_potential_energy)
コード例 #3
0
ファイル: hmc_gibbs.py プロジェクト: uiuc-arc/numpyro
 def potential_fn(z_discrete):
     """Evaluate ``potential_energy`` at ``z_hmc`` with discrete sites fixed.

     ``z_discrete`` is injected into the model call through the
     ``_gibbs_sites`` keyword argument. A shallow copy of ``model_kwargs`` is
     made first, so the enclosing scope's mapping is never modified.

     ``wrapped_model``, ``model_args``, ``model_kwargs``, ``z_hmc`` and
     ``use_enum`` all come from the enclosing scope, which is outside this
     view. Presumably ``z_hmc`` holds the continuous (HMC) latent values;
     confirm against the caller.
     """
     gibbs_kwargs = model_kwargs.copy()
     gibbs_kwargs["_gibbs_sites"] = z_discrete
     return potential_energy(
         wrapped_model, model_args, gibbs_kwargs, z_hmc, enum=use_enum
     )