def test_reparam_log_joint(model, kwargs):
    """Compare the reparameterized model's energy with the original model's.

    Builds an ``AutoIAFNormal`` guide for ``model``, wraps it in
    ``NeuTraReparam`` and evaluates two functions returned by
    ``initialize_model`` (presumably potential-energy functions, going by
    their use here -- confirm against the Pyro docs):

    * the one for the reparameterized model, at its initial parameters;
    * the one for the original model, at the same latent point after it has
      been pushed through ``neutra.transform_sample`` and the per-site
      transforms returned by ``initialize_model``.

    The two values are expected to differ exactly by the log abs det
    Jacobian of the guide's composed posterior transforms.

    :param model: model callable under test (supplied from outside this
        block, e.g. by test parametrization).
    :param dict kwargs: keyword arguments forwarded to the model and guide.
    """
    guide = AutoIAFNormal(model)
    # The guide is invoked once before being wrapped; presumably this
    # triggers its lazy setup -- confirm.
    guide(**kwargs)
    neutra = NeuTraReparam(guide)
    reparam_model = neutra.reparam(model)

    _, energy_fn, site_transforms, _ = initialize_model(
        model, model_kwargs=kwargs
    )
    neutra_init, neutra_energy_fn, _, _ = initialize_model(
        reparam_model, model_kwargs=kwargs
    )

    # Only the first initial value of the reparameterized model is used as
    # the latent point.
    neutra_values = list(neutra_init.values())
    latent_x = neutra_values[0]
    model_samples = neutra.transform_sample(latent_x)
    energy_neutra = neutra_energy_fn(neutra_init)

    # Re-create the guide's flow to obtain the Jacobian correction term.
    posterior = guide.get_posterior(**kwargs)
    flow = ComposeTransform(posterior.transforms)
    latent_y = flow(latent_x)
    log_det_jacobian = flow.log_abs_det_jacobian(latent_x, latent_y)

    # Apply each site's transform before evaluating the original model's
    # energy function.
    model_inputs = {}
    for site_name, site_value in model_samples.items():
        model_inputs[site_name] = site_transforms[site_name](site_value)
    energy_model = energy_fn(model_inputs)

    expected = energy_model - log_det_jacobian
    assert_close(energy_neutra, expected)
def test_neals_funnel_smoke(jit):
    """Smoke-test NeuTra reparameterization of ``neals_funnel`` with NUTS.

    Fits an ``AutoIAFNormal`` guide with SVI, reparameterizes the model
    through the frozen guide, draws a short NUTS chain and checks that the
    transformed samples contain both the ``x`` and ``y`` sites.

    :param bool jit: forwarded to ``NUTS`` as ``jit_compile``.
    """
    # NOTE(review): another test with this same name (taking no arguments)
    # appears later in this file; if both live in one module the later
    # definition shadows this one -- confirm.
    dim = 10
    num_svi_steps = 1000

    guide = AutoIAFNormal(neals_funnel)
    optimizer = optim.Adam({"lr": 1e-10})
    svi = SVI(neals_funnel, guide, optimizer, Trace_ELBO())
    for _ in range(num_svi_steps):
        svi.step(dim)

    # Turn off gradients on the guide before handing it to NeuTra.
    frozen_guide = guide.requires_grad_(False)
    neutra = NeuTraReparam(frozen_guide)
    reparam_model = neutra.reparam(neals_funnel)

    kernel = NUTS(reparam_model, jit_compile=jit)
    mcmc = MCMC(kernel, num_samples=50, warmup_steps=50)
    mcmc.run(dim)
    samples = mcmc.get_samples()

    # XXX: `MCMC.get_samples` adds a leftmost batch dim to all sites, not
    # uniformly at -max_plate_nesting-1; hence the unsqueeze.
    shared_latent = samples['y_shared_latent'].unsqueeze(-2)
    transformed_samples = neutra.transform_sample(shared_latent)

    assert 'x' in transformed_samples
    assert 'y' in transformed_samples
def test_neals_funnel_smoke():
    """Smoke-test NeuTra reparameterization on a locally defined funnel model.

    The model samples ``y ~ Normal(0, 3)`` and, inside a plate of size
    ``dim``, ``x ~ Normal(0, exp(y / 2))``. An ``AutoIAFNormal`` guide is
    fitted with SVI, the model is reparameterized through it, a short NUTS
    chain is drawn, and the transformed samples are checked to contain both
    the ``x`` and ``y`` sites.
    """
    # NOTE(review): another test with this same name (taking a `jit`
    # argument) appears earlier in this file; if both live in one module
    # this definition shadows the earlier one -- confirm.
    dim = 10
    num_svi_steps = 1000

    def model():
        y = pyro.sample('y', dist.Normal(0, 3))
        x_scale = torch.exp(y / 2)
        with pyro.plate("D", dim):
            pyro.sample('x', dist.Normal(0, x_scale))

    guide = AutoIAFNormal(model)
    optimizer = optim.Adam({"lr": 1e-10})
    svi = SVI(model, guide, optimizer, Trace_ELBO())
    for _ in range(num_svi_steps):
        svi.step()

    neutra = NeuTraReparam(guide)
    reparam_model = neutra.reparam(model)

    kernel = NUTS(reparam_model)
    mcmc = MCMC(kernel, num_samples=50, warmup_steps=50)
    mcmc.run()
    samples = mcmc.get_samples()

    # XXX: `MCMC.get_samples` adds a leftmost batch dim to all sites, not
    # uniformly at -max_plate_nesting-1; hence the unsqueeze.
    shared_latent = samples['y_shared_latent'].unsqueeze(-2)
    transformed_samples = neutra.transform_sample(shared_latent)

    assert 'x' in transformed_samples
    assert 'y' in transformed_samples