import os
import time
import pickle

import numpy as np
import torch
import torch.nn as nn
import torchvision
import matplotlib.pyplot as plt

import samplers
import block_samplers
import rbm
import utils

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def get_sampler(args):
    """Build a discrete MCMC sampler from the string spec in args.sampler."""
    data_dim = np.prod(args.input_size)
    if args.input_type == "binary":
        if args.sampler == "gibbs":
            sampler = samplers.PerDimGibbsSampler(data_dim, rand=False)
        elif args.sampler == "rand_gibbs":
            sampler = samplers.PerDimGibbsSampler(data_dim, rand=True)
        elif args.sampler.startswith("bg-"):
            # block Gibbs: "bg-<block_size>"
            block_size = int(args.sampler.split('-')[1])
            sampler = block_samplers.BlockGibbsSampler(data_dim, block_size)
        elif args.sampler.startswith("hb-"):
            # Hamming ball: "hb-<block_size>-<hamming_dist>"
            block_size, hamming_dist = [int(v) for v in args.sampler.split('-')[1:]]
            sampler = block_samplers.HammingBallSampler(data_dim, block_size, hamming_dist)
        elif args.sampler == "gwg":
            sampler = samplers.DiffSampler(data_dim, 1,
                                           fixed_proposal=False, approx=True,
                                           multi_hop=False, temp=2.)
        elif args.sampler.startswith("gwg-"):
            # multi-hop GWG: "gwg-<n_hops>"
            n_hops = int(args.sampler.split('-')[1])
            sampler = samplers.MultiDiffSampler(data_dim, 1, approx=True, temp=2.,
                                                n_samples=n_hops)
        else:
            raise ValueError("Invalid sampler: {}".format(args.sampler))
    else:
        # categorical data: per-dim Gibbs becomes per-dim Metropolis
        if args.sampler == "gibbs":
            sampler = samplers.PerDimMetropolisSampler(data_dim, int(args.n_out), rand=False)
        elif args.sampler == "rand_gibbs":
            sampler = samplers.PerDimMetropolisSampler(data_dim, int(args.n_out), rand=True)
        elif args.sampler == "gwg":
            sampler = samplers.DiffSamplerMultiDim(data_dim, 1, approx=True, temp=2.)
        else:
            raise ValueError("Invalid sampler: {}".format(args.sampler))
    return sampler
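# The scripts below call a makedirs helper that is not defined in this file.
# A minimal sketch of a plausible definition (an assumption, not from the
# source), plus a hypothetical usage sketch for get_sampler: the function only
# reads args.input_size, args.input_type, args.sampler, and (for categorical
# data) args.n_out, so a bare argparse.Namespace suffices to exercise it.
def makedirs(dirname):
    """Create dirname if it does not already exist (assumed helper)."""
    os.makedirs(dirname, exist_ok=True)


def _demo_get_sampler():
    """Hypothetical usage sketch; field values are illustrative only."""
    from argparse import Namespace
    demo_args = Namespace(input_size=(28, 28), input_type="binary", sampler="hb-10-1")
    return get_sampler(demo_args)  # -> HammingBallSampler over 784 binary dims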
def main(args):
    makedirs(args.save_dir)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    model = rbm.LatticePottsModel(int(args.dim), int(args.n_out), args.sigma, args.bias)
    model.to(device)
    print(device)

    if args.n_out == 3:
        plot = lambda p, x: torchvision.utils.save_image(
            x.view(x.size(0), args.dim, args.dim, 3).transpose(3, 1),
            p, normalize=False, nrow=int(x.size(0) ** .5))
    else:
        plot = None

    ess_samples = model.init_sample(args.n_samples).to(device)

    hops = {}
    ess = {}
    times = {}
    chains = {}

    temps = ['dim-gibbs', 'rand-gibbs', 'gwg']
    for temp in temps:
        if temp == 'dim-gibbs':
            sampler = samplers.PerDimMetropolisSampler(args.dim ** 2, args.n_out)
        elif temp == "rand-gibbs":
            sampler = samplers.PerDimMetropolisSampler(args.dim ** 2, args.n_out, rand=True)
        elif "bg-" in temp:
            block_size = int(temp.split('-')[1])
            sampler = block_samplers.BlockGibbsSampler(model.data_dim, block_size)
        elif "hb-" in temp:
            block_size, hamming_dist = [int(v) for v in temp.split('-')[1:]]
            sampler = block_samplers.HammingBallSampler(model.data_dim, block_size, hamming_dist)
        elif temp == "gwg":
            sampler = samplers.DiffSamplerMultiDim(args.dim ** 2, 1, approx=True, temp=2.)
        elif "gwg-" in temp:
            n_hops = int(temp.split('-')[1])
            sampler = samplers.MultiDiffSampler(model.data_dim, 1, approx=True, temp=2.,
                                                n_samples=n_hops)
        else:
            raise ValueError("Invalid sampler: {}".format(temp))

        x = model.init_dist.sample((args.n_test_samples,)).to(device)

        times[temp] = []
        hops[temp] = []
        chain = []
        cur_time = 0.
        for i in range(args.n_steps):
            # do sampling and time it
            st = time.time()
            xhat = sampler.step(x.detach(), model).detach()
            cur_time += time.time() - st

            # compute hamming distance between consecutive states
            cur_hops = (x != xhat).float().view(x.size(0), -1).sum(-1).mean().item()

            # update trajectory
            x = xhat

            if i % args.subsample == 0:
                if args.ess_statistic == "dims":
                    chain.append(x.cpu()[0].view(-1).numpy()[None])
                else:
                    # statistic: Hamming distance from a fixed set of reference samples
                    xc = x[0][None]
                    h = (xc != ess_samples).float().view(ess_samples.size(0), -1).sum(-1)
                    chain.append(h.detach().cpu().numpy()[None])

            if i % args.viz_every == 0 and plot is not None:
                plot("{}/temp_{}_samples_{}.png".format(args.save_dir, temp, i), x)

            if i % args.print_every == 0:
                times[temp].append(cur_time)
                hops[temp].append(cur_hops)
                print("temp {}, itr = {}, hop-dist = {:.4f}".format(temp, i, cur_hops))

        chain = np.concatenate(chain, 0)
        chains[temp] = chain
        ess[temp] = get_ess(chain, args.burn_in)
        print("ess = {} +/- {}".format(ess[temp].mean(), ess[temp].std()))

    ess_temps = temps
    plt.clf()
    plt.boxplot([ess[temp] for temp in ess_temps], labels=ess_temps, showfliers=False)
    plt.savefig("{}/ess.png".format(args.save_dir))

    plt.clf()
    plt.boxplot([ess[temp] / times[temp][-1] / (1. - args.burn_in) for temp in ess_temps],
                labels=ess_temps, showfliers=False)
    plt.savefig("{}/ess_per_sec.png".format(args.save_dir))

    plt.clf()
    for temp in temps:
        plt.plot(hops[temp], label="{}".format(temp))
    plt.legend()
    plt.savefig("{}/hops.png".format(args.save_dir))

    for temp in temps:
        plt.clf()
        plt.plot(chains[temp][:, 0])
        plt.savefig("{}/trace_{}.png".format(args.save_dir, temp))

    with open("{}/results.pkl".format(args.save_dir), 'wb') as f:
        results = {'ess': ess, 'hops': hops, 'chains': chains}
        pickle.dump(results, f)
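# get_ess is called above but defined elsewhere in the repo. A minimal sketch
# of one common estimator (an assumption, not the source's implementation):
# effective sample size per statistic from the integrated autocorrelation time
# of the post-burn-in chain, with a simple truncation once autocorrelation
# becomes negligible.
def get_ess_sketch(chain, burn_in):
    """chain: (n_steps, n_stats) array; burn_in: fraction of steps to drop."""
    c = chain[int(len(chain) * burn_in):]
    c = c - c.mean(0, keepdims=True)
    n = c.shape[0]
    ess = []
    for j in range(c.shape[1]):
        x = c[:, j]
        var = x.var()
        if var == 0:
            ess.append(1.)
            continue
        tau = 1.  # integrated autocorrelation time
        for lag in range(1, n // 2):
            rho = (x[:-lag] * x[lag:]).mean() / var
            if rho < 0.05:  # truncate once autocorrelation is negligible
                break
            tau += 2. * rho
        ess.append(n / tau)
    return np.array(ess)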
def main(args):
    makedirs(args.save_dir)
    logger = open("{}/log.txt".format(args.save_dir), 'w')

    def my_print(s):
        print(s)
        logger.write(str(s) + '\n')

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # load existing data
    if args.data == "synthetic":
        train_loader, test_loader, data, ground_truth_J, ground_truth_h, ground_truth_C = \
            utils.load_synthetic(args.data_file, args.batch_size)
        dim, n_out = data.size()[1:]
        ground_truth_J_norm = norm_J(ground_truth_J).to(device)
        matsave(ground_truth_J.abs().transpose(2, 1).reshape(dim * n_out, dim * n_out),
                "{}/ground_truth_J.png".format(args.save_dir))
        matsave(ground_truth_C, "{}/ground_truth_C.png".format(args.save_dir))
        matsave(ground_truth_J_norm, "{}/ground_truth_J_norm.png".format(args.save_dir))
        num_ecs = 120
        dm_indices = torch.arange(ground_truth_J_norm.size(0)).long()
    elif args.data == "PF00018":
        train_loader, test_loader, data, num_ecs, ground_truth_J_norm, ground_truth_C = \
            utils.load_ingraham(args)
        dim, n_out = data.size()[1:]
        ground_truth_J_norm = ground_truth_J_norm.to(device)
        matsave(ground_truth_C, "{}/ground_truth_C.png".format(args.save_dir))
        matsave(ground_truth_J_norm, "{}/ground_truth_dists.png".format(args.save_dir))
        dm_indices = torch.arange(ground_truth_J_norm.size(0)).long()
    else:
        train_loader, test_loader, data, num_ecs, ground_truth_J_norm, ground_truth_C, dm_indices = \
            utils.load_real_protein(args)
        dim, n_out = data.size()[1:]
        ground_truth_J_norm = ground_truth_J_norm.to(device)
        matsave(ground_truth_C, "{}/ground_truth_C.png".format(args.save_dir))
        matsave(ground_truth_J_norm, "{}/ground_truth_dists.png".format(args.save_dir))

    if args.model == "lattice_potts":
        model = rbm.LatticePottsModel(int(args.dim), int(n_out), 0., 0., learn_sigma=True)
        buffer = model.init_sample(args.buffer_size)
    elif args.model == "dense_potts":
        model = rbm.DensePottsModel(dim, n_out, learn_J=True, learn_bias=True)
        buffer = model.init_sample(args.buffer_size)
    elif args.model == "dense_ising":
        raise ValueError
    elif args.model == "mlp":
        raise ValueError

    model.to(device)

    # make J symmetric
    def get_J():
        j = model.J
        jt = j.transpose(0, 1).transpose(2, 3)
        return (j + jt) / 2

    def get_J_sub():
        j = get_J()
        j_sub = j[dm_indices, :][:, dm_indices]
        return j_sub

    if args.sampler == "gibbs":
        if "potts" in args.model:
            sampler = samplers.PerDimMetropolisSampler(dim, int(n_out), rand=False)
        else:
            sampler = samplers.PerDimGibbsSampler(dim, rand=False)
    elif args.sampler == "plm":
        sampler = samplers.PerDimMetropolisSampler(dim, int(n_out), rand=False)
    elif args.sampler == "rand_gibbs":
        if "potts" in args.model:
            sampler = samplers.PerDimMetropolisSampler(dim, int(n_out), rand=True)
        else:
            sampler = samplers.PerDimGibbsSampler(dim, rand=True)
    elif args.sampler == "gwg":
        if "potts" in args.model:
            sampler = samplers.DiffSamplerMultiDim(dim, 1, approx=True, temp=2.)
        else:
            sampler = samplers.DiffSampler(dim, 1, approx=True, fixed_proposal=False, temp=2.)
    else:
        assert "gwg-" in args.sampler
        n_hop = int(args.sampler.split('-')[1])
        if "potts" in args.model:
            raise ValueError
        else:
            sampler = samplers.MultiDiffSampler(model.data_dim, 1, approx=True, temp=2.,
                                                n_samples=n_hop)

    my_print(device)
    my_print(model)
    my_print(buffer.size())
    my_print(sampler)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr,
                                 weight_decay=args.weight_decay)

    # load ckpt
    if args.ckpt_path is not None:
        d = torch.load(args.ckpt_path)
        model.load_state_dict(d['model'])
        optimizer.load_state_dict(d['optimizer'])
        sampler.load_state_dict(d['sampler'])

    # mask matrix for PLM: zero out each site's self-interaction block
    L, D = model.J.size(0), model.J.size(2)
    num_node = L * D
    J_mask = torch.ones((num_node, num_node)).to(device)
    for i in range(L):
        J_mask[D * i:D * i + D, D * i:D * i + D] = 0

    itr = 0
    sq_errs = []
    rmses = []
    all_inds = list(range(args.buffer_size))
    while itr < args.n_iters:
        for x in train_loader:
            if args.data == "synthetic":
                x = x[0].to(device)
                weights = torch.ones((x.size(0),)).to(device)
            else:
                weights = x[1].to(device)
                if args.unweighted:
                    weights = torch.ones_like(weights)
                x = x[0].to(device)

            if args.sampler == "plm":
                # pseudolikelihood: predict each site from all others
                plm_J = model.J.transpose(2, 1).reshape(dim * n_out, dim * n_out)
                logits = torch.matmul(x.view(x.size(0), -1), plm_J * J_mask) + \
                    model.bias.view(-1)[None]
                # recover integer labels from the one-hot encoding
                x_inds = (torch.arange(x.size(-1))[None, None].to(x.device) * x).sum(-1)
                cross_entropy = nn.functional.cross_entropy(
                    input=logits.reshape((-1, D)),
                    target=x_inds.view(-1).long(),
                    reduction='none')
                cross_entropy = torch.sum(cross_entropy.reshape((-1, L)), -1)
                loss = (cross_entropy * weights).mean()
            else:
                # PCD: draw negatives from the persistent buffer and refresh them
                buffer_inds = np.random.choice(all_inds, args.batch_size, replace=False)
                x_fake = buffer[buffer_inds].to(device)
                for k in range(args.sampling_steps):
                    x_fake = sampler.step(x_fake.detach(), model).detach()
                buffer[buffer_inds] = x_fake.detach().cpu()

                logp_real = (model(x).squeeze() * weights).mean()
                logp_fake = model(x_fake).squeeze().mean()
                obj = logp_real - logp_fake
                loss = -obj

            # add l1 reg
            loss += args.l1 * norm_J(get_J()).sum()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if itr % args.print_every == 0:
                if args.sampler == "plm":
                    my_print("({}) loss = {:.4f}".format(itr, loss.item()))
                else:
                    my_print("({}) log p(real) = {:.4f}, log p(fake) = {:.4f}, "
                             "diff = {:.4f}, hops = {:.4f}".format(
                                 itr, logp_real.item(), logp_fake.item(),
                                 obj.item(), sampler._hops))
                sq_err = ((ground_truth_J_norm - norm_J(get_J_sub())) ** 2).sum()
                rmse = ((ground_truth_J_norm - norm_J(get_J_sub())) ** 2).mean().sqrt()

                # contact-prediction accuracy: rank pairs by coupling norm,
                # check against ground-truth contacts
                inds = torch.triu_indices(ground_truth_C.size(0), ground_truth_C.size(1), 1)
                C_inds = ground_truth_C[inds[0], inds[1]]
                J_inds = norm_J(get_J_sub())[inds[0], inds[1]]
                J_inds_sorted = torch.sort(J_inds, descending=True).indices
                C_inds_sorted = C_inds[J_inds_sorted]
                C_cumsum = C_inds_sorted.cumsum(0)
                arange = torch.arange(C_cumsum.size(0)) + 1
                acc_at = C_cumsum.float() / arange.float()

                my_print("\t err^2 = {:.4f}, rmse = {:.4f}, acc @ 50 = {:.4f}, "
                         "acc @ 75 = {:.4f}, acc @ 100 = {:.4f}".format(
                             sq_err, rmse, acc_at[50], acc_at[75], acc_at[100]))
                logger.flush()

            if itr % args.viz_every == 0:
                sq_err = ((ground_truth_J_norm - norm_J(get_J_sub())) ** 2).sum()
                rmse = ((ground_truth_J_norm - norm_J(get_J_sub())) ** 2).mean().sqrt()
                sq_errs.append(sq_err.item())
                plt.clf()
                plt.plot(sq_errs, label="sq_err")
                plt.legend()
                plt.savefig("{}/sq_err.png".format(args.save_dir))

                rmses.append(rmse.item())
                plt.clf()
                plt.plot(rmses, label="rmse")
                plt.legend()
                plt.savefig("{}/rmse.png".format(args.save_dir))

                matsave(get_J_sub().abs().transpose(2, 1).reshape(
                            dm_indices.size(0) * n_out, dm_indices.size(0) * n_out),
                        "{}/model_J_{}_sub.png".format(args.save_dir, itr))
                matsave(norm_J(get_J_sub()),
                        "{}/model_J_norm_{}_sub.png".format(args.save_dir, itr))
                matsave(get_J().abs().transpose(2, 1).reshape(dim * n_out, dim * n_out),
                        "{}/model_J_{}.png".format(args.save_dir, itr))
                matsave(norm_J(get_J()),
                        "{}/model_J_norm_{}.png".format(args.save_dir, itr))

                inds = torch.triu_indices(ground_truth_C.size(0), ground_truth_C.size(1), 1)
                C_inds = ground_truth_C[inds[0], inds[1]]
                J_inds = norm_J(get_J_sub())[inds[0], inds[1]]
                J_inds_sorted = torch.sort(J_inds, descending=True).indices
                C_inds_sorted = C_inds[J_inds_sorted]
                C_cumsum = C_inds_sorted.cumsum(0)
                arange = torch.arange(C_cumsum.size(0)) + 1
                acc_at = C_cumsum.float() / arange.float()

                plt.clf()
                plt.plot(acc_at[:num_ecs].detach().cpu().numpy())
                plt.savefig("{}/acc_at_{}.png".format(args.save_dir, itr))

            if itr % args.ckpt_every == 0:
                my_print("Saving checkpoint to {}/ckpt.pt".format(args.save_dir))
                torch.save({
                    "model": model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "sampler": sampler.state_dict()
                }, "{}/ckpt.pt".format(args.save_dir))

            itr += 1
            if itr > args.n_iters:
                sq_err = ((ground_truth_J_norm - norm_J(get_J_sub())) ** 2).sum()
                rmse = ((ground_truth_J_norm - norm_J(get_J_sub())) ** 2).mean().sqrt()
                with open("{}/sq_err.txt".format(args.save_dir), 'w') as f:
                    f.write(str(sq_err))
                with open("{}/rmse.txt".format(args.save_dir), 'w') as f:
                    f.write(str(rmse))
                torch.save({
                    "model": model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "sampler": sampler.state_dict()
                }, "{}/ckpt.pt".format(args.save_dir))
                quit()
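# norm_J and matsave are used above but defined elsewhere in the repo. Plausible
# sketches (assumptions, not the source's implementations): norm_J reduces each
# (n_out x n_out) coupling block to a single score via its Frobenius norm -- the
# standard statistic for ranking contacts in Potts models -- and matsave writes
# a 2D tensor to disk as a matrix image.
def norm_J_sketch(J):
    """J: (L, L, n_out, n_out) coupling tensor -> (L, L) matrix of block norms."""
    return (J ** 2).sum(-1).sum(-1).sqrt()


def matsave_sketch(mat, path):
    """Render a 2D tensor with matshow and save it to path."""
    plt.clf()
    plt.matshow(mat.detach().cpu().numpy())
    plt.savefig(path)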
def main(args):
    makedirs(args.save_dir)
    logger = open("{}/log.txt".format(args.save_dir), 'w')

    def my_print(s):
        print(s)
        logger.write(str(s) + '\n')

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # build model
    if args.model == "lattice_potts":
        model = rbm.LatticePottsModel(int(args.dim), int(args.n_out), 0., 0.,
                                      learn_sigma=True)
    elif args.model == "dense_potts":
        model = rbm.DensePottsModel(args.dim, args.n_out, learn_J=True, learn_bias=True)
    else:
        raise ValueError("Invalid model: {}".format(args.model))
    model.to(device)

    if args.sampler == "gibbs":
        sampler = samplers.PerDimMetropolisSampler(args.dim, int(args.n_out), rand=False)
    elif args.sampler == "rand_gibbs":
        sampler = samplers.PerDimMetropolisSampler(args.dim, int(args.n_out), rand=True)
    elif args.sampler == "gwg":
        sampler = samplers.DiffSamplerMultiDim(args.dim, 1, approx=True, temp=2.)
    else:
        raise ValueError("Invalid sampler: {}".format(args.sampler))

    my_print(device)
    my_print(model)
    my_print(sampler)

    # load ckpt
    my_print("Loading...")
    if args.ckpt_path is not None:
        d = torch.load(args.ckpt_path)
        model.load_state_dict(d['model'])
    my_print("Loaded!")

    # annealed importance sampling from beta = 0 up to the target at beta = 1
    betas = np.linspace(0., 1., args.n_iters)
    samples = model.init_sample(args.n_samples)
    log_w = torch.zeros((args.n_samples,)).to(device)
    # log Z_0 of the bias-only (beta = 0) initial distribution
    log_w += model.bias.logsumexp(-1).sum()

    logZs = []
    for itr, beta_k in enumerate(betas):
        if itr == 0:
            continue  # skip beta = 0

        beta_km1 = betas[itr - 1]

        # update importance weights
        with torch.no_grad():
            log_w = log_w + model(samples, beta=beta_k) - model(samples, beta=beta_km1)

        # update samples under the current intermediate distribution
        model_k = lambda x: model(x, beta=beta_k)
        for _ in range(args.steps_per_iter):
            samples = sampler.step(samples.detach(), model_k).detach()

        if itr % args.print_every == 0:
            logZ = log_w.logsumexp(0) - np.log(args.n_samples)
            logZs.append(logZ.item())
            my_print("({}) beta = {}, log Z = {:.4f}".format(itr, beta_k, logZ.item()))
            logger.flush()

        if itr % args.viz_every == 0:
            plt.clf()
            plt.plot(logZs, label="log(Z)")
            plt.legend()
            plt.savefig("{}/logZ.png".format(args.save_dir))

    logZ_final = log_w.logsumexp(0) - np.log(args.n_samples)
    my_print("Final log(Z) = {:.4f}".format(logZ_final))
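# A self-contained sanity check of the AIS bookkeeping used above (an
# assumption mirroring the loop's structure: accumulate
# log_w += log f_{beta_k}(x) - log f_{beta_{k-1}}(x), having added log Z_0 of
# the beta = 0 distribution up front, then estimate
# log Z = logsumexp(log_w) - log N). On a 2-bit model we can enumerate states
# exactly; _demo_ais_bookkeeping, W, and b are all hypothetical.
def _demo_ais_bookkeeping():
    torch.manual_seed(0)
    n_samples, n_betas = 5000, 20
    W = torch.tensor([[0., 1.5], [1.5, 0.]])  # pairwise coupling, annealed by beta
    b = torch.tensor([0.3, -0.2])             # bias, kept at full strength throughout

    def log_f(x, beta):
        # unnormalized log prob; beta = 0 is the independent bias-only model
        return beta * (x @ W * x).sum(-1) / 2 + x @ b

    # exact log Z by enumerating all four states
    states = torch.tensor([[0., 0.], [0., 1.], [1., 0.], [1., 1.]])
    logZ_true = log_f(states, 1.).logsumexp(0)

    # initialize chains from the exact beta = 0 distribution
    x = torch.bernoulli(torch.sigmoid(b).expand(n_samples, 2))
    log_w = torch.zeros(n_samples) + torch.log1p(b.exp()).sum()  # log Z_0
    betas = torch.linspace(0., 1., n_betas)
    for k in range(1, n_betas):
        log_w += log_f(x, betas[k]) - log_f(x, betas[k - 1])
        for i in range(2):  # one Gibbs sweep at beta_k
            x1, x0 = x.clone(), x.clone()
            x1[:, i], x0[:, i] = 1., 0.
            p1 = torch.sigmoid(log_f(x1, betas[k]) - log_f(x0, betas[k]))
            x[:, i] = torch.bernoulli(p1)
    logZ_est = log_w.logsumexp(0) - np.log(n_samples)
    print(float(logZ_true), float(logZ_est))  # the two values should be close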
def main(args):
    makedirs(args.save_dir)
    logger = open("{}/log.txt".format(args.save_dir), 'w')

    def my_print(s):
        print(s)
        logger.write(str(s) + '\n')

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # load existing data
    if args.data == "mnist" or args.data_file is not None:
        train_loader, test_loader, plot, viz = utils.get_data(args)
    # generate the dataset, save it, and exit
    else:
        data, data_model = utils.generate_data(args)
        my_print("Generated dataset; saving to {} and exiting.".format(args.save_dir))
        with open("{}/data.pkl".format(args.save_dir), 'wb') as f:
            pickle.dump(data, f)
        if args.data_model == "er_ising":
            ground_truth_J = data_model.J.detach().cpu()
            with open("{}/J.pkl".format(args.save_dir), 'wb') as f:
                pickle.dump(ground_truth_J, f)
        quit()

    if args.model == "lattice_potts":
        model = rbm.LatticePottsModel(int(args.dim), int(args.n_state), 0., 0.,
                                      learn_sigma=True)
        buffer = model.init_sample(args.buffer_size)
    elif args.model == "lattice_ising":
        model = rbm.LatticeIsingModel(int(args.dim), 0., 0., learn_sigma=True)
        buffer = model.init_sample(args.buffer_size)
    elif args.model == "lattice_ising_3d":
        model = rbm.LatticeIsingModel(int(args.dim), .2, learn_G=True, lattice_dim=3)
        ground_truth_J = model.J.clone().to(device)
        model.G.data = torch.randn_like(model.G.data) * .01
        model.sigma.data = torch.ones_like(model.sigma.data)
        buffer = model.init_sample(args.buffer_size)
        plt.clf()
        plt.matshow(ground_truth_J.detach().cpu().numpy())
        plt.savefig("{}/ground_truth.png".format(args.save_dir))
    elif args.model == "lattice_ising_2d":
        model = rbm.LatticeIsingModel(int(args.dim), args.sigma, learn_G=True, lattice_dim=2)
        ground_truth_J = model.J.clone().to(device)
        model.G.data = torch.randn_like(model.G.data) * .01
        model.sigma.data = torch.ones_like(model.sigma.data)
        buffer = model.init_sample(args.buffer_size)
        plt.clf()
        plt.matshow(ground_truth_J.detach().cpu().numpy())
        plt.savefig("{}/ground_truth.png".format(args.save_dir))
    elif args.model == "er_ising":
        model = rbm.ERIsingModel(int(args.dim), 2, learn_G=True)
        model.G.data = torch.randn_like(model.G.data) * .01
        buffer = model.init_sample(args.buffer_size)
        with open(args.graph_file, 'rb') as f:
            ground_truth_J = pickle.load(f)
        plt.clf()
        plt.matshow(ground_truth_J.detach().cpu().numpy())
        plt.savefig("{}/ground_truth.png".format(args.save_dir))
        ground_truth_J = ground_truth_J.to(device)
    elif args.model == "rbm":
        model = rbm.BernoulliRBM(args.dim, args.n_hidden)
        buffer = model.init_dist.sample((args.buffer_size,))
    elif args.model == "dense_potts":
        raise ValueError
    elif args.model == "dense_ising":
        raise ValueError
    elif args.model == "mlp":
        raise ValueError

    model.to(device)
    buffer = buffer.to(device)

    # make G symmetric
    def get_J():
        j = model.J
        return (j + j.t()) / 2

    if args.sampler == "gibbs":
        if "potts" in args.model:
            sampler = samplers.PerDimMetropolisSampler(model.data_dim, int(args.n_state),
                                                       rand=False)
        else:
            sampler = samplers.PerDimGibbsSampler(model.data_dim, rand=False)
    elif args.sampler == "rand_gibbs":
        if "potts" in args.model:
            sampler = samplers.PerDimMetropolisSampler(model.data_dim, int(args.n_state),
                                                       rand=True)
        else:
            sampler = samplers.PerDimGibbsSampler(model.data_dim, rand=True)
    elif args.sampler == "gwg":
        if "potts" in args.model:
            sampler = samplers.DiffSamplerMultiDim(model.data_dim, 1, approx=True, temp=2.)
        else:
            sampler = samplers.DiffSampler(model.data_dim, 1, approx=True,
                                           fixed_proposal=False, temp=2.)
    else:
        assert "gwg-" in args.sampler
        n_hop = int(args.sampler.split('-')[1])
        if "potts" in args.model:
            raise ValueError
        else:
            sampler = samplers.MultiDiffSampler(model.data_dim, 1, approx=True, temp=2.,
                                                n_samples=n_hop)

    my_print(device)
    my_print(model)
    my_print(buffer.size())
    my_print(sampler)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr,
                                 weight_decay=args.weight_decay)

    itr = 0
    sigmas = []
    sq_errs = []
    rmses = []
    while itr < args.n_iters:
        for x in train_loader:
            x = x[0].to(device)

            # PCD: advance the persistent chains
            for k in range(args.sampling_steps):
                buffer = sampler.step(buffer.detach(), model).detach()

            logp_real = model(x).squeeze().mean()
            logp_fake = model(buffer).squeeze().mean()

            obj = logp_real - logp_fake
            loss = -obj
            loss += args.l1 * get_J().abs().sum()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # zero out the diagonal of G (no self-interactions)
            model.G.data *= (1. - torch.eye(model.G.data.size(0))).to(model.G)

            if itr % args.print_every == 0:
                my_print("({}) log p(real) = {:.4f}, log p(fake) = {:.4f}, "
                         "diff = {:.4f}, hops = {:.4f}".format(
                             itr, logp_real.item(), logp_fake.item(),
                             obj.item(), sampler._hops))
                if args.model in ("lattice_potts", "lattice_ising"):
                    my_print("\tsigma true = {:.4f}, current sigma = {:.4f}".format(
                        args.sigma, model.sigma.data.item()))
                else:
                    sq_err = ((ground_truth_J - get_J()) ** 2).sum()
                    rmse = ((ground_truth_J - get_J()) ** 2).mean().sqrt()
                    my_print("\t err^2 = {:.4f}, rmse = {:.4f}".format(sq_err, rmse))
                    print(ground_truth_J)
                    print(get_J())

            if itr % args.viz_every == 0:
                if args.model in ("lattice_potts", "lattice_ising"):
                    sigmas.append(model.sigma.data.item())
                    plt.clf()
                    plt.plot(sigmas, label="model")
                    plt.plot([args.sigma for s in sigmas], label="gt")
                    plt.legend()
                    plt.savefig("{}/sigma.png".format(args.save_dir))
                else:
                    sq_err = ((ground_truth_J - get_J()) ** 2).sum()
                    sq_errs.append(sq_err.item())
                    plt.clf()
                    plt.plot(sq_errs, label="sq_err")
                    plt.legend()
                    plt.savefig("{}/sq_err.png".format(args.save_dir))

                    rmse = ((ground_truth_J - get_J()) ** 2).mean().sqrt()
                    rmses.append(rmse.item())
                    plt.clf()
                    plt.plot(rmses, label="rmse")
                    plt.legend()
                    plt.savefig("{}/rmse.png".format(args.save_dir))

                    plt.clf()
                    plt.matshow(get_J().detach().cpu().numpy())
                    plt.savefig("{}/model_{}.png".format(args.save_dir, itr))

                plot("{}/data_{}.png".format(args.save_dir, itr), x.detach().cpu())
                plot("{}/buffer_{}.png".format(args.save_dir, itr),
                     buffer[:args.batch_size].detach().cpu())

            itr += 1
            if itr > args.n_iters:
                if args.model in ("lattice_potts", "lattice_ising"):
                    final_sigma = model.sigma.data.item()
                    with open("{}/sigma.txt".format(args.save_dir), 'w') as f:
                        f.write(str(final_sigma))
                else:
                    sq_err = ((ground_truth_J - get_J()) ** 2).sum().item()
                    rmse = ((ground_truth_J - get_J()) ** 2).mean().sqrt().item()
                    with open("{}/sq_err.txt".format(args.save_dir), 'w') as f:
                        f.write(str(sq_err))
                    with open("{}/rmse.txt".format(args.save_dir), 'w') as f:
                        f.write(str(rmse))
                quit()
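# These main() functions are driven by argparse in the original scripts. A
# hedged sketch of a parser covering the flags the Ising training loop above
# actually reads -- the flag names come from the code, but every default value
# here is an illustrative assumption, not the source's configuration.
def _demo_arg_parser():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--save_dir', type=str, default="/tmp/pcd")
    parser.add_argument('--data', type=str, default="mnist")
    parser.add_argument('--data_file', type=str, default=None)
    parser.add_argument('--model', type=str, default="lattice_ising_2d")
    parser.add_argument('--sampler', type=str, default="gwg")
    parser.add_argument('--seed', type=int, default=1234567)
    parser.add_argument('--dim', type=int, default=25)
    parser.add_argument('--sigma', type=float, default=.2)
    parser.add_argument('--batch_size', type=int, default=100)
    parser.add_argument('--buffer_size', type=int, default=256)
    parser.add_argument('--sampling_steps', type=int, default=50)
    parser.add_argument('--n_iters', type=int, default=10000)
    parser.add_argument('--lr', type=float, default=.001)
    parser.add_argument('--weight_decay', type=float, default=.0)
    parser.add_argument('--l1', type=float, default=.01)
    parser.add_argument('--print_every', type=int, default=100)
    parser.add_argument('--viz_every', type=int, default=1000)
    return parser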