def local_guide(i, datum):
    dim_z = 2
    dim_o = 784
    nr_samples = 1
    alpha = torch.ones(nr_samples, dim_z) * 0.1
    mu_q_z = Variable(alpha, requires_grad=True)
    log_sigma_q_z = Variable(torch.ones(mu_q_z.size()), requires_grad=True)
    sigma_q_z = torch.exp(log_sigma_q_z)

    mu_q_w = Variable(torch.ones(dim_z, dim_o), requires_grad=True)
    log_sigma_q_w = Variable(torch.ones(dim_z, dim_o), requires_grad=True)

    guide_mu_q_w = pyro.param("factor_weight_mean", mu_q_w)
    guide_log_sigma_q_w = pyro.param("factor_weight_log_sigma", log_sigma_q_w)
    #sigma_q_w = torch.exp(log_sigma_q_w)
    guide_sigma_q_w = torch.exp(guide_log_sigma_q_w)

    guide_mu_z = pyro.param("embedding_posterior_mean_", mu_q_z)

    guide_log_sigma_q_z = pyro.param("embedding_posterior_sigma_",
                                     log_sigma_q_z)
    guide_sigma_z = torch.exp(guide_log_sigma_q_z)
    z_q = pyro.sample("embedding_of_datum_" + str(i),
                      DiagNormal(guide_mu_z, guide_sigma_z))

    w_q = pyro.sample("factor_weight", DiagNormal(guide_mu_q_w,
                                                  guide_sigma_q_w))

    return z_q, w_q
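
A note on the pyro.param calls above: the parameter store keeps the value under the given name, so the initial tensor is only used on the first call, and later calls with the same name return the stored, learnable value. A minimal sketch of that behaviour (modern Pyro API, hypothetical demo name, not part of the example itself):

import torch
import pyro

# First call registers the initial value under this name;
# later calls ignore the init argument and return the stored parameter.
w = pyro.param("factor_weight_mean_demo", torch.zeros(2, 784))
w_again = pyro.param("factor_weight_mean_demo")
assert torch.equal(w, w_again)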
Example #2
def model():
    latent = pyro.sample(
        "latent",
        DiagNormal(Variable(torch.zeros(1)), 5 * Variable(torch.ones(1))))
    x_dist = DiagNormal(latent, Variable(torch.ones(1)))
    x = pyro.observe("obs", x_dist, Variable(torch.ones(1)))
    return latent
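
This example uses the pre-0.2 Pyro API (Variable, DiagNormal, pyro.observe). As a rough, hedged translation: in current Pyro the same model/guide pair would use dist.Normal, pass observations through the obs= keyword, and be trained with SVI. A minimal sketch under those assumptions (parameter names are illustrative, not from the original):

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

def model(data):
    # prior: latent ~ Normal(0, 5); likelihood: data ~ Normal(latent, 1)
    latent = pyro.sample("latent", dist.Normal(torch.zeros(1), 5 * torch.ones(1)))
    pyro.sample("obs", dist.Normal(latent, torch.ones(1)), obs=data)
    return latent

def guide(data):
    # mean-field variational posterior over "latent"
    mf_m = pyro.param("mf_m", torch.zeros(1))
    mf_s = torch.exp(pyro.param("mf_log_s", torch.zeros(1)))
    pyro.sample("latent", dist.Normal(mf_m, mf_s))

svi = SVI(model, guide, Adam({"lr": 1e-2}), loss=Trace_ELBO())
for _ in range(1000):
    svi.step(torch.ones(1))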
Example #3
 def model():
     prior_dist = DiagNormal(self.mu0, torch.pow(self.lam0, -0.5))
     mu_latent = pyro.sample("mu_latent", prior_dist)
     x_dist = DiagNormal(mu_latent, torch.pow(self.lam, -0.5))
     # x = pyro.observe("obs", x_dist, self.data)
     pyro.map_data(self.data,
                   lambda i, x: pyro.observe("obs_%d" % i, x_dist, x),
                   batch_size=1)
     return mu_latent
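
pyro.map_data with batch_size=1 subsamples one datum per step and rescales its log-likelihood accordingly. That primitive was later replaced by pyro.plate; a hedged sketch of the equivalent pattern in current Pyro, with hypothetical tensors standing in for self.mu0, self.lam0 and self.lam:

import torch
import pyro
import pyro.distributions as dist

mu0, lam0, lam = torch.zeros(1), torch.ones(1), torch.ones(1)
data = torch.randn(10)

def model():
    mu_latent = pyro.sample("mu_latent", dist.Normal(mu0, lam0.pow(-0.5)))
    # subsample one datum per step; Pyro rescales its likelihood by 10 / 1
    with pyro.plate("data", len(data), subsample_size=1) as ind:
        pyro.sample("obs", dist.Normal(mu_latent, lam.pow(-0.5)),
                    obs=data[ind])
    return mu_latent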
Example #4
 def guide():
     mu_q_log = pyro.param(
         "mu_q_log",
         Variable(self.log_mu_n.data + 0.17, requires_grad=True))
     tau_q_log = pyro.param(
         "tau_q_log",
         Variable(self.log_tau_n.data - 0.143, requires_grad=True))
     mu_q, tau_q = torch.exp(mu_q_log), torch.exp(tau_q_log)
     q_dist = DiagNormal(mu_q, torch.pow(tau_q, -0.5))
     q_dist.reparametrized = reparametrized
     pyro.sample("mu_latent", q_dist)
Example #5
 def guide():
     mu_q = pyro.param(
         "mu_q",
         Variable(self.analytic_mu_n.data + 0.134 * torch.ones(2),
                  requires_grad=True))
     log_sig_q = pyro.param(
         "log_sig_q",
         Variable(self.analytic_log_sig_n.data - 0.09 * torch.ones(2),
                  requires_grad=True))
     sig_q = torch.exp(log_sig_q)
     q_dist = DiagNormal(mu_q, sig_q)
     q_dist.reparametrized = reparametrized
     pyro.sample("mu_latent", q_dist)
     pyro.map_data(self.data, lambda i, x: None, batch_size=1)
Example #6
def local_model(i, data):
    dim_z = 2
    dim_o = 784
    nr_samples = 1

    nr_data = data.size(0)

    #global variables
    mu_w = Variable(torch.ones(dim_z, dim_o), requires_grad=False)
    log_sigma_w = Variable(torch.ones(dim_z, dim_o), requires_grad=False)
    sigma_w = torch.exp(log_sigma_w)
    weight = pyro.sample("factor_weight", DiagNormal(mu_w, sigma_w))

    def sub_model(datum):
        mu_latent = Variable(torch.ones(nr_samples, dim_z)) * 0.5
        sigma_latent = Variable(torch.ones(mu_latent.size()))
        z = pyro.sample("embedding_of_datum_" + str(i),
                        DiagNormal(mu_latent, sigma_latent))
        mean_beta = z.mm(weight)
        beta = sigmoid(mean_beta)
        pyro.observe("obs_" + str(i), Bernoulli(beta), datum)
        return z

    # loop over the data; the loop index i gives each datum its own site names
    for i in range(nr_data):
        z = sub_model(data[i])

    return z, weight
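
Pyro pairs guide samples with model samples by site name, so "factor_weight" and "embedding_of_datum_" + str(i) must be spelled identically here and in the local_guide at the top of the page. A stripped-down sketch of that naming contract (modern API, hypothetical two-dimensional latent):

import torch
import pyro
import pyro.distributions as dist

def model(datum):
    # the site name "z" must match the guide below exactly
    z = pyro.sample("z", dist.Normal(torch.zeros(2), torch.ones(2)).to_event(1))
    p = torch.sigmoid(z.sum())
    pyro.sample("obs", dist.Bernoulli(p), obs=datum)

def guide(datum):
    loc = pyro.param("z_loc", torch.zeros(2))
    scale = torch.exp(pyro.param("z_log_scale", torch.zeros(2)))
    pyro.sample("z", dist.Normal(loc, scale).to_event(1))  # same name, same shape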
Example #7
 def guide():
     mu_q = pyro.param("mu_q",
                       Variable(torch.zeros(1), requires_grad=True))
     log_sig_q = pyro.param(
         "log_sig_q", Variable(torch.zeros(1), requires_grad=True))
     sig_q = torch.exp(log_sig_q)
     pyro.sample("mu_latent", DiagNormal(mu_q, sig_q))
Example #8
def model_sample(cll=None):
    # wrap params for use in model -- required
    # decoder = pyro.module("decoder", pt_decode)

    # sample from prior
    z_mu, z_sigma = Variable(torch.zeros(
        [1, 20])), Variable(torch.ones([1, 20]))

    # sample
    z = pyro.sample("latent", DiagNormal(z_mu, z_sigma))

    alpha = Variable(torch.ones([1, 10]) / 10.)

    if cll is None:
        # no class label was passed in, so sample one from the uniform prior
        cll = pyro.sample('class', Categorical(alpha))
        print('sampling class')

    # decode into size of imgx1 for mu
    img_mu = pt_decode.forward(z, cll)
    # bb()
    # img=Bernoulli(img_mu).sample()
    # score against actual images
    img = pyro.sample("sample", Bernoulli(img_mu))
    # return img
    return img, img_mu
Example #9
 def model():
     mu_latent = pyro.sample(
         "mu_latent", DiagNormal(self.mu0, torch.pow(self.tau0, -0.5)))
     x_dist = LogNormal(mu_latent, torch.pow(self.tau, -0.5))
     pyro.observe("obs0", x_dist, self.data[0])
     pyro.observe("obs1", x_dist, self.data[1])
     return mu_latent
Example #10
File: vae_z_c.py  Project: zaxtax/pyro
def model_given_c(data, cll):
    decoder_c = pyro.module("decoder_c", pt_decode_c)
    decoder_z = pyro.module("decoder_z", pt_decode_z)
    z_mu, z_sigma = decoder_c.forward(cll)
    z = pyro.sample("latent_z", DiagNormal(z_mu, z_sigma))
    img_mu = decoder_z.forward(z)
    pyro.observe("obs", Bernoulli(img_mu), data.view(-1, 784))
Example #11
File: vae.py  Project: zaxtax/pyro
def guide(data):
    # wrap params for use in model -- required
    encoder = pyro.module("encoder", pt_encode)

    # use the encoder to get an estimate of mu, sigma
    z_mu, z_sigma = encoder.forward(data)

    pyro.sample("latent", DiagNormal(z_mu, z_sigma))
Example #12
 def sub_model(datum):
     mu_latent = Variable(torch.ones(nr_samples, dim_z)) * 0.5
     sigma_latent = Variable(torch.ones(mu_latent.size()))
     z = pyro.sample("embedding_of_datum_" + str(i),
                     DiagNormal(mu_latent, sigma_latent))
     mean_beta = z.mm(weight)
     beta = sigmoid(mean_beta)
     pyro.observe("obs_" + str(i), Bernoulli(beta), datum)
Example #13
File: vae_z_c.py  Project: zaxtax/pyro
def guide_latent(data, cll):
    encoder_x = pyro.module("encoder_x", pt_encode_x)
    encoder_z = pyro.module("encoder_z", pt_encode_z)

    z_mu, z_sigma = encoder_x.forward(data)
    z = pyro.sample("latent_z", DiagNormal(z_mu, z_sigma))
    alpha_cat = encoder_z.forward(z)
    pyro.sample("latent_class", Categorical(alpha_cat))
Example #14
def model_xz(data, foo):
    decoder_xz = pyro.module("decoder_xz", pt_decode_xz)
    z_mu, z_sigma = Variable(torch.zeros([data.size(0), 20])), Variable(
        torch.ones([data.size(0), 20]))
    z = pyro.sample("latent", DiagNormal(z_mu, z_sigma))
    img_mu = decoder_xz.forward(z)
    pyro.observe("obs", Bernoulli(img_mu), data.view(-1, 784))
    return z
Example #15
def guide_latent2(data):
    encoder_c = pyro.module("encoder_c", pt_encode_c)
    alpha = encoder_c.forward(data)
    cll = pyro.sample("latent_class", Categorical(alpha))

    encoder = pyro.module("encoder_o", pt_encode_o)
    z_mu, z_sigma = encoder.forward(data, cll)
    z = pyro.sample("latent_z", DiagNormal(z_mu, z_sigma))
Example #16
def inspect_posterior_samples(i):
    c = local_guide(i, None)
    mean_param = Variable(torch.zeros(784), requires_grad=True)
    # do MLE for class means
    m = pyro.param("mean_of_class_" + str(c[0]), mean_param)
    sigma = Variable(torch.ones(m.size()))
    dat = pyro.sample("obs_" + str(i), DiagNormal(m, sigma))
    return dat
Example #17
File: vae_z_c.py  Project: zaxtax/pyro
def model_latent(data):
    decoder_c = pyro.module("decoder_c", pt_decode_c)
    decoder_z = pyro.module("decoder_z", pt_decode_z)
    alpha = Variable(torch.ones([data.size(0), 10])) / 10.
    cll = pyro.sample('latent_class', Categorical(alpha))
    z_mu, z_sigma = decoder_c.forward(cll)
    z = pyro.sample("latent_z", DiagNormal(z_mu, z_sigma))
    img_mu = decoder_z.forward(z)
    pyro.observe("obs", Bernoulli(img_mu), data.view(-1, 784))
Example #18
def local_model(i, datum):
    beta = Variable(torch.ones(1)) * 0.5
    c = pyro.sample("class_of_datum_" + str(i), Bernoulli(beta))
    mean_param = Variable(torch.zeros(784), requires_grad=True)
    # do MLE for class means
    m = pyro.param("mean_of_class_" + str(c[0]), mean_param)

    sigma = Variable(torch.ones(m.size()))
    pyro.observe("obs_" + str(i), DiagNormal(m, sigma), datum)
    return c
Example #19
File: vae.py  Project: zaxtax/pyro
def model(data):
    # KL-QP inference calls this model with data.

    # wrap params for use in model -- required
    decoder = pyro.module("decoder", pt_decode)

    # sample from prior
    z_mu, z_sigma = pyro.ng_zeros([data.size(0),
                                   20]), pyro.ng_ones([data.size(0), 20])
    # Variable(torch.zeros([data.size(0), 20])), Variable(torch.ones([data.size(0), 20]))

    # sample (retrieve value set by the guide)
    z = pyro.sample("latent", DiagNormal(z_mu, z_sigma))

    # decode into size of imgx2 for mu/sigma
    img_mu, img_sigma = decoder.forward(z)

    # score against actual images
    pyro.observe("obs", DiagNormal(img_mu, img_sigma), data.view(-1, 784))
Example #20
def model_sample():
    z_mu, z_sigma = Variable(torch.zeros([1,
                                          20])), Variable(torch.ones([1, 20]))
    z = pyro.sample("latent", DiagNormal(z_mu, z_sigma))

    img_mu = pt_decode_xz.forward(z)
    alpha_mu = pt_decode_c.forward(z)

    img = pyro.sample("sample_img", Bernoulli(img_mu))
    cll = pyro.sample("sample_cll", Categorical(alpha_mu))
    return img, img_mu, cll
Example #21
 def model():
     mu_latent = pyro.sample(
         "mu_latent", DiagNormal(self.mu0, torch.pow(self.tau0, -0.5)))
     unit_normal = dist.DiagNormal(Variable(torch.zeros(1, 1)),
                                   Variable(torch.ones(1, 1)))
     bijector = AffineExp(torch.pow(self.tau, -0.5), mu_latent)
     x_dist = TransformedDistribution(unit_normal, bijector)
     # x_dist = LogNormal(mu_latent, torch.pow(self.tau,-0.5))
     pyro.observe("obs0", x_dist, self.data[0])
     pyro.observe("obs1", x_dist, self.data[1])
     return mu_latent
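
The AffineExp bijector builds a log-normal likelihood by pushing a unit normal through an affine map followed by exp, matching the commented-out LogNormal alternative. For reference, the same construction with torch.distributions transforms (a sketch with hypothetical mu and sigma values, not the test's own helper):

import torch
from torch.distributions import Normal, TransformedDistribution
from torch.distributions.transforms import AffineTransform, ExpTransform

mu, sigma = torch.zeros(1), torch.ones(1)
base = Normal(torch.zeros(1), torch.ones(1))
# x = exp(sigma * eps + mu), eps ~ N(0, 1)  -- i.e. LogNormal(mu, sigma)
x_dist = TransformedDistribution(base, [AffineTransform(loc=mu, scale=sigma),
                                        ExpTransform()])
x = x_dist.rsample()
log_p = x_dist.log_prob(x)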
Example #22
def model_latent(data):

    # wrap params for use in model -- required
    decoder = pyro.module("decoder", pt_decode)

    # sample from prior
    z_mu, z_sigma = Variable(torch.zeros([data.size(0), 20])), Variable(
        torch.ones([data.size(0), 20]))

    # sample
    z = pyro.sample("latent", DiagNormal(z_mu, z_sigma))

    # decode into size of imgx2 for mu/sigma
    img_mu, alpha_mult = decoder.forward(z)

    # score against actual images
    pyro.observe("obs", Bernoulli(img_mu), data.view(-1, 784))
Example #23
File: vae.py  Project: zaxtax/pyro
def model_sample():

    # wrap params for use in model -- required
    decoder = pyro.module("decoder", pt_decode)

    # sample from prior
    z_mu, z_sigma = Variable(torch.zeros([1,
                                          20])), Variable(torch.ones([1, 20]))

    # sample
    z = pyro.sample("latent", DiagNormal(z_mu, z_sigma))

    # decode into size of imgx2 for mu/sigma
    img_mu, img_sigma = decoder.forward(z)

    # score against actual images
    #img= pyro.sample("obs", DiagNormal(img_mu, img_sigma))
    return img_mu
Example #24
def model_latent_backup(data):
    # wrap params for use in model -- required
    decoder = pyro.module("decoder", pt_decode)
    # sample from prior
    z_mu, z_sigma = Variable(torch.zeros([data.size(0), 20])), Variable(
        torch.ones([data.size(0), 20]))

    # sample
    z = pyro.sample("latent_z", DiagNormal(z_mu, z_sigma))

    alpha = Variable(torch.ones([data.size(0), 10])) / 10.
    # c = pyro.sample('latent_class', Multinomial(alpha,1))#Categorical(alpha))
    cll = pyro.sample('latent_class', Categorical(alpha))

    # bb()
    # decode into size of imgx2 for mu/sigma
    img_mu = decoder.forward(z, cll)
    # score against actual images
    pyro.observe("obs", Bernoulli(img_mu), data.view(-1, 784))
Example #25
def model_sample():
    # wrap params for use in model -- required
    # decoder = pyro.module("decoder", pt_decode)

    # sample from prior
    z_mu, z_sigma = Variable(torch.zeros([1,
                                          20])), Variable(torch.ones([1, 20]))

    # sample
    z = pyro.sample("latent", DiagNormal(z_mu, z_sigma))

    # decode into size of imgx1 for mu
    img_mu = pt_decode.forward(z)
    # bb()
    # img=Bernoulli(img_mu).sample()
    # score against actual images
    img = pyro.sample("sample", Bernoulli(img_mu))
    # return img
    return img, img_mu
Example #26
def guide():
    latent = pyro.sample(
        "latent",
        DiagNormal(Variable(torch.zeros(1)), 5 * Variable(torch.ones(1))))
    x_dist = DiagNormal(latent, Variable(torch.ones(1)))
    pass
Example #27
def guide():
    mf_m = pyro.param("mf_m", Variable(torch.zeros(1)))
    mf_v = pyro.param("mf_v", Variable(torch.ones(1)))
    latent = pyro.sample("latent", DiagNormal(mf_m, mf_v))
    pass
Example #28
 def model():
     prior_dist = DiagNormal(self.mu0, torch.pow(self.lam0, -0.5))
     mu_latent = pyro.sample("mu_latent", prior_dist)
     x_dist = DiagNormal(mu_latent, torch.pow(self.lam, -0.5))
     pyro.observe("obs", x_dist, self.data)
     return mu_latent
Example #29
def guide_observed(data, cll):
    encoder = pyro.module("encoder_o", pt_encode_o)
    z_mu, z_sigma = encoder.forward(data, cll)
    z = pyro.sample("latent_z", DiagNormal(z_mu, z_sigma))
Example #30
def guide_latent(data, cll):
    encoder = pyro.module("encoder_xz", pt_encode_xz)
    z_mu, z_sigma = encoder.forward(data)
    z = pyro.sample("latent", DiagNormal(z_mu, z_sigma))
    return z