def test_beta_bernoulli():
    data = np.array([1.0] * 8 + [0.0] * 2)

    def model(data):
        f = sample("beta", dist.Beta(1., 1.))
        sample("obs", dist.Bernoulli(f), obs=data)

    def guide():
        alpha_q = param("alpha_q", 1.0, constraint=constraints.positive)
        beta_q = param("beta_q", 1.0, constraint=constraints.positive)
        sample("beta", dist.Beta(alpha_q, beta_q))

    opt_init, opt_update, get_params = optimizers.adam(0.05)
    svi_init, svi_update, _ = svi(model, guide, elbo, opt_init, opt_update, get_params)
    rng_init, rng_train = random.split(random.PRNGKey(1))
    opt_state, constrain_fn = svi_init(rng_init, model_args=(data,))

    def body_fn(i, val):
        opt_state_, rng_ = val
        loss, opt_state_, rng_ = svi_update(i, rng_, opt_state_, model_args=(data,))
        return opt_state_, rng_

    opt_state, _ = fori_loop(0, 300, body_fn, (opt_state, rng_train))
    params = constrain_fn(get_params(opt_state))
    assert_allclose(params['alpha_q'] / (params['alpha_q'] + params['beta_q']), 0.8,
                    atol=0.05, rtol=0.05)
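# For reference, this model is conjugate, so the variational fit can be checked
# against a closed-form answer. A sanity-check sketch (illustrative, not part of
# the test suite): with a Beta(1, 1) prior and 8 successes / 2 failures the exact
# posterior is Beta(9, 3), whose mean is 0.75. The assertion above targets 0.8
# with atol=0.05 and rtol=0.05, i.e. a combined tolerance of 0.09, which covers
# the exact mean.
def _check_beta_bernoulli_conjugacy():
    alpha_post, beta_post = 1.0 + 8, 1.0 + 2           # conjugate Beta-Bernoulli update
    post_mean = alpha_post / (alpha_post + beta_post)  # = 0.75
    assert abs(post_mean - 0.8) <= 0.05 + 0.05 * 0.8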
def test_beta_bernoulli(auto_class):
    data = np.array([[1.0] * 8 + [0.0] * 2,
                     [1.0] * 4 + [0.0] * 6]).T

    def model(data):
        f = sample('beta', dist.Beta(np.ones(2), np.ones(2)))
        sample('obs', dist.Bernoulli(f), obs=data)

    opt_init, opt_update, get_params = optimizers.adam(0.01)
    rng_guide, rng_init, rng_train = random.split(random.PRNGKey(1), 3)
    guide = auto_class(rng_guide, model, get_params)
    svi_init, svi_update, _ = svi(model, guide, elbo, opt_init, opt_update, get_params)
    opt_state, constrain_fn = svi_init(rng_init, model_args=(data,), guide_args=(data,))

    def body_fn(i, val):
        opt_state_, rng_ = val
        loss, opt_state_, rng_ = svi_update(i, rng_, opt_state_,
                                            model_args=(data,), guide_args=(data,))
        return opt_state_, rng_

    opt_state, _ = fori_loop(0, 1000, body_fn, (opt_state, rng_train))
    # posterior mean of Beta(1 + successes, 1 + failures) under the Beta(1, 1) prior
    true_coefs = (np.sum(data, axis=0) + 1) / (data.shape[0] + 2)

    # test .sample_posterior method
    posterior_samples = guide.sample_posterior(random.PRNGKey(1), opt_state,
                                               sample_shape=(1000,))
    assert_allclose(np.mean(posterior_samples['beta'], 0), true_coefs, atol=0.04)
def test_uniform_normal():
    true_coef = 0.9
    data = true_coef + random.normal(random.PRNGKey(0), (1000,))

    def model(data):
        alpha = sample('alpha', dist.Uniform(0, 1))
        loc = sample('loc', dist.Uniform(0, alpha))
        sample('obs', dist.Normal(loc, 0.1), obs=data)

    opt_init, opt_update, get_params = optimizers.adam(0.01)
    rng_guide, rng_init, rng_train = random.split(random.PRNGKey(1), 3)
    guide = AutoDiagonalNormal(rng_guide, model, get_params)
    svi_init, svi_update, _ = svi(model, guide, elbo, opt_init, opt_update, get_params)
    opt_state, constrain_fn = svi_init(rng_init, model_args=(data,), guide_args=(data,))

    def body_fn(i, val):
        opt_state_, rng_ = val
        loss, opt_state_, rng_ = svi_update(i, rng_, opt_state_,
                                            model_args=(data,), guide_args=(data,))
        return opt_state_, rng_

    opt_state, _ = fori_loop(0, 1000, body_fn, (opt_state, rng_train))
    median = guide.median(opt_state)
    assert_allclose(median['loc'], true_coef, rtol=0.05)

    # test .quantiles method
    quantiles = guide.quantiles(opt_state, [0.2, 0.5])
    assert_allclose(quantiles['loc'][1], true_coef, rtol=0.1)
def main(args):
    # Generate some data.
    data = random.normal(PRNGKey(0), shape=(100,)) + 3.0

    # Construct an SVI object so we can do variational inference on our
    # model/guide pair.
    opt_init, opt_update, get_params = optimizers.adam(args.learning_rate)
    svi_init, svi_update, _ = svi(model, guide, elbo, opt_init, opt_update, get_params)
    rng = PRNGKey(0)
    opt_state, _ = svi_init(rng, model_args=(data,))

    # Training loop
    rng, = random.split(rng, 1)

    def body_fn(i, val):
        opt_state_, rng_ = val
        loss, opt_state_, rng_ = svi_update(i, rng_, opt_state_, model_args=(data,))
        return opt_state_, rng_

    opt_state, _ = lax.fori_loop(0, args.num_steps, body_fn, (opt_state, rng))

    # Report the final values of the variational parameters
    # in the guide after training.
    params = get_params(opt_state)
    for name, value in params.items():
        print("{} = {}".format(name, value))

    # For this simple (conjugate) model we know the exact posterior. In
    # particular we know that the variational distribution should be
    # centered near 3.0. So let's check this explicitly.
    assert np.abs(params["guide_loc"] - 3.0) < 0.1
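# The comment above can be made concrete. A minimal sketch of the conjugate
# computation, assuming the model (defined outside this excerpt) pairs a
# standard normal prior on loc with unit-variance observations: with n
# observations of mean xbar, the posterior is
# Normal(n * xbar / (n + 1), 1 / sqrt(n + 1)).
def _check_normal_conjugacy():
    n, xbar = 100, 3.0                # xbar stands in for the sample mean, near 3.0
    post_loc = n * xbar / (n + 1)     # ~2.97, within a few percent of 3.0
    post_scale = (n + 1) ** -0.5      # ~0.1
    assert abs(post_loc - 3.0) < 0.1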
def test_beta_bernoulli(auto_class, rtol):
    data = np.array([1.0] * 8 + [0.0] * 2)

    def model(data):
        f = sample('beta', dist.Beta(1., 1.))
        sample('obs', dist.Bernoulli(f), obs=data)

    opt_init, opt_update, get_params = optimizers.adam(0.08)
    rng_guide, rng_init, rng_train = random.split(random.PRNGKey(1), 3)
    guide = auto_class(rng_guide, model, get_params)
    svi_init, svi_update, _ = svi(model, guide, elbo, opt_init, opt_update, get_params)
    opt_state, constrain_fn = svi_init(rng_init, model_args=(data,), guide_args=(data,))

    def body_fn(i, val):
        opt_state_, rng_ = val
        loss, opt_state_, rng_ = svi_update(i, rng_, opt_state_,
                                            model_args=(data,), guide_args=(data,))
        return opt_state_, rng_

    opt_state, _ = lax.fori_loop(0, 300, body_fn, (opt_state, rng_train))
    median = guide.median(opt_state)
    assert_allclose(median['beta'], 0.8, rtol=rtol)
def test_dynamic_supports():
    true_coef = 0.9
    data = true_coef + random.normal(random.PRNGKey(0), (1000,))

    def actual_model(data):
        alpha = sample('alpha', dist.Uniform(0, 1))
        loc = sample('loc', dist.Uniform(0, alpha))
        sample('obs', dist.Normal(loc, 0.1), obs=data)

    def expected_model(data):
        alpha = sample('alpha', dist.Uniform(0, 1))
        loc = sample('loc', dist.Uniform(0, 1)) * alpha
        sample('obs', dist.Normal(loc, 0.1), obs=data)

    opt_init, opt_update, get_params = optimizers.adam(0.01)
    rng_guide, rng_init, rng_train = random.split(random.PRNGKey(1), 3)

    guide = AutoDiagonalNormal(rng_guide, actual_model, get_params)
    svi_init, _, svi_eval = svi(actual_model, guide, elbo, opt_init, opt_update, get_params)
    opt_state, constrain_fn = svi_init(rng_init, (data,), (data,))
    actual_params = get_params(opt_state)
    actual_base_values = constrain_fn(actual_params)
    actual_values = guide.median(opt_state)
    actual_loss = svi_eval(random.PRNGKey(1), opt_state, (data,), (data,))

    guide = AutoDiagonalNormal(rng_guide, expected_model, get_params)
    svi_init, _, svi_eval = svi(expected_model, guide, elbo, opt_init, opt_update, get_params)
    opt_state, constrain_fn = svi_init(rng_init, (data,), (data,))
    expected_params = get_params(opt_state)
    expected_base_values = constrain_fn(expected_params)
    expected_values = guide.median(opt_state)
    expected_loss = svi_eval(random.PRNGKey(1), opt_state, (data,), (data,))

    check_eq(actual_params, expected_params)
    check_eq(actual_base_values, expected_base_values)
    assert_allclose(actual_values['alpha'], expected_values['alpha'])
    assert_allclose(actual_values['loc'], expected_values['alpha'] * expected_values['loc'])
    assert_allclose(actual_loss, expected_loss)
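# The two models above agree because of a simple change of variables: if
# u ~ Uniform(0, 1), then alpha * u ~ Uniform(0, alpha). A minimal numeric
# sketch of that identity in plain NumPy (illustrative, independent of the
# test harness; `onp` avoids clashing with jax.numpy's `np`):
import numpy as onp

def _check_uniform_rescaling():
    rng = onp.random.RandomState(0)
    alpha = 0.7
    direct = rng.uniform(0.0, alpha, size=100000)        # loc ~ Uniform(0, alpha)
    scaled = alpha * rng.uniform(0.0, 1.0, size=100000)  # loc = alpha * Uniform(0, 1)
    assert abs(direct.mean() - scaled.mean()) < 1e-2
    assert abs(direct.var() - scaled.var()) < 1e-2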
def test_logistic_regression(auto_class):
    N, dim = 3000, 3
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = np.arange(1., dim + 1.)
    logits = np.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

    def model(data, labels):
        coefs = sample('coefs', dist.Normal(np.zeros(dim), np.ones(dim)))
        logits = np.sum(coefs * data, axis=-1)
        return sample('obs', dist.Bernoulli(logits=logits), obs=labels)

    opt_init, opt_update, get_params = optimizers.adam(0.01)
    rng_guide, rng_init, rng_train = random.split(random.PRNGKey(1), 3)
    guide = auto_class(rng_guide, model, get_params)
    svi_init, svi_update, _ = svi(model, guide, elbo, opt_init, opt_update, get_params)
    opt_state, constrain_fn = svi_init(rng_init,
                                       model_args=(data, labels),
                                       guide_args=(data, labels))

    def body_fn(i, val):
        opt_state_, rng_ = val
        loss, opt_state_, rng_ = svi_update(i, rng_, opt_state_,
                                            model_args=(data, labels),
                                            guide_args=(data, labels))
        return opt_state_, rng_

    opt_state, _ = fori_loop(0, 1000, body_fn, (opt_state, rng_train))

    # .median and .quantiles are skipped for the flow-based AutoIAFNormal
    # guide, whose marginals are not available in closed form.
    if auto_class is not AutoIAFNormal:
        median = guide.median(opt_state)
        assert_allclose(median['coefs'], true_coefs, rtol=0.1)
        # test .quantiles method
        quantiles = guide.quantiles(opt_state, [0.2, 0.5])
        assert_allclose(quantiles['coefs'][1], true_coefs, rtol=0.1)

    # test .sample_posterior method
    posterior_samples = guide.sample_posterior(random.PRNGKey(1), opt_state,
                                               sample_shape=(1000,))
    assert_allclose(np.mean(posterior_samples['coefs'], 0), true_coefs, rtol=0.1)
def test_dynamic_constraints():
    true_coef = 0.9
    data = true_coef + random.normal(random.PRNGKey(0), (1000,))

    def model(data):
        # NB: the model's constraints have no effect here
        loc = param('loc', 0., constraint=constraints.interval(0, 0.5))
        sample('obs', dist.Normal(loc, 0.1), obs=data)

    def guide():
        alpha = param('alpha', 0.5, constraint=constraints.unit_interval)
        param('loc', 0, constraint=constraints.interval(0, alpha))

    opt_init, opt_update, get_params = optimizers.adam(0.05)
    svi_init, svi_update, _ = svi(model, guide, elbo, opt_init, opt_update, get_params)
    rng_init, rng_train = random.split(random.PRNGKey(1))
    opt_state, constrain_fn = svi_init(rng_init, model_args=(data,))

    def body_fn(i, val):
        opt_state_, rng_ = val
        loss, opt_state_, rng_ = svi_update(i, rng_, opt_state_, model_args=(data,))
        return opt_state_, rng_

    opt_state, rng = fori_loop(0, 300, body_fn, (opt_state, rng_train))
    params = get_param(opt_state, model, guide, get_params, constrain_fn, rng, guide_args=())
    assert_allclose(params['loc'], true_coef, atol=0.05)
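# What makes the constraint on 'loc' "dynamic" is that its upper bound is
# itself a parameter, so the constraining transform must be recomputed from
# the current value of 'alpha' on every step. As an illustration of the
# underlying mechanics (a sketch, not the library's transform code), an
# interval(low, high) constraint is typically handled by optimizing an
# unconstrained value and squashing it with a scaled sigmoid:
def _to_interval(x, low, high):
    # Map unconstrained x in R into (low, high); gradients flow through x
    # while the constraint always holds by construction.
    return low + (high - low) / (1.0 + onp.exp(-x))

assert 0.0 < _to_interval(5.0, 0.0, 0.5) < 0.5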
def main(args):
    encoder_init, encode = encoder(args.hidden_dim, args.z_dim)
    decoder_init, decode = decoder(args.hidden_dim, 28 * 28)
    opt_init, opt_update, get_params = optimizers.adam(args.learning_rate)
    svi_init, svi_update, svi_eval = svi(model, guide, elbo, opt_init, opt_update, get_params,
                                         encode=encode, decode=decode, z_dim=args.z_dim)
    svi_update = jit(svi_update)
    rng = PRNGKey(0)

    train_init, train_fetch = load_dataset(MNIST, batch_size=args.batch_size, split='train')
    test_init, test_fetch = load_dataset(MNIST, batch_size=args.batch_size, split='test')
    num_train, train_idx = train_init()

    # Initialize the encoder/decoder networks and the SVI state.
    rng, rng_enc, rng_dec = random.split(rng, 3)
    _, encoder_params = encoder_init(rng_enc, (args.batch_size, 28 * 28))
    _, decoder_params = decoder_init(rng_dec, (args.batch_size, args.z_dim))
    params = {'encoder': encoder_params, 'decoder': decoder_params}
    rng, sample_batch = binarize(rng, train_fetch(0, train_idx)[0])
    opt_state, constrain_fn = svi_init(rng, (sample_batch,), (sample_batch,), params)
    rng, = random.split(rng, 1)

    @jit
    def epoch_train(opt_state, rng):
        def body_fn(i, val):
            loss_sum, opt_state, rng = val
            rng, batch = binarize(rng, train_fetch(i, train_idx)[0])
            loss, opt_state, rng = svi_update(i, rng, opt_state, (batch,), (batch,))
            loss_sum += loss
            return loss_sum, opt_state, rng

        return lax.fori_loop(0, num_train, body_fn, (0., opt_state, rng))

    @jit
    def eval_test(opt_state, rng):
        def body_fun(i, val):
            loss_sum, rng = val
            rng, = random.split(rng, 1)
            rng, batch = binarize(rng, test_fetch(i, test_idx)[0])
            loss = svi_eval(rng, opt_state, (batch,), (batch,)) / len(batch)
            loss_sum += loss
            return loss_sum, rng

        loss, _ = lax.fori_loop(0, num_test, body_fun, (0., rng))
        loss = loss / num_test
        return loss

    def reconstruct_img(epoch):
        img = test_fetch(0, test_idx)[0][0]
        plt.imsave(os.path.join(RESULTS_DIR, 'original_epoch={}.png'.format(epoch)),
                   img, cmap='gray')
        _, test_sample = binarize(rng, img)
        params = get_params(opt_state)
        z_mean, z_var = encode(params['encoder'], test_sample.reshape([1, -1]))
        z = dist.Normal(z_mean, z_var).sample(rng)
        img_loc = decode(params['decoder'], z).reshape([28, 28])
        plt.imsave(os.path.join(RESULTS_DIR, 'recons_epoch={}.png'.format(epoch)),
                   img_loc, cmap='gray')

    for i in range(args.num_epochs):
        t_start = time.time()
        num_train, train_idx = train_init()
        _, opt_state, rng = epoch_train(opt_state, rng)
        rng, rng_test = random.split(rng, 2)
        num_test, test_idx = test_init()
        test_loss = eval_test(opt_state, rng_test)
        reconstruct_img(i)
        print("Epoch {}: loss = {} ({:.2f} s.)".format(i, test_loss, time.time() - t_start))
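# The `binarize` helper used throughout is defined elsewhere in the example.
# A plausible minimal sketch (an assumption, not the example's actual code):
# draw {0, 1} pixels from Bernoulli distributions whose probabilities are the
# grayscale intensities, and thread the PRNG key through alongside the batch.
def _binarize_sketch(rng, batch):
    rng, rng_sample = random.split(rng)
    return rng, random.bernoulli(rng_sample, batch).astype(batch.dtype)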
def main(args):
    jax_config.update('jax_platform_name', args.device)

    print("Start vanilla HMC...")
    vanilla_samples = mcmc(args.num_warmup, args.num_samples, init_params=np.array([2., 0.]),
                           potential_fn=dual_moon_pe, progbar=True)

    opt_init, opt_update, get_params = optimizers.adam(0.001)
    rng_guide, rng_init, rng_train = random.split(random.PRNGKey(1), 3)
    guide = AutoIAFNormal(rng_guide, dual_moon_model, get_params, hidden_dims=[args.num_hidden])
    svi_init, svi_update, _ = svi(dual_moon_model, guide, elbo, opt_init, opt_update, get_params)
    opt_state, _ = svi_init(rng_init)

    def body_fn(val, i):
        opt_state_, rng_ = val
        loss, opt_state_, rng_ = svi_update(i, rng_, opt_state_)
        return (opt_state_, rng_), loss

    print("Start training guide...")
    (last_state, _), losses = lax.scan(body_fn, (opt_state, rng_train), np.arange(args.num_iters))
    print("Finish training guide. Extract samples...")
    guide_samples = guide.sample_posterior(random.PRNGKey(0), last_state,
                                           sample_shape=(args.num_samples,))

    transform = guide.get_transform(last_state)
    unpack_fn = guide.unpack_latent

    _, potential_fn, constrain_fn = initialize_model(random.PRNGKey(0), dual_moon_model)
    transformed_potential_fn = make_transformed_pe(potential_fn, transform, unpack_fn)
    transformed_constrain_fn = lambda x: constrain_fn(unpack_fn(transform(x)))  # noqa: E731

    init_params = np.zeros(guide.latent_size)
    print("\nStart NeuTra HMC...")
    zs = mcmc(args.num_warmup, args.num_samples, init_params,
              potential_fn=transformed_potential_fn)
    print("Transform samples into unwarped space...")
    samples = vmap(transformed_constrain_fn)(zs)
    summary(tree_map(lambda x: x[None, ...], samples))

    # make plots

    # IAF guide samples (for plotting)
    iaf_base_samples = dist.Normal(np.zeros(2), 1.).sample(random.PRNGKey(0), (1000,))
    iaf_trans_samples = vmap(transformed_constrain_fn)(iaf_base_samples)['x']

    x1 = np.linspace(-3, 3, 100)
    x2 = np.linspace(-3, 3, 100)
    X1, X2 = np.meshgrid(x1, x2)
    P = np.clip(np.exp(-dual_moon_pe(np.stack([X1, X2], axis=-1))), a_min=0.)

    fig = plt.figure(figsize=(12, 16), constrained_layout=True)
    gs = GridSpec(3, 2, figure=fig)
    ax1 = fig.add_subplot(gs[0, 0])
    ax2 = fig.add_subplot(gs[0, 1])
    ax3 = fig.add_subplot(gs[1, 0])
    ax4 = fig.add_subplot(gs[1, 1])
    ax5 = fig.add_subplot(gs[2, 0])
    ax6 = fig.add_subplot(gs[2, 1])

    ax1.plot(np.log(losses[1000:]))
    ax1.set_title('Autoguide training log loss (after 1000 steps)')

    ax2.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(guide_samples['x'][:, 0].copy(), guide_samples['x'][:, 1].copy(),
                n_levels=30, ax=ax2)
    ax2.set(xlim=[-3, 3], ylim=[-3, 3], xlabel='x0', ylabel='x1',
            title='Posterior using AutoIAFNormal guide')

    sns.scatterplot(iaf_base_samples[:, 0], iaf_base_samples[:, 1], ax=ax3,
                    hue=iaf_trans_samples[:, 0] < 0.)
    ax3.set(xlim=[-3, 3], ylim=[-3, 3], xlabel='x0', ylabel='x1',
            title='AutoIAFNormal base samples (True=left moon; False=right moon)')

    ax4.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(vanilla_samples[:, 0].copy(), vanilla_samples[:, 1].copy(),
                n_levels=30, ax=ax4)
    ax4.plot(vanilla_samples[-50:, 0], vanilla_samples[-50:, 1], 'bo-', alpha=0.5)
    ax4.set(xlim=[-3, 3], ylim=[-3, 3], xlabel='x0', ylabel='x1',
            title='Posterior using vanilla HMC sampler')

    sns.scatterplot(zs[:, 0], zs[:, 1], ax=ax5, hue=samples['x'][:, 0] < 0.,
                    s=30, alpha=0.5, edgecolor="none")
    ax5.set(xlim=[-5, 5], ylim=[-5, 5], xlabel='x0', ylabel='x1',
            title='Samples from the warped posterior - p(z)')

    ax6.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(samples['x'][:, 0].copy(), samples['x'][:, 1].copy(), n_levels=30, ax=ax6)
    ax6.plot(samples['x'][-50:, 0], samples['x'][-50:, 1], 'bo-', alpha=0.2)
    ax6.set(xlim=[-3, 3], ylim=[-3, 3], xlabel='x0', ylabel='x1',
            title='Posterior using NeuTra HMC sampler')

    plt.savefig("neutra.pdf")
    plt.close()
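# The key ingredient above is `make_transformed_pe`: NeuTra runs HMC on a
# latent z with x = T(z), where T is the flow learned by the guide, so the
# warped potential must account for the change of variables. A minimal sketch
# of that construction (illustrative, not the library's implementation; it
# assumes `transform` follows the usual bijector convention where
# log_abs_det_jacobian takes the input and output of the transform):
def _make_transformed_pe_sketch(potential_fn, transform, unpack_fn):
    # If x = T(z) and the target density is p(x), the pulled-back density on z
    # is p(T(z)) * |det J_T(z)|, so the potential (negative log density) picks
    # up a -log|det J_T(z)| correction.
    def transformed_potential_fn(z):
        x = transform(z)
        return potential_fn(unpack_fn(x)) - transform.log_abs_det_jacobian(z, x)

    return transformed_potential_fn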
# -- Path setup --------------------------------------------------------------

# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
# sys.path.insert(0, os.path.abspath('../..'))

# HACK: This is to ensure that local functions are documented by sphinx.
from numpyro.mcmc import hmc  # noqa: E402
from numpyro.svi import svi  # noqa: E402

os.environ['SPHINX_BUILD'] = '1'
hmc(None, None)
svi(None, None, None, None, None, None)

# -- Project information -----------------------------------------------------

project = u'Numpyro'
copyright = u'2019, Uber Technologies, Inc'
author = u'Uber AI Labs'

# The short X.Y version
version = u'0.0'
# The full version, including alpha/beta/rc tags
release = u'0.0'

# -- General configuration ---------------------------------------------------

# If your documentation needs a minimal Sphinx version, state it here.