def test_reparam_log_joint(model, kwargs):
    """Compare potential energies of a model and its NeuTra reparameterization.

    The potential energy of the reparameterized model, evaluated at its
    initial latent value, must equal the potential energy of the original
    model at the transformed value minus the log-abs-det-Jacobian of the
    guide's transform chain.

    Args:
        model: the model callable under test.
        kwargs: keyword arguments forwarded to the model and the guide.
    """
    flow_guide = AutoIAFNormal(model)
    # The guide is invoked once before being wrapped (presumably to trigger
    # its lazy setup -- confirm against AutoIAFNormal).
    flow_guide(**kwargs)
    neutra = NeuTraReparam(flow_guide)
    warped_model = neutra.reparam(model)

    # Original model: only the potential function and site transforms are used.
    base_setup = initialize_model(model, model_kwargs=kwargs)
    potential_fn = base_setup[1]
    site_transforms = base_setup[2]

    # Reparameterized model: initial params and its own potential function.
    warped_setup = initialize_model(warped_model, model_kwargs=kwargs)
    warped_init = warped_setup[0]
    warped_potential_fn = warped_setup[1]

    # The first entry of the initial params is taken as the shared latent.
    latent_x = tuple(warped_init.values())[0]
    constrained_sites = neutra.transform_sample(latent_x)
    pe_transformed = warped_potential_fn(warped_init)

    posterior = flow_guide.get_posterior(**kwargs)
    neutra_transform = ComposeTransform(posterior.transforms)
    latent_y = neutra_transform(latent_x)
    log_det_jacobian = neutra_transform.log_abs_det_jacobian(latent_x, latent_y)

    # Push every site through the original model's transform before scoring.
    unconstrained_sites = {}
    for site_name, site_value in constrained_sites.items():
        unconstrained_sites[site_name] = site_transforms[site_name](site_value)
    pe = potential_fn(unconstrained_sites)

    assert_close(pe_transformed, pe - log_det_jacobian)
def test_neals_funnel_smoke(jit):
    """Smoke-test SVI -> NeuTra -> NUTS on ``neals_funnel``.

    Args:
        jit: forwarded to ``NUTS`` as ``jit_compile``.
    """
    dim = 10
    svi_steps = 1000

    flow_guide = AutoIAFNormal(neals_funnel)
    adam = optim.Adam({"lr": 1e-10})
    svi = SVI(neals_funnel, flow_guide, adam, Trace_ELBO())
    for _ in range(svi_steps):
        svi.step(dim)

    frozen_guide = flow_guide.requires_grad_(False)
    neutra = NeuTraReparam(frozen_guide)
    reparam_model = neutra.reparam(neals_funnel)

    kernel = NUTS(reparam_model, jit_compile=jit)
    mcmc = MCMC(kernel, num_samples=50, warmup_steps=50)
    mcmc.run(dim)
    samples = mcmc.get_samples()

    # XXX: `MCMC.get_samples` adds a leftmost batch dim to all sites, not
    # uniformly at -max_plate_nesting-1; hence the unsqueeze
    shared_latent = samples['y_shared_latent'].unsqueeze(-2)
    transformed_samples = neutra.transform_sample(shared_latent)

    for site in ('x', 'y'):
        assert site in transformed_samples
def test_neals_funnel_smoke():
    """Smoke-test SVI -> NeuTra -> NUTS on a locally defined funnel model."""
    dim = 10
    svi_steps = 1000

    def model():
        # 'y' sets the scale of every 'x' drawn inside the plate.
        y = pyro.sample('y', dist.Normal(0, 3))
        x_scale = torch.exp(y / 2)
        with pyro.plate("D", dim):
            pyro.sample('x', dist.Normal(0, x_scale))

    flow_guide = AutoIAFNormal(model)
    adam = optim.Adam({"lr": 1e-10})
    svi = SVI(model, flow_guide, adam, Trace_ELBO())
    for _ in range(svi_steps):
        svi.step()

    neutra = NeuTraReparam(flow_guide)
    reparam_model = neutra.reparam(model)

    kernel = NUTS(reparam_model)
    mcmc = MCMC(kernel, num_samples=50, warmup_steps=50)
    mcmc.run()
    samples = mcmc.get_samples()

    # XXX: `MCMC.get_samples` adds a leftmost batch dim to all sites, not
    # uniformly at -max_plate_nesting-1; hence the unsqueeze
    shared_latent = samples['y_shared_latent'].unsqueeze(-2)
    transformed_samples = neutra.transform_sample(shared_latent)

    for site in ('x', 'y'):
        assert site in transformed_samples
def _parse_axis_limits(spec, option):
    """Parse a ``"lo,hi"`` string into a pair of ints.

    Args:
        spec: comma-separated text expected to hold exactly two integers.
        option: name of the option the text came from; used only in the
            error message.

    Returns:
        A 2-tuple of ints.

    Raises:
        ValueError: if an entry is not an integer, or if the number of
            entries is not two.
    """
    limits = tuple(int(token) for token in spec.strip().split(','))
    # Explicit raise rather than `assert`, so the check survives `python -O`.
    if len(limits) != 2:
        raise ValueError(
            '{} must hold exactly two comma-separated integers, got {!r}'
            .format(option, spec))
    return limits


