def test_neutra_reparam_unobserved_model():
    """Smoke-test calling a NeuTra-reparameterized model with ``data=None``.

    The AutoIAFNormal guide's parameters are only initialised (no SVI updates are
    taken) using integer observations. The reparameterized ``dirichlet_categorical``
    model is then executed under a fixed seed with ``data=None``. The test passes
    as long as that call does not raise.
    """
    observations = jnp.ones(10, dtype=jnp.int32)
    autoguide = AutoIAFNormal(dirichlet_categorical)
    svi = SVI(dirichlet_categorical, autoguide, Adam(1e-3), Trace_ELBO())
    guide_params = svi.get_params(svi.init(random.PRNGKey(0), observations))
    reparam_model = NeuTraReparam(autoguide, guide_params).reparam(dirichlet_categorical)
    # Seed the model so its sample sites can draw values.
    with handlers.seed(rng_seed=0):
        reparam_model(data=None)
def test_reparam_log_joint(model, kwargs):
    """Check NeuTra's potential energy against the original model's, corrected by the flow's Jacobian.

    For an initial point ``x`` of the reparameterized model, the NeuTra potential
    energy must equal ``U(T(x)) - log|det J_T(x)|``. Here ``U`` is the original
    model's potential energy and ``T`` is ``neutra.transform``.
    """
    autoguide = AutoIAFNormal(model)
    svi = SVI(model, autoguide, Adam(1e-10), Trace_ELBO(), **kwargs)
    guide_params = svi.get_params(svi.init(random.PRNGKey(0)))
    neutra = NeuTraReparam(autoguide, guide_params)
    neutra_model = neutra.reparam(model)

    _, potential_fn, _, _ = initialize_model(random.PRNGKey(1), model, model_kwargs=kwargs)
    neutra_init, neutra_potential_fn, _, _ = initialize_model(
        random.PRNGKey(2), neutra_model, model_kwargs=kwargs
    )
    init_values = neutra_init[0]
    # The first (presumably only) latent site of the reparameterized model.
    x = next(iter(init_values.values()))
    neutra_energy = neutra_potential_fn(init_values)

    y = neutra.transform(x)
    log_det_jacobian = neutra.transform.log_abs_det_jacobian(x, y)
    original_energy = potential_fn(autoguide._unpack_latent(y))
    assert_allclose(neutra_energy, original_energy - log_det_jacobian)
def test_neals_funnel_smoke():
    """Smoke-test NeuTra + NUTS end to end on Neal's funnel.

    The steps are:

    1. Take 1000 SVI steps on an AutoIAFNormal guide inside ``lax.fori_loop``.
    2. Reparameterize the funnel with the resulting flow.
    3. Sample the warped model with NUTS.
    4. Map the ``auto_shared_latent`` draws back and check that both ``x`` and
       ``y`` are present.
    """
    dim = 10
    autoguide = AutoIAFNormal(neals_funnel)
    # NOTE(review): with lr=1e-10 the guide barely moves; presumably intentional for a smoke test.
    svi = SVI(neals_funnel, autoguide, Adam(1e-10), Trace_ELBO())

    def svi_step(_, state):
        # svi.update yields (new_state, loss); only the state is carried through the loop.
        new_state, _loss = svi.update(state, dim)
        return new_state

    final_state = lax.fori_loop(0, 1000, svi_step, svi.init(random.PRNGKey(0), dim))
    neutra = NeuTraReparam(autoguide, svi.get_params(final_state))

    mcmc = MCMC(NUTS(neutra.reparam(neals_funnel)), num_warmup=50, num_samples=50)
    mcmc.run(random.PRNGKey(1), dim)
    warped_draws = mcmc.get_samples()["auto_shared_latent"]
    transformed_samples = neutra.transform_sample(warped_draws)
    assert "x" in transformed_samples
    assert "y" in transformed_samples
