Example #1
def main(args):
    print("Start vanilla HMC...")
    nuts_kernel = NUTS(dual_moon_model)
    mcmc = MCMC(
        nuts_kernel,
        num_warmup=args.num_warmup,
        num_samples=args.num_samples,
        num_chains=args.num_chains,
        progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True,
    )
    mcmc.run(random.PRNGKey(0))
    mcmc.print_summary()
    vanilla_samples = mcmc.get_samples()["x"].copy()

    guide = AutoBNAFNormal(
        dual_moon_model,
        hidden_factors=[args.hidden_factor, args.hidden_factor])
    svi = SVI(dual_moon_model, guide, optim.Adam(0.003), Trace_ELBO())

    print("Start training guide...")
    svi_result = svi.run(random.PRNGKey(1), args.num_iters)
    print("Finish training guide. Extract samples...")
    guide_samples = guide.sample_posterior(
        random.PRNGKey(2),
        svi_result.params,
        sample_shape=(args.num_samples, ))["x"].copy()

    print("\nStart NeuTra HMC...")
    neutra = NeuTraReparam(guide, svi_result.params)
    neutra_model = neutra.reparam(dual_moon_model)
    nuts_kernel = NUTS(neutra_model)
    mcmc = MCMC(
        nuts_kernel,
        num_warmup=args.num_warmup,
        num_samples=args.num_samples,
        num_chains=args.num_chains,
        progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True,
    )
    mcmc.run(random.PRNGKey(3))
    mcmc.print_summary()
    zs = mcmc.get_samples(group_by_chain=True)["auto_shared_latent"]
    print("Transform samples into unwarped space...")
    samples = neutra.transform_sample(zs)
    print_summary(samples)
    zs = zs.reshape(-1, 2)
    samples = samples["x"].reshape(-1, 2).copy()

    # make plots

    # guide samples (for plotting)
    guide_base_samples = dist.Normal(jnp.zeros(2),
                                     1.0).sample(random.PRNGKey(4), (1000, ))
    guide_trans_samples = neutra.transform_sample(guide_base_samples)["x"]

    x1 = jnp.linspace(-3, 3, 100)
    x2 = jnp.linspace(-3, 3, 100)
    X1, X2 = jnp.meshgrid(x1, x2)
    P = jnp.exp(DualMoonDistribution().log_prob(jnp.stack([X1, X2], axis=-1)))

    fig = plt.figure(figsize=(12, 8), constrained_layout=True)
    gs = GridSpec(2, 3, figure=fig)
    ax1 = fig.add_subplot(gs[0, 0])
    ax2 = fig.add_subplot(gs[1, 0])
    ax3 = fig.add_subplot(gs[0, 1])
    ax4 = fig.add_subplot(gs[1, 1])
    ax5 = fig.add_subplot(gs[0, 2])
    ax6 = fig.add_subplot(gs[1, 2])

    ax1.plot(svi_result.losses[1000:])
    ax1.set_title("Autoguide training loss\n(after 1000 steps)")

    ax2.contourf(X1, X2, P, cmap="OrRd")
    sns.kdeplot(x=guide_samples[:, 0],
                y=guide_samples[:, 1],
                n_levels=30,
                ax=ax2)
    ax2.set(
        xlim=[-3, 3],
        ylim=[-3, 3],
        xlabel="x0",
        ylabel="x1",
        title="Posterior using\nAutoBNAFNormal guide",
    )

    sns.scatterplot(
        x=guide_base_samples[:, 0],
        y=guide_base_samples[:, 1],
        ax=ax3,
        hue=guide_trans_samples[:, 0] < 0.0,
    )
    ax3.set(
        xlim=[-3, 3],
        ylim=[-3, 3],
        xlabel="x0",
        ylabel="x1",
        title="AutoBNAFNormal base samples\n(True=left moon; False=right moon)",
    )

    ax4.contourf(X1, X2, P, cmap="OrRd")
    sns.kdeplot(x=vanilla_samples[:, 0],
                y=vanilla_samples[:, 1],
                n_levels=30,
                ax=ax4)
    ax4.plot(vanilla_samples[-50:, 0],
             vanilla_samples[-50:, 1],
             "bo-",
             alpha=0.5)
    ax4.set(
        xlim=[-3, 3],
        ylim=[-3, 3],
        xlabel="x0",
        ylabel="x1",
        title="Posterior using\nvanilla HMC sampler",
    )

    sns.scatterplot(
        x=zs[:, 0],
        y=zs[:, 1],
        ax=ax5,
        hue=samples[:, 0] < 0.0,
        s=30,
        alpha=0.5,
        edgecolor="none",
    )
    ax5.set(
        xlim=[-5, 5],
        ylim=[-5, 5],
        xlabel="x0",
        ylabel="x1",
        title="Samples from the\nwarped posterior - p(z)",
    )

    ax6.contourf(X1, X2, P, cmap="OrRd")
    sns.kdeplot(x=samples[:, 0], y=samples[:, 1], n_levels=30, ax=ax6)
    ax6.plot(samples[-50:, 0], samples[-50:, 1], "bo-", alpha=0.2)
    ax6.set(
        xlim=[-3, 3],
        ylim=[-3, 3],
        xlabel="x0",
        ylabel="x1",
        title="Posterior using\nNeuTra HMC sampler",
    )

    plt.savefig("neutra.pdf")
Example #2
def main(args):
    encoder_nn = encoder(args.hidden_dim, args.z_dim)
    decoder_nn = decoder(args.hidden_dim, 28 * 28)
    adam = optim.Adam(args.learning_rate)
    svi = SVI(
        model, guide, adam, Trace_ELBO(), hidden_dim=args.hidden_dim, z_dim=args.z_dim
    )
    rng_key = PRNGKey(0)
    train_init, train_fetch = load_dataset(
        MNIST, batch_size=args.batch_size, split="train"
    )
    test_init, test_fetch = load_dataset(
        MNIST, batch_size=args.batch_size, split="test"
    )
    num_train, train_idx = train_init()
    rng_key, rng_key_binarize, rng_key_init = random.split(rng_key, 3)
    sample_batch = binarize(rng_key_binarize, train_fetch(0, train_idx)[0])
    svi_state = svi.init(rng_key_init, sample_batch)

    @jit
    def epoch_train(svi_state, rng_key, train_idx):
        def body_fn(i, val):
            loss_sum, svi_state = val
            rng_key_binarize = random.fold_in(rng_key, i)
            batch = binarize(rng_key_binarize, train_fetch(i, train_idx)[0])
            svi_state, loss = svi.update(svi_state, batch)
            loss_sum += loss
            return loss_sum, svi_state

        return lax.fori_loop(0, num_train, body_fn, (0.0, svi_state))

    @jit
    def eval_test(svi_state, rng_key, test_idx):
        def body_fun(i, loss_sum):
            rng_key_binarize = random.fold_in(rng_key, i)
            batch = binarize(rng_key_binarize, test_fetch(i, test_idx)[0])
            # FIXME: does this lead to a requirement for an rng_key arg in svi_eval?
            loss = svi.evaluate(svi_state, batch) / len(batch)
            loss_sum += loss
            return loss_sum

        loss = lax.fori_loop(0, num_test, body_fun, 0.0)
        loss = loss / num_test
        return loss

    def reconstruct_img(epoch, rng_key):
        img = test_fetch(0, test_idx)[0][0]
        plt.imsave(
            os.path.join(RESULTS_DIR, "original_epoch={}.png".format(epoch)),
            img,
            cmap="gray",
        )
        rng_key_binarize, rng_key_sample = random.split(rng_key)
        test_sample = binarize(rng_key_binarize, img)
        params = svi.get_params(svi_state)
        z_mean, z_var = encoder_nn[1](
            params["encoder$params"], test_sample.reshape([1, -1])
        )
        z = dist.Normal(z_mean, z_var).sample(rng_key_sample)
        img_loc = decoder_nn[1](params["decoder$params"], z).reshape([28, 28])
        plt.imsave(
            os.path.join(RESULTS_DIR, "recons_epoch={}.png".format(epoch)),
            img_loc,
            cmap="gray",
        )

    for i in range(args.num_epochs):
        rng_key, rng_key_train, rng_key_test, rng_key_reconstruct = random.split(
            rng_key, 4
        )
        t_start = time.time()
        num_train, train_idx = train_init()
        _, svi_state = epoch_train(svi_state, rng_key_train, train_idx)
        num_test, test_idx = test_init()
        test_loss = eval_test(svi_state, rng_key_test, test_idx)
        reconstruct_img(i, rng_key_reconstruct)
        print(
            "Epoch {}: loss = {} ({:.2f} s.)".format(
                i, test_loss, time.time() - t_start
            )
        )
Example #3
    def elbo_loss_fn(x):
        return Trace_ELBO().loss(random.PRNGKey(0), {}, model, guide, x)

    def renyi_loss_fn(x):
        return RenyiELBO(alpha=alpha,
                         num_particles=10).loss(random.PRNGKey(0), {}, model,
                                                guide, x)

    elbo_loss, elbo_grad = value_and_grad(elbo_loss_fn)(2.0)
    renyi_loss, renyi_grad = value_and_grad(renyi_loss_fn)(2.0)
    assert_allclose(elbo_loss, renyi_loss, rtol=1e-6)
    assert_allclose(elbo_grad, renyi_grad, rtol=1e-6)


@pytest.mark.parametrize("elbo", [Trace_ELBO(), RenyiELBO(num_particles=10)])
@pytest.mark.parametrize(
    "optimizer", [optim.Adam(0.05), optimizers.adam(0.05)])
def test_beta_bernoulli(elbo, optimizer):
    data = jnp.array([1.0] * 8 + [0.0] * 2)

    def model(data):
        f = numpyro.sample("beta", dist.Beta(1.0, 1.0))
        numpyro.sample("obs", dist.Bernoulli(f), obs=data)

    def guide(data):
        alpha_q = numpyro.param("alpha_q",
                                1.0,
                                constraint=constraints.positive)
        beta_q = numpyro.param("beta_q", 1.0, constraint=constraints.positive)
        numpyro.sample("beta", dist.Beta(alpha_q, beta_q))

    svi = SVI(model, guide, optimizer, elbo)
Example #4
def test_tracegraph_gaussian_chain(num_latents, num_steps, step_size, atol,
                                   difficulty):
    loc0 = 0.2
    data = jnp.array([-0.1, 0.03, 0.2, 0.1])
    n_data = data.shape[0]
    sum_data = data.sum()
    N = num_latents
    lambdas = [1.5 * (k + 1) / N for k in range(N + 1)]
    lambdas = list(map(lambda x: jnp.array([x]), lambdas))
    lambda_tilde_posts = [lambdas[0]]
    for k in range(1, N):
        lambda_tilde_k = (lambdas[k] * lambda_tilde_posts[k - 1]) / (
            lambdas[k] + lambda_tilde_posts[k - 1])
        lambda_tilde_posts.append(lambda_tilde_k)
    lambda_posts = [
        None
    ]  # this is never used (just a way of shifting the indexing by 1)
    for k in range(1, N):
        lambda_k = lambdas[k] + lambda_tilde_posts[k - 1]
        lambda_posts.append(lambda_k)
    lambda_N_post = (n_data * lambdas[N]) + lambda_tilde_posts[N - 1]
    lambda_posts.append(lambda_N_post)
    target_kappas = [None]
    target_kappas.extend([lambdas[k] / lambda_posts[k] for k in range(1, N)])
    target_mus = [None]
    target_mus.extend([
        loc0 * lambda_tilde_posts[k - 1] / lambda_posts[k]
        for k in range(1, N)
    ])
    target_loc_N = (sum_data * lambdas[N] / lambda_N_post +
                    loc0 * lambda_tilde_posts[N - 1] / lambda_N_post)
    target_mus.append(target_loc_N)
    np.random.seed(0)
    while True:
        mask = np.random.binomial(1, 0.3, (N, ))
        if mask.sum() < 0.4 * N and mask.sum() > 0.5:
            which_nodes_reparam = mask
            break

    class FakeNormal(dist.Normal):
        reparametrized_params = []

    def model(difficulty=0.0):
        next_mean = loc0
        for k in range(1, N + 1):
            latent_dist = dist.Normal(next_mean,
                                      jnp.power(lambdas[k - 1], -0.5))
            loc_latent = numpyro.sample("loc_latent_{}".format(k), latent_dist)
            next_mean = loc_latent

        loc_N = next_mean
        with numpyro.plate("data", data.shape[0]):
            numpyro.sample("obs",
                           dist.Normal(loc_N, jnp.power(lambdas[N], -0.5)),
                           obs=data)
        return loc_N

    def guide(difficulty=0.0):
        previous_sample = None
        for k in reversed(range(1, N + 1)):
            loc_q = numpyro.param(
                f"loc_q_{k}",
                lambda key: target_mus[k] + difficulty *
                (0.1 * random.normal(key) - 0.53),
            )
            log_sig_q = numpyro.param(
                f"log_sig_q_{k}",
                lambda key: -0.5 * jnp.log(lambda_posts[k]) + difficulty *
                (0.1 * random.normal(key) - 0.53),
            )
            sig_q = jnp.exp(log_sig_q)
            kappa_q = None
            if k != N:
                kappa_q = numpyro.param(
                    "kappa_q_%d" % k,
                    lambda key: target_kappas[k] + difficulty *
                    (0.1 * random.normal(key) - 0.53),
                )
            mean_function = loc_q if k == N else kappa_q * previous_sample + loc_q
            node_flagged = True if which_nodes_reparam[k - 1] == 1.0 else False
            Normal = dist.Normal if node_flagged else FakeNormal
            loc_latent = numpyro.sample(f"loc_latent_{k}",
                                        Normal(mean_function, sig_q))
            previous_sample = loc_latent
        return previous_sample

    adam = optim.Adam(step_size=step_size, b1=0.95, b2=0.999)
    svi = SVI(model, guide, adam, loss=TraceGraph_ELBO())
    svi_result = svi.run(jax.random.PRNGKey(0),
                         num_steps,
                         difficulty=difficulty)

    kappa_errors, log_sig_errors, loc_errors = [], [], []
    for k in range(1, N + 1):
        if k != N:
            kappa_error = jnp.sum(
                jnp.power(svi_result.params[f"kappa_q_{k}"] - target_kappas[k],
                          2))
            kappa_errors.append(kappa_error)

        loc_errors.append(
            jnp.sum(
                jnp.power(svi_result.params[f"loc_q_{k}"] - target_mus[k], 2)))
        log_sig_error = jnp.sum(
            jnp.power(
                svi_result.params[f"log_sig_q_{k}"] +
                0.5 * jnp.log(lambda_posts[k]), 2))
        log_sig_errors.append(log_sig_error)

    max_errors = (np.max(loc_errors), np.max(log_sig_errors),
                  np.max(kappa_errors))

    for i in range(3):
        assert_allclose(max_errors[i], 0, atol=atol)
Example #5
def main(args):
    # loading data
    (train_init, train_fetch_plain), num_samples = load_dataset(
        MNIST,
        batch_size=args.batch_size,
        split='train',
        batchifier=subsample_batchify_data)
    (test_init,
     test_fetch_plain), _ = load_dataset(MNIST,
                                         batch_size=args.batch_size,
                                         split='test',
                                         batchifier=split_batchify_data)

    def binarize_fetch(fetch_fn):
        @jit
        def fetch_binarized(batch_nr, batchifier_state, binarize_rng):
            batch = fetch_fn(batch_nr, batchifier_state)
            return binarize(binarize_rng, batch[0]), batch[1]

        return fetch_binarized

    train_fetch = binarize_fetch(train_fetch_plain)
    test_fetch = binarize_fetch(test_fetch_plain)

    # setting up optimizer
    optimizer = optimizers.Adam(args.learning_rate)

    # the minibatch environment in our model scales individual
    # records' contributions to the loss up by num_samples.
    # This can cause numerical instabilities, so we scale the loss
    # back down by 1/num_samples here (see the sketch after this example).
    sc_model = scale(model, scale=1 / num_samples)
    sc_guide = scale(guide, scale=1 / num_samples)

    if args.no_dp:
        svi = SVI(sc_model,
                  sc_guide,
                  optimizer,
                  ELBO(),
                  num_obs_total=num_samples,
                  z_dim=args.z_dim,
                  hidden_dim=args.hidden_dim)
    else:
        q = args.batch_size / num_samples
        target_eps = args.epsilon
        dp_scale, act_eps, _ = approximate_sigma(target_eps=target_eps,
                                                 delta=1 / num_samples,
                                                 q=q,
                                                 num_iter=int(1 / q) *
                                                 args.num_epochs,
                                                 force_smaller=True)
        print(
            f"using noise scale {dp_scale} for epsilon of {act_eps} (targeted: {target_eps})"
        )

        svi = DPSVI(sc_model,
                    sc_guide,
                    optimizer,
                    ELBO(),
                    dp_scale=dp_scale,
                    clipping_threshold=10.,
                    num_obs_total=num_samples,
                    z_dim=args.z_dim,
                    hidden_dim=args.hidden_dim)

    # preparing random number generators and initializing svi
    rng = PRNGKey(0)
    rng, binarize_rng, svi_init_rng, batchifier_rng = random.split(rng, 4)
    _, batchifier_state = train_init(rng_key=batchifier_rng)
    sample_batch = train_fetch(0, batchifier_state, binarize_rng)[0]
    svi_state = svi.init(svi_init_rng, sample_batch)

    # functions for training tasks
    @jit
    def epoch_train(svi_state, batchifier_state, num_batches, rng):
        """Trains one epoch

        :param svi_state: current state of the optimizer
        :param rng: rng key

        :return: overall training loss over the epoch
        """
        def body_fn(i, val):
            svi_state, loss = val
            binarize_rng = random.fold_in(rng, i)
            batch = train_fetch(i, batchifier_state, binarize_rng)[0]
            svi_state, batch_loss = svi.update(svi_state, batch)
            loss += batch_loss / num_batches
            return svi_state, loss

        svi_state, loss = lax.fori_loop(0, num_batches, body_fn,
                                        (svi_state, 0.))
        return svi_state, loss

    @jit
    def eval_test(svi_state, batchifier_state, num_batches, rng):
        """Evaluates current model state on test data.

        :param svi_state: current state of the optimizer
        :param rng: rng key

        :return: loss over the test split
        """
        def body_fn(i, loss_sum):
            binarize_rng = random.fold_in(rng, i)
            batch = test_fetch(i, batchifier_state, binarize_rng)[0]
            batch_loss = svi.evaluate(svi_state, batch)
            loss_sum += batch_loss / num_batches
            return loss_sum

        return lax.fori_loop(0, num_batches, body_fn, 0.)

    def reconstruct_img(epoch, num_epochs, batchifier_state, svi_state, rng):
        """Reconstructs an image for the given epoch

        Obtains a sample from the testing data set and passes it through the
        VAE. Stores the result as image file 'epoch_{epoch}_recons.png' and
        the original input as 'epoch_{epoch}_original.png' in folder '.results'.

        :param epoch: Number of the current epoch
        :param num_epochs: Number of total epochs
        :param svi_state: Current state of the optimizer
        :param rng: rng key
        """
        assert (num_epochs > 0)
        img = test_fetch_plain(0, batchifier_state)[0][0]
        plt.imsave(os.path.join(
            RESULTS_DIR, "epoch_{:0{}d}_original.png".format(
                epoch, (int(jnp.log10(num_epochs)) + 1))),
                   img,
                   cmap='gray')
        rng, rng_binarize = random.split(rng, 2)
        test_sample = binarize(rng_binarize, img)
        test_sample = jnp.reshape(test_sample, (1, *jnp.shape(test_sample)))
        params = svi.get_params(svi_state)

        samples = sample_multi_posterior_predictive(
            rng, 10, model,
            (1, args.z_dim, args.hidden_dim, np.prod(test_sample.shape[1:])),
            guide, (test_sample, args.z_dim, args.hidden_dim), params)

        img_loc = samples['obs'][0].reshape([28, 28])
        avg_img_loc = jnp.mean(samples['obs'], axis=0).reshape([28, 28])
        plt.imsave(os.path.join(
            RESULTS_DIR, "epoch_{:0{}d}_recons_single.png".format(
                epoch, (int(jnp.log10(num_epochs)) + 1))),
                   img_loc,
                   cmap='gray')
        plt.imsave(os.path.join(
            RESULTS_DIR, "epoch_{:0{}d}_recons_avg.png".format(
                epoch, (int(jnp.log10(num_epochs)) + 1))),
                   avg_img_loc,
                   cmap='gray')

    # main training loop
    for i in range(args.num_epochs):
        t_start = time.time()
        rng, data_fetch_rng, train_rng = random.split(rng, 3)
        num_train_batches, train_batchifier_state = train_init(
            rng_key=data_fetch_rng)
        svi_state, train_loss = epoch_train(svi_state, train_batchifier_state,
                                            num_train_batches, train_rng)

        rng, test_fetch_rng, test_rng, recons_rng = random.split(rng, 4)
        num_test_batches, test_batchifier_state = test_init(
            rng_key=test_fetch_rng)
        test_loss = eval_test(svi_state, test_batchifier_state,
                              num_test_batches, test_rng)

        reconstruct_img(i, args.num_epochs, test_batchifier_state, svi_state,
                        recons_rng)
        print("Epoch {}: loss = {} (on training set: {}) ({:.2f} s.)".format(
            i, test_loss, train_loss,
            time.time() - t_start))
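
The 1/num_samples scaling discussed in the comment above works by multiplying each sample site's log-probability by a constant factor. Below is a minimal sketch of that mechanism, assuming the `scale` handler used in the example behaves like `numpyro.handlers.scale` (the example's imports are not shown, so this is an assumption); the toy model is purely illustrative.

import numpyro
import numpyro.distributions as dist
from numpyro import handlers
from jax import random

def toy_model():
    # a single latent site; its log-density is what gets rescaled
    numpyro.sample("x", dist.Normal(0.0, 1.0))

# wrap the model so every site's log-probability is multiplied by 0.5
scaled_model = handlers.scale(toy_model, scale=0.5)

# tracing records the scale factor attached to each site
tr = handlers.trace(handlers.seed(scaled_model, random.PRNGKey(0))).get_trace()
print(tr["x"]["scale"])  # 0.5 -- the factor applied to this site's log-prob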
Example #6
def test_subsample_gradient(scale, subsample):
    data = jnp.array([-0.5, 2.0])
    subsample_size = 1 if subsample else len(data)
    precision = 0.06 * scale

    def model(subsample):
        with handlers.substitute(data={"data": subsample}):
            with numpyro.plate("data", len(data), subsample_size) as ind:
                x = data[ind]
                z = numpyro.sample("z", dist.Normal(0, 1))
                numpyro.sample("x", dist.Normal(z, 1), obs=x)

    def guide(subsample):
        scale = numpyro.param("scale", 1.0)
        with handlers.substitute(data={"data": subsample}):
            with numpyro.plate("data", len(data), subsample_size):
                loc = numpyro.param("loc", jnp.zeros(len(data)), event_dim=0)
                numpyro.sample("z", dist.Normal(loc, scale))

    if scale != 1.0:
        model = handlers.scale(model, scale=scale)
        guide = handlers.scale(guide, scale=scale)

    num_particles = 50000
    optimizer = optim.Adam(0.1)
    elbo = Trace_ELBO(num_particles=num_particles)
    svi = SVI(model, guide, optimizer, loss=elbo)
    svi_state = svi.init(random.PRNGKey(0), None)
    params = svi.optim.get_params(svi_state.optim_state)
    normalizer = 2 if subsample else 1
    if subsample_size == 1:
        subsample = jnp.array([0])
        loss1, grads1 = value_and_grad(
            lambda x: svi.loss.loss(
                svi_state.rng_key, svi.constrain_fn(x), svi.model, svi.guide, subsample
            )
        )(params)
        subsample = jnp.array([1])
        loss2, grads2 = value_and_grad(
            lambda x: svi.loss.loss(
                svi_state.rng_key, svi.constrain_fn(x), svi.model, svi.guide, subsample
            )
        )(params)
        grads = tree_multimap(lambda *vals: vals[0] + vals[1], grads1, grads2)
        loss = loss1 + loss2
    else:
        subsample = jnp.array([0, 1])
        loss, grads = value_and_grad(
            lambda x: svi.loss.loss(
                svi_state.rng_key, svi.constrain_fn(x), svi.model, svi.guide, subsample
            )
        )(params)

    actual_loss = loss / normalizer
    expected_loss, _ = value_and_grad(
        lambda x: svi.loss.loss(
            svi_state.rng_key, svi.constrain_fn(x), svi.model, svi.guide, None
        )
    )(params)
    assert_allclose(actual_loss, expected_loss, rtol=precision, atol=precision)

    actual_grads = {name: grad / normalizer for name, grad in grads.items()}
    expected_grads = {
        "loc": scale * jnp.array([0.5, -2.0]),
        "scale": scale * jnp.array([2.0]),
    }
    assert actual_grads.keys() == expected_grads.keys()
    for name in expected_grads:
        assert_allclose(
            actual_grads[name], expected_grads[name], rtol=precision, atol=precision
        )
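
The hard-coded expected gradients in the test above follow from the closed-form Gaussian ELBO. The sketch below re-derives them by differentiating the analytic negative ELBO, assuming (as in the model and guide above) a standard-normal prior, unit observation noise, and a mean-field Normal(loc, scale) guide evaluated at loc=0, scale=1; additive constants are dropped.

import jax
import jax.numpy as jnp

data = jnp.array([-0.5, 2.0])

def analytic_neg_elbo(params):
    loc, s = params["loc"], params["scale"]
    # E_q[-log N(x | z, 1)] per data point, up to an additive constant
    recon = 0.5 * ((loc - data) ** 2 + s ** 2)
    # KL(N(loc, s) || N(0, 1)) per data point
    kl = 0.5 * (s ** 2 + loc ** 2 - 1.0) - jnp.log(s)
    return jnp.sum(recon + kl)

grads = jax.grad(analytic_neg_elbo)({"loc": jnp.zeros(2), "scale": jnp.array(1.0)})
print(grads["loc"], grads["scale"])  # ~[0.5, -2.0] and ~2.0, matching expected_grads (times the handler scale)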
Example #7
def Adam(kwargs):
    step_size = kwargs.pop('lr')
    b1, b2 = kwargs.pop('betas', (0.9, 0.999))
    eps = kwargs.pop('eps', 1.e-8)
    return optim.Adam(step_size=step_size, b1=b1, b2=b2, eps=eps)
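
A hypothetical usage sketch for the adapter above: the incoming dict uses PyTorch-style keyword names (lr, betas, eps), which the function maps onto NumPyro's optim.Adam arguments. The learning-rate value is illustrative only.

# hypothetical call; the dict mirrors torch.optim.Adam-style keyword names
numpyro_adam = Adam({"lr": 1e-3, "betas": (0.9, 0.999), "eps": 1e-8})
# equivalent to optim.Adam(step_size=1e-3, b1=0.9, b2=0.999, eps=1e-8)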
Example #8
def Adam(kwargs):
    step_size = kwargs.pop("lr")
    b1, b2 = kwargs.pop("betas", (0.9, 0.999))
    eps = kwargs.pop("eps", 1.0e-8)
    return optim.Adam(step_size=step_size, b1=b1, b2=b2, eps=eps)
Example #9
    plt.show()

    sum_blbr = post["bl"] + post["br"]
    fig, ax = plt.subplots()
    az.plot_kde(sum_blbr, label="sum of bl and br", ax=ax)
    plt.title(method)
    pml.savefig(f'multicollinear_sum_post_{method}.pdf')
    plt.show()

# Laplace fit

m6_1 = AutoLaplaceApproximation(model)
svi = SVI(
    model,
    m6_1,
    optim.Adam(0.1),
    Trace_ELBO(),
    leg_left=df.leg_left.values,
    leg_right=df.leg_right.values,
    height=df.height.values,
    br_positive=False
)
svi_run = svi.run(random.PRNGKey(0), 2000)
p6_1 = svi_run.params
losses = svi_run.losses
post_laplace = m6_1.sample_posterior(random.PRNGKey(1), p6_1, (1000,))

analyze_post(post_laplace, 'laplace')


# MCMC fit
Example #10
def main(args):
    print("Start vanilla HMC...")
    nuts_kernel = NUTS(dual_moon_model)
    mcmc = MCMC(nuts_kernel, args.num_warmup, args.num_samples,
                progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True)
    mcmc.run(random.PRNGKey(0))
    mcmc.print_summary()
    vanilla_samples = mcmc.get_samples()['x'].copy()

    guide = AutoBNAFNormal(dual_moon_model, hidden_factors=[args.hidden_factor, args.hidden_factor])
    svi = SVI(dual_moon_model, guide, optim.Adam(0.003), AutoContinuousELBO())
    svi_state = svi.init(random.PRNGKey(1))

    print("Start training guide...")
    last_state, losses = lax.scan(lambda state, i: svi.update(state), svi_state, np.zeros(args.num_iters))
    params = svi.get_params(last_state)
    print("Finish training guide. Extract samples...")
    guide_samples = guide.sample_posterior(random.PRNGKey(0), params,
                                           sample_shape=(args.num_samples,))['x'].copy()

    transform = guide.get_transform(params)
    _, potential_fn, constrain_fn = initialize_model(random.PRNGKey(2), dual_moon_model)
    transformed_potential_fn = partial(transformed_potential_energy, potential_fn, transform)
    transformed_constrain_fn = lambda x: constrain_fn(transform(x))  # noqa: E731

    print("\nStart NeuTra HMC...")
    nuts_kernel = NUTS(potential_fn=transformed_potential_fn)
    mcmc = MCMC(nuts_kernel, args.num_warmup, args.num_samples,
                progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True)
    init_params = np.zeros(guide.latent_size)
    mcmc.run(random.PRNGKey(3), init_params=init_params)
    mcmc.print_summary()
    zs = mcmc.get_samples()
    print("Transform samples into unwarped space...")
    samples = vmap(transformed_constrain_fn)(zs)
    print_summary(tree_map(lambda x: x[None, ...], samples))
    samples = samples['x'].copy()

    # make plots

    # guide samples (for plotting)
    guide_base_samples = dist.Normal(np.zeros(2), 1.).sample(random.PRNGKey(4), (1000,))
    guide_trans_samples = vmap(transformed_constrain_fn)(guide_base_samples)['x']

    x1 = np.linspace(-3, 3, 100)
    x2 = np.linspace(-3, 3, 100)
    X1, X2 = np.meshgrid(x1, x2)
    P = np.exp(DualMoonDistribution().log_prob(np.stack([X1, X2], axis=-1)))

    fig = plt.figure(figsize=(12, 8), constrained_layout=True)
    gs = GridSpec(2, 3, figure=fig)
    ax1 = fig.add_subplot(gs[0, 0])
    ax2 = fig.add_subplot(gs[1, 0])
    ax3 = fig.add_subplot(gs[0, 1])
    ax4 = fig.add_subplot(gs[1, 1])
    ax5 = fig.add_subplot(gs[0, 2])
    ax6 = fig.add_subplot(gs[1, 2])

    ax1.plot(losses[1000:])
    ax1.set_title('Autoguide training loss\n(after 1000 steps)')

    ax2.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(guide_samples[:, 0], guide_samples[:, 1], n_levels=30, ax=ax2)
    ax2.set(xlim=[-3, 3], ylim=[-3, 3],
            xlabel='x0', ylabel='x1', title='Posterior using\nAutoBNAFNormal guide')

    sns.scatterplot(guide_base_samples[:, 0], guide_base_samples[:, 1], ax=ax3,
                    hue=guide_trans_samples[:, 0] < 0.)
    ax3.set(xlim=[-3, 3], ylim=[-3, 3],
            xlabel='x0', ylabel='x1', title='AutoBNAFNormal base samples\n(True=left moon; False=right moon)')

    ax4.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(vanilla_samples[:, 0], vanilla_samples[:, 1], n_levels=30, ax=ax4)
    ax4.plot(vanilla_samples[-50:, 0], vanilla_samples[-50:, 1], 'bo-', alpha=0.5)
    ax4.set(xlim=[-3, 3], ylim=[-3, 3],
            xlabel='x0', ylabel='x1', title='Posterior using\nvanilla HMC sampler')

    sns.scatterplot(zs[:, 0], zs[:, 1], ax=ax5, hue=samples[:, 0] < 0.,
                    s=30, alpha=0.5, edgecolor="none")
    ax5.set(xlim=[-5, 5], ylim=[-5, 5],
            xlabel='x0', ylabel='x1', title='Samples from the\nwarped posterior - p(z)')

    ax6.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(samples[:, 0], samples[:, 1], n_levels=30, ax=ax6)
    ax6.plot(samples[-50:, 0], samples[-50:, 1], 'bo-', alpha=0.2)
    ax6.set(xlim=[-3, 3], ylim=[-3, 3],
            xlabel='x0', ylabel='x1', title='Posterior using\nNeuTra HMC sampler')

    plt.savefig("neutra.pdf")
Example #11
def main(args):
    rng = PRNGKey(1234)
    rng, toy_data_rng = jax.random.split(rng, 2)
    X_train, X_test, mu_true = create_toy_data(toy_data_rng, args.num_samples,
                                               args.dimensions)

    train_init, train_fetch = subsample_batchify_data(
        (X_train, ), batch_size=args.batch_size)
    test_init, test_fetch = split_batchify_data((X_test, ),
                                                batch_size=args.batch_size)

    ## Init optimizer and training algorithms
    optimizer = optimizers.Adam(args.learning_rate)

    svi = DPSVI(model,
                guide,
                optimizer,
                ELBO(),
                dp_scale=args.sigma,
                clipping_threshold=args.clip_threshold,
                d=args.dimensions,
                num_obs_total=args.num_samples)

    rng, svi_init_rng, batchifier_rng = random.split(rng, 3)
    _, batchifier_state = train_init(rng_key=batchifier_rng)
    batch = train_fetch(0, batchifier_state)
    svi_state = svi.init(svi_init_rng, *batch)

    q = args.batch_size / args.num_samples
    eps = svi.get_epsilon(args.delta, q, num_epochs=args.num_epochs)
    print("Privacy epsilon {} (for sigma: {}, delta: {}, C: {}, q: {})".format(
        eps, args.sigma, args.clip_threshold, args.delta, q))

    @jit
    def epoch_train(svi_state, batchifier_state, num_batch):
        def body_fn(i, val):
            svi_state, loss = val
            batch = train_fetch(i, batchifier_state)
            svi_state, batch_loss = svi.update(svi_state, *batch)
            loss += batch_loss / (args.num_samples * num_batch)
            return svi_state, loss

        return lax.fori_loop(0, num_batch, body_fn, (svi_state, 0.))

    @jit
    def eval_test(svi_state, batchifier_state, num_batch):
        def body_fn(i, loss_sum):
            batch = test_fetch(i, batchifier_state)
            loss = svi.evaluate(svi_state, *batch)
            loss_sum += loss / (args.num_samples * num_batch)

            return loss_sum

        return lax.fori_loop(0, num_batch, body_fn, 0.)

    ## Train model

    for i in range(args.num_epochs):
        t_start = time.time()
        rng, data_fetch_rng = random.split(rng, 2)

        num_train_batches, train_batchifier_state = train_init(
            rng_key=data_fetch_rng)
        svi_state, train_loss = epoch_train(svi_state, train_batchifier_state,
                                            num_train_batches)
        train_loss.block_until_ready()  # todo: blocking on loss will probably ignore rest of optimization
        t_end = time.time()

        if (i % (args.num_epochs // 10) == 0):
            rng, test_fetch_rng = random.split(rng, 2)
            num_test_batches, test_batchifier_state = test_init(
                rng_key=test_fetch_rng)
            test_loss = eval_test(svi_state, test_batchifier_state,
                                  num_test_batches)

            print(
                "Epoch {}: loss = {} (on training set: {}) ({:.2f} s.)".format(
                    i, test_loss, train_loss, t_end - t_start))

    params = svi.get_params(svi_state)
    mu_loc = params['mu_loc']
    mu_std = jnp.exp(params['mu_std_log'])
    print("### expected: {}".format(mu_true))
    print("### svi result\nmu_loc: {}\nerror: {}\nmu_std: {}".format(
        mu_loc, jnp.linalg.norm(mu_loc - mu_true), mu_std))
    mu_loc, mu_std = analytical_solution(X_train)
    print("### analytical solution\nmu_loc: {}\nerror: {}\nmu_std: {}".format(
        mu_loc, jnp.linalg.norm(mu_loc - mu_true), mu_std))
    mu_loc, mu_std = ml_estimate(X_train)
    print("### ml estimate\nmu_loc: {}\nerror: {}\nmu_std: {}".format(
        mu_loc, jnp.linalg.norm(mu_loc - mu_true), mu_std))
Example #12
def main(args):
    rng = PRNGKey(123)
    rng, toy_data_rng = jax.random.split(rng)

    train_data, test_data, true_params = create_toy_data(
        toy_data_rng, args.num_samples, args.dimensions
    )

    train_init, train_fetch = subsample_batchify_data(train_data, batch_size=args.batch_size)
    test_init, test_fetch = split_batchify_data(test_data, batch_size=args.batch_size)

    ## Init optimizer and training algorithms
    optimizer = optimizers.Adam(args.learning_rate)

    # note(lumip): the values for clipping_threshold and dp_scale below are
    #   currently made up and not calibrated for any privacy target.
    svi = DPSVI(model, guide, optimizer, ELBO(),
        dp_scale=0.01, clipping_threshold=20., num_obs_total=args.num_samples
    )

    rng, svi_init_rng, data_fetch_rng = random.split(rng, 3)
    _, batchifier_state = train_init(rng_key=data_fetch_rng)
    sample_batch = train_fetch(0, batchifier_state)

    svi_state = svi.init(svi_init_rng, *sample_batch)

    @jit
    def epoch_train(svi_state, batchifier_state, num_batch):
        def body_fn(i, val):
            svi_state, loss = val
            batch = train_fetch(i, batchifier_state)
            batch_X, batch_Y = batch

            svi_state, batch_loss = svi.update(svi_state, batch_X, batch_Y)
            loss += batch_loss / (args.num_samples * num_batch)
            return svi_state, loss

        return lax.fori_loop(0, num_batch, body_fn, (svi_state, 0.))

    @jit
    def eval_test(svi_state, batchifier_state, num_batch, rng):
        params = svi.get_params(svi_state)

        def body_fn(i, val):
            loss_sum, acc_sum = val

            batch = test_fetch(i, batchifier_state)
            batch_X, batch_Y = batch

            loss = svi.evaluate(svi_state, batch_X, batch_Y)
            loss_sum += loss / (args.num_samples * num_batch)

            acc_rng = jax.random.fold_in(rng, i)
            acc = estimate_accuracy(batch_X, batch_Y, params, acc_rng, 1)
            acc_sum += acc / num_batch

            return loss_sum, acc_sum

        return lax.fori_loop(0, num_batch, body_fn, (0., 0.))

    ## Train model
    for i in range(args.num_epochs):
        t_start = time.time()
        rng, data_fetch_rng = random.split(rng, 2)

        num_train_batches, train_batchifier_state = train_init(rng_key=data_fetch_rng)
        svi_state, train_loss = epoch_train(
            svi_state, train_batchifier_state, num_train_batches
        )
        train_loss.block_until_ready()  # todo: blocking on loss will probably ignore rest of optimization
        t_end = time.time()

        if (i % (args.num_epochs//10)) == 0:
            rng, test_rng, test_fetch_rng = random.split(rng, 3)
            num_test_batches, test_batchifier_state = test_init(rng_key=test_fetch_rng)
            test_loss, test_acc = eval_test(
                svi_state, test_batchifier_state, num_test_batches, test_rng
            )
            print("Epoch {}: loss = {}, acc = {} (loss on training set: {}) ({:.2f} s.)".format(
                i, test_loss, test_acc, train_loss, t_end - t_start
            ))

    # parameters for logistic regression may be scaled arbitrarily. normalize
    #   w (and scale intercept accordingly) for comparison
    w_true = normalize(true_params[0])
    scale_true = jnp.linalg.norm(true_params[0])
    intercept_true = true_params[1] / scale_true

    params = svi.get_params(svi_state)
    w_post = normalize(params['w_loc'])
    scale_post = jnp.linalg.norm(params['w_loc'])
    intercept_post = params['intercept_loc'] / scale_post

    print("w_loc: {}\nexpected: {}\nerror: {}".format(w_post, w_true, jnp.linalg.norm(w_post-w_true)))
    print("w_std: {}".format(jnp.exp(params['w_std_log'])))
    print("")
    print("intercept_loc: {}\nexpected: {}\nerror: {}".format(intercept_post, intercept_true, jnp.abs(intercept_post-intercept_true)))
    print("intercept_std: {}".format(jnp.exp(params['intercept_std_log'])))
    print("")

    X_test, y_test = test_data
    rng, rng_acc_true, rng_acc_post = jax.random.split(rng, 3)
    # to evaluate accuracy with the true parameters, we scale them to the same
    #   scale as the found posterior (this gives better results than normalized
    #   parameters, probably due to numerical instabilities)
    acc_true = estimate_accuracy_fixed_params(X_test, y_test, w_true, intercept_true, rng_acc_true, 10)
    acc_post = estimate_accuracy(X_test, y_test, params, rng_acc_post, 10)

    print("avg accuracy on test set:  with true parameters: {} ; with found posterior: {}\n".format(acc_true, acc_post))
# Model


def model(M, A, D=None):
    a = numpyro.sample("a", dist.Normal(0, 0.2))
    bM = numpyro.sample("bM", dist.Normal(0, 0.5))
    bA = numpyro.sample("bA", dist.Normal(0, 0.5))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = numpyro.deterministic("mu", a + bM * M + bA * A)
    numpyro.sample("D", dist.Normal(mu, sigma), obs=D)


m5_3 = AutoLaplaceApproximation(model)
svi = SVI(model,
          m5_3,
          optim.Adam(1),
          Trace_ELBO(),
          M=d.M.values,
          A=d.A.values,
          D=d.D.values)
p5_3, losses = svi.run(random.PRNGKey(0), 1000)
post = m5_3.sample_posterior(random.PRNGKey(1), p5_3, (1000, ))

# Posterior

param_names = {'a', 'bA', 'bM', 'sigma'}
for p in param_names:
    print(f'posterior for {p}')
    print_summary(post[p], 0.95, False)

# PPC