# Example #1
# 0
def get_sampler(args):
    """Construct the MCMC sampler named by ``args.sampler``.

    The sampler acts on ``np.prod(args.input_size)`` dimensions. Accepted
    names depend on ``args.input_type``:

    * ``"binary"``: ``gibbs``, ``rand_gibbs``, ``bg-<block_size>``,
      ``hb-<block_size>-<hamming_dist>``, ``gwg`` and ``gwg-<n_hops>``.
    * anything else (uses ``args.n_out`` as the number of states):
      ``gibbs``, ``rand_gibbs`` and ``gwg``.

    ``rand_gibbs`` is the ``gibbs`` sampler constructed with ``rand=True``.

    Raises:
        ValueError: if the name is not supported for the input type, or a
            numeric suffix in the name cannot be parsed as an int.
    """
    n_dims = np.prod(args.input_size)
    name = args.sampler
    use_rand = name == "rand_gibbs"

    if args.input_type != "binary":
        # Non-binary inputs: per-dimension Metropolis or multi-dim GWG only.
        if name in ("gibbs", "rand_gibbs"):
            return samplers.PerDimMetropolisSampler(n_dims, int(args.n_out),
                                                    rand=use_rand)
        if name == "gwg":
            return samplers.DiffSamplerMultiDim(n_dims, 1, approx=True,
                                                temp=2.)
        raise ValueError("invalid sampler")

    if name in ("gibbs", "rand_gibbs"):
        return samplers.PerDimGibbsSampler(n_dims, rand=use_rand)
    if name.startswith("bg-"):
        # Parse the suffix before touching the sampler class.
        block_size = int(name.split('-')[1])
        return block_samplers.BlockGibbsSampler(n_dims, block_size)
    if name.startswith("hb-"):
        block_size, hamming_dist = [int(v) for v in name.split('-')[1:]]
        return block_samplers.HammingBallSampler(n_dims, block_size,
                                                 hamming_dist)
    if name == "gwg":
        return samplers.DiffSampler(n_dims, 1, fixed_proposal=False,
                                    approx=True, multi_hop=False, temp=2.)
    if name.startswith("gwg-"):
        n_hops = int(name.split('-')[1])
        return samplers.MultiDiffSampler(n_dims, 1, approx=True, temp=2.,
                                         n_samples=n_hops)
    raise ValueError("Invalid sampler...")
# Example #2
# 0
def generate_data(args):
    """Draw samples from a ground-truth model chosen by ``args.data_model``.

    Supported values are ``lattice_potts``, ``lattice_ising``,
    ``lattice_ising_3d`` and ``er_ising``. The Potts model is sampled with a
    per-dimension Metropolis sampler; the Ising variants use per-dimension
    Gibbs. ``args.n_samples`` chains are started from ``model.init_sample``
    and advanced ``args.gt_steps`` times on ``args.device``.

    Returns:
        ``(samples, model)`` where ``samples`` is the final state, detached
        and moved to CPU, and ``model`` is the model on ``args.device``.

    Raises:
        ValueError: for an unknown ``args.data_model``.
    """
    kind = args.data_model

    if kind == "lattice_potts":
        model = rbm.LatticePottsModel(args.dim, args.n_state, args.sigma)
    elif kind == "lattice_ising":
        model = rbm.LatticeIsingModel(args.dim, args.sigma)
    elif kind == "lattice_ising_3d":
        model = rbm.LatticeIsingModel(args.dim, args.sigma, lattice_dim=3)
    elif kind == "er_ising":
        model = rbm.ERIsingModel(args.dim, args.degree, args.sigma)
    else:
        raise ValueError

    if kind == "lattice_potts":
        sampler = samplers.PerDimMetropolisSampler(model.data_dim,
                                                   args.n_out,
                                                   rand=False)
    else:
        sampler = samplers.PerDimGibbsSampler(model.data_dim, rand=False)

    # Echo the ground-truth parameters of the graph-structured models.
    if kind == "lattice_ising_3d":
        print(model.sigma)
    if kind in ("lattice_ising_3d", "er_ising"):
        print(model.G)
        print(model.J)

    model = model.to(args.device)
    samples = model.init_sample(args.n_samples).to(args.device)
    print("Generating {} samples from:".format(args.n_samples))
    print(model)
    n_steps = args.gt_steps
    for _ in tqdm(range(n_steps)):
        samples = sampler.step(samples, model).detach()

    return samples.detach().cpu(), model
