Beispiel #1
0
def get_sampler(args):
    data_dim = np.prod(args.input_size)
    if args.input_type == "binary":
        if args.sampler == "gibbs":
            sampler = samplers.PerDimGibbsSampler(data_dim, rand=False)
        elif args.sampler == "rand_gibbs":
            sampler = samplers.PerDimGibbsSampler(data_dim, rand=True)
        elif args.sampler.startswith("bg-"):
            block_size = int(args.sampler.split('-')[1])
            sampler = block_samplers.BlockGibbsSampler(data_dim, block_size)
        elif args.sampler.startswith("hb-"):
            block_size, hamming_dist = [int(v) for v in args.sampler.split('-')[1:]]
            sampler = block_samplers.HammingBallSampler(data_dim, block_size, hamming_dist)
        elif args.sampler == "gwg":
            sampler = samplers.DiffSampler(data_dim, 1,
                                           fixed_proposal=False, approx=True, multi_hop=False, temp=2.)
        elif args.sampler.startswith("gwg-"):
            n_hops = int(args.sampler.split('-')[1])
            sampler = samplers.MultiDiffSampler(data_dim, 1, approx=True, temp=2., n_samples=n_hops)
        else:
            raise ValueError("Invalid sampler...")
    else:
        if args.sampler == "gibbs":
            sampler = samplers.PerDimMetropolisSampler(data_dim, int(args.n_out), rand=False)
        elif args.sampler == "rand_gibbs":
            sampler = samplers.PerDimMetropolisSampler(data_dim, int(args.n_out), rand=True)
        elif args.sampler == "gwg":
            sampler = samplers.DiffSamplerMultiDim(data_dim, 1, approx=True, temp=2.)
        else:
            raise ValueError("invalid sampler")
    return sampler
