def main(args): print("Start vanilla HMC...") nuts_kernel = NUTS(dual_moon_model) mcmc = MCMC( nuts_kernel, num_warmup=args.num_warmup, num_samples=args.num_samples, num_chains=args.num_chains, progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True, ) mcmc.run(random.PRNGKey(0)) mcmc.print_summary() vanilla_samples = mcmc.get_samples()["x"].copy() guide = AutoBNAFNormal( dual_moon_model, hidden_factors=[args.hidden_factor, args.hidden_factor]) svi = SVI(dual_moon_model, guide, optim.Adam(0.003), Trace_ELBO()) print("Start training guide...") svi_result = svi.run(random.PRNGKey(1), args.num_iters) print("Finish training guide. Extract samples...") guide_samples = guide.sample_posterior( random.PRNGKey(2), svi_result.params, sample_shape=(args.num_samples, ))["x"].copy() print("\nStart NeuTra HMC...") neutra = NeuTraReparam(guide, svi_result.params) neutra_model = neutra.reparam(dual_moon_model) nuts_kernel = NUTS(neutra_model) mcmc = MCMC( nuts_kernel, num_warmup=args.num_warmup, num_samples=args.num_samples, num_chains=args.num_chains, progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True, ) mcmc.run(random.PRNGKey(3)) mcmc.print_summary() zs = mcmc.get_samples(group_by_chain=True)["auto_shared_latent"] print("Transform samples into unwarped space...") samples = neutra.transform_sample(zs) print_summary(samples) zs = zs.reshape(-1, 2) samples = samples["x"].reshape(-1, 2).copy() # make plots # guide samples (for plotting) guide_base_samples = dist.Normal(jnp.zeros(2), 1.0).sample(random.PRNGKey(4), (1000, )) guide_trans_samples = neutra.transform_sample(guide_base_samples)["x"] x1 = jnp.linspace(-3, 3, 100) x2 = jnp.linspace(-3, 3, 100) X1, X2 = jnp.meshgrid(x1, x2) P = jnp.exp(DualMoonDistribution().log_prob(jnp.stack([X1, X2], axis=-1))) fig = plt.figure(figsize=(12, 8), constrained_layout=True) gs = GridSpec(2, 3, figure=fig) ax1 = fig.add_subplot(gs[0, 0]) ax2 = fig.add_subplot(gs[1, 0]) ax3 = fig.add_subplot(gs[0, 1]) ax4 = fig.add_subplot(gs[1, 1]) ax5 = fig.add_subplot(gs[0, 2]) ax6 = fig.add_subplot(gs[1, 2]) ax1.plot(svi_result.losses[1000:]) ax1.set_title("Autoguide training loss\n(after 1000 steps)") ax2.contourf(X1, X2, P, cmap="OrRd") sns.kdeplot(x=guide_samples[:, 0], y=guide_samples[:, 1], n_levels=30, ax=ax2) ax2.set( xlim=[-3, 3], ylim=[-3, 3], xlabel="x0", ylabel="x1", title="Posterior using\nAutoBNAFNormal guide", ) sns.scatterplot( x=guide_base_samples[:, 0], y=guide_base_samples[:, 1], ax=ax3, hue=guide_trans_samples[:, 0] < 0.0, ) ax3.set( xlim=[-3, 3], ylim=[-3, 3], xlabel="x0", ylabel="x1", title="AutoBNAFNormal base samples\n(True=left moon; False=right moon)", ) ax4.contourf(X1, X2, P, cmap="OrRd") sns.kdeplot(x=vanilla_samples[:, 0], y=vanilla_samples[:, 1], n_levels=30, ax=ax4) ax4.plot(vanilla_samples[-50:, 0], vanilla_samples[-50:, 1], "bo-", alpha=0.5) ax4.set( xlim=[-3, 3], ylim=[-3, 3], xlabel="x0", ylabel="x1", title="Posterior using\nvanilla HMC sampler", ) sns.scatterplot( x=zs[:, 0], y=zs[:, 1], ax=ax5, hue=samples[:, 0] < 0.0, s=30, alpha=0.5, edgecolor="none", ) ax5.set( xlim=[-5, 5], ylim=[-5, 5], xlabel="x0", ylabel="x1", title="Samples from the\nwarped posterior - p(z)", ) ax6.contourf(X1, X2, P, cmap="OrRd") sns.kdeplot(x=samples[:, 0], y=samples[:, 1], n_levels=30, ax=ax6) ax6.plot(samples[-50:, 0], samples[-50:, 1], "bo-", alpha=0.2) ax6.set( xlim=[-3, 3], ylim=[-3, 3], xlabel="x0", ylabel="x1", title="Posterior using\nNeuTra HMC sampler", ) plt.savefig("neutra.pdf")
def main(args):
    encoder_nn = encoder(args.hidden_dim, args.z_dim)
    decoder_nn = decoder(args.hidden_dim, 28 * 28)
    adam = optim.Adam(args.learning_rate)
    svi = SVI(
        model, guide, adam, Trace_ELBO(), hidden_dim=args.hidden_dim, z_dim=args.z_dim
    )
    rng_key = PRNGKey(0)
    train_init, train_fetch = load_dataset(
        MNIST, batch_size=args.batch_size, split="train"
    )
    test_init, test_fetch = load_dataset(
        MNIST, batch_size=args.batch_size, split="test"
    )
    num_train, train_idx = train_init()
    rng_key, rng_key_binarize, rng_key_init = random.split(rng_key, 3)
    sample_batch = binarize(rng_key_binarize, train_fetch(0, train_idx)[0])
    svi_state = svi.init(rng_key_init, sample_batch)

    @jit
    def epoch_train(svi_state, rng_key, train_idx):
        def body_fn(i, val):
            loss_sum, svi_state = val
            rng_key_binarize = random.fold_in(rng_key, i)
            batch = binarize(rng_key_binarize, train_fetch(i, train_idx)[0])
            svi_state, loss = svi.update(svi_state, batch)
            loss_sum += loss
            return loss_sum, svi_state

        return lax.fori_loop(0, num_train, body_fn, (0.0, svi_state))

    @jit
    def eval_test(svi_state, rng_key, test_idx):
        def body_fun(i, loss_sum):
            rng_key_binarize = random.fold_in(rng_key, i)
            batch = binarize(rng_key_binarize, test_fetch(i, test_idx)[0])
            # FIXME: does this lead to a requirement for an rng_key arg in svi_eval?
            loss = svi.evaluate(svi_state, batch) / len(batch)
            loss_sum += loss
            return loss_sum

        loss = lax.fori_loop(0, num_test, body_fun, 0.0)
        loss = loss / num_test
        return loss

    def reconstruct_img(epoch, rng_key):
        img = test_fetch(0, test_idx)[0][0]
        plt.imsave(
            os.path.join(RESULTS_DIR, "original_epoch={}.png".format(epoch)),
            img,
            cmap="gray",
        )
        rng_key_binarize, rng_key_sample = random.split(rng_key)
        test_sample = binarize(rng_key_binarize, img)
        params = svi.get_params(svi_state)
        z_mean, z_var = encoder_nn[1](
            params["encoder$params"], test_sample.reshape([1, -1])
        )
        z = dist.Normal(z_mean, z_var).sample(rng_key_sample)
        img_loc = decoder_nn[1](params["decoder$params"], z).reshape([28, 28])
        plt.imsave(
            os.path.join(RESULTS_DIR, "recons_epoch={}.png".format(epoch)),
            img_loc,
            cmap="gray",
        )

    for i in range(args.num_epochs):
        rng_key, rng_key_train, rng_key_test, rng_key_reconstruct = random.split(
            rng_key, 4
        )
        t_start = time.time()
        num_train, train_idx = train_init()
        _, svi_state = epoch_train(svi_state, rng_key_train, train_idx)
        rng_key, rng_key_test, rng_key_reconstruct = random.split(rng_key, 3)
        num_test, test_idx = test_init()
        test_loss = eval_test(svi_state, rng_key_test, test_idx)
        reconstruct_img(i, rng_key_reconstruct)
        print(
            "Epoch {}: loss = {} ({:.2f} s.)".format(
                i, test_loss, time.time() - t_start
            )
        )
    def elbo_loss_fn(x):
        return Trace_ELBO().loss(random.PRNGKey(0), {}, model, guide, x)

    def renyi_loss_fn(x):
        return RenyiELBO(alpha=alpha, num_particles=10).loss(
            random.PRNGKey(0), {}, model, guide, x
        )

    elbo_loss, elbo_grad = value_and_grad(elbo_loss_fn)(2.0)
    renyi_loss, renyi_grad = value_and_grad(renyi_loss_fn)(2.0)
    assert_allclose(elbo_loss, renyi_loss, rtol=1e-6)
    assert_allclose(elbo_grad, renyi_grad, rtol=1e-6)


@pytest.mark.parametrize("elbo", [Trace_ELBO(), RenyiELBO(num_particles=10)])
@pytest.mark.parametrize("optimizer", [optim.Adam(0.05), optimizers.adam(0.05)])
def test_beta_bernoulli(elbo, optimizer):
    data = jnp.array([1.0] * 8 + [0.0] * 2)

    def model(data):
        f = numpyro.sample("beta", dist.Beta(1.0, 1.0))
        numpyro.sample("obs", dist.Bernoulli(f), obs=data)

    def guide(data):
        alpha_q = numpyro.param("alpha_q", 1.0, constraint=constraints.positive)
        beta_q = numpyro.param("beta_q", 1.0, constraint=constraints.positive)
        numpyro.sample("beta", dist.Beta(alpha_q, beta_q))

    svi = SVI(model, guide, optimizer, elbo)
def test_tracegraph_gaussian_chain(num_latents, num_steps, step_size, atol, difficulty):
    loc0 = 0.2
    data = jnp.array([-0.1, 0.03, 0.2, 0.1])
    n_data = data.shape[0]
    sum_data = data.sum()
    N = num_latents
    lambdas = [1.5 * (k + 1) / N for k in range(N + 1)]
    lambdas = list(map(lambda x: jnp.array([x]), lambdas))
    lambda_tilde_posts = [lambdas[0]]
    for k in range(1, N):
        lambda_tilde_k = (lambdas[k] * lambda_tilde_posts[k - 1]) / (
            lambdas[k] + lambda_tilde_posts[k - 1]
        )
        lambda_tilde_posts.append(lambda_tilde_k)
    # the leading None is never used (just a way of shifting the indexing by 1)
    lambda_posts = [None]
    for k in range(1, N):
        lambda_k = lambdas[k] + lambda_tilde_posts[k - 1]
        lambda_posts.append(lambda_k)
    lambda_N_post = (n_data * lambdas[N]) + lambda_tilde_posts[N - 1]
    lambda_posts.append(lambda_N_post)
    target_kappas = [None]
    target_kappas.extend([lambdas[k] / lambda_posts[k] for k in range(1, N)])
    target_mus = [None]
    target_mus.extend(
        [loc0 * lambda_tilde_posts[k - 1] / lambda_posts[k] for k in range(1, N)]
    )
    target_loc_N = (
        sum_data * lambdas[N] / lambda_N_post
        + loc0 * lambda_tilde_posts[N - 1] / lambda_N_post
    )
    target_mus.append(target_loc_N)
    np.random.seed(0)
    while True:
        mask = np.random.binomial(1, 0.3, (N,))
        if mask.sum() < 0.4 * N and mask.sum() > 0.5:
            which_nodes_reparam = mask
            break

    class FakeNormal(dist.Normal):
        reparametrized_params = []

    def model(difficulty=0.0):
        next_mean = loc0
        for k in range(1, N + 1):
            latent_dist = dist.Normal(next_mean, jnp.power(lambdas[k - 1], -0.5))
            loc_latent = numpyro.sample("loc_latent_{}".format(k), latent_dist)
            next_mean = loc_latent

        loc_N = next_mean
        with numpyro.plate("data", data.shape[0]):
            numpyro.sample(
                "obs", dist.Normal(loc_N, jnp.power(lambdas[N], -0.5)), obs=data
            )
        return loc_N

    def guide(difficulty=0.0):
        previous_sample = None
        for k in reversed(range(1, N + 1)):
            loc_q = numpyro.param(
                f"loc_q_{k}",
                lambda key: target_mus[k]
                + difficulty * (0.1 * random.normal(key) - 0.53),
            )
            log_sig_q = numpyro.param(
                f"log_sig_q_{k}",
                lambda key: -0.5 * jnp.log(lambda_posts[k])
                + difficulty * (0.1 * random.normal(key) - 0.53),
            )
            sig_q = jnp.exp(log_sig_q)
            kappa_q = None
            if k != N:
                kappa_q = numpyro.param(
                    "kappa_q_%d" % k,
                    lambda key: target_kappas[k]
                    + difficulty * (0.1 * random.normal(key) - 0.53),
                )
            mean_function = loc_q if k == N else kappa_q * previous_sample + loc_q
            node_flagged = True if which_nodes_reparam[k - 1] == 1.0 else False
            Normal = dist.Normal if node_flagged else FakeNormal
            loc_latent = numpyro.sample(f"loc_latent_{k}", Normal(mean_function, sig_q))
            previous_sample = loc_latent
        return previous_sample

    adam = optim.Adam(step_size=step_size, b1=0.95, b2=0.999)
    svi = SVI(model, guide, adam, loss=TraceGraph_ELBO())
    svi_result = svi.run(jax.random.PRNGKey(0), num_steps, difficulty=difficulty)

    kappa_errors, log_sig_errors, loc_errors = [], [], []
    for k in range(1, N + 1):
        if k != N:
            kappa_error = jnp.sum(
                jnp.power(svi_result.params[f"kappa_q_{k}"] - target_kappas[k], 2)
            )
            kappa_errors.append(kappa_error)
        loc_errors.append(
            jnp.sum(jnp.power(svi_result.params[f"loc_q_{k}"] - target_mus[k], 2))
        )
        log_sig_error = jnp.sum(
            jnp.power(
                svi_result.params[f"log_sig_q_{k}"] + 0.5 * jnp.log(lambda_posts[k]), 2
            )
        )
        log_sig_errors.append(log_sig_error)

    max_errors = (np.max(loc_errors), np.max(log_sig_errors), np.max(kappa_errors))
    for i in range(3):
        assert_allclose(max_errors[i], 0, atol=atol)
def main(args):
    # loading data
    (train_init, train_fetch_plain), num_samples = load_dataset(
        MNIST, batch_size=args.batch_size, split='train',
        batchifier=subsample_batchify_data
    )
    (test_init, test_fetch_plain), _ = load_dataset(
        MNIST, batch_size=args.batch_size, split='test',
        batchifier=split_batchify_data
    )

    def binarize_fetch(fetch_fn):
        @jit
        def fetch_binarized(batch_nr, batchifier_state, binarize_rng):
            batch = fetch_fn(batch_nr, batchifier_state)
            return binarize(binarize_rng, batch[0]), batch[1]

        return fetch_binarized

    train_fetch = binarize_fetch(train_fetch_plain)
    test_fetch = binarize_fetch(test_fetch_plain)

    # setting up optimizer
    optimizer = optimizers.Adam(args.learning_rate)

    # the minibatch environment in our model scales individual records'
    # contributions to the loss up by num_samples. This can cause numerical
    # instabilities, so we scale the loss down by 1/num_samples here.
    sc_model = scale(model, scale=1 / num_samples)
    sc_guide = scale(guide, scale=1 / num_samples)

    if args.no_dp:
        svi = SVI(sc_model, sc_guide, optimizer, ELBO(),
                  num_obs_total=num_samples, z_dim=args.z_dim,
                  hidden_dim=args.hidden_dim)
    else:
        q = args.batch_size / num_samples
        target_eps = args.epsilon
        dp_scale, act_eps, _ = approximate_sigma(
            target_eps=target_eps, delta=1 / num_samples, q=q,
            num_iter=int(1 / q) * args.num_epochs, force_smaller=True
        )
        print(f"using noise scale {dp_scale} for epsilon of {act_eps} (targeted: {target_eps})")

        svi = DPSVI(sc_model, sc_guide, optimizer, ELBO(),
                    dp_scale=dp_scale, clipping_threshold=10.,
                    num_obs_total=num_samples, z_dim=args.z_dim,
                    hidden_dim=args.hidden_dim)

    # preparing random number generators and initializing svi
    rng = PRNGKey(0)
    rng, binarize_rng, svi_init_rng, batchifier_rng = random.split(rng, 4)
    _, batchifier_state = train_init(rng_key=batchifier_rng)
    sample_batch = train_fetch(0, batchifier_state, binarize_rng)[0]
    svi_state = svi.init(svi_init_rng, sample_batch)

    # functions for training tasks
    @jit
    def epoch_train(svi_state, batchifier_state, num_batches, rng):
        """Trains one epoch.

        :param svi_state: current state of the optimizer
        :param batchifier_state: state of the training batchifier
        :param num_batches: number of batches in the epoch
        :param rng: rng key
        :return: overall training loss over the epoch
        """
        def body_fn(i, val):
            svi_state, loss = val
            binarize_rng = random.fold_in(rng, i)
            batch = train_fetch(i, batchifier_state, binarize_rng)[0]
            svi_state, batch_loss = svi.update(svi_state, batch)
            loss += batch_loss / num_batches
            return svi_state, loss

        return lax.fori_loop(0, num_batches, body_fn, (svi_state, 0.))

    @jit
    def eval_test(svi_state, batchifier_state, num_batches, rng):
        """Evaluates the current model state on test data.

        :param svi_state: current state of the optimizer
        :param batchifier_state: state of the test batchifier
        :param num_batches: number of batches in the test split
        :param rng: rng key
        :return: loss over the test split
        """
        def body_fn(i, loss_sum):
            binarize_rng = random.fold_in(rng, i)
            batch = test_fetch(i, batchifier_state, binarize_rng)[0]
            batch_loss = svi.evaluate(svi_state, batch)
            loss_sum += batch_loss / num_batches
            return loss_sum

        return lax.fori_loop(0, num_batches, body_fn, 0.)

    def reconstruct_img(epoch, num_epochs, batchifier_state, svi_state, rng):
        """Reconstructs an image for the given epoch.

        Obtains a sample from the testing data set and passes it through the
        VAE. Stores the result as image file 'epoch_{epoch}_recons.png' and
        the original input as 'epoch_{epoch}_original.png' in folder '.results'.

        :param epoch: number of the current epoch
        :param num_epochs: number of total epochs
        :param batchifier_state: state of the test batchifier
        :param svi_state: current state of the optimizer
        :param rng: rng key
        """
        assert num_epochs > 0
        img = test_fetch_plain(0, batchifier_state)[0][0]
        plt.imsave(
            os.path.join(
                RESULTS_DIR,
                "epoch_{:0{}d}_original.png".format(epoch, int(jnp.log10(num_epochs)) + 1),
            ),
            img,
            cmap='gray',
        )
        rng, rng_binarize = random.split(rng, 2)
        test_sample = binarize(rng_binarize, img)
        test_sample = jnp.reshape(test_sample, (1, *jnp.shape(test_sample)))
        params = svi.get_params(svi_state)

        samples = sample_multi_posterior_predictive(
            rng, 10, model,
            (1, args.z_dim, args.hidden_dim, np.prod(test_sample.shape[1:])),
            guide, (test_sample, args.z_dim, args.hidden_dim), params
        )
        img_loc = samples['obs'][0].reshape([28, 28])
        avg_img_loc = jnp.mean(samples['obs'], axis=0).reshape([28, 28])
        plt.imsave(
            os.path.join(
                RESULTS_DIR,
                "epoch_{:0{}d}_recons_single.png".format(epoch, int(jnp.log10(num_epochs)) + 1),
            ),
            img_loc,
            cmap='gray',
        )
        plt.imsave(
            os.path.join(
                RESULTS_DIR,
                "epoch_{:0{}d}_recons_avg.png".format(epoch, int(jnp.log10(num_epochs)) + 1),
            ),
            avg_img_loc,
            cmap='gray',
        )

    # main training loop
    for i in range(args.num_epochs):
        t_start = time.time()
        rng, data_fetch_rng, train_rng = random.split(rng, 3)
        num_train_batches, train_batchifier_state = train_init(rng_key=data_fetch_rng)
        svi_state, train_loss = epoch_train(
            svi_state, train_batchifier_state, num_train_batches, train_rng
        )

        rng, test_fetch_rng, test_rng, recons_rng = random.split(rng, 4)
        num_test_batches, test_batchifier_state = test_init(rng_key=test_fetch_rng)
        test_loss = eval_test(
            svi_state, test_batchifier_state, num_test_batches, test_rng
        )

        reconstruct_img(i, args.num_epochs, test_batchifier_state, svi_state, recons_rng)
        print("Epoch {}: loss = {} (on training set: {}) ({:.2f} s.)".format(
            i, test_loss, train_loss, time.time() - t_start))
def test_subsample_gradient(scale, subsample):
    data = jnp.array([-0.5, 2.0])
    subsample_size = 1 if subsample else len(data)
    precision = 0.06 * scale

    def model(subsample):
        with handlers.substitute(data={"data": subsample}):
            with numpyro.plate("data", len(data), subsample_size) as ind:
                x = data[ind]
                z = numpyro.sample("z", dist.Normal(0, 1))
                numpyro.sample("x", dist.Normal(z, 1), obs=x)

    def guide(subsample):
        scale = numpyro.param("scale", 1.0)
        with handlers.substitute(data={"data": subsample}):
            with numpyro.plate("data", len(data), subsample_size):
                loc = numpyro.param("loc", jnp.zeros(len(data)), event_dim=0)
                numpyro.sample("z", dist.Normal(loc, scale))

    if scale != 1.0:
        model = handlers.scale(model, scale=scale)
        guide = handlers.scale(guide, scale=scale)

    num_particles = 50000
    optimizer = optim.Adam(0.1)
    elbo = Trace_ELBO(num_particles=num_particles)
    svi = SVI(model, guide, optimizer, loss=elbo)
    svi_state = svi.init(random.PRNGKey(0), None)
    params = svi.optim.get_params(svi_state.optim_state)
    normalizer = 2 if subsample else 1
    if subsample_size == 1:
        subsample = jnp.array([0])
        loss1, grads1 = value_and_grad(
            lambda x: svi.loss.loss(
                svi_state.rng_key, svi.constrain_fn(x), svi.model, svi.guide, subsample
            )
        )(params)
        subsample = jnp.array([1])
        loss2, grads2 = value_and_grad(
            lambda x: svi.loss.loss(
                svi_state.rng_key, svi.constrain_fn(x), svi.model, svi.guide, subsample
            )
        )(params)
        grads = tree_multimap(lambda *vals: vals[0] + vals[1], grads1, grads2)
        loss = loss1 + loss2
    else:
        subsample = jnp.array([0, 1])
        loss, grads = value_and_grad(
            lambda x: svi.loss.loss(
                svi_state.rng_key, svi.constrain_fn(x), svi.model, svi.guide, subsample
            )
        )(params)

    actual_loss = loss / normalizer
    expected_loss, _ = value_and_grad(
        lambda x: svi.loss.loss(
            svi_state.rng_key, svi.constrain_fn(x), svi.model, svi.guide, None
        )
    )(params)
    assert_allclose(actual_loss, expected_loss, rtol=precision, atol=precision)

    actual_grads = {name: grad / normalizer for name, grad in grads.items()}
    expected_grads = {
        "loc": scale * jnp.array([0.5, -2.0]),
        "scale": scale * jnp.array([2.0]),
    }
    assert actual_grads.keys() == expected_grads.keys()
    for name in expected_grads:
        assert_allclose(
            actual_grads[name], expected_grads[name], rtol=precision, atol=precision
        )
def Adam(kwargs): step_size = kwargs.pop("lr") b1, b2 = kwargs.pop("betas", (0.9, 0.999)) eps = kwargs.pop("eps", 1.0e-8) return optim.Adam(step_size=step_size, b1=b1, b2=b2, eps=eps)
    plt.show()

    sum_blbr = post["bl"] + post["br"]
    fig, ax = plt.subplots()
    az.plot_kde(sum_blbr, label="sum of bl and br", ax=ax)
    plt.title(method)
    pml.savefig(f'multicollinear_sum_post_{method}.pdf')
    plt.show()


# Laplace fit
m6_1 = AutoLaplaceApproximation(model)
svi = SVI(
    model, m6_1, optim.Adam(0.1), Trace_ELBO(),
    leg_left=df.leg_left.values, leg_right=df.leg_right.values,
    height=df.height.values, br_positive=False,
)
svi_run = svi.run(random.PRNGKey(0), 2000)
p6_1 = svi_run.params
losses = svi_run.losses
post_laplace = m6_1.sample_posterior(random.PRNGKey(1), p6_1, (1000,))
analyze_post(post_laplace, 'laplace')

# MCMC fit
def main(args):
    print("Start vanilla HMC...")
    nuts_kernel = NUTS(dual_moon_model)
    mcmc = MCMC(nuts_kernel, args.num_warmup, args.num_samples,
                progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True)
    mcmc.run(random.PRNGKey(0))
    mcmc.print_summary()
    vanilla_samples = mcmc.get_samples()['x'].copy()

    guide = AutoBNAFNormal(dual_moon_model,
                           hidden_factors=[args.hidden_factor, args.hidden_factor])
    svi = SVI(dual_moon_model, guide, optim.Adam(0.003), AutoContinuousELBO())
    svi_state = svi.init(random.PRNGKey(1))

    print("Start training guide...")
    last_state, losses = lax.scan(lambda state, i: svi.update(state),
                                  svi_state, np.zeros(args.num_iters))
    params = svi.get_params(last_state)
    print("Finish training guide. Extract samples...")
    guide_samples = guide.sample_posterior(
        random.PRNGKey(0), params, sample_shape=(args.num_samples,))['x'].copy()

    transform = guide.get_transform(params)
    _, potential_fn, constrain_fn = initialize_model(random.PRNGKey(2), dual_moon_model)
    transformed_potential_fn = partial(transformed_potential_energy, potential_fn, transform)
    transformed_constrain_fn = lambda x: constrain_fn(transform(x))  # noqa: E731

    print("\nStart NeuTra HMC...")
    nuts_kernel = NUTS(potential_fn=transformed_potential_fn)
    mcmc = MCMC(nuts_kernel, args.num_warmup, args.num_samples,
                progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True)
    init_params = np.zeros(guide.latent_size)
    mcmc.run(random.PRNGKey(3), init_params=init_params)
    mcmc.print_summary()
    zs = mcmc.get_samples()
    print("Transform samples into unwarped space...")
    samples = vmap(transformed_constrain_fn)(zs)
    print_summary(tree_map(lambda x: x[None, ...], samples))
    samples = samples['x'].copy()

    # make plots

    # guide samples (for plotting)
    guide_base_samples = dist.Normal(np.zeros(2), 1.).sample(random.PRNGKey(4), (1000,))
    guide_trans_samples = vmap(transformed_constrain_fn)(guide_base_samples)['x']

    x1 = np.linspace(-3, 3, 100)
    x2 = np.linspace(-3, 3, 100)
    X1, X2 = np.meshgrid(x1, x2)
    P = np.exp(DualMoonDistribution().log_prob(np.stack([X1, X2], axis=-1)))

    fig = plt.figure(figsize=(12, 8), constrained_layout=True)
    gs = GridSpec(2, 3, figure=fig)
    ax1 = fig.add_subplot(gs[0, 0])
    ax2 = fig.add_subplot(gs[1, 0])
    ax3 = fig.add_subplot(gs[0, 1])
    ax4 = fig.add_subplot(gs[1, 1])
    ax5 = fig.add_subplot(gs[0, 2])
    ax6 = fig.add_subplot(gs[1, 2])

    ax1.plot(losses[1000:])
    ax1.set_title('Autoguide training loss\n(after 1000 steps)')

    ax2.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(guide_samples[:, 0], guide_samples[:, 1], n_levels=30, ax=ax2)
    ax2.set(xlim=[-3, 3], ylim=[-3, 3], xlabel='x0', ylabel='x1',
            title='Posterior using\nAutoBNAFNormal guide')

    sns.scatterplot(guide_base_samples[:, 0], guide_base_samples[:, 1], ax=ax3,
                    hue=guide_trans_samples[:, 0] < 0.)
    ax3.set(xlim=[-3, 3], ylim=[-3, 3], xlabel='x0', ylabel='x1',
            title='AutoBNAFNormal base samples\n(True=left moon; False=right moon)')

    ax4.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(vanilla_samples[:, 0], vanilla_samples[:, 1], n_levels=30, ax=ax4)
    ax4.plot(vanilla_samples[-50:, 0], vanilla_samples[-50:, 1], 'bo-', alpha=0.5)
    ax4.set(xlim=[-3, 3], ylim=[-3, 3], xlabel='x0', ylabel='x1',
            title='Posterior using\nvanilla HMC sampler')

    sns.scatterplot(zs[:, 0], zs[:, 1], ax=ax5, hue=samples[:, 0] < 0.,
                    s=30, alpha=0.5, edgecolor="none")
    ax5.set(xlim=[-5, 5], ylim=[-5, 5], xlabel='x0', ylabel='x1',
            title='Samples from the\nwarped posterior - p(z)')

    ax6.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(samples[:, 0], samples[:, 1], n_levels=30, ax=ax6)
    ax6.plot(samples[-50:, 0], samples[-50:, 1], 'bo-', alpha=0.2)
    ax6.set(xlim=[-3, 3], ylim=[-3, 3], xlabel='x0', ylabel='x1',
            title='Posterior using\nNeuTra HMC sampler')
    plt.savefig("neutra.pdf")
def main(args):
    rng = PRNGKey(1234)
    rng, toy_data_rng = jax.random.split(rng, 2)
    X_train, X_test, mu_true = create_toy_data(toy_data_rng, args.num_samples, args.dimensions)

    train_init, train_fetch = subsample_batchify_data((X_train,), batch_size=args.batch_size)
    test_init, test_fetch = split_batchify_data((X_test,), batch_size=args.batch_size)

    ## Init optimizer and training algorithms
    optimizer = optimizers.Adam(args.learning_rate)

    svi = DPSVI(model, guide, optimizer, ELBO(),
                dp_scale=args.sigma, clipping_threshold=args.clip_threshold,
                d=args.dimensions, num_obs_total=args.num_samples)

    rng, svi_init_rng, batchifier_rng = random.split(rng, 3)
    _, batchifier_state = train_init(rng_key=batchifier_rng)
    batch = train_fetch(0, batchifier_state)
    svi_state = svi.init(svi_init_rng, *batch)

    q = args.batch_size / args.num_samples
    eps = svi.get_epsilon(args.delta, q, num_epochs=args.num_epochs)
    print("Privacy epsilon {} (for sigma: {}, delta: {}, C: {}, q: {})".format(
        eps, args.sigma, args.clip_threshold, args.delta, q))

    @jit
    def epoch_train(svi_state, batchifier_state, num_batch):
        def body_fn(i, val):
            svi_state, loss = val
            batch = train_fetch(i, batchifier_state)
            svi_state, batch_loss = svi.update(svi_state, *batch)
            loss += batch_loss / (args.num_samples * num_batch)
            return svi_state, loss

        return lax.fori_loop(0, num_batch, body_fn, (svi_state, 0.))

    @jit
    def eval_test(svi_state, batchifier_state, num_batch):
        def body_fn(i, loss_sum):
            batch = test_fetch(i, batchifier_state)
            loss = svi.evaluate(svi_state, *batch)
            loss_sum += loss / (args.num_samples * num_batch)
            return loss_sum

        return lax.fori_loop(0, num_batch, body_fn, 0.)

    ## Train model
    for i in range(args.num_epochs):
        t_start = time.time()
        rng, data_fetch_rng = random.split(rng, 2)
        num_train_batches, train_batchifier_state = train_init(rng_key=data_fetch_rng)
        svi_state, train_loss = epoch_train(svi_state, train_batchifier_state, num_train_batches)
        # todo: blocking on the loss will probably ignore the rest of the optimization
        train_loss.block_until_ready()
        t_end = time.time()

        if i % (args.num_epochs // 10) == 0:
            rng, test_fetch_rng = random.split(rng, 2)
            num_test_batches, test_batchifier_state = test_init(rng_key=test_fetch_rng)
            test_loss = eval_test(svi_state, test_batchifier_state, num_test_batches)
            print("Epoch {}: loss = {} (on training set: {}) ({:.2f} s.)".format(
                i, test_loss, train_loss, t_end - t_start))

    params = svi.get_params(svi_state)
    mu_loc = params['mu_loc']
    mu_std = jnp.exp(params['mu_std_log'])

    print("### expected: {}".format(mu_true))
    print("### svi result\nmu_loc: {}\nerror: {}\nmu_std: {}".format(
        mu_loc, jnp.linalg.norm(mu_loc - mu_true), mu_std))

    mu_loc, mu_std = analytical_solution(X_train)
    print("### analytical solution\nmu_loc: {}\nerror: {}\nmu_std: {}".format(
        mu_loc, jnp.linalg.norm(mu_loc - mu_true), mu_std))

    mu_loc, mu_std = ml_estimate(X_train)
    print("### ml estimate\nmu_loc: {}\nerror: {}\nmu_std: {}".format(
        mu_loc, jnp.linalg.norm(mu_loc - mu_true), mu_std))
def main(args):
    rng = PRNGKey(123)
    rng, toy_data_rng = jax.random.split(rng)

    train_data, test_data, true_params = create_toy_data(
        toy_data_rng, args.num_samples, args.dimensions
    )
    train_init, train_fetch = subsample_batchify_data(train_data, batch_size=args.batch_size)
    test_init, test_fetch = split_batchify_data(test_data, batch_size=args.batch_size)

    ## Init optimizer and training algorithms
    optimizer = optimizers.Adam(args.learning_rate)

    # note(lumip): value for c currently completely made up
    # value for dp_scale completely made up currently
    svi = DPSVI(
        model, guide, optimizer, ELBO(),
        dp_scale=0.01, clipping_threshold=20., num_obs_total=args.num_samples
    )

    rng, svi_init_rng, data_fetch_rng = random.split(rng, 3)
    _, batchifier_state = train_init(rng_key=data_fetch_rng)
    sample_batch = train_fetch(0, batchifier_state)

    svi_state = svi.init(svi_init_rng, *sample_batch)

    @jit
    def epoch_train(svi_state, batchifier_state, num_batch):
        def body_fn(i, val):
            svi_state, loss = val
            batch = train_fetch(i, batchifier_state)
            batch_X, batch_Y = batch

            svi_state, batch_loss = svi.update(svi_state, batch_X, batch_Y)
            loss += batch_loss / (args.num_samples * num_batch)
            return svi_state, loss

        return lax.fori_loop(0, num_batch, body_fn, (svi_state, 0.))

    @jit
    def eval_test(svi_state, batchifier_state, num_batch, rng):
        params = svi.get_params(svi_state)

        def body_fn(i, val):
            loss_sum, acc_sum = val
            batch = test_fetch(i, batchifier_state)
            batch_X, batch_Y = batch

            loss = svi.evaluate(svi_state, batch_X, batch_Y)
            loss_sum += loss / (args.num_samples * num_batch)

            acc_rng = jax.random.fold_in(rng, i)
            acc = estimate_accuracy(batch_X, batch_Y, params, acc_rng, 1)
            acc_sum += acc / num_batch

            return loss_sum, acc_sum

        return lax.fori_loop(0, num_batch, body_fn, (0., 0.))

    ## Train model
    for i in range(args.num_epochs):
        t_start = time.time()
        rng, data_fetch_rng = random.split(rng, 2)

        num_train_batches, train_batchifier_state = train_init(rng_key=data_fetch_rng)
        svi_state, train_loss = epoch_train(
            svi_state, train_batchifier_state, num_train_batches
        )
        # todo: blocking on the loss will probably ignore the rest of the optimization
        train_loss.block_until_ready()
        t_end = time.time()

        if (i % (args.num_epochs // 10)) == 0:
            rng, test_rng, test_fetch_rng = random.split(rng, 3)
            num_test_batches, test_batchifier_state = test_init(rng_key=test_fetch_rng)
            test_loss, test_acc = eval_test(
                svi_state, test_batchifier_state, num_test_batches, test_rng
            )
            print("Epoch {}: loss = {}, acc = {} (loss on training set: {}) ({:.2f} s.)".format(
                i, test_loss, test_acc, train_loss, t_end - t_start
            ))

    # parameters for logistic regression may be scaled arbitrarily; normalize
    # w (and scale the intercept accordingly) for comparison
    w_true = normalize(true_params[0])
    scale_true = jnp.linalg.norm(true_params[0])
    intercept_true = true_params[1] / scale_true

    params = svi.get_params(svi_state)
    w_post = normalize(params['w_loc'])
    scale_post = jnp.linalg.norm(params['w_loc'])
    intercept_post = params['intercept_loc'] / scale_post

    print("w_loc: {}\nexpected: {}\nerror: {}".format(
        w_post, w_true, jnp.linalg.norm(w_post - w_true)))
    print("w_std: {}".format(jnp.exp(params['w_std_log'])))
    print("")
    print("intercept_loc: {}\nexpected: {}\nerror: {}".format(
        intercept_post, intercept_true, jnp.abs(intercept_post - intercept_true)))
    print("intercept_std: {}".format(jnp.exp(params['intercept_std_log'])))
    print("")

    X_test, y_test = test_data
    rng, rng_acc_true, rng_acc_post = jax.random.split(rng, 3)
    # for evaluating accuracy with the true parameters, we scale them to the
    # same scale as the found posterior (this gives better results than using
    # normalized parameters, probably due to numerical instabilities)
    acc_true = estimate_accuracy_fixed_params(X_test, y_test, w_true, intercept_true, rng_acc_true, 10)
    acc_post = estimate_accuracy(X_test, y_test, params, rng_acc_post, 10)

    print("avg accuracy on test set: with true parameters: {} ; with found posterior: {}\n".format(
        acc_true, acc_post))
# Model
def model(M, A, D=None):
    a = numpyro.sample("a", dist.Normal(0, 0.2))
    bM = numpyro.sample("bM", dist.Normal(0, 0.5))
    bA = numpyro.sample("bA", dist.Normal(0, 0.5))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = numpyro.deterministic("mu", a + bM * M + bA * A)
    numpyro.sample("D", dist.Normal(mu, sigma), obs=D)


m5_3 = AutoLaplaceApproximation(model)
svi = SVI(
    model, m5_3, optim.Adam(1), Trace_ELBO(),
    M=d.M.values, A=d.A.values, D=d.D.values,
)
p5_3, losses = svi.run(random.PRNGKey(0), 1000)
post = m5_3.sample_posterior(random.PRNGKey(1), p5_3, (1000,))

# Posterior
param_names = {'a', 'bA', 'bM', 'sigma'}
for p in param_names:
    print(f'posterior for {p}')
    print_summary(post[p], 0.95, False)

# PPC
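# One possible posterior-predictive check sketch (not from the original source):
# it draws predictive samples of D from the fitted posterior with numpyro's
# Predictive utility, reusing `model`, `post`, and the dataframe `d` from above.
from numpyro.infer import Predictive

predictive = Predictive(
    model,
    posterior_samples={k: post[k] for k in ["a", "bA", "bM", "sigma"]},
    return_sites=["mu", "D"],
)
ppc = predictive(random.PRNGKey(2), M=d.M.values, A=d.A.values)
mu_pred = ppc["mu"]                   # posterior draws of the linear predictor
D_mean = jnp.mean(ppc["D"], axis=0)   # posterior predictive mean of D per observation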