Example #1
0
File: vae.py — Project: cnheider/numpyro
def main(args):
    """Train a VAE on binarized MNIST and save reconstruction images.

    Per epoch: run SVI updates over the training set, evaluate the mean
    per-example test loss, and write an original/reconstruction image
    pair to RESULTS_DIR.
    """
    encoder_nn = encoder(args.hidden_dim, args.z_dim)
    decoder_nn = decoder(args.hidden_dim, 28 * 28)
    adam = optim.Adam(args.learning_rate)
    svi = SVI(model,
              guide,
              adam,
              ELBO(),
              hidden_dim=args.hidden_dim,
              z_dim=args.z_dim)
    rng_key = PRNGKey(0)
    train_init, train_fetch = load_dataset(MNIST,
                                           batch_size=args.batch_size,
                                           split='train')
    test_init, test_fetch = load_dataset(MNIST,
                                         batch_size=args.batch_size,
                                         split='test')
    num_train, train_idx = train_init()
    rng_key, rng_key_binarize, rng_key_init = random.split(rng_key, 3)
    # Initialize SVI state from one sample batch so parameter shapes are known.
    sample_batch = binarize(rng_key_binarize, train_fetch(0, train_idx)[0])
    svi_state = svi.init(rng_key_init, sample_batch)

    @jit
    def epoch_train(svi_state, rng_key):
        """One pass over the training set; returns (total_loss, svi_state)."""
        def body_fn(i, val):
            loss_sum, svi_state = val
            # Fresh key per batch so each batch gets independent binarization noise.
            rng_key_binarize = random.fold_in(rng_key, i)
            batch = binarize(rng_key_binarize, train_fetch(i, train_idx)[0])
            svi_state, loss = svi.update(svi_state, batch)
            loss_sum += loss
            return loss_sum, svi_state

        return lax.fori_loop(0, num_train, body_fn, (0., svi_state))

    @jit
    def eval_test(svi_state, rng_key):
        """Mean per-example loss over all test batches."""
        def body_fun(i, loss_sum):
            rng_key_binarize = random.fold_in(rng_key, i)
            batch = binarize(rng_key_binarize, test_fetch(i, test_idx)[0])
            # FIXME: does this lead to a requirement for an rng_key arg in svi_eval?
            loss = svi.evaluate(svi_state, batch) / len(batch)
            loss_sum += loss
            return loss_sum

        loss = lax.fori_loop(0, num_test, body_fun, 0.)
        loss = loss / num_test
        return loss

    def reconstruct_img(epoch, rng_key):
        """Save one test image and its VAE reconstruction for this epoch."""
        img = test_fetch(0, test_idx)[0][0]
        plt.imsave(os.path.join(RESULTS_DIR,
                                'original_epoch={}.png'.format(epoch)),
                   img,
                   cmap='gray')
        rng_key_binarize, rng_key_sample = random.split(rng_key)
        test_sample = binarize(rng_key_binarize, img)
        params = svi.get_params(svi_state)
        z_mean, z_var = encoder_nn[1](params['encoder$params'],
                                      test_sample.reshape([1, -1]))
        z = dist.Normal(z_mean, z_var).sample(rng_key_sample)
        img_loc = decoder_nn[1](params['decoder$params'], z).reshape([28, 28])
        plt.imsave(os.path.join(RESULTS_DIR,
                                'recons_epoch={}.png'.format(epoch)),
                   img_loc,
                   cmap='gray')

    for i in range(args.num_epochs):
        # One split per epoch yields independent train/test/reconstruct keys.
        # Fix: the original split rng_key a second time mid-loop, discarding
        # the rng_key_test/rng_key_reconstruct produced here — that redundant
        # split is removed.
        rng_key, rng_key_train, rng_key_test, rng_key_reconstruct = random.split(
            rng_key, 4)
        t_start = time.time()
        num_train, train_idx = train_init()
        _, svi_state = epoch_train(svi_state, rng_key_train)
        num_test, test_idx = test_init()
        test_loss = eval_test(svi_state, rng_key_test)
        reconstruct_img(i, rng_key_reconstruct)
        print("Epoch {}: loss = {} ({:.2f} s.)".format(i, test_loss,
                                                       time.time() - t_start))
