Example #1
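# Training loop for a flow model on 2-D toy data: each step draws a fresh batch
# with toy_data.inf_train_gen, minimizes the negative mean log-density, and
# records the loss for TensorBoard logging via SummaryWriter.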
# create summary writer and write to log directory
timestamp = cutils.get_timestamp()
log_dir = os.path.join(cutils.get_log_root(), args.dataset_name, timestamp)
writer = SummaryWriter(log_dir=log_dir)
filename = os.path.join(log_dir, 'config.json')
with open(filename, 'w') as file:
    json.dump(vars(args), file)

tbar = tqdm(range(args.num_training_steps))
for step in tbar:
    flow.train()
    optimizer.zero_grad()
    scheduler.step(step)

    batch = toy_data.inf_train_gen(args.dataset_name,
                                   batch_size=args.batch_size)
    batch = torch.from_numpy(batch).type(torch.float32).to(device)

    _, log_density = flow.log_prob(batch)
    loss = -torch.mean(log_density)
    loss.backward()
    if args.grad_norm_clip_value is not None:
        clip_grad_norm_(flow.parameters(), args.grad_norm_clip_value)
    optimizer.step()

    if (step + 1) % args.monitor_interval == 0:
        s = 'Loss: {:.4f}'.format(loss.item())
        tbar.set_description(s)

        summaries = {'loss': loss.detach()}
        for summary, value in summaries.items():
            # (the original example is truncated here; presumably each summary
            # is written to TensorBoard)
            writer.add_scalar(summary, value, global_step=step)
Example #2
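# Fits a Gaussian mixture model with a TensorFlow 1.x graph; instead of the
# synthetic GMM generator (commented out below), the data come from
# toy_data.inf_train_gen("checkerboard", ...).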
import numpy as np
import tensorflow as tf  # TF1-style graph mode (tf.placeholder) is used below

import tf_gmm_tools
from toy_data import inf_train_gen

DIMENSIONS = 2
COMPONENTS = 8
NUM_POINTS = 10000

TRAINING_STEPS = 1000
TOLERANCE = 10e-6

# PREPARING DATA

# generating DATA_POINTS points from a GMM with COMPONENTS components
#data, true_means, true_covariances, true_weights, responsibilities = tf_gmm_tools.generate_gmm_data(
#    NUM_POINTS, COMPONENTS, DIMENSIONS, seed=10, diagonal=True)
data = inf_train_gen("checkerboard", rng=None, batch_size=20000)
print(data)

# BUILDING COMPUTATIONAL GRAPH

# model inputs: data points and prior parameters
input = tf.placeholder(tf.float64, [None, DIMENSIONS])
alpha = tf.placeholder_with_default(tf.cast(1.0, tf.float64), [])
beta = tf.placeholder_with_default(tf.cast(1.0, tf.float64), [])

# constants: ln(2 * PI) * D
ln2piD = tf.constant(np.log(2 * np.pi) * DIMENSIONS, dtype=tf.float64)

# computing input statistics
dim_means = tf.reduce_mean(input, 0)
dim_distances = tf.squared_difference(input, tf.expand_dims(dim_means, 0))
Example #3
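# Jointly trains an energy-based model (logp_net) and a latent-variable
# generator g; toy batches are drawn on the fly with toy_data.inf_train_gen,
# and the generator's entropy-gradient term is estimated by one of several
# flag-selected methods.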
def main(args):
    utils.makedirs(args.save_dir)
    cn_sgld_dir = "{}/{}".format(args.save_dir, "condition_number")
    utils.makedirs(cn_sgld_dir)
    h_sgld_dir = "{}/{}".format(args.save_dir, "largest_sv")
    utils.makedirs(h_sgld_dir)
    l_sgld_dir = "{}/{}".format(args.save_dir, "smallest_sv")
    utils.makedirs(l_sgld_dir)
    data_sgld_dir = "{}/{}".format(args.save_dir, "data_sgld")
    utils.makedirs(data_sgld_dir)
    gen_sgld_dir = "{}/{}".format(args.save_dir, "generator_sgld")
    utils.makedirs(gen_sgld_dir)
    z_sgld_dir = "{}/{}".format(args.save_dir, "z_only_sgld")
    utils.makedirs(z_sgld_dir)

    data_sgld_chain_dir = "{}/{}_chain".format(args.save_dir,
                                               "data_sgld_chain")
    utils.makedirs(data_sgld_chain_dir)
    gen_sgld_chain_dir = "{}/{}_chain".format(args.save_dir,
                                              "generator_sgld_chain")
    utils.makedirs(gen_sgld_chain_dir)
    z_sgld_chain_dir = "{}/{}_chain".format(args.save_dir, "z_only_sgld_chain")
    utils.makedirs(z_sgld_chain_dir)
    logp_net, g = get_models(args)

    with open("{}/args.txt".format(args.save_dir), 'w') as f:
        json.dump(args.__dict__, f)

    e_optimizer = torch.optim.Adam(logp_net.parameters(),
                                   lr=args.lr,
                                   betas=[args.beta1, args.beta2],
                                   weight_decay=args.weight_decay)
    g_optimizer = torch.optim.Adam(g.parameters(),
                                   lr=args.glr,
                                   betas=[args.beta1, args.beta2],
                                   weight_decay=args.weight_decay)

    class P(nn.Module):
        def __init__(self):
            super().__init__()
            self.logsigma = nn.Parameter(torch.zeros(args.noise_dim, ) - 2)

    post_logsigma = P()
    p_optimizer = torch.optim.Adam(post_logsigma.parameters(),
                                   lr=args.lr * 10,
                                   betas=[args.beta1, args.beta2])

    train_loader, test_loader, plot = get_data(args)

    def sample_q(n, requires_grad=False):
        h = torch.randn((n, args.noise_dim)).to(device)
        if requires_grad:
            h.requires_grad_()
        x_mu = g.generator(h)
        x = x_mu + torch.randn_like(x_mu) * g.logsigma.exp()
        return x, h

    def logq_joint(x, h, return_mu=False):
        logph = distributions.Normal(0, 1).log_prob(h).sum(1)
        gmu = g.generator(h)
        px_given_h = distributions.Normal(gmu, g.logsigma.exp())
        logpx_given_h = px_given_h.log_prob(x).flatten(start_dim=1).sum(1)
        if return_mu:
            return logpx_given_h + logph, gmu
        else:
            return logpx_given_h + logph

    g.train()
    g.to(device)
    logp_net.train()
    logp_net.to(device)
    post_logsigma.to(device)

    itr = 0
    stepsize = 1. / args.noise_dim
    sgld_lr = 1. / args.noise_dim
    sgld_lr_z = 1. / args.noise_dim
    sgld_lr_zne = 1. / args.noise_dim
    for epoch in range(args.n_epochs):
        for x_d, _ in train_loader:
            if args.dataset in TOY_DSETS:
                x_d = toy_data.inf_train_gen(args.dataset,
                                             batch_size=args.batch_size)
                x_d = torch.from_numpy(x_d).float().to(device)
            else:
                x_d = x_d.to(device)

            # sample from q(x, h)
            x_g, h_g = sample_q(args.batch_size)
            x_g_ref = x_g
            # ebm obj
            ld = logp_net(x_d).squeeze()
            lg_detach = logp_net(x_g_ref.detach()).squeeze()
            logp_obj = (ld - lg_detach).mean()

            e_loss = -logp_obj + (ld**2).mean() * args.p_control
            if itr % args.e_iters == 0:
                e_optimizer.zero_grad()
                e_loss.backward()
                e_optimizer.step()

            x_g, h_g = sample_q(args.batch_size)
            # gen obj
            lg = logp_net(x_g).squeeze()

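            # Select an entropy-gradient estimator for the generator objective:
            # single-sample bounds, a singular-value bound, amortized-posterior
            # variants, HMC posterior samples (default), or a gradient-penalty
            # surrogate when args.gp != 0.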
            if args.gp == 0:
                if args.my_single_sample:
                    logq = logq_joint(x_g.detach(), h_g.detach())
                    # mine
                    logq_obj = lg.mean() - args.ent_weight * logq.mean()
                    g_error_entropy = -logq.mean()
                elif args.adji_single_sample:
                    # adji
                    mean_output_summed = g.generator(h_g)
                    c = ((x_g - mean_output_summed) /
                         g.logsigma.exp()**2).detach()
                    g_error_entropy = torch.mul(c, x_g).mean(0).sum()
                    logq_obj = lg.mean() + args.ent_weight * g_error_entropy
                elif args.sv_bound:
                    v, t = find_extreme_singular_vectors(
                        g.generator, h_g, args.niters, args.v_norm)
                    log_sv = log_sigular_values_sum_bound(
                        g.generator, h_g, v, args.v_norm)
                    logpx = -log_sv.mean() - distributions.Normal(
                        0, g.logsigma.exp()).entropy() * args.data_dim
                    g_error_entropy = logpx
                    logq_obj = lg.mean() - args.ent_weight * logpx
                elif args.pg_inf:
                    post = distributions.Normal(h_g.detach(),
                                                post_logsigma.logsigma.exp())
                    h_g_post = post.rsample()
                    joint, mean_output_summed = logq_joint(x_g.detach(),
                                                           h_g_post,
                                                           return_mu=True)
                    post_ent = post.entropy().sum(1)
                    elbo = joint + post_ent
                    post_loss = -elbo.mean()
                    p_optimizer.zero_grad()
                    post_loss.backward()
                    p_optimizer.step()

                    c = ((x_g - mean_output_summed) /
                         g.logsigma.exp()**2).detach()
                    g_error_entropy = torch.mul(c, x_g).mean(0).sum()
                    logq_obj = lg.mean() + args.ent_weight * g_error_entropy

                elif args.ub_inf:
                    post = distributions.Normal(h_g.detach(),
                                                post_logsigma.logsigma.exp())
                    joint = logq_joint(x_g.detach(), h_g.detach())
                    post_logp = post.log_prob(h_g.detach()).sum(1)
                    uelbo = joint - post_logp
                    post_loss = uelbo.mean()
                    p_optimizer.zero_grad()
                    post_loss.backward(retain_graph=True)
                    p_optimizer.step()

                    # c = ((x_g - mean_output_summed) / g.logsigma.exp() ** 2).detach()
                    g_error_entropy = uelbo.mean()  # torch.mul(c, x_g).mean(0).sum()
                    logq_obj = lg.mean() - args.ent_weight * uelbo.mean()

                else:
                    num_samples_posterior = 2
                    h_given_x, acceptRate, stepsize = hmc.get_gen_posterior_samples(
                        g.generator,
                        x_g.detach(),
                        h_g.clone(),
                        g.logsigma.exp().detach(),
                        burn_in=2,
                        num_samples_posterior=num_samples_posterior,
                        leapfrog_steps=5,
                        stepsize=stepsize,
                        flag_adapt=1,
                        hmc_learning_rate=.02,
                        hmc_opt_accept=.67)

                    mean_output_summed = torch.zeros_like(x_g)
                    mean_output = g.generator(h_given_x)
                    # for h in [h_g, h_given_x]:
                    for cnt in range(num_samples_posterior):
                        mean_output_summed = mean_output_summed + mean_output[
                            cnt * args.batch_size:(cnt + 1) * args.batch_size]
                    mean_output_summed = mean_output_summed / num_samples_posterior

                    c = ((x_g - mean_output_summed) /
                         g.logsigma.exp()**2).detach()
                    g_error_entropy = torch.mul(c, x_g).mean(0).sum()
                    logq_obj = lg.mean() + args.ent_weight * g_error_entropy
            else:
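                # args.gp != 0: either use an exact 2-D Jacobian SVD
                # (args.brute_force) to get a log-density term from the
                # singular values, or penalize the Hutchinson estimate of
                # ||J^T eps||^2 for deviating from args.gp.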
                x_g, h_g = sample_q(args.batch_size, requires_grad=True)
                if args.brute_force:
                    jac = torch.zeros((x_g.size(0), x_g.size(1), h_g.size(1)))
                    j = torch.autograd.grad(x_g[:, 0].sum(),
                                            h_g,
                                            retain_graph=True)[0]
                    jac[:, 0, :] = j
                    j = torch.autograd.grad(x_g[:, 1].sum(),
                                            h_g,
                                            retain_graph=True)[0]
                    jac[:, 1, :] = j
                    u, s, v = torch.svd(jac)
                    logs = s.log()
                    logpx = 0 - logs.sum(1)
                    g_error_entropy = logpx.mean()
                    logq_obj = lg.mean() - args.ent_weight * logpx.mean()

                else:
                    eps = torch.randn_like(x_g)
                    epsJ = torch.autograd.grad(x_g,
                                               h_g,
                                               grad_outputs=eps,
                                               retain_graph=True)[0]
                    #eps2 = torch.randn_like(x_g)
                    #epsJ2 = torch.autograd.grad(x_g, h_g, grad_outputs=eps2, retain_graph=True)[0]
                    epsJtJeps = (epsJ * epsJ).sum(1)
                    g_error_entropy = ((epsJtJeps - args.gp)**2).mean()
                    logq_obj = lg.mean() - args.ent_weight * g_error_entropy

            g_loss = -logq_obj
            if itr % args.e_iters == 0:
                g_optimizer.zero_grad()
                g_loss.backward()
                g_optimizer.step()

            if args.clamp:
                g.logsigma.data.clamp_(np.log(.01), np.log(.0101))
            else:
                g.logsigma.data.clamp_(np.log(.01), np.log(.3))

            if itr % args.print_every == 0:
                print(
                    "({}) | log p obj = {:.4f}, log q obj = {:.4f}, sigma = {:.4f} | "
                    "log p(x_d) = {:.4f}, log p(x_m) = {:.4f}, ent = {:.4f} | "
                    "sgld_lr = {}, sgld_lr_z = {}, sgld_lr_zne = {} | stepsize = {}"
                    .format(itr, logp_obj.item(), logq_obj.item(),
                            g.logsigma.exp().item(),
                            ld.mean().item(),
                            lg_detach.mean().item(), g_error_entropy.item(),
                            sgld_lr, sgld_lr_z, sgld_lr_zne, stepsize))
                if args.pg_inf or args.ub_inf:
                    print("    log sigma = {}, {}".format(
                        post_logsigma.logsigma.exp().mean().item(),
                        post_logsigma.logsigma.exp().std().item()))

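            # Periodic visualization: scatter and density plots for toy data;
            # otherwise singular-value histograms of the generator Jacobian and
            # MALA chains run in data space and in latent space.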
            if itr % args.viz_every == 0:
                if args.dataset in TOY_DSETS:
                    plt.clf()
                    xg = x_g_ref.detach().cpu().numpy()
                    xd = x_d.cpu().numpy()

                    ax = plt.subplot(1, 4, 1, aspect="equal", title='refined')
                    ax.scatter(xg[:, 0], xg[:, 1], s=1)

                    ax = plt.subplot(1, 4, 2, aspect="equal", title='data')
                    ax.scatter(xd[:, 0], xd[:, 1], s=1)

                    ax = plt.subplot(1, 4, 3, aspect="equal")
                    logp_net.cpu()
                    utils.plt_flow_density(lambda x: logp_net(x),
                                           ax,
                                           low=x_d.min().item(),
                                           high=x_d.max().item())
                    plt.savefig("/{}/{}.png".format(args.save_dir, itr))
                    logp_net.to(device)

                    ax = plt.subplot(1, 4, 4, aspect="equal")
                    logp_net.cpu()
                    utils.plt_flow_density(lambda x: logp_net(x),
                                           ax,
                                           low=x_d.min().item(),
                                           high=x_d.max().item(),
                                           exp=False)
                    plt.savefig("/{}/{}.png".format(args.save_dir, itr))
                    logp_net.to(device)

                    x_g, h_g = sample_q(args.batch_size, requires_grad=True)
                    jac = torch.zeros((x_g.size(0), x_g.size(1), h_g.size(1)))

                    j = torch.autograd.grad(x_g[:, 0].sum(),
                                            h_g,
                                            retain_graph=True)[0]
                    jac[:, 0, :] = j
                    j = torch.autograd.grad(x_g[:, 1].sum(),
                                            h_g,
                                            retain_graph=True)[0]
                    jac[:, 1, :] = j
                    u, s, v = torch.svd(jac)

                    s1, s2 = s[:, 0].detach(), s[:, 1].detach()
                    plt.clf()
                    plt.hist(s1.numpy(), alpha=.75)
                    plt.hist(s2.numpy(), alpha=.75)
                    plt.savefig("{}/{}_svd.png".format(args.save_dir, itr))

                    plt.clf()
                    plt.hist(s1.log().numpy(), alpha=.75)
                    plt.hist(s2.log().numpy(), alpha=.75)
                    plt.savefig("{}/{}_log_svd.png".format(args.save_dir, itr))
                else:
                    x_g, h_g = sample_q(args.batch_size, requires_grad=True)
                    J = brute_force_jac(x_g, h_g)
                    c, h, l = condition_number(J)
                    plt.clf()
                    plt.hist(c.numpy())
                    plt.savefig("{}/cn_{}.png".format(cn_sgld_dir, itr))
                    plt.clf()
                    plt.hist(h.numpy())
                    plt.savefig("{}/large_s_{}.png".format(h_sgld_dir, itr))
                    plt.clf()
                    plt.hist(l.numpy())
                    plt.savefig("{}/small_s_{}.png".format(l_sgld_dir, itr))
                    plt.clf()

                    plot("{}/{}_init.png".format(data_sgld_dir, itr),
                         x_g.view(x_g.size(0), *args.data_size))
                    #plot("{}/{}_ref.png".format(args.save_dir, itr), x_g_ref.view(x_g.size(0), *args.data_size))
                    # input space sgld
                    x_sgld = x_g.clone()
                    steps = [x_sgld.clone()]
                    accepts = []
                    for k in range(args.sgld_steps):
                        [x_sgld], a = MALA([x_sgld],
                                           lambda x: logp_net(x).squeeze(),
                                           sgld_lr)
                        steps.append(x_sgld.clone())
                        accepts.append(a.item())
                    ar = np.mean(accepts)
                    print("accept rate: {}".format(ar))
                    sgld_lr = sgld_lr + .2 * (ar - .57) * sgld_lr
                    plot("{}/{}_ref.png".format(data_sgld_dir, itr),
                         x_sgld.view(x_g.size(0), *args.data_size))

                    chain = torch.cat([step[0][None] for step in steps], 0)
                    plot("{}/{}.png".format(data_sgld_chain_dir, itr),
                         chain.view(chain.size(0), *args.data_size))

                    # latent space sgld
                    eps_sgld = torch.randn_like(x_g)
                    z_sgld = torch.randn(
                        (eps_sgld.size(0), args.noise_dim)).to(eps_sgld.device)
                    vs = (z_sgld.requires_grad_(), eps_sgld.requires_grad_())
                    steps = [vs]
                    accepts = []
                    gfn = lambda z, e: g.generator(z) + g.logsigma.exp() * e
                    efn = lambda z, e: logp_net(gfn(z, e)).squeeze()
                    x_sgld = gfn(z_sgld, eps_sgld)
                    plot("{}/{}_init.png".format(gen_sgld_dir, itr),
                         x_sgld.view(x_g.size(0), *args.data_size))
                    for k in range(args.sgld_steps):
                        vs, a = MALA(vs, efn, sgld_lr_z)
                        steps.append([v.clone() for v in vs])
                        accepts.append(a.item())
                    ar = np.mean(accepts)
                    print("accept rate: {}".format(ar))
                    sgld_lr_z = sgld_lr_z + .2 * (ar - .57) * sgld_lr_z
                    z_sgld, eps_sgld = steps[-1]
                    x_sgld = gfn(z_sgld, eps_sgld)
                    plot("{}/{}_ref.png".format(gen_sgld_dir, itr),
                         x_sgld.view(x_g.size(0), *args.data_size))

                    z_steps, eps_steps = zip(*steps)
                    z_chain = torch.cat([step[0][None] for step in z_steps], 0)
                    eps_chain = torch.cat(
                        [step[0][None] for step in eps_steps], 0)
                    chain = gfn(z_chain, eps_chain)
                    plot("{}/{}.png".format(gen_sgld_chain_dir, itr),
                         chain.view(chain.size(0), *args.data_size))

                    # latent space sgld no eps
                    z_sgld = torch.randn(
                        (eps_sgld.size(0), args.noise_dim)).to(eps_sgld.device)
                    vs = (z_sgld.requires_grad_(), )
                    steps = [vs]
                    accepts = []
                    gfn = lambda z: g.generator(z)
                    efn = lambda z: logp_net(gfn(z)).squeeze()
                    x_sgld = gfn(z_sgld)
                    plot("{}/{}_init.png".format(z_sgld_dir, itr),
                         x_sgld.view(x_g.size(0), *args.data_size))
                    for k in range(args.sgld_steps):
                        vs, a = MALA(vs, efn, sgld_lr_zne)
                        steps.append([v.clone() for v in vs])
                        accepts.append(a.item())
                    ar = np.mean(accepts)
                    print("accept rate: {}".format(ar))
                    sgld_lr_zne = sgld_lr_zne + .2 * (ar - .57) * sgld_lr_zne
                    z_sgld, = steps[-1]
                    x_sgld = gfn(z_sgld)
                    plot("{}/{}_ref.png".format(z_sgld_dir, itr),
                         x_sgld.view(x_g.size(0), *args.data_size))

                    z_steps = [s[0] for s in steps]
                    z_chain = torch.cat([step[0][None] for step in z_steps], 0)
                    chain = gfn(z_chain)
                    plot("{}/{}.png".format(z_sgld_chain_dir, itr),
                         chain.view(chain.size(0), *args.data_size))

            itr += 1
Example #4
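# Variant of the joint EBM/generator training loop: generator samples can be
# refined with HMC in data space (args.refine) or in latent space
# (args.refine_latent) before serving as EBM negatives, and the EBM update can
# be staggered (args.stagger); toy batches again come from toy_data.inf_train_gen.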
def main(args):
    utils.makedirs(args.save_dir)
    logp_net, g = get_models(args)

    e_optimizer = torch.optim.Adam(logp_net.parameters(), lr=args.lr, betas=[0.5, .999], weight_decay=args.weight_decay)
    g_optimizer = torch.optim.Adam(g.parameters(), lr=args.lr / 1, betas=[0.5, .999], weight_decay=args.weight_decay)

    train_loader, test_loader, plot = get_data(args)



    def sample_q(n):
        h = torch.randn((n, args.noise_dim)).to(device)
        x_mu = g.generator(h)
        x = x_mu + torch.randn_like(x_mu) * g.logsigma.exp()
        return x, h

    def logq_joint(x, h):
        logph = distributions.Normal(0, 1).log_prob(h).sum(1)
        px_given_h = distributions.Normal(g.generator(h), g.logsigma.exp())
        logpx_given_h = px_given_h.log_prob(x).flatten(start_dim=1).sum(1)
        return logpx_given_h + logph

    g.train()
    g.to(device)
    logp_net.train()
    logp_net.to(device)

    itr = 0
    stepsize = 1. / args.noise_dim
    stepsize_x = 1. / args.data_dim
    es = []
    for epoch in range(args.n_epochs):
        for x_d, _ in train_loader:
            if args.dataset in TOY_DSETS:
                x_d = toy_data.inf_train_gen(args.dataset, batch_size=args.batch_size)
                x_d = torch.from_numpy(x_d).float().to(device)
            else:
                x_d = x_d.to(device)

            # sample from q(x, h)
            x_g, h_g = sample_q(args.batch_size)
            x_gc = x_g.clone()
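            # Optionally refine the generator samples with HMC, either directly
            # in data space or jointly over (h, eps) in latent space, before
            # using them as negatives for the EBM objective.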
            if args.refine:
                x_g_ref, acceptRate_x, stepsize_x = hmc.get_ebm_samples(logp_net, x_g.detach(),
                                                                        burn_in=2, num_samples_posterior=1,
                                                                        leapfrog_steps=5,
                                                                        stepsize=stepsize_x, flag_adapt=1,
                                                                        hmc_learning_rate=.02, hmc_opt_accept=.67)
            elif args.refine_latent:
                h_g_ref, eps_ref, acceptRate_x, stepsize_x = hmc.get_ebm_latent_samples(logp_net, g.generator,
                                                                                        h_g.detach(), torch.randn_like(x_g).detach(),
                                                                                        g.logsigma.exp().detach(),
                                                                                        burn_in=2, num_samples_posterior=1,
                                                                                        leapfrog_steps=5,
                                                                                        stepsize=stepsize_x, flag_adapt=1,
                                                                                        hmc_learning_rate=.02, hmc_opt_accept=.67)
                x_g_ref = g.generator(h_g_ref) + eps_ref * g.logsigma.exp()
            else:
                x_g_ref = x_g
            # ebm obj
            ld = logp_net(x_d).squeeze()
            lg_detach = logp_net(x_g_ref.detach()).squeeze()
            logp_obj = (ld - lg_detach).mean()

            if args.stagger:
                if lg_detach.mean() > ld.mean() - 2 * ld.std() or itr < 100:
                    e_loss = -logp_obj + (ld ** 2).mean() * args.p_control
                    e_optimizer.zero_grad()
                    e_loss.backward()
                    e_optimizer.step()
                    #print('e')
                    if itr < 100:
                        es.append(1)
                    else:
                        es = es[1:] + [1]
                else:
                    #print('no-e', lg_detach.mean(), ld.mean(), ld.std())
                    es = es[1:] + [0]
                if itr % args.print_every == 0:
                    print('e frac', np.mean(es))
            else:
                e_loss = -logp_obj + (ld ** 2).mean() * args.p_control
                e_optimizer.zero_grad()
                e_loss.backward()
                e_optimizer.step()

            # gen obj
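            # Generator objective: "reverse_kl" estimates the entropy term with
            # HMC samples from the posterior over h given x; "kl" instead
            # maximizes the joint log-density of the refined samples under q.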
            if args.mode == "reverse_kl":
                x_g, h_g = sample_q(args.batch_size)

                lg = logp_net(x_g).squeeze()
                num_samples_posterior = 2
                h_given_x, acceptRate, stepsize = hmc.get_gen_posterior_samples(
                    g.generator, x_g.detach(), h_g.clone(), g.logsigma.exp().detach(), burn_in=2,
                    num_samples_posterior=num_samples_posterior, leapfrog_steps=5, stepsize=stepsize, flag_adapt=1,
                    hmc_learning_rate=.02, hmc_opt_accept=.67)

                mean_output_summed = torch.zeros_like(x_g)
                mean_output = g.generator(h_given_x)
                # for h in [h_g, h_given_x]:
                for cnt in range(num_samples_posterior):
                    mean_output_summed = mean_output_summed + mean_output[cnt*args.batch_size:(cnt+1)*args.batch_size]
                mean_output_summed = mean_output_summed / num_samples_posterior

                c = ((x_g - mean_output_summed) / g.logsigma.exp() ** 2).detach()
                g_error_entropy = torch.mul(c, x_g).mean(0).sum()
                logq_obj = lg.mean() + g_error_entropy

            elif args.mode == "kl":
                h_given_x, acceptRate, stepsize = hmc.get_gen_posterior_samples(
                    g.generator, x_g_ref.detach(), h_g.clone(), g.logsigma.exp().detach(), burn_in=2,
                    num_samples_posterior=1, leapfrog_steps=5, stepsize=stepsize, flag_adapt=1,
                    hmc_learning_rate=.02, hmc_opt_accept=.67
                )
                logq_obj = logq_joint(x_g_ref.detach(), h_given_x.detach()).mean()
                g_error_entropy = torch.zeros_like(logq_obj)
            else:
                raise ValueError

            g_loss = -logq_obj
            g_optimizer.zero_grad()
            g_loss.backward()
            g_optimizer.step()

            g.logsigma.data.clamp_(np.log(.01), np.log(.3))

            if itr % args.print_every == 0:
                delta = (x_gc - x_g_ref).flatten(start_dim=1).norm(dim=1).mean()
                print("({}) | log p obj = {:.4f}, log q obj = {:.4f}, sigma = {:.4f} | "
                      "log p(x_d) = {:.4f}, log p(x_m) = {:.4f}, ent = {:.4f} | "
                      "stepsize = {:.4f}, stepsize_x = {} | delta = {:.4f}".format(
                    itr, logp_obj.item(), logq_obj.item(), g.logsigma.exp().item(),
                    ld.mean().item(), lg_detach.mean().item(), g_error_entropy.item(),
                    stepsize, stepsize_x, delta.item()))

            if itr % args.viz_every == 0:
                if args.dataset in TOY_DSETS:
                    plt.clf()
                    xg = x_gc.detach().cpu().numpy()
                    xgr = x_g_ref.detach().cpu().numpy()
                    xd = x_d.cpu().numpy()

                    ax = plt.subplot(1, 5, 1, aspect="equal", title='init')
                    ax.scatter(xg[:, 0], xg[:, 1], s=1)

                    ax = plt.subplot(1, 5, 2, aspect="equal", title='refined')
                    ax.scatter(xgr[:, 0], xgr[:, 1], s=1)

                    ax = plt.subplot(1, 5, 3, aspect="equal", title='data')
                    ax.scatter(xd[:, 0], xd[:, 1], s=1)

                    ax = plt.subplot(1, 5, 4, aspect="equal")
                    logp_net.cpu()
                    utils.plt_flow_density(logp_net, ax, low=x_d.min().item(), high=x_d.max().item())
                    plt.savefig("/{}/{}.png".format(args.save_dir, itr))
                    logp_net.to(device)

                    ax = plt.subplot(1, 5, 5, aspect="equal")
                    logp_net.cpu()
                    utils.plt_flow_density(logp_net, ax, low=x_d.min().item(), high=x_d.max().item(), exp=False)
                    plt.savefig("/{}/{}.png".format(args.save_dir, itr))
                    logp_net.to(device)
                else:
                    plot("{}/{}_init.png".format(args.save_dir, itr), x_gc.view(x_g.size(0), *args.data_size))
                    plot("{}/{}_ref.png".format(args.save_dir, itr), x_g_ref.view(x_g.size(0), *args.data_size))

            itr += 1
Example #5
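# Single-optimizer variant: the EBM and generator are trained together, with
# negatives produced by SGLD-style refinement of generator samples jointly in
# data and latent space (refine_q / refine_q_hmc below).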
def main(args):
    utils.makedirs(args.save_dir)
    if args.dataset in TOY_DSETS:
        logp_net = nn.Sequential(
            nn.utils.weight_norm(nn.Linear(args.data_dim, args.h_dim)),
            nn.LeakyReLU(.2),
            nn.utils.weight_norm(nn.Linear(args.h_dim, args.h_dim)),
            nn.LeakyReLU(.2), nn.Linear(args.h_dim, 1))
        logp_fn = lambda x: logp_net(x)  # - (x * x).flatten(start_dim=1).sum(1)/10

        class G(nn.Module):
            def __init__(self):
                super().__init__()
                self.generator = nn.Sequential(
                    nn.Linear(args.noise_dim, args.h_dim, bias=False),
                    nn.BatchNorm1d(args.h_dim, affine=True), nn.ReLU(),
                    nn.Linear(args.h_dim, args.h_dim, bias=False),
                    nn.BatchNorm1d(args.h_dim, affine=True), nn.ReLU(),
                    nn.Linear(args.h_dim, args.data_dim))
                self.logsigma = nn.Parameter(torch.zeros(1, ))
    else:
        logp_net = nn.Sequential(
            nn.utils.weight_norm(nn.Linear(args.data_dim, 1000)),
            nn.LeakyReLU(.2), nn.utils.weight_norm(nn.Linear(1000, 500)),
            nn.LeakyReLU(.2), nn.utils.weight_norm(nn.Linear(500, 500)),
            nn.LeakyReLU(.2), nn.utils.weight_norm(nn.Linear(500, 250)),
            nn.LeakyReLU(.2), nn.utils.weight_norm(nn.Linear(250, 250)),
            nn.LeakyReLU(.2), nn.utils.weight_norm(nn.Linear(250, 250)),
            nn.LeakyReLU(.2), nn.Linear(250, 1, bias=False))
        logp_fn = lambda x: logp_net(x)  # - (x * x).flatten(start_dim=1).sum(1)/10

        class G(nn.Module):
            def __init__(self):
                super().__init__()
                self.generator = nn.Sequential(
                    nn.Linear(args.noise_dim, 500, bias=False),
                    nn.BatchNorm1d(500, affine=True), nn.Softplus(),
                    nn.Linear(500, 500, bias=False),
                    nn.BatchNorm1d(500, affine=True), nn.Softplus(),
                    nn.Linear(500, args.data_dim), nn.Sigmoid())
                # self.logsigma = nn.Parameter((torch.ones(1,) * (args.sgld_step * 2)**.5).log(), requires_grad=False)
                #self.logsigma = nn.Parameter(torch.zeros(1, ) - 5)
                self.logsigma = nn.Parameter((torch.ones(1, ) * .1).log(),
                                             requires_grad=False)

    g = G()

    params = list(logp_net.parameters()) + list(g.parameters())

    optimizer = torch.optim.Adam(params,
                                 lr=args.lr,
                                 betas=[.0, .9],
                                 weight_decay=args.weight_decay)
    #optimizer = torch.optim.Adam(params, lr=args.lr, weight_decay=args.weight_decay)

    train_loader, test_loader = get_data(args)

    sqrt = lambda x: int(torch.sqrt(torch.Tensor([x])))
    plot = lambda p, x: torchvision.utils.save_image(
        torch.clamp(x, 0, 1), p, normalize=False, nrow=sqrt(x.size(0)))

    def sample_q(n):
        h = torch.randn((n, args.noise_dim)).to(device)
        x_mu = g.generator(h)
        x = x_mu + torch.randn_like(x_mu) * g.logsigma.exp()
        return x, h

    def logq_unnorm(x, h):
        logph = distributions.Normal(0, 1).log_prob(h).sum(1)
        px_given_h = distributions.Normal(g.generator(h), g.logsigma.exp())
        logpx_given_h = px_given_h.log_prob(x).sum(1)
        return logpx_given_h + logph

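    # SGLD-style refinement that alternates noisy gradient updates on h and x;
    # the x update follows the gradient of log p(x) plus the difference of
    # joint log-densities at the current h and a freshly proposed h_tilde.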
    def refine_q(x_init, h_init, n_steps, sgld_step):
        x_k = torch.clone(x_init).requires_grad_()
        h_k = torch.clone(h_init).requires_grad_()
        sgld_sigma = (2 * sgld_step)**.5
        for k in range(n_steps):
            logp = logp_fn(x_k)[:, 0]
            logq = logq_unnorm(x_k, h_k)
            g_h = torch.autograd.grad(logq.sum(), [h_k], retain_graph=True)[0]

            # sample h tilde ~ q(h|x_k)
            h_tilde = h_k + g_h * sgld_step + torch.randn_like(
                h_k) * sgld_sigma
            h_tilde = h_tilde.detach()

            logq_tilde = logq_unnorm(x_k, h_tilde)
            logq = logq_unnorm(x_k, h_k)
            g_x = torch.autograd.grad(logp.sum() + logq.sum() -
                                      logq_tilde.sum(), [x_k],
                                      retain_graph=True)[0]

            # x update
            x_k = x_k + g_x * sgld_step + torch.randn_like(x_k) * sgld_sigma
            x_k = x_k.detach().requires_grad_()  #.clamp(0, 1)

            # h update
            logq = logq_unnorm(x_k, h_k)
            g_h = torch.autograd.grad(logq.sum(), [h_k], retain_graph=True)[0]
            h_k = h_k + g_h * sgld_step + torch.randn_like(h_k) * sgld_sigma
            h_k = h_k.detach().requires_grad_()

        # return x_k.detach().clamp(0, 1), h_k.detach()
        return x_k.detach(), h_k.detach()

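    # Momentum (underdamped) variant of refine_q; beta controls the friction of
    # the velocity updates for both x and h.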
    def refine_q_hmc(x_init, h_init, n_steps, sgld_step, beta=.5):
        x_k = torch.clone(x_init).requires_grad_()
        h_k = torch.clone(h_init).requires_grad_()
        v_x = torch.zeros_like(x_k)
        v_h = torch.zeros_like(h_k)
        sgld_sigma_tilde = (2 * sgld_step)**.5
        sgld_sigma = (2 * beta * sgld_step)**.5
        for k in range(n_steps):
            logp = logp_fn(x_k)[:, 0]
            logq = logq_unnorm(x_k, h_k)
            g_h = torch.autograd.grad(logq.sum(), [h_k], retain_graph=True)[0]

            # sample h tilde ~ q(h|x_k)
            h_tilde = h_k + g_h * sgld_step + torch.randn_like(
                h_k) * sgld_sigma_tilde
            h_tilde = h_tilde.detach()

            logq_tilde = logq_unnorm(x_k, h_tilde)
            logq = logq_unnorm(x_k, h_k)
            g_x = torch.autograd.grad(logp.sum() + logq.sum() -
                                      logq_tilde.sum(), [x_k],
                                      retain_graph=True)[0]
            # x update
            v_x = (v_x * (1 - beta) + sgld_step * g_x +
                   torch.randn_like(x_k) * sgld_sigma).detach()
            x_k = (x_k + v_x).detach().requires_grad_().clamp(0, 1)

            # h update
            logq = logq_unnorm(x_k, h_k)
            g_h = torch.autograd.grad(logq.sum(), [h_k], retain_graph=True)[0]

            v_h = (v_h * (1 - beta) + sgld_step * g_h +
                   torch.randn_like(h_k) * sgld_sigma).detach()
            h_k = (h_k + v_h).detach().requires_grad_()

        return x_k.detach().clamp(0, 1), h_k.detach()

    g.train()
    g.to(device)
    logp_net.to(device)

    itr = 0
    for epoch in range(args.n_epochs):
        for x_d, _ in train_loader:
            optimizer.zero_grad()
            if args.dataset in TOY_DSETS:
                x_d = toy_data.inf_train_gen(args.dataset,
                                             batch_size=args.batch_size)
                x_d = torch.from_numpy(x_d).float().to(device)
            else:
                x_d = x_d.to(device)

            x_init, h_init = sample_q(args.batch_size)
            # use the momentum-based refinement when --hmc is set
            if args.hmc:
                x, h = refine_q_hmc(x_init, h_init, args.n_steps,
                                    args.sgld_step)
            else:
                x, h = refine_q(x_init, h_init, args.n_steps, args.sgld_step)

            ld = logp_fn(x_d)[:, 0]
            lm = logp_fn(x.detach())[:, 0]
            li = logp_fn(x_init.detach())[:, 0]
            logp_obj = (ld - lm).mean()
            logq_obj = logq_unnorm(x.detach(), h.detach()).mean()

            loss = -(logp_obj + 3 * logq_obj) + args.p_control * (ld**2).mean()
            loss.backward()
            optimizer.step()

            if itr % args.print_every == 0:
                print(
                    "({}) | log p obj = {:.4f}, log q obj = {:.4f}, sigma = {:.4f} | log p(x_d) = {:.4f}, log p(x_m) = {:.4f}, log p(x_i) = {:.4f}"
                    .format(itr, logp_obj.item(), logq_obj.item(),
                            g.logsigma.exp().item(),
                            ld.mean().item(),
                            lm.mean().item(),
                            li.mean().item()))

            if itr % args.viz_every == 0:
                if args.dataset in TOY_DSETS:
                    plt.clf()
                    xm = x.cpu().numpy()
                    xn = x_d.cpu().numpy()
                    xi = x_init.detach().cpu().numpy()
                    ax = plt.subplot(1, 5, 1, aspect="equal", title='refined')
                    ax.scatter(xm[:, 0], xm[:, 1], s=1)

                    ax = plt.subplot(1, 5, 2, aspect="equal", title='initial')
                    ax.scatter(xi[:, 0], xi[:, 1], s=1)

                    ax = plt.subplot(1, 5, 3, aspect="equal", title='data')
                    ax.scatter(xn[:, 0], xn[:, 1], s=1)

                    ax = plt.subplot(1, 5, 4, aspect="equal")
                    logp_net.cpu()
                    utils.plt_flow_density(logp_fn,
                                           ax,
                                           low=x_d.min().item(),
                                           high=x_d.max().item())
                    plt.savefig("/{}/{}.png".format(args.save_dir, itr))
                    logp_net.to(device)

                    ax = plt.subplot(1, 5, 5, aspect="equal")
                    logp_net.cpu()
                    utils.plt_flow_density(logp_fn,
                                           ax,
                                           low=x_d.min().item(),
                                           high=x_d.max().item(),
                                           exp=False)
                    plt.savefig("/{}/{}.png".format(args.save_dir, itr))
                    logp_net.to(device)
                else:
                    plot("{}/init_{}.png".format(args.save_dir, itr),
                         x_init.view(x_init.size(0), *args.data_size))
                    plot("{}/ref_{}.png".format(args.save_dir, itr),
                         x.view(x.size(0), *args.data_size))
                    plot("{}/data_{}.png".format(args.save_dir, itr),
                         x_d.view(x_d.size(0), *args.data_size))

            itr += 1
Example #6
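# Data helper: returns train/test loaders plus plotting callbacks for binarized
# MNIST, for 2-D toy datasets generated with toy_data.inf_train_gen (encoded
# with toy_data.Int2Gray), or for a pickled dataset given via args.data_file.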
def get_data(args):
    if args.data == "mnist":
        transform = tr.Compose([
            tr.Resize(args.img_size),
            tr.ToTensor(), lambda x: (x > .5).float().view(-1)
        ])
        train_data = torchvision.datasets.MNIST(root="../data",
                                                train=True,
                                                transform=transform,
                                                download=True)
        test_data = torchvision.datasets.MNIST(root="../data",
                                               train=False,
                                               transform=transform,
                                               download=True)
        train_loader = DataLoader(train_data,
                                  args.batch_size,
                                  shuffle=True,
                                  drop_last=True)
        test_loader = DataLoader(test_data,
                                 args.batch_size,
                                 shuffle=True,
                                 drop_last=True)
        sqrt = lambda x: int(torch.sqrt(torch.Tensor([x])))
        plot = lambda p, x: torchvision.utils.save_image(x.view(
            x.size(0), 1, args.img_size, args.img_size),
                                                         p,
                                                         normalize=True,
                                                         nrow=sqrt(x.size(0)))
        encoder = None
        viz = None
    elif args.data in toy_data.TOY_DSETS:
        data = []
        seen = 0
        while seen < args.n_toy_data:
            x = toy_data.inf_train_gen(args.data, batch_size=args.batch_size)
            data.append(x)
            seen += x.shape[0]
        data = np.concatenate(data, 0)
        m, M = data.min(), data.max()
        delta = M - m
        buffer = delta / 8.
        encoder = toy_data.Int2Gray(min=m - buffer, max=M + buffer)

        def plot(p, x):
            plt.clf()
            x = x.cpu().detach().numpy()
            x = encoder.decode_batch(x)
            visualize_flow.plt_samples(x, plt.gca())
            plt.savefig(p)

        def viz(p, model):
            plt.clf()
            visualize_flow.plt_flow_density(
                lambda x: model(encoder.encode_batch(x)), plt.gca(), npts=200)
            plt.savefig(p)

        data = torch.from_numpy(data).float()
        e_data = encoder.encode_batch(data)
        y = torch.zeros_like(data[:, 0])
        train_data = TensorDataset(e_data, y)
        train_loader = DataLoader(train_data,
                                  args.batch_size,
                                  shuffle=True,
                                  drop_last=True)
        test_loader = train_loader

    elif args.data_file is not None:
        with open(args.data_file, 'rb') as f:
            x = pickle.load(f)
        x = torch.tensor(x).float()
        train_data = TensorDataset(x)
        train_loader = DataLoader(train_data,
                                  args.batch_size,
                                  shuffle=True,
                                  drop_last=True)
        test_loader = train_loader
        viz = None
        if args.model == "lattice_ising" or args.model == "lattice_ising_2d":
            plot = lambda p, x: torchvision.utils.save_image(
                x.view(x.size(0), 1, args.dim, args.dim),
                p,
                normalize=False,
                nrow=int(x.size(0)**.5))
        elif args.model == "lattice_potts":
            plot = lambda p, x: torchvision.utils.save_image(
                x.view(x.size(0), args.dim, args.dim, 3).transpose(3, 1),
                p,
                normalize=False,
                nrow=int(x.size(0)**.5))
        else:
            plot = lambda p, x: None
    else:
        raise ValueError

    return train_loader, test_loader, plot, viz
Example #7
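# Trains an EBM on 2-D toy data: mode "lsd" uses a Stein-discrepancy objective
# against an adversarially trained critic, while "sm" uses score matching for
# reference; HMC samples are drawn periodically for visualization.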
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--data',
                        choices=[
                            'swissroll', '8gaussians', 'pinwheel', 'circles',
                            'moons', '2spirals', 'checkerboard', 'rings'
                        ],
                        type=str,
                        default='moons')
    parser.add_argument('--niters', type=int, default=10000)
    parser.add_argument('--batch_size', type=int, default=100)
    parser.add_argument('--test_batch_size', type=int, default=1000)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--weight_decay', type=float, default=0)
    parser.add_argument('--critic_weight_decay', type=float, default=0)
    parser.add_argument('--save', type=str, default='/tmp/test_lsd')
    parser.add_argument('--mode',
                        type=str,
                        default="lsd",
                        choices=['lsd', 'sm'])
    parser.add_argument('--viz_freq', type=int, default=100)
    parser.add_argument('--save_freq', type=int, default=10000)
    parser.add_argument('--log_freq', type=int, default=100)
    parser.add_argument('--base_dist', action="store_true")
    parser.add_argument('--c_iters', type=int, default=5)
    parser.add_argument('--l2', type=float, default=10.)
    parser.add_argument('--exact_trace', action="store_true")
    parser.add_argument('--n_steps', type=int, default=10)
    args = parser.parse_args()

    # logger
    utils.makedirs(args.save)
    logger = utils.get_logger(logpath=os.path.join(args.save, 'logs'),
                              filepath=os.path.abspath(__file__))
    logger.info(args)

    # fit a gaussian to the training data
    init_size = 1000
    init_batch = sample_data(args, init_size).requires_grad_()
    mu, std = init_batch.mean(0), init_batch.std(0)
    base_dist = distributions.Normal(mu, std)

    # neural netz
    critic = networks.SmallMLP(2, n_out=2)
    net = networks.SmallMLP(2)

    ebm = EBM(net, base_dist if args.base_dist else None)
    ebm.to(device)
    critic.to(device)

    # for sampling
    init_fn = lambda: base_dist.sample_n(args.test_batch_size)
    cov = utils.cov(init_batch)
    sampler = HMCSampler(ebm,
                         .3,
                         5,
                         init_fn,
                         device=device,
                         covariance_matrix=cov)

    logger.info(ebm)
    logger.info(critic)

    # optimizers
    optimizer = optim.Adam(ebm.parameters(),
                           lr=args.lr,
                           weight_decay=args.weight_decay,
                           betas=(.0, .999))
    critic_optimizer = optim.Adam(critic.parameters(),
                                  lr=args.lr,
                                  betas=(.0, .999),
                                  weight_decay=args.critic_weight_decay)

    time_meter = utils.RunningAverageMeter(0.98)
    loss_meter = utils.RunningAverageMeter(0.98)

    ebm.train()
    end = time.time()
    for itr in range(args.niters):

        optimizer.zero_grad()
        critic_optimizer.zero_grad()

        x = sample_data(args, args.batch_size)
        x.requires_grad_()

        if args.mode == "lsd":
            # our method

            # compute dlogp(x)/dx
            logp_u = ebm(x)
            sq = keep_grad(logp_u.sum(), x)
            fx = critic(x)
            # compute (dlogp(x)/dx)^T * f(x)
            sq_fx = (sq * fx).sum(-1)

            # compute/estimate Tr(df/dx)
            if args.exact_trace:
                tr_dfdx = exact_jacobian_trace(fx, x)
            else:
                tr_dfdx = approx_jacobian_trace(fx, x)

            stats = (sq_fx + tr_dfdx)
            loss = stats.mean()  # estimate of S(p, q)
            l2_penalty = (
                fx * fx).sum(1).mean() * args.l2  # penalty to enforce f \in F
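            # Stein-discrepancy estimate: S(p, q) ~ E_x[(dlogp(x)/dx)^T f(x) + Tr(df/dx)];
            # the critic f maximizes it (minus the L2 penalty) while the EBM minimizes it.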

            # adversarial!
            if args.c_iters > 0 and itr % (args.c_iters + 1) != 0:
                (-1. * loss + l2_penalty).backward()
                critic_optimizer.step()
            else:
                loss.backward()
                optimizer.step()

        elif args.mode == "sm":
            # score matching for reference
            fx = ebm(x)
            dfdx = torch.autograd.grad(fx.sum(),
                                       x,
                                       retain_graph=True,
                                       create_graph=True)[0]
            eps = torch.randn_like(dfdx)  # use hutchinson here as well
            epsH = torch.autograd.grad(dfdx,
                                       x,
                                       grad_outputs=eps,
                                       create_graph=True,
                                       retain_graph=True)[0]

            trH = (epsH * eps).sum(1)
            norm_s = (dfdx * dfdx).sum(1)

            loss = (trH + .5 * norm_s).mean()
            loss.backward()
            optimizer.step()
        else:
            assert False

        loss_meter.update(loss.item())
        time_meter.update(time.time() - end)

        if itr % args.log_freq == 0:
            log_message = (
                'Iter {:04d} | Time {:.4f}({:.4f}) | Loss {:.4f}({:.4f})'.
                format(itr, time_meter.val, time_meter.avg, loss_meter.val,
                       loss_meter.avg))
            logger.info(log_message)

        if itr % args.save_freq == 0 or itr == args.niters:
            ebm.cpu()
            utils.makedirs(args.save)
            torch.save({
                'args': args,
                'state_dict': ebm.state_dict(),
            }, os.path.join(args.save, 'checkpt.pth'))
            ebm.to(device)

        if itr % args.viz_freq == 0:
            # plot dat
            plt.clf()
            npts = 100
            p_samples = toy_data.inf_train_gen(args.data, batch_size=npts**2)
            q_samples = sampler.sample(args.n_steps)

            ebm.cpu()

            x_enc = critic(x)
            xes = x_enc.detach().cpu().numpy()
            trans = xes.min()
            scale = xes.max() - xes.min()
            xes = (xes - trans) / scale * 8 - 4

            plt.figure(figsize=(4, 4))
            visualize_transform(
                [p_samples, q_samples.detach().cpu().numpy(), xes],
                ["data", "model", "embed"], [ebm], ["model"],
                npts=npts)

            fig_filename = os.path.join(args.save, 'figs',
                                        '{:04d}.png'.format(itr))
            utils.makedirs(os.path.dirname(fig_filename))
            plt.savefig(fig_filename)
            plt.close()

            ebm.to(device)
        end = time.time()

    logger.info('Training has finished, can I get a yeet?')
Example #8
def sample_data(args, batch_size):
    x = toy_data.inf_train_gen(args.data, batch_size=batch_size)
    x = torch.from_numpy(x).type(torch.float32).to(device)
    return x
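# A minimal usage sketch (not part of the original example); it assumes an
# argparse-style namespace whose `data` field names one of the 2-D toy datasets
# supported by toy_data.inf_train_gen:
#
#   args = argparse.Namespace(data="checkerboard")
#   batch = sample_data(args, batch_size=256)  # float32 tensor of shape (256, 2)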