def benchmark_hmc(args, features, labels):
    """Run one MCMC algorithm on the logistic-regression ``model`` and report diagnostics.

    Args:
        args: Parsed CLI namespace. Reads these attributes:

            - ``algo``: one of "HMC", "NUTS", "HMCECS", "SA", "FlowHMCECS".
            - ``num_steps``: used by HMC only.
            - ``dense_mass``: used by HMC, NUTS and HMCECS.
            - ``num_warmup`` and ``num_samples``.
        features: Design matrix passed to ``model``. Its row count sets the HMC step size.
        labels: Targets passed to ``model``.

    Prints the mean acceptance probability, the MCMC summary and the elapsed MCMC time.

    Raises:
        ValueError: If ``args.algo`` is not one of the supported algorithms.
    """
    rng_key = random.PRNGKey(1)
    # a MAP estimate at the following source
    # https://github.com/google/edward2/blob/master/examples/no_u_turn_sampler/logistic_regression.py#L117
    ref_params = {
        "coefs": jnp.array(
            [
                +2.03420663e00, -3.53567265e-02, -1.49223924e-01, -3.07049364e-01,
                -1.00028366e-01, -1.46827862e-01, -1.64167881e-01, -4.20344204e-01,
                +9.47479829e-02, -1.12681836e-02, +2.64442056e-01, -1.22087866e-01,
                -6.00568838e-02, -3.79419506e-01, -1.06668741e-01, -2.97053963e-01,
                -2.05253899e-01, -4.69537191e-02, -2.78072730e-02, -1.43250525e-01,
                -6.77954629e-02, -4.34899796e-03, +5.90927452e-02, +7.23133609e-02,
                +1.38526391e-02, -1.24497898e-01, -1.50733739e-02, -2.68872194e-02,
                -1.80925727e-02, +3.47936489e-02, +4.03552800e-02, -9.98773426e-03,
                +6.20188080e-02, +1.15002751e-01, +1.32145107e-01, +2.69109547e-01,
                +2.45785132e-01, +1.19035013e-01, -2.59744357e-02, +9.94279515e-04,
                +3.39266285e-02, -1.44057125e-02, -6.95222765e-02, -7.52013028e-02,
                +1.21171586e-01, +2.29205526e-02, +1.47308692e-01, -8.34354162e-02,
                -9.34122875e-02, -2.97472421e-02, -3.03937674e-01, -1.70958012e-01,
                -1.59496680e-01, -1.88516974e-01, -1.20889175e00,
            ]
        )
    }

    if args.algo == "HMC":
        step_size = jnp.sqrt(0.5 / features.shape[0])
        trajectory_length = step_size * args.num_steps
        kernel = HMC(
            model,
            step_size=step_size,
            trajectory_length=trajectory_length,
            adapt_step_size=False,
            dense_mass=args.dense_mass,
        )
        subsample_size = None
    elif args.algo == "NUTS":
        kernel = NUTS(model, dense_mass=args.dense_mass)
        subsample_size = None
    elif args.algo == "HMCECS":
        subsample_size = 1000
        inner_kernel = NUTS(
            model,
            init_strategy=init_to_value(values=ref_params),
            dense_mass=args.dense_mass,
        )
        # note: with num_blocks=100, we update 10 indices at each MCMC step,
        # so it takes 50000 MCMC steps to iterate over the whole dataset
        kernel = HMCECS(inner_kernel, num_blocks=100, proxy=HMCECS.taylor_proxy(ref_params))
    elif args.algo == "SA":
        # NB: this kernel requires large num_warmup and num_samples,
        # and running on GPU is much faster than on CPU
        kernel = SA(model, adapt_state_size=1000, init_strategy=init_to_value(values=ref_params))
        subsample_size = None
    elif args.algo == "FlowHMCECS":
        subsample_size = 1000
        guide = AutoBNAFNormal(model, num_flows=1, hidden_factors=[8])
        svi = SVI(model, guide, numpyro.optim.Adam(0.01), Trace_ELBO())
        svi_result = svi.run(random.PRNGKey(2), 2000, features, labels)
        params, losses = svi_result.params, svi_result.losses
        # Show the flow's training curve; depending on the backend plt.show() may block.
        plt.plot(losses)
        plt.show()

        neutra = NeuTraReparam(guide, params)
        neutra_model = neutra.reparam(model)
        # Start at the origin of the NeuTra latent space, sized like the MAP reference
        # (55 coefficients). Assumes ``coefs`` is the model's only latent site -- TODO
        # confirm against ``model``.
        neutra_ref_params = {"auto_shared_latent": jnp.zeros(ref_params["coefs"].shape[0])}
        # no need to adapt mass matrix if the flow does a good job
        inner_kernel = NUTS(
            neutra_model,
            init_strategy=init_to_value(values=neutra_ref_params),
            adapt_mass_matrix=False,
        )
        kernel = HMCECS(inner_kernel, num_blocks=100, proxy=HMCECS.taylor_proxy(neutra_ref_params))
    else:
        raise ValueError(
            "Invalid algorithm, either 'HMC', 'NUTS', 'HMCECS', 'SA', or 'FlowHMCECS'."
        )

    # Start the timer only now, so "MCMC elapsed time" excludes kernel setup. For
    # FlowHMCECS, setup includes guide training and the plt.show() above.
    # perf_counter is monotonic, unlike time.time.
    start = time.perf_counter()
    mcmc = MCMC(kernel, num_warmup=args.num_warmup, num_samples=args.num_samples)
    mcmc.run(rng_key, features, labels, subsample_size, extra_fields=("accept_prob",))
    print("Mean accept prob:", jnp.mean(mcmc.get_extra_fields()["accept_prob"]))
    mcmc.print_summary(exclude_deterministic=False)
    print("\nMCMC elapsed time:", time.perf_counter() - start)
