def local_guide(i, datum):
    """Mean-field Normal guide for datum ``i``'s embedding and the factor weights.

    Both variational scales are stored in log space and exponentiated, which
    keeps them positive.

    Args:
        i: datum index; used only to name the per-datum embedding sample site.
        datum: unused here (presumably accepted to mirror the model's
            signature — confirm).

    Returns:
        (z_q, w_q): the sampled embedding and the sampled factor-weight matrix.
        Their parameters are initialised here with shapes (1, 2) and (2, 784).

    NOTE(review): the embedding params ("embedding_posterior_mean_",
    "embedding_posterior_sigma_") are not indexed by ``i``, so every datum
    shares one variational embedding distribution — confirm this is intended.
    Also note "embedding_posterior_sigma_" actually holds a *log* sigma.
    """
    dim_z = 2
    dim_o = 784
    nr_samples = 1

    # Initial values for the variational parameters.
    init_mu_z = Variable(torch.ones(nr_samples, dim_z) * 0.1, requires_grad=True)
    init_log_sigma_z = Variable(torch.ones(nr_samples, dim_z), requires_grad=True)
    init_mu_w = Variable(torch.ones(dim_z, dim_o), requires_grad=True)
    init_log_sigma_w = Variable(torch.ones(dim_z, dim_o), requires_grad=True)

    # Shared factor-weight posterior.
    guide_mu_q_w = pyro.param("factor_weight_mean", init_mu_w)
    guide_log_sigma_q_w = pyro.param("factor_weight_log_sigma", init_log_sigma_w)
    guide_sigma_q_w = torch.exp(guide_log_sigma_q_w)

    # Embedding posterior.
    guide_mu_z = pyro.param("embedding_posterior_mean_", init_mu_z)
    guide_log_sigma_q_z = pyro.param("embedding_posterior_sigma_", init_log_sigma_z)
    guide_sigma_z = torch.exp(guide_log_sigma_q_z)

    z_q = pyro.sample("embedding_of_datum_" + str(i),
                      DiagNormal(guide_mu_z, guide_sigma_z))
    w_q = pyro.sample("factor_weight", DiagNormal(guide_mu_q_w, guide_sigma_q_w))
    return z_q, w_q
def model():
    """Scalar latent with a DiagNormal(0, 5) prior and one unit-noise observation.

    Returns:
        The sampled ``latent`` value.
    """
    prior_loc = Variable(torch.zeros(1))
    prior_scale = 5 * Variable(torch.ones(1))
    latent = pyro.sample("latent", DiagNormal(prior_loc, prior_scale))
    # Likelihood: observed value fixed at 1.0, unit scale around the latent.
    likelihood = DiagNormal(latent, Variable(torch.ones(1)))
    pyro.observe("obs", likelihood, Variable(torch.ones(1)))
    return latent
def model():
    """Normal-Normal model, scoring ``self.data`` one datum at a time via map_data.

    A closure over ``self``. ``lam0`` and ``lam`` act as precisions: each scale
    is computed as ``lam ** -0.5``.

    Returns:
        The sampled ``mu_latent``.
    """
    prior_scale = torch.pow(self.lam0, -0.5)
    mu_latent = pyro.sample("mu_latent", DiagNormal(self.mu0, prior_scale))
    obs_dist = DiagNormal(mu_latent, torch.pow(self.lam, -0.5))

    def observe_one(i, x):
        # One observe site per datum, named by its index.
        return pyro.observe("obs_%d" % i, obs_dist, x)

    pyro.map_data(self.data, observe_one, batch_size=1)
    return mu_latent
def guide():
    """Guide for ``mu_latent`` with a log-parameterized mean and precision.

    A closure over ``self`` and ``reparametrized``. The initial values are
    deliberately offset from ``self.log_mu_n`` / ``self.log_tau_n``.
    """
    init_mu_log = Variable(self.log_mu_n.data + 0.17, requires_grad=True)
    mu_q_log = pyro.param("mu_q_log", init_mu_log)
    init_tau_log = Variable(self.log_tau_n.data - 0.143, requires_grad=True)
    tau_q_log = pyro.param("tau_q_log", init_tau_log)

    mu_q = torch.exp(mu_q_log)
    tau_q = torch.exp(tau_q_log)
    # tau_q is a precision, so the scale is tau_q ** -0.5.
    q_dist = DiagNormal(mu_q, torch.pow(tau_q, -0.5))
    q_dist.reparametrized = reparametrized
    pyro.sample("mu_latent", q_dist)
