def main(args):
    """Train a VAE on binarized MNIST with numpyro's SVI.

    Each epoch runs one optimization pass over the training set, evaluates the
    mean per-example loss on the test set, and saves an original/reconstructed
    image pair under RESULTS_DIR.
    """
    encoder_nn = encoder(args.hidden_dim, args.z_dim)
    decoder_nn = decoder(args.hidden_dim, 28 * 28)
    adam = optim.Adam(args.learning_rate)
    svi = SVI(model, guide, adam, ELBO(), hidden_dim=args.hidden_dim, z_dim=args.z_dim)
    rng_key = PRNGKey(0)
    train_init, train_fetch = load_dataset(MNIST, batch_size=args.batch_size, split='train')
    test_init, test_fetch = load_dataset(MNIST, batch_size=args.batch_size, split='test')
    num_train, train_idx = train_init()
    rng_key, rng_key_binarize, rng_key_init = random.split(rng_key, 3)
    sample_batch = binarize(rng_key_binarize, train_fetch(0, train_idx)[0])
    svi_state = svi.init(rng_key_init, sample_batch)

    @jit
    def epoch_train(svi_state, rng_key):
        # One optimization pass over every training batch, accumulating loss.
        def body_fn(i, val):
            loss_sum, svi_state = val
            # Per-batch key so each batch is binarized independently.
            rng_key_binarize = random.fold_in(rng_key, i)
            batch = binarize(rng_key_binarize, train_fetch(i, train_idx)[0])
            svi_state, loss = svi.update(svi_state, batch)
            loss_sum += loss
            return loss_sum, svi_state

        return lax.fori_loop(0, num_train, body_fn, (0., svi_state))

    @jit
    def eval_test(svi_state, rng_key):
        # Mean per-example loss over all test batches.
        def body_fun(i, loss_sum):
            rng_key_binarize = random.fold_in(rng_key, i)
            batch = binarize(rng_key_binarize, test_fetch(i, test_idx)[0])
            # FIXME: does this lead to a requirement for an rng_key arg in svi_eval?
            loss = svi.evaluate(svi_state, batch) / len(batch)
            loss_sum += loss
            return loss_sum

        loss = lax.fori_loop(0, num_test, body_fun, 0.)
        loss = loss / num_test
        return loss

    def reconstruct_img(epoch, rng_key):
        # Save one test digit and the model's reconstruction of it.
        img = test_fetch(0, test_idx)[0][0]
        plt.imsave(os.path.join(RESULTS_DIR, 'original_epoch={}.png'.format(epoch)), img, cmap='gray')
        rng_key_binarize, rng_key_sample = random.split(rng_key)
        test_sample = binarize(rng_key_binarize, img)
        params = svi.get_params(svi_state)
        z_mean, z_var = encoder_nn[1](params['encoder$params'], test_sample.reshape([1, -1]))
        z = dist.Normal(z_mean, z_var).sample(rng_key_sample)
        img_loc = decoder_nn[1](params['decoder$params'], z).reshape([28, 28])
        plt.imsave(os.path.join(RESULTS_DIR, 'recons_epoch={}.png'.format(epoch)), img_loc, cmap='gray')

    for i in range(args.num_epochs):
        # Fix: previously this split produced rng_key_test/rng_key_reconstruct
        # that were immediately overwritten by the 3-way split below; only the
        # training key is consumed here.
        rng_key, rng_key_train = random.split(rng_key)
        t_start = time.time()
        num_train, train_idx = train_init()
        _, svi_state = epoch_train(svi_state, rng_key_train)
        rng_key, rng_key_test, rng_key_reconstruct = random.split(rng_key, 3)
        num_test, test_idx = test_init()
        test_loss = eval_test(svi_state, rng_key_test)
        reconstruct_img(i, rng_key_reconstruct)
        print("Epoch {}: loss = {} ({:.2f} s.)".format(i, test_loss, time.time() - t_start))
def test_covtype_data_load():
    """Smoke-test that the covertype dataset loads with the expected shapes."""
    _, fetch = load_dataset(COVTYPE, shuffle=False)
    features, labels = fetch()
    assert jnp.shape(features) == (581012, 54)
    assert jnp.shape(labels) == (581012, )
def fetch_aa_dihedrals(aa):
    """Load the nine-mer dihedral data for amino acid ``aa`` and stack it into one array."""
    _, fetch = load_dataset(NINE_MERS, split=aa)
    sequences = fetch()
    return jnp.stack(sequences)
