Example #1
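# Assumed context (not shown in this excerpt): parsed `args`, a `device`, the CNF
# module, get_batch() and calc_loss() helpers, plus odeint (presumably from
# torchdiffeq) and the usual os/numpy/torch/matplotlib/tqdm imports.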
def main():
    save_dir = args.save_dir + f"/{args.dataset}/{args.width}/{args.hidden_dim}"
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

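    # Integration interval of the flow: base-distribution time t0 to data time t1.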
    t0 = 0
    t1 = 10

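    # Base distribution p(z0): zero-mean 2-D Gaussian with covariance 0.3·I.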
    p_z0 = torch.distributions.MultivariateNormal(
        loc=torch.tensor([0.0, 0.0]).to(device),
        covariance_matrix=torch.tensor([[0.3, 0.0], [0.0, 0.3]]).to(device))

    odefunc = CNF(in_out_dim=2, hidden_dim=args.hidden_dim, width=args.width)
    odefunc = odefunc.to(device)

    optimizer = optim.Adam(odefunc.parameters(), lr=1e-3, weight_decay=0.)

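    # Fixed held-out batch (20x the training batch size) for periodic evaluation.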
    x_test, logp_diff_t1_test = get_batch(args.num_samples * 20, args.dataset)

    # Train model
    best_loss = np.inf  # best held-out loss seen so far (used for checkpointing)
    for itr in tqdm(range(args.epochs + 1)):
        optimizer.zero_grad()

        x, logp_diff_t1 = get_batch(args.num_samples, args.dataset)

        loss = calc_loss(odefunc, x, logp_diff_t1, t0, t1, p_z0)

        loss.backward()
        optimizer.step()

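        # Every 100 iterations: map the held-out batch backwards from t1 to t0
        # and report the negative log-likelihood under the base distribution.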
        if itr % 100 == 0:
            z_t, logp_diff_t = odeint(
                odefunc,
                (x_test, logp_diff_t1_test),
                torch.tensor([t1, t0]).type(torch.float32).to(device),
                atol=1e-5,
                rtol=1e-5,
                method='dopri5',
            )

            z_t0, logp_diff_t0 = z_t[-1], logp_diff_t[-1]

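            # Change of variables: log p(x) = log p_z0(z_t0) - accumulated log-density change.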
            logp_x = p_z0.log_prob(z_t0).to(device) - logp_diff_t0.view(-1)
            loss = -logp_x.mean(0).item()

            print(f"{itr} Test loss: {loss}")

            if loss < best_loss:
                best_loss = loss
                torch.save(odefunc.state_dict(), f"{save_dir}/best_model.pt")

            torch.save(odefunc.state_dict(), f"{save_dir}/last_model.pt")

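            # 2-D histogram of the held-out batch transported back to t0.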
            plt.figure(figsize=(4, 4), dpi=200)
            plt.hist2d(*z_t0.detach().cpu().numpy().T,
                       bins=300,
                       density=True,
                       range=[[-2, 2], [-2, 2]])
            plt.axis('off')
            plt.gca().invert_yaxis()
            plt.margins(0, 0)

            plt.savefig(save_dir + f"/tgt_itr_{itr:05d}.jpg",
                        pad_inches=0,
                        bbox_inches='tight')
            plt.close()

    odefunc.load_state_dict(torch.load(f"{save_dir}/best_model.pt"))

    # Generate evolution of sampled points
    z_t0 = p_z0.sample([30000]).to(device)
    logp_diff_t0 = torch.zeros(30000, 1).type(torch.float32).to(device)

    z_t, logp_diff_t = odeint(
        odefunc,
        (z_t0, logp_diff_t0),
        torch.tensor(np.linspace(t0, t1, 21)).to(device),
        atol=1e-5,
        rtol=1e-5,
        method='dopri5',
    )

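    # Save a histogram of the generated samples at each of the 21 time steps.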
    for (t, z) in zip(np.linspace(t0, t1, 21), z_t.detach().cpu().numpy()):
        plt.figure(figsize=(4, 4), dpi=200)
        plt.hist2d(*z.T, bins=300, density=True, range=[[-2, 2], [-2, 2]])
        plt.axis('off')
        plt.gca().invert_yaxis()
        plt.margins(0, 0)

        plt.savefig(save_dir + f"/samples_{t:f}.jpg",
                    pad_inches=0,
                    bbox_inches='tight')
        plt.close()

    # Generate evolution of density
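    # Fix a regular grid at t1 and integrate it backwards to t0, tracking the log-density change.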
    x = np.linspace(-2, 2, 100)
    y = np.linspace(-2, 2, 100)
    points = np.vstack(np.meshgrid(x, y)).reshape([2, -1]).T

    z_t1 = torch.tensor(points).type(torch.float32).to(device)
    logp_diff_t1 = torch.zeros(z_t1.shape[0], 1).type(torch.float32).to(device)

    z_t, logp_diff_t = odeint(
        odefunc,
        (z_t1, logp_diff_t1),
        torch.tensor(np.linspace(t1, t0, 21)).to(device),
        atol=1e-5,
        rtol=1e-5,
        method='dopri5',
    )

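    # Evaluate the density at the fixed grid points and draw it as a filled contour plot.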
    for (t, z, logp_diff) in zip(np.linspace(t0, t1, 21), z_t, logp_diff_t):
        logp = p_z0.log_prob(z) - logp_diff.view(-1)

        plt.figure(figsize=(4, 4), dpi=200)
        plt.tricontourf(*z_t1.detach().cpu().numpy().T,
                        np.exp(logp.detach().cpu().numpy()), 200)
        plt.tight_layout()
        plt.axis('off')
        plt.gca().invert_yaxis()
        plt.margins(0, 0)

        plt.savefig(save_dir + f"/density_{t:f}.jpg",
                    pad_inches=0,
                    bbox_inches='tight')
        plt.close()
Example #2
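# Assumed context (not shown in this excerpt): parsed `args`, a `device`, the CNF
# module, calc_loss() and energy_function_1..4 helpers, plus odeint (presumably
# from torchdiffeq) and the usual os/numpy/torch/matplotlib/tqdm imports.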
def main():
    save_dir = args.save_dir + f"/{args.energy_fun}/{args.width}/{args.hidden_dim}"
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    t0 = 0
    t1 = 10

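    # Base distribution: standard 2-D Gaussian (identity covariance).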
    norm = torch.distributions.MultivariateNormal(loc=torch.tensor([0.0, 0.0]).to(device),
                                                  covariance_matrix=torch.tensor([[1.0, 0.0], [0.0, 1.0]]).to(device))

    odefunc = CNF(in_out_dim=2, hidden_dim=args.hidden_dim, width=args.width)
    odefunc = odefunc.to(device)

    optimizer = optim.Adam(odefunc.parameters(), lr=1e-3, weight_decay=0.)

    z0_test = norm.sample([args.num_samples * 20]).to(device)
    logp_z0_test = norm.log_prob(z0_test).unsqueeze(dim=-1).to(device)

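    # Select the target energy function by its index.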
    if args.energy_fun == 1:
        energy_fun = energy_function_1
    elif args.energy_fun == 2:
        energy_fun = energy_function_2
    elif args.energy_fun == 3:
        energy_fun = energy_function_3
    elif args.energy_fun == 4:
        energy_fun = energy_function_4
    else:
        raise Exception("Energy function not implemented.")

    # Train model
    best_loss = np.inf  # best held-out loss seen so far (used for checkpointing)
    if args.train:
        for itr in tqdm(range(args.epochs + 1)):
            optimizer.zero_grad()

            z0 = norm.sample([args.num_samples]).to(device)
            logp_z0 = norm.log_prob(z0).unsqueeze(dim=-1).to(device)

            loss = calc_loss(odefunc, z0, logp_z0, t0, t1, energy_fun)

            loss.backward()
            optimizer.step()

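            # Every 100 iterations: push the held-out batch forward from t0 to t1
            # and evaluate the mean log-density plus mean energy of the result.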
            if itr % 100 == 0:
                z_t, logp_diff_t = odeint(
                    odefunc,
                    (z0_test, logp_z0_test),
                    torch.tensor([t0, t1]).type(torch.float32).to(device),
                    atol=1e-5,
                    rtol=1e-5,
                    method='dopri5',
                )

                x, logp_x = z_t[-1], logp_diff_t[-1]

                loss = (logp_x.mean(0) + energy_fun(x).mean(0)).item()

                print(f"{itr} Test loss: {loss}")

                if loss < best_loss:
                    best_loss = loss
                    torch.save(odefunc.state_dict(),
                               f"{save_dir}/best_model.pt")

                torch.save(odefunc.state_dict(),
                           f"{save_dir}/last_model.pt")

                plt.figure(figsize=(4, 4), dpi=200)
                plt.hist2d(*x.detach().cpu().numpy().T,
                           bins=300, density=True, range=[[-4, 4], [-4, 4]])
                plt.axis('off')
                plt.gca().invert_yaxis()
                plt.margins(0, 0)

                plt.savefig(save_dir + f"/tgt_itr_{itr:05d}.jpg",
                            pad_inches=0, bbox_inches='tight')
                plt.close()

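    # Reload the best checkpoint before generating visualizations.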
    odefunc.load_state_dict(torch.load(f"{save_dir}/best_model.pt"))

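    # If training was skipped, compute best_loss (used in the filenames below) from the loaded model.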
    if not args.train:
        z_t, logp_diff_t = odeint(
            odefunc,
            (z0_test, logp_z0_test),
            torch.tensor([t0, t1]).type(torch.float32).to(device),
            atol=1e-5,
            rtol=1e-5,
            method='dopri5',
        )

        x, logp_x = z_t[-1], logp_diff_t[-1]

        best_loss = (logp_x.mean(0) + energy_fun(x).mean(0)).item()

    # Generate evolution of density
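    # Push a regular grid (with its base log-density) forward from t0 to t1.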
    x = np.linspace(-6, 6, 600)
    y = np.linspace(-6, 6, 600)
    points = np.vstack(np.meshgrid(x, y)).reshape([2, -1]).T

    z_t0 = torch.tensor(points).type(torch.float32).to(device)
    logp_t0 = norm.log_prob(z_t0).unsqueeze(dim=-1).to(device)

    z_t, logp_t = odeint(
        odefunc,
        (z_t0, logp_t0),
        torch.tensor(np.linspace(t0, t1, 21)).to(device),
        atol=1e-5,
        rtol=1e-5,
        method='dopri5',
    )

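    # Filled contour plot of the transported density at each time step.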
    for (t, z, logp) in zip(np.linspace(t0, t1, 21), z_t, logp_t):
        plt.figure(figsize=(4, 4), dpi=200)
        plt.tricontourf(*z.detach().cpu().numpy().T,
                        np.exp(logp.view(-1).detach().cpu().numpy()), 200)
        plt.xlim([-4, 4])
        plt.ylim([-4, 4])
        plt.tight_layout()
        plt.axis('off')
        plt.gca().invert_yaxis()
        plt.margins(0, 0)

        plt.savefig(save_dir + f"/density_{best_loss:.10f}_{t:f}.jpg",
                    pad_inches=0, bbox_inches='tight')
        plt.close()