def guide():
    """Two-dimensional Normal guide for ``mu_latent`` with a log-scale parameter.

    A closure over ``self`` and ``reparametrized``. The initial values are
    offset from ``self.analytic_mu_n`` / ``self.analytic_log_sig_n``.
    """
    mu_init = self.analytic_mu_n.data + 0.134 * torch.ones(2)
    mu_q = pyro.param("mu_q", Variable(mu_init, requires_grad=True))
    log_sig_init = self.analytic_log_sig_n.data - 0.09 * torch.ones(2)
    log_sig_q = pyro.param("log_sig_q", Variable(log_sig_init, requires_grad=True))

    q_dist = DiagNormal(mu_q, torch.exp(log_sig_q))
    q_dist.reparametrized = reparametrized
    pyro.sample("mu_latent", q_dist)

    def no_local_latents(i, x):
        # No per-datum sites in the guide (presumably mirrors the model's map_data).
        return None

    pyro.map_data(self.data, no_local_latents, batch_size=1)
def local_model(i, data):
    """Factor model: shared weight matrix plus a latent embedding per datum.

    A global ``factor_weight`` matrix is sampled once. For each row of
    ``data``, a 2-d embedding ``z`` is sampled and the row is observed under
    ``Bernoulli(sigmoid(z @ weight))``.

    Args:
        i: unused. In the original code, the loop variable shadowed it, so the
            site names were always built from the loop index; this is kept.
        data: tensor indexed along dim 0 by datum (presumably (N, 784) — confirm).

    Returns:
        (z, weight): the embedding sampled for the last datum (``None`` when
        ``data`` has no rows) and the sampled weight matrix.
    """
    dim_z = 2
    dim_o = 784
    nr_samples = 1
    nr_data = data.size(0)

    # Global variables: fixed (non-learnable) prior over the weight matrix.
    mu_w = Variable(torch.ones(dim_z, dim_o), requires_grad=False)
    log_sigma_w = Variable(torch.ones(dim_z, dim_o), requires_grad=False)
    sigma_w = torch.exp(log_sigma_w)
    weight = pyro.sample("factor_weight", DiagNormal(mu_w, sigma_w))

    def sub_model(idx, datum):
        # Sample datum idx's embedding, score the datum, and return the embedding.
        mu_latent = Variable(torch.ones(nr_samples, dim_z)) * 0.5
        sigma_latent = Variable(torch.ones(mu_latent.size()))
        z_local = pyro.sample("embedding_of_datum_" + str(idx),
                              DiagNormal(mu_latent, sigma_latent))
        beta = sigmoid(z_local.mm(weight))
        pyro.observe("obs_" + str(idx), Bernoulli(beta), datum)
        return z_local

    # Previously `z` was local to sub_model, so `return z, weight` raised NameError.
    z = None
    for idx in range(nr_data):
        z = sub_model(idx, data[idx])
    return z, weight
def guide():
    """Learnable scalar Normal guide for ``mu_latent`` with a log-space scale."""
    loc = pyro.param("mu_q", Variable(torch.zeros(1), requires_grad=True))
    log_scale = pyro.param("log_sig_q", Variable(torch.zeros(1), requires_grad=True))
    pyro.sample("mu_latent", DiagNormal(loc, torch.exp(log_scale)))
def model_sample(cll=None):
    """Draw one image from the generative model, optionally conditioned on a class.

    Args:
        cll: class label passed to the decoder. When ``None`` (the default), a
            class is sampled from a Categorical with equal weights over 10 classes.

    Returns:
        (img, img_mu): the sampled image and the Bernoulli means it was drawn from.
    """
    # Standard Normal prior over the 20-d latent code.
    z_mu, z_sigma = Variable(torch.zeros([1, 20])), Variable(torch.ones([1, 20]))
    z = pyro.sample("latent", DiagNormal(z_mu, z_sigma))

    # Previously `cll.data.cpu().numpy() is None`: it crashed on the None default,
    # and `.numpy()` never returns None, so this branch was unreachable.
    if cll is None:
        alpha = Variable(torch.ones([1, 10]) / 10.)
        cll = pyro.sample('class', Categorical(alpha))
        print('sampling class')

    # Decode into per-pixel Bernoulli means.
    img_mu = pt_decode.forward(z, cll)
    img = pyro.sample("sample", Bernoulli(img_mu))
    return img, img_mu