Beispiel #2
0
def main(args):
    """Benchmark several samplers on a lattice Potts model.

    For every sampler name in ``temps`` this runs ``args.n_steps`` MCMC steps
    from a fresh draw of ``model.init_dist``. Along the way it records:

    * the mean number of changed entries per step ("hops");
    * the cumulative sampler-step time at every ``args.print_every`` steps;
    * a subsampled chain statistic, whose effective sample size is estimated
      with ``get_ess``.

    Box plots of ESS and ESS/sec, hop curves, trace plots and ``results.pkl``
    are written to ``args.save_dir``.

    Args:
        args: parsed command-line namespace. Fields read here: ``save_dir``,
            ``seed``, ``dim``, ``n_out``, ``sigma``, ``bias``, ``n_samples``,
            ``n_test_samples``, ``n_steps``, ``subsample``, ``ess_statistic``,
            ``viz_every``, ``print_every`` and ``burn_in``.
    """
    makedirs(args.save_dir)

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    model = rbm.LatticePottsModel(int(args.dim), int(args.n_out), args.sigma,
                                  args.bias)
    model.to(device)
    print(device)

    if args.n_out == 3:
        # Each sample is reshaped to (dim, dim, 3) and saved as an image grid.
        plot = lambda p, x: torchvision.utils.save_image(
            x.view(x.size(0), args.dim, args.dim, 3).transpose(3, 1),
            p,
            normalize=False,
            nrow=int(x.size(0)**.5))
    else:
        plot = None

    # Reference samples for the Hamming-distance ESS statistic.
    ess_samples = model.init_sample(args.n_samples).to(device)

    hops = {}
    ess = {}
    times = {}        # cumulative sampler time at each print step
    total_times = {}  # sampler time over all n_steps
    chains = {}

    temps = ['dim-gibbs', 'rand-gibbs', 'gwg']
    for temp in temps:
        if temp == 'dim-gibbs':
            sampler = samplers.PerDimMetropolisSampler(args.dim**2, args.n_out)
        elif temp == "rand-gibbs":
            sampler = samplers.PerDimMetropolisSampler(args.dim**2,
                                                       args.n_out,
                                                       rand=True)
        elif temp.startswith("bg-"):
            block_size = int(temp.split('-')[1])
            sampler = block_samplers.BlockGibbsSampler(model.data_dim,
                                                       block_size)
        elif temp.startswith("hb-"):
            block_size, hamming_dist = [int(v) for v in temp.split('-')[1:]]
            sampler = block_samplers.HammingBallSampler(
                model.data_dim, block_size, hamming_dist)
        elif temp == "gwg":
            sampler = samplers.DiffSamplerMultiDim(args.dim**2,
                                                   1,
                                                   approx=True,
                                                   temp=2.)
        elif temp.startswith("gwg-"):
            n_hops = int(temp.split('-')[1])
            sampler = samplers.MultiDiffSampler(model.data_dim,
                                                1,
                                                approx=True,
                                                temp=2.,
                                                n_samples=n_hops)
        else:
            raise ValueError("Invalid sampler...")

        x = model.init_dist.sample((args.n_test_samples, )).to(device)

        times[temp] = []
        hops[temp] = []
        chain = []
        cur_time = 0.
        for i in range(args.n_steps):
            # Only the sampler step itself is timed.
            st = time.time()
            xhat = sampler.step(x.detach(), model).detach()
            cur_time += time.time() - st

            # Mean number of changed entries per chain in this step.
            cur_hops = (x != xhat).float().view(x.size(0),
                                                -1).sum(-1).mean().item()

            # update trajectory
            x = xhat

            if i % args.subsample == 0:
                if args.ess_statistic == "dims":
                    chain.append(x.cpu()[0].view(-1).numpy()[None])
                else:
                    # Hamming distance of the first chain to every reference sample.
                    xc = x[0][None]
                    h = (xc != ess_samples).float().view(
                        ess_samples.size(0), -1).sum(-1)
                    chain.append(h.detach().cpu().numpy()[None])

            if i % args.viz_every == 0 and plot is not None:
                # Relative to save_dir, like every other output of this function.
                plot(
                    "{}/temp_{}_samples_{}.png".format(
                        args.save_dir, temp, i), x)

            if i % args.print_every == 0:
                times[temp].append(cur_time)
                hops[temp].append(cur_hops)
                print("temp {}, itr = {}, hop-dist = {:.4f}".format(
                    temp, i, cur_hops))

        total_times[temp] = cur_time
        chain = np.concatenate(chain, 0)
        chains[temp] = chain
        ess[temp] = get_ess(chain, args.burn_in)
        print("ess = {} +/- {}".format(ess[temp].mean(), ess[temp].std()))

    ess_temps = temps
    plt.clf()
    plt.boxplot([ess[temp] for temp in ess_temps],
                labels=ess_temps,
                showfliers=False)
    plt.savefig("{}/ess.png".format(args.save_dir))

    # The chain spans all n_steps, so normalise by the total sampler time.
    # The last print-step time would under-count it. The (1 - burn_in)
    # factor presumably matches the fraction of the chain get_ess keeps.
    plt.clf()
    plt.boxplot([
        ess[temp] / total_times[temp] / (1. - args.burn_in)
        for temp in ess_temps
    ],
                labels=ess_temps,
                showfliers=False)
    plt.savefig("{}/ess_per_sec.png".format(args.save_dir))

    plt.clf()
    for temp in temps:
        plt.plot(hops[temp], label="{}".format(temp))

    plt.legend()
    plt.savefig("{}/hops.png".format(args.save_dir))

    for temp in temps:
        plt.clf()
        plt.plot(chains[temp][:, 0])
        plt.savefig("{}/trace_{}.png".format(args.save_dir, temp))

    with open("{}/results.pkl".format(args.save_dir), 'wb') as f:
        results = {'ess': ess, 'hops': hops, 'chains': chains, 'times': times}
        pickle.dump(results, f)