dist.TruncatedNormal(low=0., loc=jnp.array([0.5, 0.05, 1.5, 0.05]), scale=jnp.array([0.5, 0.05, 0.5, 0.05]))) # integrate dz/dt, the result will have shape N x 2 z = odeint(dz_dt, z_init, ts, theta, rtol=1e-5, atol=1e-3, mxstep=500) # measurement errors, we expect that measured hare has larger error than measured lynx sigma = numpyro.sample("sigma", dist.Exponential(jnp.array([1, 2]))) # measured populations (in log scale) numpyro.sample("y", dist.Normal(jnp.log(z), sigma), obs=y) import pandas as pd device = 'cpu' numpyro.set_platform(device) _, fetch = load_dataset(LYNXHARE, shuffle=False) num_warmup = 1000 num_chains = 1 num_samples = 1000 numpyro.set_host_device_count(num_chains) year, data = fetch() # data is in hare -> lynx order df_data = pd.DataFrame(columns=['hare', 'lynx']) df_data['hare'] = data[:, 0] df_data['lynx'] = data[:, 1] df_data.index = year # use dense_mass for better mixing rate mcmc = MCMC(
def test_baseball_data_load():
    """Smoke-test that the baseball training split loads with the expected shapes."""
    init, fetch = load_dataset(BASEBALL, split='train', shuffle=False)
    _, batch_idx = init()
    first, second = fetch(0, batch_idx)
    assert np.shape(first) == (18, 2)
    assert np.shape(second) == (18, )
def test_sp500_data_load():
    """Smoke-test that the S&P 500 training split loads with matching shapes.

    Fix: the original assertion compared ``np.shape(date)`` with itself, which
    is a tautology; it must compare the date and value arrays.
    """
    _, fetch = load_dataset(SP500, split='train', shuffle=False)
    date, value = fetch()
    assert np.shape(date) == np.shape(value) == (2427, )
def main(args):
    """Train the VAE on binarized MNIST and write reconstruction images per epoch."""
    encoder_init, encode = encoder(args.hidden_dim, args.z_dim)
    decoder_init, decode = decoder(args.hidden_dim, 28 * 28)
    opt_init, opt_update = optimizers.adam(args.learning_rate)
    svi_init, svi_update, svi_eval = svi(model, guide, elbo, opt_init, opt_update,
                                         encode=encode, decode=decode, z_dim=args.z_dim)
    svi_update = jit(svi_update)
    rng = PRNGKey(0)
    train_init, train_fetch = load_dataset(MNIST, batch_size=args.batch_size, split='train')
    test_init, test_fetch = load_dataset(MNIST, batch_size=args.batch_size, split='test')
    num_train, train_idx = train_init()

    # Initialize network parameters and the SVI optimizer state.
    _, encoder_params = encoder_init((args.batch_size, 28 * 28))
    _, decoder_params = decoder_init((args.batch_size, args.z_dim))
    params = {'encoder': encoder_params, 'decoder': decoder_params}
    rng, sample_batch = binarize(rng, train_fetch(0, train_idx)[0])
    opt_state = svi_init(rng, (sample_batch, ), (sample_batch, ), params)
    rng, = random.split(rng, 1)

    @jit
    def epoch_train(opt_state, rng):
        # One full pass over the training batches, accumulating the loss.
        def body_fn(i, carry):
            loss_acc, opt_state, rng = carry
            rng, batch = binarize(rng, train_fetch(i, train_idx)[0])
            loss, opt_state, rng = svi_update(i, opt_state, rng, (batch, ), (batch, ), )
            return loss_acc + loss, opt_state, rng

        return lax.fori_loop(0, num_train, body_fn, (0., opt_state, rng))

    @jit
    def eval_test(opt_state, rng):
        # Mean per-example loss across all test batches.
        def body_fun(i, carry):
            loss_acc, rng = carry
            rng, = random.split(rng, 1)
            rng, batch = binarize(rng, test_fetch(i, test_idx)[0])
            batch_loss = svi_eval(opt_state, rng, (batch, ), (batch, )) / len(batch)
            return loss_acc + batch_loss, rng

        total, _ = lax.fori_loop(0, num_test, body_fun, (0., rng))
        return total / num_test

    def reconstruct_img(epoch):
        # Dump one test digit and the model's reconstruction of it to disk.
        img = test_fetch(0, test_idx)[0][0]
        plt.imsave(os.path.join(RESULTS_DIR, 'original_epoch={}.png'.format(epoch)), img, cmap='gray')
        _, test_sample = binarize(rng, img)
        params = optimizers.get_params(opt_state)
        z_mean, z_var = encode(params['encoder'], test_sample.reshape([1, -1]))
        z = dist.norm(z_mean, z_var).rvs(random_state=rng)
        img_loc = decode(params['decoder'], z).reshape([28, 28])
        plt.imsave(os.path.join(RESULTS_DIR, 'recons_epoch={}.png'.format(epoch)), img_loc, cmap='gray')

    for i in range(args.num_epochs):
        t_start = time.time()
        num_train, train_idx = train_init()
        _, opt_state, rng = epoch_train(opt_state, rng)
        rng, rng_test = random.split(rng, 2)
        num_test, test_idx = test_init()
        test_loss = eval_test(opt_state, rng_test)
        reconstruct_img(i)
        print("Epoch {}: loss = {} ({:.2f} s.)".format(i, test_loss, time.time() - t_start))