def model():
    """Normal prior on ``mu_latent`` with two LogNormal observations.

    A closure over ``self``. ``tau0`` and ``tau`` act as precisions
    (scale = tau ** -0.5).

    Returns:
        The sampled ``mu_latent``.
    """
    prior = DiagNormal(self.mu0, torch.pow(self.tau0, -0.5))
    mu_latent = pyro.sample("mu_latent", prior)
    obs_dist = LogNormal(mu_latent, torch.pow(self.tau, -0.5))
    for idx, site in enumerate(("obs0", "obs1")):
        pyro.observe(site, obs_dist, self.data[idx])
    return mu_latent
def model_given_c(data, cll):
    """Model whose latent prior is produced from a given class label ``cll``.

    ``data`` is flattened to rows of 784 values and scored under Bernoulli
    pixel means decoded from the latent.
    """
    decoder_c = pyro.module("decoder_c", pt_decode_c)
    decoder_z = pyro.module("decoder_z", pt_decode_z)

    # Class-conditional prior parameters for z.
    prior_loc, prior_scale = decoder_c.forward(cll)
    latent = pyro.sample("latent_z", DiagNormal(prior_loc, prior_scale))

    pixel_means = decoder_z.forward(latent)
    pyro.observe("obs", Bernoulli(pixel_means), data.view(-1, 784))
def guide(data):
    """Amortized guide: the encoder maps ``data`` to the parameters of ``latent``."""
    # Wrap the encoder's params for use in the guide -- required.
    encoder = pyro.module("encoder", pt_encode)
    # The encoder estimates the posterior mean and scale.
    posterior_loc, posterior_scale = encoder.forward(data)
    pyro.sample("latent", DiagNormal(posterior_loc, posterior_scale))
def sub_model(datum):
    """Sample one datum's embedding and score ``datum`` under Bernoulli(sigmoid(z @ weight)).

    Relies on free names from an enclosing scope: ``nr_samples``, ``dim_z``,
    ``weight``, ``i`` and ``sigmoid``.
    """
    loc = Variable(torch.ones(nr_samples, dim_z)) * 0.5
    scale = Variable(torch.ones(loc.size()))
    z = pyro.sample("embedding_of_datum_" + str(i), DiagNormal(loc, scale))
    probs = sigmoid(z.mm(weight))
    pyro.observe("obs_" + str(i), Bernoulli(probs), datum)
def guide_latent(data, cll):
    """Guide that encodes ``data`` into ``latent_z``, then ``latent_z`` into a class.

    ``cll`` is accepted but not used.
    """
    encoder_x = pyro.module("encoder_x", pt_encode_x)
    encoder_z = pyro.module("encoder_z", pt_encode_z)
    z_loc, z_scale = encoder_x.forward(data)
    latent = pyro.sample("latent_z", DiagNormal(z_loc, z_scale))
    class_weights = encoder_z.forward(latent)
    pyro.sample("latent_class", Categorical(class_weights))
def model_xz(data, foo):
    """Bernoulli-decoder model with a standard Normal 20-d prior per datum.

    ``foo`` is accepted but not used.

    Returns:
        The sampled latent ``z``.
    """
    decoder_xz = pyro.module("decoder_xz", pt_decode_xz)
    batch = data.size(0)
    z_loc = Variable(torch.zeros([batch, 20]))
    z_scale = Variable(torch.ones([batch, 20]))
    z = pyro.sample("latent", DiagNormal(z_loc, z_scale))
    pixel_means = decoder_xz.forward(z)
    pyro.observe("obs", Bernoulli(pixel_means), data.view(-1, 784))
    return z
def guide_latent2(data):
    """Guide that first picks a class from ``data``, then encodes ``(data, class)`` into z."""
    encoder_c = pyro.module("encoder_c", pt_encode_c)
    cll = pyro.sample("latent_class", Categorical(encoder_c.forward(data)))
    encoder = pyro.module("encoder_o", pt_encode_o)
    z_loc, z_scale = encoder.forward(data, cll)
    pyro.sample("latent_z", DiagNormal(z_loc, z_scale))
def inspect_posterior_samples(i):
    """Draw a 784-d sample for datum ``i`` from the class mean selected via ``local_guide``.

    NOTE(review): the ``local_guide`` visible in this file returns
    ``(z_q, w_q)``, so ``c[0]`` is a tensor and the param name embeds its repr.
    This presumably targets a ``local_guide`` that returns a class label — confirm.

    Returns:
        The sampled ``obs_<i>`` value.
    """
    c = local_guide(i, None)
    # Class means are point estimates (MLE), stored as params keyed by class.
    class_mean = pyro.param("mean_of_class_" + str(c[0]),
                            Variable(torch.zeros(784), requires_grad=True))
    unit_scale = Variable(torch.ones(class_mean.size()))
    return pyro.sample("obs_" + str(i), DiagNormal(class_mean, unit_scale))