Beispiel #3
0
def main(args):
    """Sample the hidden sources of an ``fhmm.FHMM`` model with several samplers.

    Draws one ground-truth source array ``Xgt`` and an observation ``ygt``
    from the model. Then, for every sampler name in ``temps``, it runs
    ``args.n_steps + 1`` MCMC steps targeting ``model.log_p_joint(ygt, x)``.
    Every 10 steps it records:

    * the mean log joint;
    * the reconstruction error of the predicted mean against ``ygt``;
    * the fraction of entries differing from ``Xgt``;
    * the hop distance.

    If ``args.anneal`` is not None, the ``sigma`` passed to ``log_p_joint`` is
    swept linearly from ``args.anneal`` to ``args.obs_sigma`` over the run.
    Plots and ``results.pkl`` are written under ``args.save_dir``.
    """
    makedirs("{}/sources".format(args.save_dir))

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    W = args.W_init_sigma * torch.randn((args.K,))
    W0 = args.W_init_sigma * torch.randn((1,))
    p = args.X_keep_prob * torch.ones((args.K,))
    v = args.X0_mean * torch.ones((args.K,))

    model = fhmm.FHMM(args.N, args.K, W, W0, args.obs_sigma, p, v, alt_logpx=args.alt)
    model.to(device)
    print("device is", device)

    # generate data
    Xgt = model.sample_X(1)
    p_y_given_Xgt = model.p_y_given_x(Xgt)

    mu = p_y_given_Xgt.loc
    mu_true = mu[0]
    plt.clf()
    plt.plot(mu_true.detach().cpu().numpy(), label="mean")
    ygt = p_y_given_Xgt.sample()[0]
    plt.plot(ygt.detach().cpu().numpy(), label='sample')
    plt.legend()
    plt.savefig("{}/data.png".format(args.save_dir))
    ygt = ygt.to(device)

    for k in range(args.K):
        plt.clf()
        plt.plot(Xgt[0, :, k].detach().cpu().numpy())
        plt.savefig("{}/sources/x_{}.png".format(args.save_dir, k))

    logp_joint_real = model.log_p_joint(ygt, Xgt).item()
    print("joint likelihood of real data is {}".format(logp_joint_real))

    log_joints = {}
    diffs = {}
    times = {}
    recons = {}
    hops = {}
    mus = {}

    dim = args.K * args.N
    x_init = model.sample_X(args.n_test_samples).to(device)
    samp_model = lambda _x: model.log_p_joint(ygt, _x)

    # The annealing schedule is identical for every sampler and step, so build
    # it once rather than re-allocating an (n_steps + 1)-array on every step.
    if args.anneal is None:
        sigma_schedule = None
    else:
        sigma_schedule = np.linspace(args.anneal, args.obs_sigma, args.n_steps + 1)

    temps = ['bg-1', 'bg-2', 'hb-10-1', 'gwg', 'gwg-3', 'gwg-5']
    for temp in temps:
        makedirs("{}/{}".format(args.save_dir, temp))
        if temp == 'dim-gibbs':
            sampler = samplers.PerDimGibbsSampler(dim)
        elif temp == "rand-gibbs":
            sampler = samplers.PerDimGibbsSampler(dim, rand=True)
        elif temp.startswith("bg-"):
            block_size = int(temp.split('-')[1])
            sampler = block_samplers.BlockGibbsSampler(dim, block_size)
        elif temp.startswith("hb-"):
            block_size, hamming_dist = [int(v) for v in temp.split('-')[1:]]
            sampler = block_samplers.HammingBallSampler(dim, block_size, hamming_dist)
        elif temp == "gwg":
            sampler = samplers.DiffSampler(dim, 1,
                                           fixed_proposal=False, approx=True, multi_hop=False, temp=2.)
        elif temp.startswith("gwg-"):
            n_hops = int(temp.split('-')[1])
            sampler = samplers.MultiDiffSampler(dim, 1,
                                                approx=True, temp=2., n_samples=n_hops)
        else:
            raise ValueError("Invalid sampler...")

        # Flatten each chain's sources to a vector of length N * K.
        x = x_init.clone().view(x_init.size(0), -1)

        diffs[temp] = []
        log_joints[temp] = []
        hops[temp] = []
        recons[temp] = []
        start_time = time.time()
        for i in range(args.n_steps + 1):
            if sigma_schedule is None:
                sm = samp_model
            else:
                s = sigma_schedule[i]
                # sm is only used within this iteration, so capturing s is safe.
                sm = lambda _x: model.log_p_joint(ygt, _x, sigma=s)
            xhat = sampler.step(x.detach(), sm).detach()

            # compute hamming dist
            cur_hops = (x != xhat).float().sum(-1).mean().item()
            # update trajectory
            x = xhat

            if i % 1000 == 0:
                p_y_given_x = model.p_y_given_x(x)
                mu = p_y_given_x.loc
                plt.clf()
                plt.plot(mu_true.detach().cpu().numpy(), label="true")
                plt.plot(mu[0].detach().cpu().numpy() + .01, label='mu0')
                plt.plot(mu[1].detach().cpu().numpy() - .01, label='mu1')
                plt.legend()
                plt.savefig("{}/{}/mean_{}.png".format(args.save_dir, temp, i))
                mus[temp] = mu[0].detach().cpu().numpy()

            if i % 10 == 0:
                p_y_given_x = model.p_y_given_x(x)
                mu = p_y_given_x.loc
                err = ((mu - ygt[None]) ** 2).sum(1).mean()
                recons[temp].append(err.item())

                log_j = model.log_p_joint(ygt, x)
                # Fraction of source entries that differ from the ground truth.
                diff = (x.view(x.size(0), args.N, args.K) != Xgt).float().view(x.size(0), -1).mean(1)
                log_joints[temp].append(log_j.mean().item())
                diffs[temp].append(diff.mean().item())
                hops[temp].append(cur_hops)
                print("temp {}, itr = {}, log-joint = {:.4f}, "
                      "hop-dist = {:.4f}, recons = {:.4f}".format(temp, i, log_j.mean().item(), cur_hops, err.item()))

        for k in range(args.K):
            plt.clf()
            xr = x.view(x.size(0), args.N, args.K)
            plt.plot(xr[0, :, k].detach().cpu().numpy())
            plt.savefig("{}/{}/source_{}.png".format(args.save_dir, temp, k))

        times[temp] = time.time() - start_time

    plt.clf()
    for temp in temps:
        plt.plot(log_joints[temp], label=temp)
    # Horizontal reference line, as long as the recorded curves.
    plt.plot([logp_joint_real] * len(log_joints[temps[-1]]), label="true")
    plt.legend()
    plt.savefig("{}/joints.png".format(args.save_dir))

    plt.clf()
    for temp in temps:
        plt.plot(recons[temp], label=temp)
    plt.legend()
    plt.savefig("{}/recons.png".format(args.save_dir))

    plt.clf()
    for temp in temps:
        plt.plot(diffs[temp], label=temp)
    plt.legend()
    plt.savefig("{}/errs.png".format(args.save_dir))

    plt.clf()
    for i, temp in enumerate(temps):
        # Small vertical offsets keep overlapping curves distinguishable.
        plt.plot(mus[temp] + float(i) * .01, label=temp)
    plt.plot(mu_true.detach().cpu().numpy(), label="true")
    plt.legend()
    plt.savefig("{}/mean.png".format(args.save_dir))

    plt.clf()
    for temp in temps:
        plt.plot(hops[temp], label="{}".format(temp))

    plt.legend()
    plt.savefig("{}/hops.png".format(args.save_dir))

    with open("{}/results.pkl".format(args.save_dir), 'wb') as f:
        results = {
            'hops': hops,
            'recons': recons,
            'joints': log_joints,
            'diffs': diffs,
            'times': times,
        }
        pickle.dump(results, f)
