# Module-level imports assumed by the functions below; the original excerpt
# omits them. `rbm`, `fhmm`, `mmd`, `utils`, `samplers`, and `block_samplers`
# are modules from the surrounding repository, and `get_ess` (an
# effective-sample-size estimator) is assumed to be defined elsewhere in it.
import argparse
import os
import pickle
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision

import block_samplers
import fhmm
import mmd
import rbm
import samplers
import utils

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def makedirs(dirname):
    # Minimal stand-in: create the output directory (and parents) if needed.
    os.makedirs(dirname, exist_ok=True)


def get_sampler(args):
    """Build a sampler from the `args.sampler` string."""
    data_dim = np.prod(args.input_size)
    if args.input_type == "binary":
        if args.sampler == "gibbs":
            # Systematic-scan per-dimension Gibbs.
            sampler = samplers.PerDimGibbsSampler(data_dim, rand=False)
        elif args.sampler == "rand_gibbs":
            # Random-scan per-dimension Gibbs.
            sampler = samplers.PerDimGibbsSampler(data_dim, rand=True)
        elif args.sampler.startswith("bg-"):
            # Block Gibbs; e.g. "bg-2" uses blocks of size 2.
            block_size = int(args.sampler.split('-')[1])
            sampler = block_samplers.BlockGibbsSampler(data_dim, block_size)
        elif args.sampler.startswith("hb-"):
            # Hamming-ball sampler; e.g. "hb-10-1" uses block size 10 and
            # Hamming distance 1.
            block_size, hamming_dist = [int(v) for v in args.sampler.split('-')[1:]]
            sampler = block_samplers.HammingBallSampler(data_dim, block_size, hamming_dist)
        elif args.sampler == "gwg":
            # Gibbs-With-Gradients, one flip per step.
            sampler = samplers.DiffSampler(data_dim, 1,
                                           fixed_proposal=False, approx=True,
                                           multi_hop=False, temp=2.)
        elif args.sampler.startswith("gwg-"):
            # Multi-hop GWG; e.g. "gwg-3" proposes 3 flips per step.
            n_hops = int(args.sampler.split('-')[1])
            sampler = samplers.MultiDiffSampler(data_dim, 1,
                                                approx=True, temp=2.,
                                                n_samples=n_hops)
        else:
            raise ValueError("Invalid sampler: {}".format(args.sampler))
    else:
        # Categorical data: per-dimension Metropolis or multi-dim GWG.
        if args.sampler == "gibbs":
            sampler = samplers.PerDimMetropolisSampler(data_dim, int(args.n_out), rand=False)
        elif args.sampler == "rand_gibbs":
            sampler = samplers.PerDimMetropolisSampler(data_dim, int(args.n_out), rand=True)
        elif args.sampler == "gwg":
            sampler = samplers.DiffSamplerMultiDim(data_dim, 1, approx=True, temp=2.)
        else:
            raise ValueError("Invalid sampler: {}".format(args.sampler))
    return sampler
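# Usage sketch for get_sampler (the argument values here are hypothetical;
# in the experiment scripts they come from argparse). For example, the string
# "hb-10-1" selects the Hamming-ball sampler with block size 10 and Hamming
# distance 1:
#
#     args = argparse.Namespace(input_size=[784], input_type="binary",
#                               sampler="hb-10-1")
#     sampler = get_sampler(args)  # -> block_samplers.HammingBallSampler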
def main(args):
    """Compare MCMC samplers on a lattice Potts model via hop distance and ESS."""
    makedirs(args.save_dir)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    model = rbm.LatticePottsModel(int(args.dim), int(args.n_out), args.sigma, args.bias)
    model.to(device)
    print(device)

    if args.n_out == 3:
        # Three states map onto RGB channels for visualization.
        plot = lambda p, x: torchvision.utils.save_image(
            x.view(x.size(0), args.dim, args.dim, 3).transpose(3, 1),
            p, normalize=False, nrow=int(x.size(0) ** .5))
    else:
        plot = None

    ess_samples = model.init_sample(args.n_samples).to(device)

    hops = {}
    ess = {}
    times = {}
    chains = {}

    temps = ['dim-gibbs', 'rand-gibbs', 'gwg']
    for temp in temps:
        if temp == 'dim-gibbs':
            sampler = samplers.PerDimMetropolisSampler(args.dim ** 2, args.n_out)
        elif temp == "rand-gibbs":
            sampler = samplers.PerDimMetropolisSampler(args.dim ** 2, args.n_out, rand=True)
        elif "bg-" in temp:
            block_size = int(temp.split('-')[1])
            sampler = block_samplers.BlockGibbsSampler(model.data_dim, block_size)
        elif "hb-" in temp:
            block_size, hamming_dist = [int(v) for v in temp.split('-')[1:]]
            sampler = block_samplers.HammingBallSampler(
                model.data_dim, block_size, hamming_dist)
        elif temp == "gwg":
            sampler = samplers.DiffSamplerMultiDim(args.dim ** 2, 1, approx=True, temp=2.)
        elif "gwg-" in temp:
            n_hops = int(temp.split('-')[1])
            sampler = samplers.MultiDiffSampler(model.data_dim, 1,
                                                approx=True, temp=2., n_samples=n_hops)
        else:
            raise ValueError("Invalid sampler: {}".format(temp))

        x = model.init_dist.sample((args.n_test_samples,)).to(device)

        times[temp] = []
        hops[temp] = []
        chain = []
        cur_time = 0.
        for i in range(args.n_steps):
            # Do sampling and time it.
            st = time.time()
            xhat = sampler.step(x.detach(), model).detach()
            cur_time += time.time() - st

            # Compute the Hamming distance between successive states.
            cur_hops = (x != xhat).float().view(x.size(0), -1).sum(-1).mean().item()

            # Update the trajectory.
            x = xhat

            if i % args.subsample == 0:
                if args.ess_statistic == "dims":
                    # Track the raw dimensions of the first chain.
                    chain.append(x.cpu()[0].view(-1).numpy()[None])
                else:
                    # Track Hamming distances to a set of reference samples.
                    xc = x[0][None]
                    h = (xc != ess_samples).float().view(ess_samples.size(0), -1).sum(-1)
                    chain.append(h.detach().cpu().numpy()[None])

            if i % args.viz_every == 0 and plot is not None:
                plot("{}/temp_{}_samples_{}.png".format(args.save_dir, temp, i), x)

            if i % args.print_every == 0:
                times[temp].append(cur_time)
                hops[temp].append(cur_hops)
                print("temp {}, itr = {}, hop-dist = {:.4f}".format(temp, i, cur_hops))

        chain = np.concatenate(chain, 0)
        chains[temp] = chain
        ess[temp] = get_ess(chain, args.burn_in)
        print("ess = {} +/- {}".format(ess[temp].mean(), ess[temp].std()))

    ess_temps = temps
    plt.clf()
    plt.boxplot([ess[temp] for temp in ess_temps], labels=ess_temps, showfliers=False)
    plt.savefig("{}/ess.png".format(args.save_dir))

    plt.clf()
    # ESS per second, corrected for the fraction of the chain kept after burn-in.
    plt.boxplot([ess[temp] / times[temp][-1] / (1. - args.burn_in) for temp in ess_temps],
                labels=ess_temps, showfliers=False)
    plt.savefig("{}/ess_per_sec.png".format(args.save_dir))

    plt.clf()
    for temp in temps:
        plt.plot(hops[temp], label="{}".format(temp))
    plt.legend()
    plt.savefig("{}/hops.png".format(args.save_dir))

    for temp in temps:
        plt.clf()
        plt.plot(chains[temp][:, 0])
        plt.savefig("{}/trace_{}.png".format(args.save_dir, temp))

    with open("{}/results.pkl".format(args.save_dir), 'wb') as f:
        results = {'ess': ess, 'hops': hops, 'chains': chains}
        pickle.dump(results, f)
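# `get_ess` is called above but not shown in this excerpt. Below is a minimal
# sketch of a compatible implementation, assuming TensorFlow Probability is
# available; `_get_ess_sketch` is a hypothetical stand-in, and the original
# helper may differ.
def _get_ess_sketch(chain, burn_in):
    import tensorflow_probability as tfp
    # Discard the burn-in prefix of the (steps, dims) chain array.
    n = chain.shape[0]
    kept = chain[int(burn_in * n):]
    # Estimate the effective sample size of each tracked statistic.
    ess_vals = tfp.mcmc.effective_sample_size(kept).numpy()
    # Constant dimensions yield NaN; fall back to an ESS of 1 for them.
    ess_vals[np.isnan(ess_vals)] = 1.
    return ess_vals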
def main(args): makedirs("{}/sources".format(args.save_dir)) torch.manual_seed(args.seed) np.random.seed(args.seed) W = args.W_init_sigma * torch.randn((args.K,)) W0 = args.W_init_sigma * torch.randn((1,)) p = args.X_keep_prob * torch.ones((args.K,)) v = args.X0_mean * torch.ones((args.K,)) model = fhmm.FHMM(args.N, args.K, W, W0, args.obs_sigma, p, v, alt_logpx=args.alt) model.to(device) print("device is", device) # generate data Xgt = model.sample_X(1) p_y_given_Xgt = model.p_y_given_x(Xgt) mu = p_y_given_Xgt.loc mu_true = mu[0] plt.clf() plt.plot(mu_true.detach().cpu().numpy(), label="mean") ygt = p_y_given_Xgt.sample()[0] plt.plot(ygt.detach().cpu().numpy(), label='sample') plt.legend() plt.savefig("{}/data.png".format(args.save_dir)) ygt = ygt.to(device) for k in range(args.K): plt.clf() plt.plot(Xgt[0, :, k].detach().cpu().numpy()) plt.savefig("{}/sources/x_{}.png".format(args.save_dir, k)) logp_joint_real = model.log_p_joint(ygt, Xgt).item() print("joint likelihood of real data is {}".format(logp_joint_real)) log_joints = {} diffs = {} times = {} recons = {} ars = {} hops = {} phops = {} mus = {} dim = args.K * args.N x_init = model.sample_X(args.n_test_samples).to(device) samp_model = lambda _x: model.log_p_joint(ygt, _x) temps = ['bg-1', 'bg-2', 'hb-10-1', 'gwg', 'gwg-3', 'gwg-5'] for temp in temps: makedirs("{}/{}".format(args.save_dir, temp)) if temp == 'dim-gibbs': sampler = samplers.PerDimGibbsSampler(dim) elif temp == "rand-gibbs": sampler = samplers.PerDimGibbsSampler(dim, rand=True) elif "bg-" in temp: block_size = int(temp.split('-')[1]) sampler = block_samplers.BlockGibbsSampler(dim, block_size) elif "hb-" in temp: block_size, hamming_dist = [int(v) for v in temp.split('-')[1:]] sampler = block_samplers.HammingBallSampler(dim, block_size, hamming_dist) elif temp == "gwg": sampler = samplers.DiffSampler(dim, 1, fixed_proposal=False, approx=True, multi_hop=False, temp=2.) 
elif "gwg-" in temp: n_hops = int(temp.split('-')[1]) sampler = samplers.MultiDiffSampler(dim, 1, approx=True, temp=2., n_samples=n_hops) else: raise ValueError("Invalid sampler...") x = x_init.clone().view(x_init.size(0), -1) diffs[temp] = [] log_joints[temp] = [] ars[temp] = [] hops[temp] = [] phops[temp] = [] recons[temp] = [] start_time = time.time() for i in range(args.n_steps + 1): if args.anneal is None: sm = samp_model else: s = np.linspace(args.anneal, args.obs_sigma, args.n_steps + 1)[i] sm = lambda _x: model.log_p_joint(ygt, _x, sigma=s) xhat = sampler.step(x.detach(), sm).detach() # compute hamming dist cur_hops = (x != xhat).float().sum(-1).mean().item() # update trajectory x = xhat if i % 1000 == 0: p_y_given_x = model.p_y_given_x(x) mu = p_y_given_x.loc plt.clf() plt.plot(mu_true.detach().cpu().numpy(), label="true") plt.plot(mu[0].detach().cpu().numpy() + .01, label='mu0') plt.plot(mu[1].detach().cpu().numpy() - .01, label='mu1') plt.legend() plt.savefig("{}/{}/mean_{}.png".format(args.save_dir, temp, i)) mus[temp] = mu[0].detach().cpu().numpy() if i % 10 == 0: p_y_given_x = model.p_y_given_x(x) mu = p_y_given_x.loc err = ((mu - ygt[None]) ** 2).sum(1).mean() recons[temp].append(err.item()) log_j = model.log_p_joint(ygt, x) diff = (x.view(x.size(0), args.N, args.K) != Xgt).float().view(x.size(0), -1).mean(1) log_joints[temp].append(log_j.mean().item()) diffs[temp].append(diff.mean().item()) hops[temp].append(cur_hops) print("temp {}, itr = {}, log-joint = {:.4f}, " "hop-dist = {:.4f}, recons = {:.4f}".format(temp, i, log_j.mean().item(), cur_hops, err.item())) for k in range(args.K): plt.clf() xr = x.view(x.size(0), args.N, args.K) plt.plot(xr[0, :, k].detach().cpu().numpy()) plt.savefig("{}/{}/source_{}.png".format(args.save_dir, temp, k)) times[temp] = time.time() - start_time plt.clf() for temp in temps: plt.plot(log_joints[temp], label=temp) plt.plot([logp_joint_real for _ in log_joints[temp]], label="true") plt.legend() plt.savefig("{}/joints.png".format(args.save_dir)) plt.clf() for temp in temps: plt.plot(recons[temp], label=temp) plt.legend() plt.savefig("{}/recons.png".format(args.save_dir)) plt.clf() for temp in temps: plt.plot(diffs[temp], label=temp) plt.legend() plt.savefig("{}/errs.png".format(args.save_dir)) plt.clf() for i, temp in enumerate(temps): plt.plot(mus[temp] + float(i) * .01, label=temp) plt.plot(mu_true.detach().cpu().numpy(), label="true") plt.legend() plt.savefig("{}/mean.png".format(args.save_dir)) plt.clf() for temp in temps: plt.plot(hops[temp], label="{}".format(temp)) plt.legend() plt.savefig("{}/hops.png".format(args.save_dir)) with open("{}/results.pkl".format(args.save_dir), 'wb') as f: results = { 'hops': hops, 'recons': recons, 'joints': log_joints, } pickle.dump(results, f)
def main(args):
    """Train or randomly initialize a Bernoulli RBM, then compare samplers
    against long-run Gibbs ground truth via log-MMD, hop distance, and ESS."""
    makedirs(args.save_dir)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    model = rbm.BernoulliRBM(args.n_visible, args.n_hidden)
    model.to(device)
    print(device)

    if args.data == "mnist":
        assert args.n_visible == 784
        train_loader, test_loader, plot, viz = utils.get_data(args)

        # Initialize the visible biases from the data mean.
        init_data = []
        for x, _ in train_loader:
            init_data.append(x)
        init_data = torch.cat(init_data, 0)
        init_mean = init_data.mean(0).clamp(.01, .99)

        model = rbm.BernoulliRBM(args.n_visible, args.n_hidden, data_mean=init_mean)
        model.to(device)

        optimizer = torch.optim.Adam(model.parameters(), lr=args.rbm_lr)

        # Train with contrastive divergence (CD-k, k = args.cd).
        itr = 0
        for x, _ in train_loader:
            x = x.to(device)
            xhat = model.gibbs_sample(v=x, n_steps=args.cd)

            d = model.logp_v_unnorm(x)
            m = model.logp_v_unnorm(xhat)

            obj = d - m
            loss = -obj.mean()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if itr % args.print_every == 0:
                print("{} | log p(data) = {:.4f}, log p(model) = {:.4f}, "
                      "diff = {:.4f}".format(itr, d.mean(), m.mean(), (d - m).mean()))
            itr += 1
    else:
        # Random RBM: draw the parameters from fixed Gaussians.
        model.W.data = torch.randn_like(model.W.data) * (.05 ** .5)
        model.b_v.data = torch.randn_like(model.b_v.data) * 1.0
        model.b_h.data = torch.randn_like(model.b_h.data) * 1.0
        viz = plot = None

    # Generate ground-truth samples with a long Gibbs run.
    gt_samples = model.gibbs_sample(n_steps=args.gt_steps,
                                    n_samples=args.n_samples + args.n_test_samples,
                                    plot=True)
    kmmd = mmd.MMD(mmd.exp_avg_hamming, False)
    gt_samples, gt_samples2 = gt_samples[:args.n_samples], gt_samples[args.n_samples:]
    if plot is not None:
        plot("{}/ground_truth.png".format(args.save_dir), gt_samples2)
    opt_stat = kmmd.compute_mmd(gt_samples2, gt_samples)
    print("gt <--> gt log-mmd", opt_stat, opt_stat.log10())

    new_samples = model.gibbs_sample(n_steps=0, n_samples=args.n_test_samples)

    log_mmds = {}
    log_mmds['gibbs'] = []
    ars = {}
    hops = {}
    ess = {}
    times = {}
    chains = {}
    chain = []

    # Baseline: plain Gibbs sampling on the RBM itself.
    times['gibbs'] = []
    start_time = time.time()
    for i in range(args.n_steps):
        if i % args.print_every == 0:
            stat = kmmd.compute_mmd(new_samples, gt_samples)
            log_stat = stat.log10().item()
            log_mmds['gibbs'].append(log_stat)
            print("gibbs", i, stat, stat.log10())
            times['gibbs'].append(time.time() - start_time)
        new_samples = model.gibbs_sample(new_samples, 1)
        if i % args.subsample == 0:
            if args.ess_statistic == "dims":
                chain.append(new_samples.cpu().numpy()[0][None])
            else:
                xc = new_samples[0][None]
                h = (xc != gt_samples).float().sum(-1)
                chain.append(h.detach().cpu().numpy()[None])

    chain = np.concatenate(chain, 0)
    chains['gibbs'] = chain
    ess['gibbs'] = get_ess(chain, args.burn_in)
    print("ess = {} +/- {}".format(ess['gibbs'].mean(), ess['gibbs'].std()))

    temps = ['bg-1', 'bg-2', 'hb-10-1', 'gwg', 'gwg-3', 'gwg-5']
    for temp in temps:
        if temp == 'dim-gibbs':
            sampler = samplers.PerDimGibbsSampler(args.n_visible)
        elif temp == "rand-gibbs":
            sampler = samplers.PerDimGibbsSampler(args.n_visible, rand=True)
        elif "bg-" in temp:
            block_size = int(temp.split('-')[1])
            sampler = block_samplers.BlockGibbsSampler(args.n_visible, block_size)
        elif "hb-" in temp:
            block_size, hamming_dist = [int(v) for v in temp.split('-')[1:]]
            sampler = block_samplers.HammingBallSampler(
                args.n_visible, block_size, hamming_dist)
        elif temp == "gwg":
            sampler = samplers.DiffSampler(args.n_visible, 1,
                                           fixed_proposal=False, approx=True,
                                           multi_hop=False, temp=2.)
        elif "gwg-" in temp:
            n_hops = int(temp.split('-')[1])
            sampler = samplers.MultiDiffSampler(args.n_visible, 1,
                                                approx=True, temp=2., n_samples=n_hops)
        else:
            raise ValueError("Invalid sampler: {}".format(temp))

        x = model.init_dist.sample((args.n_test_samples,)).to(device)

        log_mmds[temp] = []
        ars[temp] = []
        hops[temp] = []
        times[temp] = []
        chain = []
        cur_time = 0.
        for i in range(args.n_steps):
            # Do sampling and time it.
            st = time.time()
            xhat = sampler.step(x.detach(), model).detach()
            cur_time += time.time() - st

            # Compute the Hamming distance between successive states.
            cur_hops = (x != xhat).float().sum(-1).mean().item()

            # Update the trajectory.
            x = xhat

            if i % args.subsample == 0:
                if args.ess_statistic == "dims":
                    chain.append(x.cpu().numpy()[0][None])
                else:
                    xc = x[0][None]
                    h = (xc != gt_samples).float().sum(-1)
                    chain.append(h.detach().cpu().numpy()[None])

            if i % args.viz_every == 0 and plot is not None:
                plot("{}/temp_{}_samples_{}.png".format(args.save_dir, temp, i), x)

            if i % args.print_every == 0:
                hard_samples = x
                stat = kmmd.compute_mmd(hard_samples, gt_samples)
                log_stat = stat.log10().item()
                log_mmds[temp].append(log_stat)
                times[temp].append(cur_time)
                hops[temp].append(cur_hops)
                print("temp {}, itr = {}, log-mmd = {:.4f}, hop-dist = {:.4f}".format(
                    temp, i, log_stat, cur_hops))

        chain = np.concatenate(chain, 0)
        ess[temp] = get_ess(chain, args.burn_in)
        chains[temp] = chain
        print("ess = {} +/- {}".format(ess[temp].mean(), ess[temp].std()))

    ess_temps = temps
    plt.clf()
    plt.boxplot([ess[temp] for temp in ess_temps], labels=ess_temps, showfliers=False)
    plt.savefig("{}/ess.png".format(args.save_dir))

    plt.clf()
    # ESS per second, corrected for the fraction of the chain kept after burn-in.
    plt.boxplot([ess[temp] / times[temp][-1] / (1. - args.burn_in) for temp in ess_temps],
                labels=ess_temps, showfliers=False)
    plt.savefig("{}/ess_per_sec.png".format(args.save_dir))

    plt.clf()
    for temp in temps + ['gibbs']:
        plt.plot(log_mmds[temp], label="{}".format(temp))
    plt.legend()
    plt.savefig("{}/results.png".format(args.save_dir))

    plt.clf()
    # Acceptance rates are never recorded in this loop, so this plot is empty.
    for temp in temps:
        plt.plot(ars[temp], label="{}".format(temp))
    plt.legend()
    plt.savefig("{}/ars.png".format(args.save_dir))

    plt.clf()
    for temp in temps:
        plt.plot(hops[temp], label="{}".format(temp))
    plt.legend()
    plt.savefig("{}/hops.png".format(args.save_dir))

    for temp in temps:
        plt.clf()
        plt.plot(chains[temp][:, 0])
        plt.savefig("{}/trace_{}.png".format(args.save_dir, temp))

    with open("{}/results.pkl".format(args.save_dir), 'wb') as f:
        results = {
            'ess': ess,
            'hops': hops,
            'log_mmds': log_mmds,
            'chains': chains,
            'times': times
        }
        pickle.dump(results, f)