def main(args):
    """Plot vanilla HMC, autoguide, and NeuTra HMC results on BananaShaped.

    Eight panels are drawn on one figure: the density itself, vanilla HMC
    samples, and, for each of a DiagNormal and a BNAF autoguide, the guide
    samples, the warped-space NeuTra HMC samples and the NeuTra samples
    transformed back. The figure is saved as ``neutra.pdf`` in the directory
    containing this file.

    Args:
        args: options object. Attributes read here are ``rng_seed``,
            ``x_lim``, ``y_lim``, ``param_a``, ``param_b``, ``num_samples``
            and ``num_flows``; the object is also forwarded to ``run_hmc``
            and ``fit_guide``.

    Raises:
        ValueError: if ``args.x_lim`` or ``args.y_lim`` is not a pair of
            comma-separated integers.
    """
    pyro.set_rng_seed(args.rng_seed)

    fig = plt.figure(figsize=(8, 16), constrained_layout=True)
    gs = GridSpec(4, 2, figure=fig)
    # Top row: reference density (left) and vanilla HMC (right). Rows 1-3:
    # DiagNormal pipeline in the left column, BNAF pipeline in the right.
    ax_density = fig.add_subplot(gs[0, 0])
    ax_vanilla = fig.add_subplot(gs[0, 1])
    ax_diag_guide, ax_diag_warped, ax_diag_post = [
        fig.add_subplot(gs[row, 0]) for row in (1, 2, 3)]
    ax_bnaf_guide, ax_bnaf_warped, ax_bnaf_post = [
        fig.add_subplot(gs[row, 1]) for row in (1, 2, 3)]

    xlim = _parse_axis_limits(args.x_lim, 'x_lim')
    ylim = _parse_axis_limits(args.y_lim, 'y_lim')

    # Evaluate the density once on a 100 x 100 grid; every contour panel
    # reuses it as its background.
    x1, x2 = torch.meshgrid(
        [torch.linspace(*xlim, 100), torch.linspace(*ylim, 100)])
    d = BananaShaped(args.param_a, args.param_b)
    p = torch.exp(d.log_prob(torch.stack([x1, x2], dim=-1)))

    def draw_density(ax, title):
        """Draw the filled density contours on ``ax`` and label the axes."""
        ax.contourf(x1, x2, p, cmap='OrRd')
        ax.set(xlabel='x0', ylabel='x1', xlim=xlim, ylim=ylim, title=title)

    def draw_kde(points, ax):
        """Overlay a 2-D KDE of ``points`` (columns 0 and 1) on ``ax``."""
        # NOTE(review): positional x/y is kept as in the original call; newer
        # seaborn releases may require keywords here -- confirm the version
        # this example targets before changing it.
        sns.kdeplot(points[:, 0], points[:, 1], ax=ax)

    def show_guide_fit(guide, ax, title):
        """Fit ``guide`` and plot ``args.num_samples`` draws from it."""
        fit_guide(guide, args)
        with pyro.plate('N', args.num_samples):
            guide_samples = guide()['x'].detach().cpu().numpy()
        draw_density(ax, title)
        draw_kde(guide_samples, ax)

    def show_neutra_hmc(guide, ax_warped, ax_post, warped_title, post_title):
        """Run HMC on the NeuTra-reparameterized model and plot both spaces."""
        neutra = NeuTraReparam(guide.requires_grad_(False))
        neutra_model = poutine.reparam(model, config=lambda _: neutra)
        mcmc = run_hmc(args, neutra_model)
        zs = mcmc.get_samples()['x_shared_latent']

        # Convert to NumPy for plotting, like every other sample array in
        # this function; the tensor itself is still what gets transformed.
        warped = zs.cpu().numpy()
        sns.scatterplot(x=warped[:, 0], y=warped[:, 1], alpha=0.2,
                        ax=ax_warped)
        ax_warped.set(xlabel='x0', ylabel='x1', title=warped_title)

        samples = neutra.transform_sample(zs)
        samples = samples['x'].cpu().numpy()
        draw_density(ax_post, post_title)
        draw_kde(samples, ax_post)

    # 1. Plot the BananaShaped density
    draw_density(ax_density, 'BananaShaped distribution: \nlog density')

    # 2. Run vanilla HMC
    logging.info('\nDrawing samples using vanilla HMC ...')
    mcmc = run_hmc(args, model)
    vanilla_samples = mcmc.get_samples()['x'].cpu().numpy()
    draw_density(ax_vanilla, 'Posterior \n(vanilla HMC)')
    draw_kde(vanilla_samples, ax_vanilla)

    # 3(a). Fit a diagonal normal autoguide
    logging.info('\nFitting a DiagNormal autoguide ...')
    guide = AutoDiagonalNormal(model, init_scale=0.05)
    show_guide_fit(guide, ax_diag_guide, 'Posterior \n(DiagNormal autoguide)')

    # 3(b). Draw samples using NeuTra HMC
    logging.info(
        '\nDrawing samples using DiagNormal autoguide + NeuTra HMC ...')
    show_neutra_hmc(
        guide, ax_diag_warped, ax_diag_post,
        'Posterior (warped) samples \n(DiagNormal + NeuTra HMC)',
        'Posterior (transformed) \n(DiagNormal + NeuTra HMC)')

    # 4(a). Fit a BNAF autoguide
    logging.info('\nFitting a BNAF autoguide ...')
    guide = AutoNormalizingFlow(
        model, partial(iterated, args.num_flows, block_autoregressive))
    show_guide_fit(guide, ax_bnaf_guide, 'Posterior \n(BNAF autoguide)')

    # 4(b). Draw samples using NeuTra HMC
    logging.info('\nDrawing samples using BNAF autoguide + NeuTra HMC ...')
    show_neutra_hmc(
        guide, ax_bnaf_warped, ax_bnaf_post,
        'Posterior (warped) samples \n(BNAF + NeuTra HMC)',
        'Posterior (transformed) \n(BNAF + NeuTra HMC)')

    plt.savefig(os.path.join(os.path.dirname(__file__), 'neutra.pdf'))