Example #2
0
def test_covtype_data_load():
    """Covtype loads unshuffled with the expected feature/label shapes."""
    _, fetch = load_dataset(COVTYPE, shuffle=False)
    features, labels = fetch()
    n_rows = 581012
    assert jnp.shape(features) == (n_rows, 54)
    assert jnp.shape(labels) == (n_rows, )
Example #3
0
def fetch_aa_dihedrals(aa):
    """Fetch the NINE_MERS dihedral data for amino acid *aa* as one stacked array."""
    _, fetch_fn = load_dataset(NINE_MERS, split=aa)
    samples = fetch_fn()
    return jnp.stack(samples)
Example #4
0
        dist.TruncatedNormal(low=0.,
                             loc=jnp.array([0.5, 0.05, 1.5, 0.05]),
                             scale=jnp.array([0.5, 0.05, 0.5, 0.05])))
    # integrate dz/dt, the result will have shape N x 2
    z = odeint(dz_dt, z_init, ts, theta, rtol=1e-5, atol=1e-3, mxstep=500)
    # measurement errors, we expect that measured hare has larger error than measured lynx
    sigma = numpyro.sample("sigma", dist.Exponential(jnp.array([1, 2])))
    # measured populations (in log scale)
    numpyro.sample("y", dist.Normal(jnp.log(z), sigma), obs=y)


import pandas as pd
# NOTE(review): mid-file import follows the scraped example's layout, not
# top-of-file convention.
device = 'cpu'
numpyro.set_platform(device)

_, fetch = load_dataset(LYNXHARE, shuffle=False)

# MCMC run configuration.
num_warmup = 1000
num_chains = 1
num_samples = 1000
numpyro.set_host_device_count(num_chains)

year, data = fetch()  # data is in hare -> lynx order

# Tabulate the two population series indexed by year.
df_data = pd.DataFrame(columns=['hare', 'lynx'])
df_data['hare'] = data[:, 0]
df_data['lynx'] = data[:, 1]