def main(args):
    """Compare vanilla NUTS with NeuTra-reparameterized NUTS on ``dual_moon_model``.

    Steps:

    1. Run NUTS directly on the model.
    2. Train an ``AutoBNAFNormal`` guide with SVI.
    3. Run NUTS on the NeuTra-warped model and map its samples back to the original
       space.
    4. Save a six-panel comparison figure to ``neutra.pdf``.

    Args:
        args: Parsed CLI namespace. Reads ``num_warmup``, ``num_samples``,
            ``num_chains``, ``hidden_factor`` and ``num_iters``.
    """

    def run_nuts(model, rng_key):
        # Run NUTS on ``model`` with the shared CLI settings and print its summary.
        # The progress bar is disabled when NUMPYRO_SPHINXBUILD is set (docs build).
        mcmc = MCMC(
            NUTS(model),
            args.num_warmup,
            args.num_samples,
            num_chains=args.num_chains,
            progress_bar="NUMPYRO_SPHINXBUILD" not in os.environ,
        )
        mcmc.run(rng_key)
        mcmc.print_summary()
        return mcmc

    print("Start vanilla HMC...")
    mcmc = run_nuts(dual_moon_model, random.PRNGKey(0))
    vanilla_samples = mcmc.get_samples()["x"].copy()

    guide = AutoBNAFNormal(
        dual_moon_model, hidden_factors=[args.hidden_factor, args.hidden_factor]
    )
    svi = SVI(dual_moon_model, guide, optim.Adam(0.003), Trace_ELBO())

    print("Start training guide...")
    svi_result = svi.run(random.PRNGKey(1), args.num_iters)
    print("Finish training guide. Extract samples...")
    guide_samples = guide.sample_posterior(
        random.PRNGKey(2), svi_result.params, sample_shape=(args.num_samples,)
    )["x"].copy()

    print("\nStart NeuTra HMC...")
    neutra = NeuTraReparam(guide, svi_result.params)
    neutra_model = neutra.reparam(dual_moon_model)
    mcmc = run_nuts(neutra_model, random.PRNGKey(3))
    zs = mcmc.get_samples(group_by_chain=True)["auto_shared_latent"]
    print("Transform samples into unwarped space...")
    samples = neutra.transform_sample(zs)
    print_summary(samples)
    zs = zs.reshape(-1, 2)
    samples = samples["x"].reshape(-1, 2).copy()

    # make plots

    # guide samples (for plotting): standard-normal base draws mapped through the flow
    guide_base_samples = dist.Normal(jnp.zeros(2), 1.0).sample(random.PRNGKey(4), (1000,))
    guide_trans_samples = neutra.transform_sample(guide_base_samples)["x"]

    # Density of the target evaluated on a 100x100 grid over [-3, 3]^2.
    x1 = jnp.linspace(-3, 3, 100)
    x2 = jnp.linspace(-3, 3, 100)
    X1, X2 = jnp.meshgrid(x1, x2)
    P = jnp.exp(DualMoonDistribution().log_prob(jnp.stack([X1, X2], axis=-1)))

    fig = plt.figure(figsize=(12, 8), constrained_layout=True)
    gs = GridSpec(2, 3, figure=fig)
    ax1 = fig.add_subplot(gs[0, 0])
    ax2 = fig.add_subplot(gs[1, 0])
    ax3 = fig.add_subplot(gs[0, 1])
    ax4 = fig.add_subplot(gs[1, 1])
    ax5 = fig.add_subplot(gs[0, 2])
    ax6 = fig.add_subplot(gs[1, 2])

    # The first 1000 losses are skipped (the panel is empty if num_iters <= 1000).
    ax1.plot(svi_result.losses[1000:])
    ax1.set_title("Autoguide training loss\n(after 1000 steps)")

    # seaborn >= 0.12 requires x/y as keywords; `levels` is the documented name of `n_levels`.
    ax2.contourf(X1, X2, P, cmap="OrRd")
    sns.kdeplot(x=guide_samples[:, 0], y=guide_samples[:, 1], levels=30, ax=ax2)
    ax2.set(
        xlim=[-3, 3],
        ylim=[-3, 3],
        xlabel="x0",
        ylabel="x1",
        title="Posterior using\nAutoBNAFNormal guide",
    )

    sns.scatterplot(
        x=guide_base_samples[:, 0],
        y=guide_base_samples[:, 1],
        ax=ax3,
        hue=guide_trans_samples[:, 0] < 0.0,
    )
    ax3.set(
        xlim=[-3, 3],
        ylim=[-3, 3],
        xlabel="x0",
        ylabel="x1",
        title="AutoBNAFNormal base samples\n(True=left moon; False=right moon)",
    )

    ax4.contourf(X1, X2, P, cmap="OrRd")
    sns.kdeplot(x=vanilla_samples[:, 0], y=vanilla_samples[:, 1], levels=30, ax=ax4)
    ax4.plot(vanilla_samples[-50:, 0], vanilla_samples[-50:, 1], "bo-", alpha=0.5)
    ax4.set(
        xlim=[-3, 3],
        ylim=[-3, 3],
        xlabel="x0",
        ylabel="x1",
        title="Posterior using\nvanilla HMC sampler",
    )

    sns.scatterplot(
        x=zs[:, 0],
        y=zs[:, 1],
        ax=ax5,
        hue=samples[:, 0] < 0.0,
        s=30,
        alpha=0.5,
        edgecolor="none",
    )
    ax5.set(
        xlim=[-5, 5],
        ylim=[-5, 5],
        xlabel="x0",
        ylabel="x1",
        title="Samples from the\nwarped posterior - p(z)",
    )

    ax6.contourf(X1, X2, P, cmap="OrRd")
    sns.kdeplot(x=samples[:, 0], y=samples[:, 1], levels=30, ax=ax6)
    ax6.plot(samples[-50:, 0], samples[-50:, 1], "bo-", alpha=0.2)
    ax6.set(
        xlim=[-3, 3],
        ylim=[-3, 3],
        xlabel="x0",
        ylabel="x1",
        title="Posterior using\nNeuTra HMC sampler",
    )

    plt.savefig("neutra.pdf")