def model_latent(data):
    """Hierarchical model: class -> latent z -> Bernoulli pixels, one row per datum."""
    decoder_c = pyro.module("decoder_c", pt_decode_c)
    decoder_z = pyro.module("decoder_z", pt_decode_z)

    # Equal weights over 10 classes for every datum.
    class_weights = Variable(torch.ones([data.size(0), 10])) / 10.
    cll = pyro.sample('latent_class', Categorical(class_weights))

    z_loc, z_scale = decoder_c.forward(cll)
    latent = pyro.sample("latent_z", DiagNormal(z_loc, z_scale))
    pyro.observe("obs", Bernoulli(decoder_z.forward(latent)), data.view(-1, 784))
def local_model(i, datum):
    """Two-class model for one datum with point-estimated (MLE) class means.

    NOTE(review): ``str(c[0])`` stringifies a tensor element, so the param name
    contains its repr rather than a bare 0/1 — confirm this keying is intended.

    Returns:
        The sampled class ``c``.
    """
    coin_prob = Variable(torch.ones(1)) * 0.5
    c = pyro.sample("class_of_datum_" + str(i), Bernoulli(coin_prob))

    initial_mean = Variable(torch.zeros(784), requires_grad=True)
    class_mean = pyro.param("mean_of_class_" + str(c[0]), initial_mean)

    unit_scale = Variable(torch.ones(class_mean.size()))
    pyro.observe("obs_" + str(i), DiagNormal(class_mean, unit_scale), datum)
    return c
def model(data):
    """VAE model with a Gaussian decoder likelihood over flattened 784-d images.

    Called with the mini-batch ``data`` by the inference routine (KL-QP).
    """
    # Wrap the decoder's params for use in the model -- required.
    decoder = pyro.module("decoder", pt_decode)

    # Standard Normal prior, one 20-d latent per datum.
    latent_shape = [data.size(0), 20]
    z_loc = pyro.ng_zeros(latent_shape)
    z_scale = pyro.ng_ones(latent_shape)
    z = pyro.sample("latent", DiagNormal(z_loc, z_scale))

    img_mu, img_sigma = decoder.forward(z)
    pyro.observe("obs", DiagNormal(img_mu, img_sigma), data.view(-1, 784))
def model_sample():
    """Sample a latent from the prior, then decode both an image and a class.

    Returns:
        (img, img_mu, cll): the sampled image, its Bernoulli means, and the
        sampled class.
    """
    prior = DiagNormal(Variable(torch.zeros([1, 20])), Variable(torch.ones([1, 20])))
    latent = pyro.sample("latent", prior)

    img_mu = pt_decode_xz.forward(latent)
    class_weights = pt_decode_c.forward(latent)

    img = pyro.sample("sample_img", Bernoulli(img_mu))
    cll = pyro.sample("sample_cll", Categorical(class_weights))
    return img, img_mu, cll
def model():
    """Normal prior on ``mu_latent``; observations follow an AffineExp-transformed N(0, 1).

    A closure over ``self``. The transformed distribution is presumably
    equivalent to ``LogNormal(mu_latent, tau ** -0.5)`` — confirm against
    ``AffineExp``.

    Returns:
        The sampled ``mu_latent``.
    """
    prior_scale = torch.pow(self.tau0, -0.5)
    mu_latent = pyro.sample("mu_latent", DiagNormal(self.mu0, prior_scale))

    standard_normal = dist.DiagNormal(Variable(torch.zeros(1, 1)),
                                      Variable(torch.ones(1, 1)))
    transform = AffineExp(torch.pow(self.tau, -0.5), mu_latent)
    obs_dist = TransformedDistribution(standard_normal, transform)

    for idx, site in enumerate(("obs0", "obs1")):
        pyro.observe(site, obs_dist, self.data[idx])
    return mu_latent
