def test_scan_constrain_reparam_compatible():
    """Check that a ``scan``-based model and its unrolled twin agree.

    ``model`` declares one ``x_{i}``/``y_{i}`` pair of sites per step in a
    plain Python loop, while ``fun_model`` expresses the same recursion with
    ``scan`` and the stacked sites ``"x"`` and ``"y"``.  Given matching
    parameter values, ``potential_energy`` must return the same result for
    both formulations.
    """

    def model(T, q=1, r=1, phi=0.0, beta=0.0):
        # Reference formulation: explicit loop with per-step site names.
        x, mu = 0.0, 0.0
        for step in range(T):
            x = numpyro.sample(f"x_{step}", dist.LogNormal(phi * x, q))
            mu = beta * mu + x
            numpyro.sample(f"y_{step}", dist.Normal(mu, r))

    def fun_model(T, q=1, r=1, phi=0.0, beta=0.0):
        # Same recursion, carried through ``scan`` as an (x, mu) state.
        def transition(state, i):
            prev_x, prev_mu = state
            next_x = numpyro.sample("x", dist.LogNormal(phi * prev_x, q))
            next_mu = beta * prev_mu + next_x
            numpyro.sample("y", dist.Normal(next_mu, r))
            return (next_x, next_mu), None

        scan(transition, (0.0, 0.0), jnp.arange(T))

    T = 10

    # Per-step values for the unrolled model (x_0, y_0, x_1, y_1, ...).
    params = {}
    for step in range(T):
        params.update({
            f"x_{step}": (step + 1.0) / 10,
            f"y_{step}": -step / 5,
        })

    # The same values, stacked along the leading axis for the scan model.
    stacked_x = jnp.arange(1, T + 1) / 10
    stacked_y = -jnp.arange(T) / 5
    fun_params = {"x": stacked_x, "y": stacked_y}

    actual_log_joint = potential_energy(fun_model, (T,), {}, fun_params)
    expected_log_joint = potential_energy(model, (T,), {}, params)
    assert_allclose(actual_log_joint, expected_log_joint)
def test_model_with_transformed_distribution():
    """Compare a reparameterized model against a hand-computed reference.

    The reference maps ``params`` through the ``biject_to`` transform of each
    prior's support and combines the negated log-probabilities with the
    negated log-abs-det-Jacobian terms.  The model with ``TransformReparam``
    applied to site ``'y'`` (fed through ``'y_base'``) must yield the same
    samples from ``constrain_fn`` and the same ``potential_energy``.
    """
    x_prior = dist.HalfNormal(2)
    y_prior = dist.LogNormal(scale=3.)  # transformed distribution

    def model():
        numpyro.sample('x', x_prior)
        numpyro.sample('y', y_prior)

    seeded_model = handlers.seed(model, random.PRNGKey(0))
    params = {'x': jnp.array(-5.), 'y': jnp.array(7.)}

    # --- reference values computed by hand -------------------------------
    inv_transforms = {
        'x': biject_to(x_prior.support),
        'y': biject_to(y_prior.support),
    }
    expected_samples = transform_fn(inv_transforms, params)

    x_log_prob = x_prior.log_prob(expected_samples['x'])
    y_log_prob = y_prior.log_prob(expected_samples['y'])
    x_log_det = inv_transforms['x'].log_abs_det_jacobian(
        params['x'], expected_samples['x'])
    y_log_det = inv_transforms['y'].log_abs_det_jacobian(
        params['y'], expected_samples['y'])
    # Same left-to-right subtraction order as a single chained expression.
    expected_potential_energy = -x_log_prob - y_log_prob - x_log_det - y_log_det

    # --- values obtained through the reparameterized model ---------------
    reparam_model = handlers.reparam(seeded_model, {'y': TransformReparam()})
    base_params = {'x': params['x'], 'y_base': params['y']}
    reseeded_model = handlers.seed(reparam_model, random.PRNGKey(0))
    actual_samples = constrain_fn(
        reseeded_model, (), {}, base_params, return_deterministic=True)
    actual_potential_energy = potential_energy(
        reparam_model, (), {}, base_params)

    for site in ('x', 'y'):
        assert_allclose(expected_samples[site], actual_samples[site])
    assert_allclose(actual_potential_energy, expected_potential_energy)
def potential_fn(z_discrete):
    """Return the potential energy of ``wrapped_model`` for ``z_discrete``.

    ``z_discrete`` is handed to the model through the ``"_gibbs_sites"``
    keyword argument; ``z_hmc``, ``model_args`` and ``use_enum`` come from
    the enclosing scope.
    """
    # Work on a copy so the enclosing ``model_kwargs`` is left untouched.
    kwargs_with_gibbs = model_kwargs.copy()
    kwargs_with_gibbs["_gibbs_sites"] = z_discrete
    energy = potential_energy(
        wrapped_model, model_args, kwargs_with_gibbs, z_hmc, enum=use_enum
    )
    return energy