# Example #3
# 0
def main(args):
    """Benchmark MCMC samplers on a lattice Potts model.

    For every sampler name in ``temps`` this runs ``args.n_steps`` transitions
    on ``args.n_test_samples`` parallel chains started from
    ``model.init_dist`` and records:

    * cumulative wall-clock time spent inside ``sampler.step``;
    * the mean (over chains) Hamming distance moved per step ("hops");
    * a trace of chain 0, kept every ``args.subsample`` steps, either as the
      raw state (``args.ess_statistic == "dims"``) or as its Hamming
      distances to the fixed reference set ``ess_samples``.

    The ESS of each trace is estimated with ``get_ess(chain, args.burn_in)``.
    Box plots of ESS and ESS per second, hop-distance curves, per-sampler
    trace plots and ``results.pkl`` (``ess``/``hops``/``chains``) are written
    to ``args.save_dir``.

    Relies on ``device``, ``makedirs``, ``get_ess``, ``rbm``, ``samplers``
    and ``block_samplers`` defined elsewhere in the module.
    """
    makedirs(args.save_dir)

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    model = rbm.LatticePottsModel(int(args.dim), int(args.n_out), args.sigma,
                                  args.bias)
    model.to(device)
    print(device)

    if args.n_out == 3:
        # With exactly 3 states each site can be rendered as an RGB pixel.
        plot = lambda p, x: torchvision.utils.save_image(
            x.view(x.size(0), args.dim, args.dim, 3).transpose(3, 1),
            p,
            normalize=False,
            nrow=int(x.size(0)**.5))
    else:
        plot = None

    # Fixed reference set for the Hamming-distance ESS statistic.
    ess_samples = model.init_sample(args.n_samples).to(device)

    hops = {}
    ess = {}
    times = {}        # cumulative sampling time at each print step
    total_times = {}  # sampling time for the whole run, per sampler
    chains = {}

    temps = ['dim-gibbs', 'rand-gibbs', 'gwg']
    for temp in temps:
        if temp == 'dim-gibbs':
            sampler = samplers.PerDimMetropolisSampler(args.dim**2, args.n_out)
        elif temp == "rand-gibbs":
            sampler = samplers.PerDimMetropolisSampler(args.dim**2,
                                                       args.n_out,
                                                       rand=True)
        elif "bg-" in temp:
            block_size = int(temp.split('-')[1])
            sampler = block_samplers.BlockGibbsSampler(model.data_dim,
                                                       block_size)
        elif "hb-" in temp:
            block_size, hamming_dist = [int(v) for v in temp.split('-')[1:]]
            sampler = block_samplers.HammingBallSampler(
                model.data_dim, block_size, hamming_dist)
        elif temp == "gwg":
            sampler = samplers.DiffSamplerMultiDim(args.dim**2,
                                                   1,
                                                   approx=True,
                                                   temp=2.)
        elif "gwg-" in temp:
            n_hops = int(temp.split('-')[1])
            sampler = samplers.MultiDiffSampler(model.data_dim,
                                                1,
                                                approx=True,
                                                temp=2.,
                                                n_samples=n_hops)
        else:
            raise ValueError("Invalid sampler...")

        x = model.init_dist.sample((args.n_test_samples, )).to(device)

        times[temp] = []
        hops[temp] = []
        chain = []
        cur_time = 0.
        for i in range(args.n_steps):
            # do sampling and time it (only the transition itself is timed)
            st = time.time()
            xhat = sampler.step(x.detach(), model).detach()
            cur_time += time.time() - st

            # mean Hamming distance moved by this step, averaged over chains
            cur_hops = (x != xhat).float().view(x.size(0),
                                                -1).sum(-1).mean().item()

            # update trajectory
            x = xhat

            if i % args.subsample == 0:
                if args.ess_statistic == "dims":
                    chain.append(x.cpu()[0].view(-1).numpy()[None])
                else:
                    xc = x[0][None]
                    h = (xc != ess_samples).float().view(
                        ess_samples.size(0), -1).sum(-1)
                    chain.append(h.detach().cpu().numpy()[None])

            if i % args.viz_every == 0 and plot is not None:
                # Path is relative to save_dir (no leading "/").
                plot(
                    "{}/temp_{}_samples_{}.png".format(
                        args.save_dir, temp, i), x)

            if i % args.print_every == 0:
                times[temp].append(cur_time)
                hops[temp].append(cur_hops)
                print("temp {}, itr = {}, hop-dist = {:.4f}".format(
                    temp, i, cur_hops))

        # times[temp][-1] is only the time at the last *print* step; keep the
        # full-run total for the ESS-per-second normalisation.
        total_times[temp] = cur_time
        chain = np.concatenate(chain, 0)
        chains[temp] = chain
        ess[temp] = get_ess(chain, args.burn_in)
        print("ess = {} +/- {}".format(ess[temp].mean(), ess[temp].std()))

    ess_temps = temps
    plt.clf()
    plt.boxplot([ess[temp] for temp in ess_temps],
                labels=ess_temps,
                showfliers=False)
    plt.savefig("{}/ess.png".format(args.save_dir))

    plt.clf()
    # ESS is computed after discarding the burn-in fraction, so normalise by
    # the (approximate) time spent producing the retained part of the chain.
    plt.boxplot([
        ess[temp] / total_times[temp] / (1. - args.burn_in)
        for temp in ess_temps
    ],
                labels=ess_temps,
                showfliers=False)
    plt.savefig("{}/ess_per_sec.png".format(args.save_dir))

    plt.clf()
    for temp in temps:
        plt.plot(hops[temp], label="{}".format(temp))

    plt.legend()
    plt.savefig("{}/hops.png".format(args.save_dir))

    for temp in temps:
        plt.clf()
        plt.plot(chains[temp][:, 0])
        plt.savefig("{}/trace_{}.png".format(args.save_dir, temp))

    with open("{}/results.pkl".format(args.save_dir), 'wb') as f:
        results = {'ess': ess, 'hops': hops, 'chains': chains}
        pickle.dump(results, f)