def model_latent(data):
    """Bernoulli-decoder model with a standard Normal 20-d prior per datum."""
    # Wrap the decoder's params for use in the model -- required.
    decoder = pyro.module("decoder", pt_decode)

    batch = data.size(0)
    z_loc = Variable(torch.zeros([batch, 20]))
    z_scale = Variable(torch.ones([batch, 20]))
    z = pyro.sample("latent", DiagNormal(z_loc, z_scale))

    # The decoder returns two outputs; only the pixel means are scored here.
    img_mu, _alpha_mult = decoder.forward(z)
    pyro.observe("obs", Bernoulli(img_mu), data.view(-1, 784))
def model_sample():
    """Decode a single prior sample and return the decoder's mean image.

    Returns:
        The first decoder output (``img_mu``); the second output is discarded.
    """
    # Wrap the decoder's params for use in the model -- required.
    decoder = pyro.module("decoder", pt_decode)
    prior = DiagNormal(Variable(torch.zeros([1, 20])), Variable(torch.ones([1, 20])))
    latent = pyro.sample("latent", prior)
    img_mu, _img_sigma = decoder.forward(latent)
    return img_mu
def model_latent_backup(data):
    """Model with independent Normal latent z and a class label, decoded jointly to pixels."""
    # Wrap the decoder's params for use in the model -- required.
    decoder = pyro.module("decoder", pt_decode)

    # Standard Normal prior over a 20-d latent per datum.
    z_loc = Variable(torch.zeros([data.size(0), 20]))
    z_scale = Variable(torch.ones([data.size(0), 20]))
    z = pyro.sample("latent_z", DiagNormal(z_loc, z_scale))

    # Equal weights over 10 classes per datum.
    class_weights = Variable(torch.ones([data.size(0), 10])) / 10.
    cll = pyro.sample('latent_class', Categorical(class_weights))

    img_mu = decoder.forward(z, cll)
    pyro.observe("obs", Bernoulli(img_mu), data.view(-1, 784))
def model_sample():
    """Draw one image by decoding a prior sample with ``pt_decode`` directly.

    Returns:
        (img, img_mu): the sampled image and the Bernoulli means it was drawn from.
    """
    z_loc = Variable(torch.zeros([1, 20]))
    z_scale = Variable(torch.ones([1, 20]))
    latent = pyro.sample("latent", DiagNormal(z_loc, z_scale))
    img_mu = pt_decode.forward(latent)
    img = pyro.sample("sample", Bernoulli(img_mu))
    return img, img_mu
def guide():
    """Non-learnable guide that samples ``latent`` from a fixed DiagNormal(0, 5).

    NOTE(review): ``x_dist`` is constructed but never used — confirm it is not
    meant to be scored.
    """
    fixed = DiagNormal(Variable(torch.zeros(1)), 5 * Variable(torch.ones(1)))
    latent = pyro.sample("latent", fixed)
    x_dist = DiagNormal(latent, Variable(torch.ones(1)))
def guide():
    """Mean-field guide for ``latent`` with params ``mf_m`` (loc) and ``mf_v`` (scale).

    NOTE(review): the initial Variables are created without ``requires_grad``,
    and ``mf_v`` is used directly as a scale (no exp/log transform) — confirm
    both are intended.
    """
    loc = pyro.param("mf_m", Variable(torch.zeros(1)))
    scale = pyro.param("mf_v", Variable(torch.ones(1)))
    pyro.sample("latent", DiagNormal(loc, scale))
def model():
    """Normal-Normal model scoring all of ``self.data`` in a single observe site.

    A closure over ``self``; ``lam0``/``lam`` act as precisions (scale = lam ** -0.5).

    Returns:
        The sampled ``mu_latent``.
    """
    mu_latent = pyro.sample("mu_latent",
                            DiagNormal(self.mu0, torch.pow(self.lam0, -0.5)))
    likelihood = DiagNormal(mu_latent, torch.pow(self.lam, -0.5))
    pyro.observe("obs", likelihood, self.data)
    return mu_latent
def guide_observed(data, cll):
    """Guide for ``latent_z`` when the class ``cll`` is observed alongside ``data``."""
    encoder = pyro.module("encoder_o", pt_encode_o)
    posterior_loc, posterior_scale = encoder.forward(data, cll)
    pyro.sample("latent_z", DiagNormal(posterior_loc, posterior_scale))
def guide_latent(data, cll):
    """Encode ``data`` into the posterior over ``latent``; ``cll`` is unused.

    Returns:
        The sampled ``latent``.
    """
    encoder = pyro.module("encoder_xz", pt_encode_xz)
    posterior_loc, posterior_scale = encoder.forward(data)
    return pyro.sample("latent", DiagNormal(posterior_loc, posterior_scale))