def do_test_per_param_optim(self, fixed_param, free_param):
    pyro._param_store._clear_cache()

    def model():
        prior_dist = DiagNormal(self.mu0, torch.pow(self.lam0, -0.5))
        mu_latent = pyro.sample("mu_latent", prior_dist)
        x_dist = DiagNormal(mu_latent, torch.pow(self.lam, -0.5))
        pyro.observe("obs", x_dist, self.data)
        return mu_latent

    def guide():
        mu_q = pyro.param("mu_q", Variable(torch.zeros(1), requires_grad=True))
        log_sig_q = pyro.param(
            "log_sig_q", Variable(torch.zeros(1), requires_grad=True))
        sig_q = torch.exp(log_sig_q)
        pyro.sample("mu_latent", DiagNormal(mu_q, sig_q))

    def optim_params(param_name, param):
        if param_name == fixed_param:
            return {'lr': 0.00}
        elif param_name == free_param:
            return {'lr': 0.01}

    kl_optim = KL_QP(model, guide, pyro.optim(torch.optim.Adam, optim_params))
    for k in range(3):
        kl_optim.step()

    free_param_unchanged = torch.equal(
        pyro.param(free_param).data, torch.zeros(1))
    fixed_param_unchanged = torch.equal(
        pyro.param(fixed_param).data, torch.zeros(1))
    passed_test = fixed_param_unchanged and not free_param_unchanged
    self.assertTrue(passed_test)

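# Illustrative sketch (not part of the original file): the helper above would
# typically be driven by a pair of test methods, one per direction, so that each
# parameter is checked both as the frozen (lr=0) parameter and as the free
# (lr>0) parameter. The method names below are assumptions.
def test_fix_mu_q(self):
    self.do_test_per_param_optim("mu_q", "log_sig_q")


def test_fix_log_sig_q(self):
    self.do_test_per_param_optim("log_sig_q", "mu_q")
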
def test_elbo_nonreparametrized(self):
    pyro._param_store._clear_cache()

    def model():
        p_latent = pyro.sample("p_latent", Beta(self.alpha0, self.beta0))
        x_dist = Bernoulli(p_latent)
        pyro.map_data(self.data,
                      lambda i, x: pyro.observe("obs", x_dist, x),
                      batch_size=2)
        return p_latent

    def guide():
        alpha_q_log = pyro.param(
            "alpha_q_log",
            Variable(self.log_alpha_n.data + 0.17, requires_grad=True))
        beta_q_log = pyro.param(
            "beta_q_log",
            Variable(self.log_beta_n.data - 0.143, requires_grad=True))
        alpha_q, beta_q = torch.exp(alpha_q_log), torch.exp(beta_q_log)
        pyro.sample("p_latent", Beta(alpha_q, beta_q))
        pyro.map_data(self.data, lambda i, x: None, batch_size=2)

    kl_optim = KL_QP(
        model, guide,
        pyro.optim(torch.optim.Adam, {"lr": .001, "betas": (0.97, 0.999)}))
    for k in range(6001):
        kl_optim.step()

    alpha_error = torch.abs(
        pyro.param("alpha_q_log") - self.log_alpha_n).data.cpu().numpy()[0]
    beta_error = torch.abs(
        pyro.param("beta_q_log") - self.log_beta_n).data.cpu().numpy()[0]
    self.assertEqual(0.0, alpha_error, prec=0.05)
    self.assertEqual(0.0, beta_error, prec=0.05)

def do_elbo_test(self, reparametrized, n_steps):
    pyro._param_store._clear_cache()

    def model():
        mu_latent = pyro.sample(
            "mu_latent", DiagNormal(self.mu0, torch.pow(self.tau0, -0.5)))
        x_dist = LogNormal(mu_latent, torch.pow(self.tau, -0.5))
        pyro.observe("obs0", x_dist, self.data[0])
        pyro.observe("obs1", x_dist, self.data[1])
        return mu_latent

    def guide():
        mu_q_log = pyro.param(
            "mu_q_log",
            Variable(self.log_mu_n.data + 0.17, requires_grad=True))
        tau_q_log = pyro.param(
            "tau_q_log",
            Variable(self.log_tau_n.data - 0.143, requires_grad=True))
        mu_q, tau_q = torch.exp(mu_q_log), torch.exp(tau_q_log)
        q_dist = DiagNormal(mu_q, torch.pow(tau_q, -0.5))
        q_dist.reparametrized = reparametrized
        pyro.sample("mu_latent", q_dist)

    kl_optim = KL_QP(
        model, guide,
        pyro.optim(torch.optim.Adam, {"lr": .0005, "betas": (0.96, 0.999)}))
    for k in range(n_steps):
        kl_optim.step()

    mu_error = torch.abs(
        pyro.param("mu_q_log") - self.log_mu_n).data.cpu().numpy()[0]
    tau_error = torch.abs(
        pyro.param("tau_q_log") - self.log_tau_n).data.cpu().numpy()[0]
    self.assertEqual(0.0, mu_error, prec=0.07)
    self.assertEqual(0.0, tau_error, prec=0.07)

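# Illustrative sketch (not in the original file): the do_elbo_test helper above
# would typically be exercised with both gradient estimators. The method names
# and step counts below are assumptions; the non-reparametrized estimator has
# higher variance and so would usually be given more steps.
def test_elbo_lognormal_reparametrized(self):
    self.do_elbo_test(True, 5000)


def test_elbo_lognormal_nonreparametrized(self):
    self.do_elbo_test(False, 15000)
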
def test_elbo_with_transformed_distribution(self):
    pyro._param_store._clear_cache()

    def model():
        mu_latent = pyro.sample(
            "mu_latent", DiagNormal(self.mu0, torch.pow(self.tau0, -0.5)))
        unit_normal = dist.DiagNormal(Variable(torch.zeros(1, 1)),
                                      Variable(torch.ones(1, 1)))
        bijector = AffineExp(torch.pow(self.tau, -0.5), mu_latent)
        # equivalent construction:
        # x_dist = LogNormal(mu_latent, torch.pow(self.tau, -0.5))
        x_dist = TransformedDistribution(unit_normal, bijector)
        pyro.observe("obs0", x_dist, self.data[0])
        pyro.observe("obs1", x_dist, self.data[1])
        return mu_latent

    def guide():
        mu_q_log = pyro.param(
            "mu_q_log",
            Variable(self.log_mu_n.data + 0.17, requires_grad=True))
        tau_q_log = pyro.param(
            "tau_q_log",
            Variable(self.log_tau_n.data - 0.143, requires_grad=True))
        mu_q, tau_q = torch.exp(mu_q_log), torch.exp(tau_q_log)
        q_dist = DiagNormal(mu_q, torch.pow(tau_q, -0.5))
        pyro.sample("mu_latent", q_dist)

    kl_optim = KL_QP(
        model, guide,
        pyro.optim(torch.optim.Adam, {"lr": .0005, "betas": (0.96, 0.999)}))
    for k in range(9001):
        kl_optim.step()

    mu_error = torch.abs(
        pyro.param("mu_q_log") - self.log_mu_n).data.cpu().numpy()[0]
    tau_error = torch.abs(
        pyro.param("tau_q_log") - self.log_tau_n).data.cpu().numpy()[0]
    self.assertEqual(0.0, mu_error, prec=0.05)
    self.assertEqual(0.0, tau_error, prec=0.05)

def do_test_fixedness(self, model_fixed, guide_fixed):
    pyro._param_store._clear_cache()

    def model():
        alpha_p_log = pyro.param(
            "alpha_p_log", Variable(self.alpha_p_log_0, requires_grad=True))
        beta_p_log = pyro.param(
            "beta_p_log", Variable(self.beta_p_log_0, requires_grad=True))
        alpha_p, beta_p = torch.exp(alpha_p_log), torch.exp(beta_p_log)
        lambda_latent = pyro.sample("lambda_latent", Gamma(alpha_p, beta_p))
        x_dist = Poisson(lambda_latent)
        pyro.observe("obs", x_dist, self.data)
        return lambda_latent

    def guide():
        alpha_q_log = pyro.param(
            "alpha_q_log", Variable(self.alpha_q_log_0, requires_grad=True))
        beta_q_log = pyro.param(
            "beta_q_log", Variable(self.beta_q_log_0, requires_grad=True))
        alpha_q, beta_q = torch.exp(alpha_q_log), torch.exp(beta_q_log)
        pyro.sample("lambda_latent", Gamma(alpha_q, beta_q))

    kl_optim = KL_QP(model, guide,
                     pyro.optim(torch.optim.Adam, {"lr": .001}),
                     model_fixed=model_fixed,
                     guide_fixed=guide_fixed)
    for _ in range(10):
        kl_optim.step()

    model_unchanged = \
        torch.equal(pyro.param("alpha_p_log").data, self.alpha_p_log_0) and \
        torch.equal(pyro.param("beta_p_log").data, self.beta_p_log_0)
    guide_unchanged = \
        torch.equal(pyro.param("alpha_q_log").data, self.alpha_q_log_0) and \
        torch.equal(pyro.param("beta_q_log").data, self.beta_q_log_0)
    bad = (model_fixed and not model_unchanged) or \
          (guide_fixed and not guide_unchanged)
    return not bad

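# Illustrative sketch (not in the original file): callers of the fixedness
# helper above would assert the returned flag, e.g. checking that fixing the
# model leaves the model parameters untouched. Method names are assumptions.
def test_model_fixed(self):
    self.assertTrue(self.do_test_fixedness(model_fixed=True, guide_fixed=False))


def test_guide_fixed(self):
    self.assertTrue(self.do_test_fixedness(model_fixed=False, guide_fixed=True))
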
def do_elbo_test(self, reparametrized, n_steps):
    pyro._param_store._clear_cache()

    def model():
        prior_dist = DiagNormal(self.mu0, torch.pow(self.lam0, -0.5))
        mu_latent = pyro.sample("mu_latent", prior_dist)
        x_dist = DiagNormal(mu_latent, torch.pow(self.lam, -0.5))
        # x = pyro.observe("obs", x_dist, self.data)
        pyro.map_data(self.data,
                      lambda i, x: pyro.observe("obs_%d" % i, x_dist, x),
                      batch_size=1)
        return mu_latent

    def guide():
        mu_q = pyro.param(
            "mu_q",
            Variable(self.analytic_mu_n.data + 0.134 * torch.ones(2),
                     requires_grad=True))
        log_sig_q = pyro.param(
            "log_sig_q",
            Variable(self.analytic_log_sig_n.data - 0.09 * torch.ones(2),
                     requires_grad=True))
        sig_q = torch.exp(log_sig_q)
        q_dist = DiagNormal(mu_q, sig_q)
        q_dist.reparametrized = reparametrized
        pyro.sample("mu_latent", q_dist)
        pyro.map_data(self.data, lambda i, x: None, batch_size=1)

    kl_optim = KL_QP(model, guide, pyro.optim(torch.optim.Adam, {"lr": .001}))
    for k in range(n_steps):
        kl_optim.step()

    mu_error = torch.sum(
        torch.pow(self.analytic_mu_n - pyro.param("mu_q"), 2.0))
    log_sig_error = torch.sum(
        torch.pow(self.analytic_log_sig_n - pyro.param("log_sig_q"), 2.0))
    self.assertEqual(0.0, mu_error.data.cpu().numpy()[0], prec=0.02)
    self.assertEqual(0.0, log_sig_error.data.cpu().numpy()[0], prec=0.02)

def test_elbo_nonreparametrized(self):
    pyro._param_store._clear_cache()

    def model():
        lambda_latent = pyro.sample("lambda_latent",
                                    Gamma(self.alpha0, self.beta0))
        x_dist = Exponential(lambda_latent)
        pyro.observe("obs0", x_dist, self.data[0])
        pyro.observe("obs1", x_dist, self.data[1])
        return lambda_latent

    def guide():
        alpha_q_log = pyro.param(
            "alpha_q_log",
            Variable(self.log_alpha_n.data + 0.17, requires_grad=True))
        beta_q_log = pyro.param(
            "beta_q_log",
            Variable(self.log_beta_n.data - 0.143, requires_grad=True))
        alpha_q, beta_q = torch.exp(alpha_q_log), torch.exp(beta_q_log)
        pyro.sample("lambda_latent", Gamma(alpha_q, beta_q))

    kl_optim = KL_QP(
        model, guide,
        pyro.optim(torch.optim.Adam, {"lr": .0003, "betas": (0.97, 0.999)}))
    for k in range(10001):
        kl_optim.step()

    alpha_error = torch.abs(
        pyro.param("alpha_q_log") - self.log_alpha_n).data.cpu().numpy()[0]
    beta_error = torch.abs(
        pyro.param("beta_q_log") - self.log_beta_n).data.cpu().numpy()[0]
    self.assertEqual(0.0, alpha_error, prec=0.05)
    self.assertEqual(0.0, beta_error, prec=0.05)

    cll = pyro.sample("sample_cll", Categorical(alpha))
    # return img
    return img, img_mu, cll


def per_param_args(name, param):
    if name == "decoder":
        return {"lr": .0001}
    else:
        return {"lr": .0001}


# or alternatively
adam_params = {"lr": .0001}

inference = KL_QP(model_latent, guide_latent,
                  pyro.optim(optim.Adam, adam_params))
inference_c = KL_QP(model_given_c, guide_given_c,
                    pyro.optim(optim.Adam, adam_params))

mnist_data = Variable(train_loader.dataset.train_data.float() / 255.)
mnist_labels = Variable(train_loader.dataset.train_labels)
mnist_size = mnist_data.size(0)
batch_size = 128  # 64

# TODO: batches not necessarily
all_batches = np.arange(0, mnist_size, batch_size)

if all_batches[-1] != mnist_size:
    all_batches = list(all_batches) + [mnist_size]

vis = visdom.Visdom(env='vae_z_c')

def model_sample(data, cll):
    classifier = pyro.module("classifier", pt_classify)
    alpha_cat = classifier.forward(data)
    cll = pyro.sample('observed_class', Categorical(alpha_cat))
    return cll


def guide(data, cll):
    return lambda foo: None


# or alternatively
adam_params = {"lr": .0001}

inference_opt = KL_QP(model_obs, guide, pyro.optim(optim.Adam, adam_params))

mnist_data = Variable(train_loader.dataset.train_data.float() / 255.)
mnist_labels = Variable(train_loader.dataset.train_labels)
mnist_size = mnist_data.size(0)
batch_size = 128  # 64

# TODO: batches not necessarily
all_batches = np.arange(0, mnist_size, batch_size)

if all_batches[-1] != mnist_size:
    all_batches = list(all_batches) + [mnist_size]

vis = visdom.Visdom()

for i in range(1000):

def model_sample_given_class(cll=None):
    pass


def per_param_args(name, param):
    if name == "decoder":
        return {"lr": .0001}
    else:
        return {"lr": .0001}


# or alternatively
adam_params = {"lr": .0001}
# optim.SGD(lr=.0001)

inference_latent_class = KL_QP(model_latent, guide_latent,
                               pyro.optim(optim.Adam, adam_params))
inference_observed_class = KL_QP(model_observed, guide_observed,
                                 pyro.optim(optim.Adam, adam_params))
inference_observed_class_scored = KL_QP(model_observed, guide_observed2,
                                        pyro.optim(optim.Adam, adam_params))

mnist_data = Variable(train_loader.dataset.train_data.float() / 255.)
mnist_labels = Variable(train_loader.dataset.train_labels)
mnist_size = mnist_data.size(0)
batch_size = 128  # 64

mnist_data_test = Variable(test_loader.dataset.test_data.float() / 255.)
mnist_labels_test = Variable(test_loader.dataset.test_labels)

def inspect_posterior_samples(i):
    c = local_guide(i, None)
    mean_param = Variable(torch.zeros(784), requires_grad=True)
    # do MLE for class means
    m = pyro.param("mean_of_class_" + str(c[0]), mean_param)
    sigma = Variable(torch.ones(m.size()))
    dat = pyro.sample("obs_" + str(i), DiagNormal(m, sigma))
    return dat


# grad_step = ELBo(local_model, local_guide, model_ML=true, optimizer="adam")
optim_fct = pyro.optim(torch.optim.Adam, {'lr': .0001})

data = Variable(mnist.train_data).float() / 255.
nr_samples = data.size(0)

grad_step = KL_QP(local_model, local_guide, optim_fct)

d0 = inspect_posterior_samples(0)
d1 = inspect_posterior_samples(1)

vis = visdom.Visdom()

nr_epochs = 50

# apply it to minibatches of data by hand:
for epoch in range(nr_epochs):
    total_loss = 0.
    for i in range(nr_samples):
        # print('starting datum ' + str(i))
        # mod_forward = local_model(i, data[i])
        # mod_inv = local_guide(i, data[i])
        loss_sample = grad_step(i, data[i])

                     DiagNormal(guide_mu_z, guide_sigma_z))
    w_q = pyro.sample("factor_weight",
                      DiagNormal(guide_mu_q_w, guide_sigma_q_w))
    return z_q, w_q


# grad_step = ELBo(local_model, local_guide, model_ML=true, optimizer="adam")
adam_params = {"lr": .00000000000001}
adam_optim = pyro.optim(torch.optim.Adam, adam_params)

data = Variable(mnist.train_data).float() / 255.
nr_samples = data.size(0)
nr_epochs = 1000

grad_step = KL_QP(local_model, local_guide, adam_optim)

# apply it to minibatches of data by hand:
for j in range(nr_epochs):
    score = 0
    for i in range(nr_batches):
        score_d = grad_step(i, data[i])
        score += score_d / float(nr_samples)
        print('Local Score ' + str(-score))
    print('Epoch score ' + str(-score))

    return cll


def inspect_posterior_samples(i):
    cll = local_guide(i, None)
    mean_param = Variable(torch.zeros(1, 784), requires_grad=True)
    # do MLE for class means
    mu = pyro.param("mean_of_class_" + str(cll[0]), mean_param)
    dat = pyro.sample("obs_" + str(i), Bernoulli(mu))
    return dat


# grad_step = ELBo(local_model, local_guide, model_ML=true, optimizer="adam")
optim_fct = pyro.optim(torch.optim.Adam, {'lr': .0001})

inference = KL_QP(local_model, local_guide, optim_fct)

d0 = inspect_posterior_samples(0)
d1 = inspect_posterior_samples(1)

vis = visdom.Visdom()

nr_epochs = 50

# apply it to minibatches of data by hand:
mnist_data = Variable(train_loader.dataset.train_data.float() / 255.)
mnist_labels = Variable(train_loader.dataset.train_labels)
mnist_size = mnist_data.size(0)
batch_size = 1  # 64

all_batches = np.arange(0, mnist_size, batch_size)

    # return img
    return img, img_mu


def per_param_args(name, param):
    if name == "decoder":
        return {"lr": .0001}
    else:
        return {"lr": .0001}


# or alternatively
adam_params = {"lr": .0001}
# optim.SGD(lr=.0001)

kl_optim = KL_QP(model, guide, pyro.optim(optim.Adam, adam_params))
# kl_optim = KL_QP(model, guide, pyro.optim(optim.Adam, per_param_args))
# num_steps = 1

mnist_data = Variable(train_loader.dataset.train_data.float() / 255.)
mnist_size = mnist_data.size(0)
batch_size = 512  # 64

# TODO: batches not necessarily
all_batches = np.arange(0, mnist_size, batch_size)

if all_batches[-1] != mnist_size:
    all_batches = list(all_batches) + [mnist_size]

vis = visdom.Visdom(env='vae_mnist')

alpha = sample("alpha", Exponential(lam=(Tensor([1])))) else: alpha = (Tensor([alpha])) return core(prefix, alpha, condition, generate) def guide(prefix, condition=False, generate=0, alpha=None): if alpha is None: alpha_point_estimate = softplus(param("alpha-point-estimate", (torch.ones(1), requires_grad=True))) alpha = sample("alpha", Delta(v=alpha_point_estimate)) else: alpha = (Tensor([alpha])) return core(prefix, alpha, condition, generate) # Set up neural net parameter inference optimizer = pyro.optim(torch.optim.Adam, { "lr": .005, "betas": (0.97, 0.999) }) infer = KL_QP(model, guide, optimizer) # Set up softmax alpha inference alpha_optimizer = pyro.optim(torch.optim.SGD, { "lr": .005, "momentum": 0.1 }) char_names = ["char_{0}".format(i) for i in range(1000)] alpha_model = block(model, expose=["alpha", "alpha-point-estimate"] + char_names) alpha_guide = block(guide, expose=["alpha", "alpha-point-estimate"] + char_names) alpha_infer = KL_QP(alpha_model, alpha_guide, alpha_optimizer) for k in range(10000000): # Draw a random text sample chunk = dataset.random_chunk(chunk_len=200) # Fit alpha to current text sample, keeping neural net params fixed prev_alpha_loss = float("-inf")
# and try importance!
def guide():
    latent = pyro.sample(
        "latent",
        DiagNormal(Variable(torch.zeros(1)), 5 * Variable(torch.ones(1))))
    x_dist = DiagNormal(latent, Variable(torch.ones(1)))
    pass


infer = ImportanceSampling(model, guide)
exp = lw_expectation(infer, lambda x: x, 100)
print(exp)


# and try VI!
def guide():
    mf_m = pyro.param("mf_m", Variable(torch.zeros(1)))
    mf_v = pyro.param("mf_v", Variable(torch.ones(1)))
    latent = pyro.sample("latent", DiagNormal(mf_m, mf_v))
    pass


infer = KL_QP(model, guide)
exp = lw_expectation(infer, lambda x: x, 100)
print(exp)