df_data.index = year
# use dense_mass for better mixing rate
mcmc = MCMC(
Example #5
0
def test_baseball_data_load():
    """Baseball train split: first batch is an 18x2 array plus 18 targets."""
    init_fn, fetch_fn = load_dataset(BASEBALL, split='train', shuffle=False)
    _, batch_idx = init_fn()
    batch = fetch_fn(0, batch_idx)
    assert np.shape(batch[0]) == (18, 2)
    assert np.shape(batch[1]) == (18, )
Example #6
0
def test_sp500_data_load():
    """SP500 train split returns 2427 aligned dates and values."""
    _, fetch = load_dataset(SP500, split='train', shuffle=False)
    date, value = fetch()
    # Bug fix: the original asserted np.shape(date) == np.shape(date),
    # comparing `date` with itself, so `value` was never checked.
    assert np.shape(date) == (2427, )
    assert np.shape(value) == (2427, )
Example #7
0
File: vae.py — Project: ColCarroll/numpyro
def main(args):
    """Train a VAE on binarized MNIST with the early functional svi API.

    NOTE(review): this targets an older numpyro/jax interface where
    `svi(...)` returns (init, update, eval) functions and
    `optimizers.adam` returns (opt_init, opt_update) — confirm against
    the pinned library versions before reuse.
    """
    # stax-style constructors: each returns (init_fn, apply_fn).
    encoder_init, encode = encoder(args.hidden_dim, args.z_dim)
    decoder_init, decode = decoder(args.hidden_dim, 28 * 28)
    opt_init, opt_update = optimizers.adam(args.learning_rate)
    svi_init, svi_update, svi_eval = svi(model,
                                         guide,
                                         elbo,
                                         opt_init,
                                         opt_update,
                                         encode=encode,
                                         decode=decode,
                                         z_dim=args.z_dim)
    svi_update = jit(svi_update)
    rng = PRNGKey(0)
    train_init, train_fetch = load_dataset(MNIST,
                                           batch_size=args.batch_size,
                                           split='train')
    test_init, test_fetch = load_dataset(MNIST,
                                         batch_size=args.batch_size,
                                         split='test')
    num_train, train_idx = train_init()
    # Network parameters are initialized from input shapes only here.
    _, encoder_params = encoder_init((args.batch_size, 28 * 28))
    _, decoder_params = decoder_init((args.batch_size, args.z_dim))
    params = {'encoder': encoder_params, 'decoder': decoder_params}
    # In this version binarize threads the rng and returns (new_rng, batch).
    rng, sample_batch = binarize(rng, train_fetch(0, train_idx)[0])
    opt_state = svi_init(rng, (sample_batch, ), (sample_batch, ), params)
    rng, = random.split(rng, 1)  # advance the key before training starts

    @jit
    def epoch_train(opt_state, rng):
        # One pass over the training batches; fori_loop threads
        # (total_loss, opt_state, rng) through body_fn.
        def body_fn(i, val):
            loss_sum, opt_state, rng = val
            rng, batch = binarize(rng, train_fetch(i, train_idx)[0])
            loss, opt_state, rng = svi_update(
                i,
                opt_state,
                rng,
                (batch, ),
                (batch, ),
            )
            loss_sum += loss
            return loss_sum, opt_state, rng

        return lax.fori_loop(0, num_train, body_fn, (0., opt_state, rng))

    @jit
    def eval_test(opt_state, rng):
        # Mean per-example loss over the test set.
        def body_fun(i, val):
            loss_sum, rng = val
            rng, = random.split(rng, 1)
            rng, batch = binarize(rng, test_fetch(i, test_idx)[0])
            loss = svi_eval(opt_state, rng, (batch, ), (batch, )) / len(batch)
            loss_sum += loss
            return loss_sum, rng

        loss, _ = lax.fori_loop(0, num_test, body_fun, (0., rng))
        loss = loss / num_test
        return loss

    def reconstruct_img(epoch):
        # Save one test image and its reconstruction; closes over the
        # current `opt_state`, `rng` and `test_idx` from the epoch loop.
        img = test_fetch(0, test_idx)[0][0]
        plt.imsave(os.path.join(RESULTS_DIR,
                                'original_epoch={}.png'.format(epoch)),
                   img,
                   cmap='gray')
        _, test_sample = binarize(rng, img)
        params = optimizers.get_params(opt_state)
        z_mean, z_var = encode(params['encoder'], test_sample.reshape([1, -1]))
        # scipy-style sampling from the approximate posterior over z.
        z = dist.norm(z_mean, z_var).rvs(random_state=rng)
        img_loc = decode(params['decoder'], z).reshape([28, 28])
        plt.imsave(os.path.join(RESULTS_DIR,
                                'recons_epoch={}.png'.format(epoch)),
                   img_loc,
                   cmap='gray')

    for i in range(args.num_epochs):
        t_start = time.time()
        num_train, train_idx = train_init()
        _, opt_state, rng = epoch_train(opt_state, rng)
        rng, rng_test = random.split(rng, 2)
        num_test, test_idx = test_init()
        test_loss = eval_test(opt_state, rng_test)
        reconstruct_img(i)
        print("Epoch {}: loss = {} ({:.2f} s.)".format(i, test_loss,
                                                       time.time() - t_start))
Example #8
0
    # non-centered parameterization
    num_dept = len(onp.unique(dept))
    z = numpyro.sample('z', dist.Normal(np.zeros((num_dept, 2)), 1))
    v = np.dot(scale_tril, z.T).T

    logits = v_mu[0] + v[dept, 0] + (v_mu[1] + v[dept, 1]) * male
    if admit is None:
        # we use a Delta site to record probs for predictive distribution
        probs = expit(logits)
        numpyro.sample('probs', dist.Delta(probs), obs=probs)
    numpyro.sample('admit', dist.Binomial(applications, logits=logits), obs=admit)




# UCBadmit train split yields (dept, male, applications, admit) arrays.
_, fetch_train = load_dataset(UCBADMIT, split='train', shuffle=False)
dept, male, applications, admit = fetch_train()
rng_key, rng_key_predict = random.split(random.PRNGKey(1))


kernel = NUTS(glmm)
# Disable the progress bar when building under Sphinx.
mcmc = MCMC(kernel, args.num_warmup, args.num_samples, args.num_chains,
            progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True)

mcmc.run(rng_key, dept, male, applications, admit)

zs = mcmc.get_samples()

# Posterior-predictive admission probabilities from the model's 'probs' site.
pred_probs = Predictive(glmm, zs)(rng_key_predict, dept, male, applications)['probs']

fig, ax = plt.subplots(1, 1)
Example #9
0
def main(args):
    """Train a VAE on binarized MNIST (functional svi API with get_params).

    Each epoch: jit-compiled training pass, mean test loss, and an
    original/reconstruction image pair written to RESULTS_DIR.
    """
    # stax-style constructors: each returns (init_fn, apply_fn).
    encoder_init, encode = encoder(args.hidden_dim, args.z_dim)
    decoder_init, decode = decoder(args.hidden_dim, 28 * 28)
    opt_init, opt_update, get_params = optimizers.adam(args.learning_rate)
    svi_init, svi_update, svi_eval = svi(model,
                                         guide,
                                         elbo,
                                         opt_init,
                                         opt_update,
                                         get_params,
                                         encode=encode,
                                         decode=decode,
                                         z_dim=args.z_dim)
    rng = PRNGKey(0)
    train_init, train_fetch = load_dataset(MNIST,
                                           batch_size=args.batch_size,
                                           split='train')
    test_init, test_fetch = load_dataset(MNIST,
                                         batch_size=args.batch_size,
                                         split='test')
    num_train, train_idx = train_init()
    # Dedicated subkeys for network init, batch binarization and SVI init.
    rng, rng_enc, rng_dec, rng_binarize, rng_init = random.split(rng, 5)
    _, encoder_params = encoder_init(rng_enc, (args.batch_size, 28 * 28))
    _, decoder_params = decoder_init(rng_dec, (args.batch_size, args.z_dim))
    params = {'encoder': encoder_params, 'decoder': decoder_params}
    sample_batch = binarize(rng_binarize, train_fetch(0, train_idx)[0])
    opt_state, constrain_fn = svi_init(rng_init, (sample_batch, ),
                                       (sample_batch, ), params)

    @jit
    def epoch_train(opt_state, rng):
        # One pass over the training batches; fori_loop threads
        # (total_loss, opt_state, rng) through body_fn.
        def body_fn(i, val):
            loss_sum, opt_state, rng = val
            rng, rng_binarize = random.split(rng)
            batch = binarize(rng_binarize, train_fetch(i, train_idx)[0])
            # TODO: we will want to merge (i, rng, opt_state) into `svi_state`
            # Here the index `i` is reseted after each epoch, which causes no
            # problem for static learning rate, but it is not a right way for
            # scheduled learning rate.
            loss, opt_state, rng = svi_update(
                i,
                rng,
                opt_state,
                (batch, ),
                (batch, ),
            )
            loss_sum += loss
            return loss_sum, opt_state, rng

        return lax.fori_loop(0, num_train, body_fn, (0., opt_state, rng))

    @jit
    def eval_test(opt_state, rng):
        # Mean per-example loss over the test set.
        def body_fun(i, val):
            loss_sum, rng = val
            rng, rng_binarize, rng_eval = random.split(rng, 3)
            batch = binarize(rng_binarize, test_fetch(i, test_idx)[0])
            loss = svi_eval(rng_eval, opt_state, (batch, ),
                            (batch, )) / len(batch)
            loss_sum += loss
            return loss_sum, rng

        loss, _ = lax.fori_loop(0, num_test, body_fun, (0., rng))
        loss = loss / num_test
        return loss

    def reconstruct_img(epoch, rng):
        # Save one test image and its reconstruction; closes over the
        # current `opt_state` and `test_idx` from the epoch loop.
        img = test_fetch(0, test_idx)[0][0]
        plt.imsave(os.path.join(RESULTS_DIR,
                                'original_epoch={}.png'.format(epoch)),
                   img,
                   cmap='gray')
        rng_binarize, rng_sample = random.split(rng)
        test_sample = binarize(rng_binarize, img)
        params = get_params(opt_state)
        z_mean, z_var = encode(params['encoder'], test_sample.reshape([1, -1]))
        z = dist.Normal(z_mean, z_var).sample(rng_sample)
        img_loc = decode(params['decoder'], z).reshape([28, 28])
        plt.imsave(os.path.join(RESULTS_DIR,
                                'recons_epoch={}.png'.format(epoch)),
                   img_loc,
                   cmap='gray')

    for i in range(args.num_epochs):
        t_start = time.time()
        num_train, train_idx = train_init()
        _, opt_state, rng = epoch_train(opt_state, rng)
        rng, rng_test, rng_reconstruct = random.split(rng, 3)
        num_test, test_idx = test_init()
        test_loss = eval_test(opt_state, rng_test)
        reconstruct_img(i, rng_reconstruct)
        print("Epoch {}: loss = {} ({:.2f} s.)".format(i, test_loss,
                                                       time.time() - t_start))