# non-centered parameterization num_dept = len(onp.unique(dept)) z = numpyro.sample('z', dist.Normal(np.zeros((num_dept, 2)), 1)) v = np.dot(scale_tril, z.T).T logits = v_mu[0] + v[dept, 0] + (v_mu[1] + v[dept, 1]) * male if admit is None: # we use a Delta site to record probs for predictive distribution probs = expit(logits) numpyro.sample('probs', dist.Delta(probs), obs=probs) numpyro.sample('admit', dist.Binomial(applications, logits=logits), obs=admit) _, fetch_train = load_dataset(UCBADMIT, split='train', shuffle=False) dept, male, applications, admit = fetch_train() rng_key, rng_key_predict = random.split(random.PRNGKey(1)) kernel = NUTS(glmm) mcmc = MCMC(kernel, args.num_warmup, args.num_samples, args.num_chains, progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True) mcmc.run(rng_key, dept, male, applications, admit) zs = mcmc.get_samples() pred_probs = Predictive(glmm, zs)(rng_key_predict, dept, male, applications)['probs'] fig, ax = plt.subplots(1, 1)
def main(args):
    """Run VAE training on binarized MNIST, reporting test loss every epoch."""
    encoder_init, encode = encoder(args.hidden_dim, args.z_dim)
    decoder_init, decode = decoder(args.hidden_dim, 28 * 28)
    opt_init, opt_update, get_params = optimizers.adam(args.learning_rate)
    svi_init, svi_update, svi_eval = svi(model, guide, elbo, opt_init, opt_update,
                                         get_params, encode=encode, decode=decode,
                                         z_dim=args.z_dim)
    rng = PRNGKey(0)
    train_init, train_fetch = load_dataset(MNIST, batch_size=args.batch_size, split='train')
    test_init, test_fetch = load_dataset(MNIST, batch_size=args.batch_size, split='test')
    num_train, train_idx = train_init()

    # Split keys for parameter initialization and the first binarized batch.
    rng, rng_enc, rng_dec, rng_binarize, rng_init = random.split(rng, 5)
    _, encoder_params = encoder_init(rng_enc, (args.batch_size, 28 * 28))
    _, decoder_params = decoder_init(rng_dec, (args.batch_size, args.z_dim))
    params = {'encoder': encoder_params, 'decoder': decoder_params}
    sample_batch = binarize(rng_binarize, train_fetch(0, train_idx)[0])
    opt_state, constrain_fn = svi_init(rng_init, (sample_batch, ), (sample_batch, ), params)

    @jit
    def epoch_train(opt_state, rng):
        # Single training epoch; returns (total_loss, opt_state, rng).
        def body_fn(i, carry):
            loss_acc, opt_state, rng = carry
            rng, rng_binarize = random.split(rng)
            batch = binarize(rng_binarize, train_fetch(i, train_idx)[0])
            # TODO: we will want to merge (i, rng, opt_state) into `svi_state`.
            # The index `i` is reset after each epoch, which is harmless for a
            # static learning rate but incorrect for a scheduled learning rate.
            loss, opt_state, rng = svi_update(i, rng, opt_state, (batch, ), (batch, ), )
            return loss_acc + loss, opt_state, rng

        return lax.fori_loop(0, num_train, body_fn, (0., opt_state, rng))

    @jit
    def eval_test(opt_state, rng):
        # Mean per-example loss across the test batches.
        def body_fun(i, carry):
            loss_acc, rng = carry
            rng, rng_binarize, rng_eval = random.split(rng, 3)
            batch = binarize(rng_binarize, test_fetch(i, test_idx)[0])
            batch_loss = svi_eval(rng_eval, opt_state, (batch, ), (batch, )) / len(batch)
            return loss_acc + batch_loss, rng

        total, _ = lax.fori_loop(0, num_test, body_fun, (0., rng))
        return total / num_test

    def reconstruct_img(epoch, rng):
        # Save one test digit and its reconstruction under RESULTS_DIR.
        img = test_fetch(0, test_idx)[0][0]
        plt.imsave(os.path.join(RESULTS_DIR, 'original_epoch={}.png'.format(epoch)), img, cmap='gray')
        rng_binarize, rng_sample = random.split(rng)
        test_sample = binarize(rng_binarize, img)
        params = get_params(opt_state)
        z_mean, z_var = encode(params['encoder'], test_sample.reshape([1, -1]))
        z = dist.Normal(z_mean, z_var).sample(rng_sample)
        img_loc = decode(params['decoder'], z).reshape([28, 28])
        plt.imsave(os.path.join(RESULTS_DIR, 'recons_epoch={}.png'.format(epoch)), img_loc, cmap='gray')

    for i in range(args.num_epochs):
        t_start = time.time()
        num_train, train_idx = train_init()
        _, opt_state, rng = epoch_train(opt_state, rng)
        rng, rng_test, rng_reconstruct = random.split(rng, 3)
        num_test, test_idx = test_init()
        test_loss = eval_test(opt_state, rng_test)
        reconstruct_img(i, rng_reconstruct)
        print("Epoch {}: loss = {} ({:.2f} s.)".format(i, test_loss, time.time() - t_start))