# Example #4
# 0
def main(args):
    """Train a Potts model on sequence data and track contact recovery.

    Data come from ``utils.load_synthetic`` (``args.data == "synthetic"``),
    ``utils.load_ingraham`` (``"PF00018"``) or ``utils.load_real_protein``.
    A ``lattice_potts`` or ``dense_potts`` model is then fitted, either by
    pseudo-likelihood (``args.sampler == "plm"``) or contrastively against a
    persistent sample buffer advanced by the chosen MCMC sampler. An L1
    penalty ``args.l1 * norm_J(J).sum()`` on the symmetrised couplings is
    added to the loss.

    Every ``args.print_every`` iterations it logs:

    * the loss;
    * the squared error and RMSE between ``norm_J`` of the sub-selected
      couplings and the ground truth;
    * the precision of the top-50/75/100 predicted contacts against
      ``ground_truth_C``.

    Every ``args.viz_every`` iterations it saves error curves, coupling
    images and a precision-vs-k plot. Every ``args.ckpt_every`` iterations
    it writes ``ckpt.pt``. After ``args.n_iters + 1`` updates it writes
    ``sq_err.txt``/``rmse.txt`` and a final checkpoint, then calls
    ``quit()``.

    Raises:
        ValueError: for an unsupported model, or an unsupported
            sampler/model combination.
    """
    makedirs(args.save_dir)
    logger = open("{}/log.txt".format(args.save_dir), 'w')

    def my_print(s):
        print(s)
        logger.write(str(s) + '\n')

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # load existing data
    if args.data == "synthetic":
        train_loader, test_loader, data, ground_truth_J, ground_truth_h, ground_truth_C = utils.load_synthetic(
            args.data_file, args.batch_size)
        dim, n_out = data.size()[1:]
        ground_truth_J_norm = norm_J(ground_truth_J).to(device)
        matsave(ground_truth_J.abs().transpose(2, 1).reshape(dim * n_out, dim * n_out),
                "{}/ground_truth_J.png".format(args.save_dir))
        matsave(ground_truth_C, "{}/ground_truth_C.png".format(args.save_dir))
        matsave(ground_truth_J_norm, "{}/ground_truth_J_norm.png".format(args.save_dir))
        num_ecs = 120
        dm_indices = torch.arange(ground_truth_J_norm.size(0)).long()
    elif args.data == "PF00018":
        train_loader, test_loader, data, num_ecs, ground_truth_J_norm, ground_truth_C = utils.load_ingraham(args)
        dim, n_out = data.size()[1:]
        ground_truth_J_norm = ground_truth_J_norm.to(device)
        matsave(ground_truth_C, "{}/ground_truth_C.png".format(args.save_dir))
        matsave(ground_truth_J_norm, "{}/ground_truth_dists.png".format(args.save_dir))
        dm_indices = torch.arange(ground_truth_J_norm.size(0)).long()
    else:
        train_loader, test_loader, data, num_ecs, ground_truth_J_norm, ground_truth_C, dm_indices = utils.load_real_protein(args)
        dim, n_out = data.size()[1:]
        ground_truth_J_norm = ground_truth_J_norm.to(device)
        matsave(ground_truth_C, "{}/ground_truth_C.png".format(args.save_dir))
        matsave(ground_truth_J_norm, "{}/ground_truth_dists.png".format(args.save_dir))

    # One if/elif chain: an unknown model must fail here, not later with an
    # undefined `model`.
    if args.model == "lattice_potts":
        model = rbm.LatticePottsModel(int(args.dim), int(n_out), 0., 0., learn_sigma=True)
        buffer = model.init_sample(args.buffer_size)
    elif args.model == "dense_potts":
        model = rbm.DensePottsModel(dim, n_out, learn_J=True, learn_bias=True)
        buffer = model.init_sample(args.buffer_size)
    else:
        # dense_ising / mlp are not supported by this script.
        raise ValueError("unsupported model: {}".format(args.model))

    model.to(device)

    # make G symmetric
    def get_J():
        j = model.J
        jt = j.transpose(0, 1).transpose(2, 3)
        return (j + jt) / 2

    def get_J_sub():
        j = get_J()
        j_sub = j[dm_indices, :][:, dm_indices]
        return j_sub

    def J_errors():
        """Return (sum of squared errors, RMSE) of norm_J(J_sub) vs ground truth."""
        with torch.no_grad():
            diff2 = (ground_truth_J_norm - norm_J(get_J_sub())) ** 2
        return diff2.sum(), diff2.mean().sqrt()

    def contact_precision():
        """Precision of the top-k predicted contacts; entry k-1 is the top-k value.

        Upper-triangular pairs are ranked by norm_J(J_sub), and the running
        fraction of true contacts (ground_truth_C) is returned.
        """
        with torch.no_grad():
            inds = torch.triu_indices(ground_truth_C.size(0), ground_truth_C.size(1), 1)
            C_inds = ground_truth_C[inds[0], inds[1]]
            J_inds = norm_J(get_J_sub())[inds[0], inds[1]]
            J_inds_sorted = torch.sort(J_inds, descending=True).indices
            C_inds_sorted = C_inds[J_inds_sorted]
            C_cumsum = C_inds_sorted.cumsum(0)
            arange = torch.arange(C_cumsum.size(0)) + 1
            return C_cumsum.float() / arange.float()

    if args.sampler == "gibbs":
        if "potts" in args.model:
            sampler = samplers.PerDimMetropolisSampler(dim, int(n_out), rand=False)
        else:
            sampler = samplers.PerDimGibbsSampler(dim, rand=False)
    elif args.sampler == "plm":
        sampler = samplers.PerDimMetropolisSampler(dim, int(n_out), rand=False)
    elif args.sampler == "rand_gibbs":
        if "potts" in args.model:
            sampler = samplers.PerDimMetropolisSampler(dim, int(n_out), rand=True)
        else:
            sampler = samplers.PerDimGibbsSampler(dim, rand=True)
    elif args.sampler == "gwg":
        if "potts" in args.model:
            sampler = samplers.DiffSamplerMultiDim(dim, 1, approx=True, temp=2.)
        else:
            sampler = samplers.DiffSampler(dim, 1, approx=True, fixed_proposal=False, temp=2.)
    else:
        # Explicit check instead of `assert`, which is stripped under -O.
        if "gwg-" not in args.sampler:
            raise ValueError("unsupported sampler: {}".format(args.sampler))
        n_hop = int(args.sampler.split('-')[1])
        if "potts" in args.model:
            raise ValueError
        else:
            sampler = samplers.MultiDiffSampler(model.data_dim, 1, approx=True, temp=2., n_samples=n_hop)

    my_print(device)
    my_print(model)
    my_print(buffer.size())
    my_print(sampler)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    # load ckpt
    if args.ckpt_path is not None:
        d = torch.load(args.ckpt_path)
        model.load_state_dict(d['model'])
        optimizer.load_state_dict(d['optimizer'])
        sampler.load_state_dict(d['sampler'])

    # mask matrix for PLM: zero the D x D self-coupling block of every site
    L, D = model.J.size(0), model.J.size(2)
    num_node = L * D
    J_mask = torch.ones((num_node, num_node)).to(device)
    for i in range(L):
        J_mask[D * i:D * i + D, D * i:D * i + D] = 0

    itr = 0
    sq_errs = []
    rmses = []
    all_inds = list(range(args.buffer_size))
    # `<=` rather than `<`: the finalisation below fires at itr == n_iters + 1.
    # With `<`, an epoch that ended exactly at itr == n_iters skipped it.
    while itr <= args.n_iters:
        for x in train_loader:
            if args.data == "synthetic":
                x = x[0].to(device)
                weights = torch.ones((x.size(0),)).to(device)
            else:
                # real-data batches carry per-sequence weights
                weights = x[1].to(device)
                if args.unweighted:
                    weights = torch.ones_like(weights)
                x = x[0].to(device)

            if args.sampler == "plm":
                plm_J = model.J.transpose(2, 1).reshape(dim * n_out, dim * n_out)
                logits = torch.matmul(x.view(x.size(0), -1), plm_J * J_mask) + model.bias.view(-1)[None]
                # assumes x is one-hot over its last axis -> per-site state index
                x_inds = (torch.arange(x.size(-1))[None, None].to(x.device) * x).sum(-1)
                cross_entropy = nn.functional.cross_entropy(
                    input=logits.reshape((-1, D)),
                    target=x_inds.view(-1).long(),
                    reduction="none")
                cross_entropy = torch.sum(cross_entropy.reshape((-1, L)), -1)
                loss = (cross_entropy * weights).mean()

            else:
                buffer_inds = np.random.choice(all_inds, args.batch_size, replace=False)
                x_fake = buffer[buffer_inds].to(device)
                for k in range(args.sampling_steps):
                    x_fake = sampler.step(x_fake.detach(), model).detach()

                buffer[buffer_inds] = x_fake.detach().cpu()

                logp_real = (model(x).squeeze() * weights).mean()
                logp_fake = model(x_fake).squeeze().mean()

                obj = logp_real - logp_fake
                loss = -obj

            # add l1 reg
            loss += args.l1 * norm_J(get_J()).sum()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if itr % args.print_every == 0:
                if args.sampler == "plm":
                    my_print("({}) loss = {:.4f}".format(itr, loss.item()))
                else:
                    my_print("({}) log p(real) = {:.4f}, log p(fake) = {:.4f}, diff = {:.4f}, hops = {:.4f}".format(itr,
                                                                                                  logp_real.item(),
                                                                                                  logp_fake.item(),
                                                                                                  obj.item(),
                                                                                                  sampler._hops))

                sq_err, rmse = J_errors()
                acc_at = contact_precision()
                # acc_at[k - 1] is the precision of the top-k contacts.
                my_print("\t err^2 = {:.4f}, rmse = {:.4f}, acc @ 50 = {:.4f}, acc @ 75 = {:.4f}, acc @ 100 = {:.4f}".format(sq_err, rmse,
                                                                                                         acc_at[49],
                                                                                                         acc_at[74],
                                                                                                         acc_at[99]))
                logger.flush()

            if itr % args.viz_every == 0:
                sq_err, rmse = J_errors()

                sq_errs.append(sq_err.item())
                plt.clf()
                plt.plot(sq_errs, label="sq_err")
                plt.legend()
                plt.savefig("{}/sq_err.png".format(args.save_dir))

                rmses.append(rmse.item())
                plt.clf()
                plt.plot(rmses, label="rmse")
                plt.legend()
                plt.savefig("{}/rmse.png".format(args.save_dir))

                matsave(get_J_sub().abs().transpose(2, 1).reshape(dm_indices.size(0) * n_out,
                                                                  dm_indices.size(0) * n_out),
                        "{}/model_J_{}_sub.png".format(args.save_dir, itr))
                matsave(norm_J(get_J_sub()), "{}/model_J_norm_{}_sub.png".format(args.save_dir, itr))

                matsave(get_J().abs().transpose(2, 1).reshape(dim * n_out, dim * n_out),
                        "{}/model_J_{}.png".format(args.save_dir, itr))
                matsave(norm_J(get_J()), "{}/model_J_norm_{}.png".format(args.save_dir, itr))

                acc_at = contact_precision()
                plt.clf()
                plt.plot(acc_at[:num_ecs].detach().cpu().numpy())
                plt.savefig("{}/acc_at_{}.png".format(args.save_dir, itr))

            if itr % args.ckpt_every == 0:
                my_print("Saving checkpoint to {}/ckpt.pt".format(args.save_dir))
                torch.save({
                    "model": model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "sampler": sampler.state_dict()
                }, "{}/ckpt.pt".format(args.save_dir))

            itr += 1

            if itr > args.n_iters:
                sq_err, rmse = J_errors()
                # Write plain numbers, not the tensor repr.
                with open("{}/sq_err.txt".format(args.save_dir), 'w') as f:
                    f.write(str(sq_err.item()))
                with open("{}/rmse.txt".format(args.save_dir), 'w') as f:
                    f.write(str(rmse.item()))

                torch.save({
                    "model": model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "sampler": sampler.state_dict()
                }, "{}/ckpt.pt".format(args.save_dir))

                logger.close()
                quit()