Beispiel #4
0
def main(args):
    """Compare MCMC samplers on a Bernoulli RBM.

    Setting up the RBM depends on ``args.data``:

    * ``"mnist"``: the RBM is rebuilt with ``data_mean`` taken from the
      training set. It is then trained for one pass over ``train_loader``
      using the objective ``logp_v_unnorm(x) - logp_v_unnorm(xhat)``, where
      ``xhat`` comes from ``args.cd`` Gibbs steps.
    * anything else: the weights and biases are drawn at random.

    A long ``model.gibbs_sample`` run (``args.gt_steps``) supplies two sets of
    reference samples. The ``"gibbs"`` baseline and each sampler in ``temps``
    then run for ``args.n_steps`` steps. Each run tracks the log10 MMD to the
    reference samples, the hop distance, the sampling time, and the ESS of a
    subsampled chain statistic. Plots and ``results.pkl`` are written to
    ``args.save_dir``.
    """
    makedirs(args.save_dir)

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    model = rbm.BernoulliRBM(args.n_visible, args.n_hidden)
    model.to(device)
    print(device)

    if args.data == "mnist":
        assert args.n_visible == 784
        train_loader, test_loader, plot, viz = utils.get_data(args)

        # Per-pixel training mean, clamped away from 0 and 1.
        init_data = torch.cat([x for x, _ in train_loader], 0)
        init_mean = init_data.mean(0).clamp(.01, .99)

        model = rbm.BernoulliRBM(args.n_visible,
                                 args.n_hidden,
                                 data_mean=init_mean)
        model.to(device)

        optimizer = torch.optim.Adam(model.parameters(), lr=args.rbm_lr)

        # train! (one pass over train_loader; itr counts batches)
        for itr, (x, _) in enumerate(train_loader):
            x = x.to(device)
            xhat = model.gibbs_sample(v=x, n_steps=args.cd)

            d = model.logp_v_unnorm(x)
            m = model.logp_v_unnorm(xhat)

            obj = d - m
            loss = -obj.mean()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if itr % args.print_every == 0:
                print(
                    "{} | log p(data) = {:.4f}, log p(model) = {:.4f}, diff = {:.4f}"
                    .format(itr, d.mean().item(), m.mean().item(), obj.mean().item()))

    else:
        model.W.data = torch.randn_like(model.W.data) * (.05**.5)
        model.b_v.data = torch.randn_like(model.b_v.data) * 1.0
        model.b_h.data = torch.randn_like(model.b_h.data) * 1.0
        viz = plot = None

    gt_samples = model.gibbs_sample(n_steps=args.gt_steps,
                                    n_samples=args.n_samples +
                                    args.n_test_samples,
                                    plot=True)
    kmmd = mmd.MMD(mmd.exp_avg_hamming, False)
    # Split into a reference set and a held-out set used to measure the
    # MMD "floor" between two sets of ground-truth samples.
    gt_samples, gt_samples2 = gt_samples[:args.n_samples], gt_samples[
        args.n_samples:]
    if plot is not None:
        plot("{}/ground_truth.png".format(args.save_dir), gt_samples2)
    opt_stat = kmmd.compute_mmd(gt_samples2, gt_samples)
    print("gt <--> gt log-mmd", opt_stat, opt_stat.log10())

    new_samples = model.gibbs_sample(n_steps=0, n_samples=args.n_test_samples)

    log_mmds = {}
    log_mmds['gibbs'] = []
    ars = {}
    hops = {}
    ess = {}
    times = {}        # cumulative time at each print step
    total_times = {}  # sampler time over all n_steps
    chains = {}
    chain = []

    times['gibbs'] = []
    start_time = time.time()
    for i in range(args.n_steps):
        if i % args.print_every == 0:
            stat = kmmd.compute_mmd(new_samples, gt_samples)
            log_stat = stat.log10().item()
            log_mmds['gibbs'].append(log_stat)
            print("gibbs", i, stat, stat.log10())
            # NOTE: unlike the samplers below, this wall-clock time also
            # includes the MMD computations done inside this loop.
            times['gibbs'].append(time.time() - start_time)
        new_samples = model.gibbs_sample(new_samples, 1)
        if i % args.subsample == 0:
            if args.ess_statistic == "dims":
                chain.append(new_samples.cpu().numpy()[0][None])
            else:
                xc = new_samples[0][None]
                h = (xc != gt_samples).float().sum(-1)
                chain.append(h.detach().cpu().numpy()[None])

    chain = np.concatenate(chain, 0)
    chains['gibbs'] = chain
    ess['gibbs'] = get_ess(chain, args.burn_in)
    print("ess = {} +/- {}".format(ess['gibbs'].mean(), ess['gibbs'].std()))

    temps = ['bg-1', 'bg-2', 'hb-10-1', 'gwg', 'gwg-3', 'gwg-5']
    for temp in temps:
        if temp == 'dim-gibbs':
            sampler = samplers.PerDimGibbsSampler(args.n_visible)
        elif temp == "rand-gibbs":
            sampler = samplers.PerDimGibbsSampler(args.n_visible, rand=True)
        elif temp.startswith("bg-"):
            block_size = int(temp.split('-')[1])
            sampler = block_samplers.BlockGibbsSampler(args.n_visible,
                                                       block_size)
        elif temp.startswith("hb-"):
            block_size, hamming_dist = [int(v) for v in temp.split('-')[1:]]
            sampler = block_samplers.HammingBallSampler(
                args.n_visible, block_size, hamming_dist)
        elif temp == "gwg":
            sampler = samplers.DiffSampler(args.n_visible,
                                           1,
                                           fixed_proposal=False,
                                           approx=True,
                                           multi_hop=False,
                                           temp=2.)
        elif temp.startswith("gwg-"):
            n_hops = int(temp.split('-')[1])
            sampler = samplers.MultiDiffSampler(args.n_visible,
                                                1,
                                                approx=True,
                                                temp=2.,
                                                n_samples=n_hops)
        else:
            raise ValueError("Invalid sampler...")

        x = model.init_dist.sample((args.n_test_samples, )).to(device)

        log_mmds[temp] = []
        ars[temp] = []
        hops[temp] = []
        times[temp] = []
        chain = []
        cur_time = 0.
        for i in range(args.n_steps):
            # Only the sampler step itself is timed.
            st = time.time()
            xhat = sampler.step(x.detach(), model).detach()
            cur_time += time.time() - st

            # compute hamming dist
            cur_hops = (x != xhat).float().sum(-1).mean().item()

            # update trajectory
            x = xhat

            if i % args.subsample == 0:
                if args.ess_statistic == "dims":
                    chain.append(x.cpu().numpy()[0][None])
                else:
                    xc = x[0][None]
                    h = (xc != gt_samples).float().sum(-1)
                    chain.append(h.detach().cpu().numpy()[None])

            if i % args.viz_every == 0 and plot is not None:
                # Relative to save_dir, like every other output of this function.
                plot(
                    "{}/temp_{}_samples_{}.png".format(
                        args.save_dir, temp, i), x)

            if i % args.print_every == 0:
                hard_samples = x
                stat = kmmd.compute_mmd(hard_samples, gt_samples)
                log_stat = stat.log10().item()
                log_mmds[temp].append(log_stat)
                times[temp].append(cur_time)
                hops[temp].append(cur_hops)
                print("temp {}, itr = {}, log-mmd = {:.4f}, hop-dist = {:.4f}".
                      format(temp, i, log_stat, cur_hops))
        total_times[temp] = cur_time
        chain = np.concatenate(chain, 0)
        ess[temp] = get_ess(chain, args.burn_in)
        chains[temp] = chain
        print("ess = {} +/- {}".format(ess[temp].mean(), ess[temp].std()))

    ess_temps = temps
    plt.clf()
    plt.boxplot([ess[temp] for temp in ess_temps],
                labels=ess_temps,
                showfliers=False)
    plt.savefig("{}/ess.png".format(args.save_dir))

    # The chain spans all n_steps, so normalise by the total sampler time.
    # The last print-step time would under-count it. The (1 - burn_in)
    # factor presumably matches the fraction of the chain get_ess keeps.
    plt.clf()
    plt.boxplot([
        ess[temp] / total_times[temp] / (1. - args.burn_in)
        for temp in ess_temps
    ],
                labels=ess_temps,
                showfliers=False)
    plt.savefig("{}/ess_per_sec.png".format(args.save_dir))

    plt.clf()
    for temp in temps + ['gibbs']:
        plt.plot(log_mmds[temp], label="{}".format(temp))

    plt.legend()
    plt.savefig("{}/results.png".format(args.save_dir))

    # NOTE(review): ars[temp] is initialised but never appended to in this
    # function, so ars.png contains no data.
    plt.clf()
    for temp in temps:
        plt.plot(ars[temp], label="{}".format(temp))

    plt.legend()
    plt.savefig("{}/ars.png".format(args.save_dir))

    plt.clf()
    for temp in temps:
        plt.plot(hops[temp], label="{}".format(temp))

    plt.legend()
    plt.savefig("{}/hops.png".format(args.save_dir))

    for temp in temps:
        plt.clf()
        plt.plot(chains[temp][:, 0])
        plt.savefig("{}/trace_{}.png".format(args.save_dir, temp))

    with open("{}/results.pkl".format(args.save_dir), 'wb') as f:
        results = {
            'ess': ess,
            'hops': hops,
            'log_mmds': log_mmds,
            'chains': chains,
            'times': times
        }
        pickle.dump(results, f)