# Example #5
# 0
def main(args):
    """Estimate log Z of a (trained) Potts model with annealed importance sampling.

    Inverse temperatures go from 0 to 1 over ``args.n_iters`` evenly spaced
    values (``np.linspace``). ``args.n_samples`` particles start from
    ``model.init_sample``. Their log-weights are initialised to
    ``model.bias.logsumexp(-1).sum()``.

    At each new beta_k the log-weights are first incremented by
    ``model(x, beta=beta_k) - model(x, beta=beta_{k-1})``. The particles are
    then moved with ``args.steps_per_iter`` sampler steps targeting beta_k.

    The running estimate ``logsumexp(log_w) - log(n_samples)`` is logged
    every ``args.print_every`` iterations and plotted every
    ``args.viz_every``. The final estimate is logged at the end.

    Raises:
        ValueError: for an unsupported ``args.model`` or ``args.sampler``.
    """
    makedirs(args.save_dir)
    logger = open("{}/log.txt".format(args.save_dir), 'w')

    def my_print(s):
        print(s)
        logger.write(str(s) + '\n')

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # One if/elif/else chain. With two independent `if`s, lattice_potts used
    # to fall into the second chain's `else` and always raised.
    if args.model == "lattice_potts":
        model = rbm.LatticePottsModel(int(args.dim),
                                      int(args.n_out),
                                      0.,
                                      0.,
                                      learn_sigma=True)
    elif args.model == "dense_potts":
        model = rbm.DensePottsModel(args.dim,
                                    args.n_out,
                                    learn_J=True,
                                    learn_bias=True)
    else:
        raise ValueError("unsupported model: {}".format(args.model))

    model.to(device)

    if args.sampler == "gibbs":
        sampler = samplers.PerDimMetropolisSampler(args.dim,
                                                   int(args.n_out),
                                                   rand=False)
    elif args.sampler == "rand_gibbs":
        sampler = samplers.PerDimMetropolisSampler(args.dim,
                                                   int(args.n_out),
                                                   rand=True)
    elif args.sampler == "gwg":
        sampler = samplers.DiffSamplerMultiDim(args.dim,
                                               1,
                                               approx=True,
                                               temp=2.)
    else:
        raise ValueError("unsupported sampler: {}".format(args.sampler))

    my_print(device)
    my_print(model)
    my_print(sampler)

    # load ckpt
    my_print("Loading...")
    if args.ckpt_path is not None:
        ckpt = torch.load(args.ckpt_path)
        model.load_state_dict(ckpt['model'])
    my_print("Loaded!")
    betas = np.linspace(0., 1., args.n_iters)

    # Particles must live on the same device as the model and log_w.
    samples = model.init_sample(args.n_samples).to(device)
    log_w = torch.zeros((args.n_samples, )).to(device)
    # presumably log Z of the beta = 0 (bias-only) distribution — confirm
    log_w += model.bias.logsumexp(-1).sum()

    logZs = []
    for itr, beta_k in enumerate(betas):
        if itr == 0:
            continue  # skip beta_0 = 0: it is the starting distribution

        beta_km1 = betas[itr - 1]

        # update importance weights using the particles from the last step
        with torch.no_grad():
            log_w = log_w + model(samples, beta=beta_k) - model(
                samples, beta=beta_km1)
        # update samples with a kernel that targets beta_k
        model_k = lambda x: model(x, beta=beta_k)
        for _ in range(args.steps_per_iter):
            samples = sampler.step(samples.detach(), model_k).detach()

        if itr % args.print_every == 0:
            logZ = log_w.logsumexp(0) - np.log(args.n_samples)
            logZs.append(logZ.item())
            my_print("({}) beta = {}, log Z = {:.4f}".format(
                itr, beta_k, logZ.item()))
            logger.flush()

        if itr % args.viz_every == 0:
            plt.clf()
            plt.plot(logZs, label="log(Z)")
            plt.legend()
            plt.savefig("{}/logZ.png".format(args.save_dir))

    logZ_final = log_w.logsumexp(0) - np.log(args.n_samples)
    my_print("Final log(Z) = {:.4f}".format(logZ_final.item()))
    logger.close()
# Example #6
# 0
def main(args):
    """Train a graphical model by persistent contrastive divergence and track
    parameter recovery.

    If ``args.data == "mnist"`` or ``args.data_file`` is set, data are loaded
    with ``utils.get_data``. Otherwise a dataset is generated with
    ``utils.generate_data``, pickled to ``data.pkl`` (plus ``J.pkl`` for
    ``er_ising``), and the process exits.

    The model is trained on the objective ``mean log p(real) -
    mean log p(buffer)``, minus ``args.l1 * |J_sym|.sum()``. The persistent
    ``buffer`` is advanced by ``args.sampling_steps`` sampler steps per
    update, and the diagonal of ``model.G`` is zeroed after each optimiser
    step.

    For ``lattice_potts``/``lattice_ising`` the learned ``sigma`` is logged
    against ``args.sigma``. For the other models, the squared error and RMSE
    of the symmetrised couplings against ``ground_truth_J`` are logged and
    plotted. After ``args.n_iters + 1`` updates the final metric is written
    to ``sigma.txt`` or ``sq_err.txt``/``rmse.txt``, and the process exits
    via ``quit()``.

    Raises:
        ValueError: for an unsupported model or sampler/model combination.
    """
    makedirs(args.save_dir)
    logger = open("{}/log.txt".format(args.save_dir), 'w')

    def my_print(s):
        print(s)
        logger.write(str(s) + '\n')

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # load existing data
    if args.data == "mnist" or args.data_file is not None:
        train_loader, test_loader, plot, viz = utils.get_data(args)
    # generate the dataset
    else:
        data, data_model = utils.generate_data(args)
        my_print(
            "we have created your data, but what have you done for me lately?????"
        )
        with open("{}/data.pkl".format(args.save_dir), 'wb') as f:
            pickle.dump(data, f)
        if args.data_model == "er_ising":
            ground_truth_J = data_model.J.detach().cpu()
            with open("{}/J.pkl".format(args.save_dir), 'wb') as f:
                pickle.dump(ground_truth_J, f)
        quit()

    if args.model == "lattice_potts":
        model = rbm.LatticePottsModel(int(args.dim),
                                      int(args.n_state),
                                      0.,
                                      0.,
                                      learn_sigma=True)
        buffer = model.init_sample(args.buffer_size)
    elif args.model == "lattice_ising":
        model = rbm.LatticeIsingModel(int(args.dim), 0., 0., learn_sigma=True)
        buffer = model.init_sample(args.buffer_size)
    elif args.model == "lattice_ising_3d":
        model = rbm.LatticeIsingModel(int(args.dim),
                                      .2,
                                      learn_G=True,
                                      lattice_dim=3)
        # Keep the initial couplings as ground truth, then re-initialise G.
        ground_truth_J = model.J.clone().to(device)
        model.G.data = torch.randn_like(model.G.data) * .01
        model.sigma.data = torch.ones_like(model.sigma.data)
        buffer = model.init_sample(args.buffer_size)
        plt.clf()
        plt.matshow(ground_truth_J.detach().cpu().numpy())
        plt.savefig("{}/ground_truth.png".format(args.save_dir))
    elif args.model == "lattice_ising_2d":
        model = rbm.LatticeIsingModel(int(args.dim),
                                      args.sigma,
                                      learn_G=True,
                                      lattice_dim=2)
        ground_truth_J = model.J.clone().to(device)
        model.G.data = torch.randn_like(model.G.data) * .01
        model.sigma.data = torch.ones_like(model.sigma.data)
        buffer = model.init_sample(args.buffer_size)
        plt.clf()
        plt.matshow(ground_truth_J.detach().cpu().numpy())
        plt.savefig("{}/ground_truth.png".format(args.save_dir))
    elif args.model == "er_ising":
        model = rbm.ERIsingModel(int(args.dim), 2, learn_G=True)
        model.G.data = torch.randn_like(model.G.data) * .01
        buffer = model.init_sample(args.buffer_size)
        with open(args.graph_file, 'rb') as f:
            ground_truth_J = pickle.load(f)
            plt.clf()
            plt.matshow(ground_truth_J.detach().cpu().numpy())
            plt.savefig("{}/ground_truth.png".format(args.save_dir))
        ground_truth_J = ground_truth_J.to(device)
    elif args.model == "rbm":
        # NOTE(review): no ground_truth_J is set for "rbm", so the J-metric
        # branches below would fail for it — confirm intended support.
        model = rbm.BernoulliRBM(args.dim, args.n_hidden)
        buffer = model.init_dist.sample((args.buffer_size, ))
    elif args.model in ("dense_potts", "dense_ising", "mlp"):
        raise ValueError
    else:
        # Fail early instead of with an undefined `model` below.
        raise ValueError("unsupported model: {}".format(args.model))

    model.to(device)
    buffer = buffer.to(device)

    # make G symmetric
    def get_J():
        j = model.J
        return (j + j.t()) / 2

    def J_errors():
        """Return (sum of squared errors, RMSE) of the symmetrised J vs ground truth."""
        with torch.no_grad():
            diff2 = (ground_truth_J - get_J())**2
        return diff2.sum(), diff2.mean().sqrt()

    if args.sampler == "gibbs":
        if "potts" in args.model:
            sampler = samplers.PerDimMetropolisSampler(model.data_dim,
                                                       int(args.n_state),
                                                       rand=False)
        else:
            sampler = samplers.PerDimGibbsSampler(model.data_dim, rand=False)
    elif args.sampler == "rand_gibbs":
        if "potts" in args.model:
            sampler = samplers.PerDimMetropolisSampler(model.data_dim,
                                                       int(args.n_state),
                                                       rand=True)
        else:
            sampler = samplers.PerDimGibbsSampler(model.data_dim, rand=True)
    elif args.sampler == "gwg":
        if "potts" in args.model:
            sampler = samplers.DiffSamplerMultiDim(model.data_dim,
                                                   1,
                                                   approx=True,
                                                   temp=2.)
        else:
            sampler = samplers.DiffSampler(model.data_dim,
                                           1,
                                           approx=True,
                                           fixed_proposal=False,
                                           temp=2.)
    else:
        # Explicit check instead of `assert`, which is stripped under -O.
        if "gwg-" not in args.sampler:
            raise ValueError("unsupported sampler: {}".format(args.sampler))
        n_hop = int(args.sampler.split('-')[1])
        if "potts" in args.model:
            raise ValueError
        else:
            sampler = samplers.MultiDiffSampler(model.data_dim,
                                                1,
                                                approx=True,
                                                temp=2.,
                                                n_samples=n_hop)

    my_print(device)
    my_print(model)
    my_print(buffer.size())
    my_print(sampler)

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)
    itr = 0
    sigmas = []
    sq_errs = []
    rmses = []
    # `<=` rather than `<`: the finalisation below fires at itr == n_iters + 1.
    # With `<`, an epoch that ended exactly at itr == n_iters skipped it.
    while itr <= args.n_iters:
        for x in train_loader:
            x = x[0].to(device)

            for k in range(args.sampling_steps):
                buffer = sampler.step(buffer.detach(), model).detach()

            logp_real = model(x).squeeze().mean()
            logp_fake = model(buffer).squeeze().mean()

            obj = logp_real - logp_fake
            loss = -obj
            loss += args.l1 * get_J().abs().sum()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # zero the diagonal of G (no self-couplings)
            model.G.data *= (1. - torch.eye(model.G.data.size(0))).to(model.G)

            if itr % args.print_every == 0:
                my_print(
                    "({}) log p(real) = {:.4f}, log p(fake) = {:.4f}, diff = {:.4f}, hops = {:.4f}"
                    .format(itr, logp_real.item(), logp_fake.item(),
                            obj.item(), sampler._hops))
                if args.model in ("lattice_potts", "lattice_ising"):
                    my_print(
                        "\tsigma true = {:.4f}, current sigma = {:.4f}".format(
                            args.sigma, model.sigma.data.item()))
                else:
                    sq_err, rmse = J_errors()
                    my_print("\t err^2 = {:.4f}, rmse = {:.4f}".format(
                        sq_err, rmse))
                    print(ground_truth_J)
                    print(get_J())

            if itr % args.viz_every == 0:
                if args.model in ("lattice_potts", "lattice_ising"):
                    sigmas.append(model.sigma.data.item())
                    plt.clf()
                    plt.plot(sigmas, label="model")
                    plt.plot([args.sigma for s in sigmas], label="gt")
                    plt.legend()
                    plt.savefig("{}/sigma.png".format(args.save_dir))
                else:
                    sq_err, rmse = J_errors()
                    sq_errs.append(sq_err.item())
                    plt.clf()
                    plt.plot(sq_errs, label="sq_err")
                    plt.legend()
                    plt.savefig("{}/sq_err.png".format(args.save_dir))

                    rmses.append(rmse.item())
                    plt.clf()
                    plt.plot(rmses, label="rmse")
                    plt.legend()
                    plt.savefig("{}/rmse.png".format(args.save_dir))

                    plt.clf()
                    plt.matshow(get_J().detach().cpu().numpy())
                    plt.savefig("{}/model_{}.png".format(args.save_dir, itr))

                plot("{}/data_{}.png".format(args.save_dir, itr),
                     x.detach().cpu())
                plot("{}/buffer_{}.png".format(args.save_dir, itr),
                     buffer[:args.batch_size].detach().cpu())

            itr += 1

            if itr > args.n_iters:
                if args.model in ("lattice_potts", "lattice_ising"):
                    final_sigma = model.sigma.data.item()
                    with open("{}/sigma.txt".format(args.save_dir), 'w') as f:
                        f.write(str(final_sigma))
                else:
                    sq_err, rmse = J_errors()
                    with open("{}/sq_err.txt".format(args.save_dir), 'w') as f:
                        f.write(str(sq_err.item()))
                    with open("{}/rmse.txt".format(args.save_dir), 'w') as f:
                        f.write(str(rmse.item()))

                logger.close()
                quit()