Example no. 1
    def do_test_per_param_optim(self, fixed_param, free_param):
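        # The parameter named fixed_param gets lr=0.0 and must stay at its zero init;
        # free_param gets lr=0.01 and must move away from it.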
        pyro._param_store._clear_cache()

        def model():
            prior_dist = DiagNormal(self.mu0, torch.pow(self.lam0, -0.5))
            mu_latent = pyro.sample("mu_latent", prior_dist)
            x_dist = DiagNormal(mu_latent, torch.pow(self.lam, -0.5))
            pyro.observe("obs", x_dist, self.data)
            return mu_latent

        def guide():
            mu_q = pyro.param("mu_q",
                              Variable(torch.zeros(1), requires_grad=True))
            log_sig_q = pyro.param(
                "log_sig_q", Variable(torch.zeros(1), requires_grad=True))
            sig_q = torch.exp(log_sig_q)
            pyro.sample("mu_latent", DiagNormal(mu_q, sig_q))

        def optim_params(param_name, param):
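            # pyro.optim calls this back with each parameter's name to pick that parameter's optimizer args.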
            if param_name == fixed_param:
                return {'lr': 0.00}
            elif param_name == free_param:
                return {'lr': 0.01}

        kl_optim = KL_QP(model, guide,
                         pyro.optim(torch.optim.Adam, optim_params))
        for k in range(3):
            kl_optim.step()

        free_param_unchanged = torch.equal(
            pyro.param(free_param).data, torch.zeros(1))
        fixed_param_unchanged = torch.equal(
            pyro.param(fixed_param).data, torch.zeros(1))
        passed_test = fixed_param_unchanged and not free_param_unchanged
        self.assertTrue(passed_test)
Example no. 2
    def test_elbo_nonreparametrized(self):
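        # Beta-Bernoulli model with map_data mini-batching; after optimization the guide's
        # log-parameters should match self.log_alpha_n / self.log_beta_n.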
        pyro._param_store._clear_cache()

        def model():
            p_latent = pyro.sample("p_latent", Beta(self.alpha0, self.beta0))
            x_dist = Bernoulli(p_latent)
            pyro.map_data(self.data,
                          lambda i, x: pyro.observe("obs", x_dist, x),
                          batch_size=2)
            return p_latent

        def guide():
            alpha_q_log = pyro.param(
                "alpha_q_log",
                Variable(self.log_alpha_n.data + 0.17, requires_grad=True))
            beta_q_log = pyro.param(
                "beta_q_log",
                Variable(self.log_beta_n.data - 0.143, requires_grad=True))
            alpha_q, beta_q = torch.exp(alpha_q_log), torch.exp(beta_q_log)
            pyro.sample("p_latent", Beta(alpha_q, beta_q))
            pyro.map_data(self.data, lambda i, x: None, batch_size=2)

        kl_optim = KL_QP(
            model, guide,
            pyro.optim(torch.optim.Adam, {
                "lr": .001,
                "betas": (0.97, 0.999)
            }))
        for k in range(6001):
            kl_optim.step()
            # if k % 1000 == 0:
            #     print("alpha_q", torch.exp(pyro.param("alpha_q_log")).data.numpy()[0])
            #     print("beta_q", torch.exp(pyro.param("beta_q_log")).data.numpy()[0])

        # print("alpha_n", self.alpha_n.data.numpy()[0])
        # print("beta_n", self.beta_n.data.numpy()[0])
        # print("alpha_0", self.alpha0.data.numpy()[0])
        # print("beta_0", self.beta0.data.numpy()[0])

        alpha_error = torch.abs(pyro.param("alpha_q_log") -
                                self.log_alpha_n).data.cpu().numpy()[0]
        beta_error = torch.abs(pyro.param("beta_q_log") -
                               self.log_beta_n).data.cpu().numpy()[0]
        # print "alpha_error", alpha_error
        # print "beta_error", beta_error
        self.assertEqual(0.0, alpha_error, prec=0.05)
        self.assertEqual(0.0, beta_error, prec=0.05)
Example no. 3
    def do_elbo_test(self, reparametrized, n_steps):
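        # Normal prior with a LogNormal likelihood; the guide's DiagNormal sample can be
        # switched between reparameterized and score-function estimators via the reparametrized flag.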
        pyro._param_store._clear_cache()

        def model():
            mu_latent = pyro.sample(
                "mu_latent", DiagNormal(self.mu0, torch.pow(self.tau0, -0.5)))
            x_dist = LogNormal(mu_latent, torch.pow(self.tau, -0.5))
            pyro.observe("obs0", x_dist, self.data[0])
            pyro.observe("obs1", x_dist, self.data[1])
            return mu_latent

        def guide():
            mu_q_log = pyro.param(
                "mu_q_log",
                Variable(self.log_mu_n.data + 0.17, requires_grad=True))
            tau_q_log = pyro.param(
                "tau_q_log",
                Variable(self.log_tau_n.data - 0.143, requires_grad=True))
            mu_q, tau_q = torch.exp(mu_q_log), torch.exp(tau_q_log)
            q_dist = DiagNormal(mu_q, torch.pow(tau_q, -0.5))
            q_dist.reparametrized = reparametrized
            pyro.sample("mu_latent", q_dist)

        kl_optim = KL_QP(
            model, guide,
            pyro.optim(torch.optim.Adam, {
                "lr": .0005,
                "betas": (0.96, 0.999)
            }))
        for k in range(n_steps):
            kl_optim.step()
            # if k % 1000 == 0:
            #     print("log_mu_q", pyro.param("mu_q_log").data.numpy()[0])
            #     print("log_tau_q", pyro.param("tau_q_log").data.numpy()[0])

        # print("log_mu_n", self.log_mu_n.data.numpy()[0])
        # print("log_tau_n", self.log_tau_n.data.numpy()[0])

        mu_error = torch.abs(pyro.param("mu_q_log") -
                             self.log_mu_n).data.cpu().numpy()[0]
        tau_error = torch.abs(pyro.param("tau_q_log") -
                              self.log_tau_n).data.cpu().numpy()[0]
        # print "mu_error", mu_error
        # print "tau_error", tau_error
        self.assertEqual(0.0, mu_error, prec=0.07)
        self.assertEqual(0.0, tau_error, prec=0.07)
Example no. 4
    def test_elbo_with_transformed_distribution(self):
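        # The LogNormal likelihood is expressed as a TransformedDistribution:
        # a unit normal pushed through an AffineExp bijector.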
        pyro._param_store._clear_cache()

        def model():
            mu_latent = pyro.sample(
                "mu_latent", DiagNormal(self.mu0, torch.pow(self.tau0, -0.5)))
            unit_normal = dist.DiagNormal(Variable(torch.zeros(1, 1)),
                                          Variable(torch.ones(1, 1)))
            bijector = AffineExp(torch.pow(self.tau, -0.5), mu_latent)
            x_dist = TransformedDistribution(unit_normal, bijector)
            # equivalent to: x_dist = LogNormal(mu_latent, torch.pow(self.tau, -0.5))
            pyro.observe("obs0", x_dist, self.data[0])
            pyro.observe("obs1", x_dist, self.data[1])
            return mu_latent

        def guide():
            mu_q_log = pyro.param(
                "mu_q_log",
                Variable(self.log_mu_n.data + 0.17, requires_grad=True))
            tau_q_log = pyro.param(
                "tau_q_log",
                Variable(self.log_tau_n.data - 0.143, requires_grad=True))
            mu_q, tau_q = torch.exp(mu_q_log), torch.exp(tau_q_log)
            q_dist = DiagNormal(mu_q, torch.pow(tau_q, -0.5))
            pyro.sample("mu_latent", q_dist)

        kl_optim = KL_QP(
            model, guide,
            pyro.optim(torch.optim.Adam, {
                "lr": .0005,
                "betas": (0.96, 0.999)
            }))
        for k in range(9001):
            kl_optim.step()

        mu_error = torch.abs(pyro.param("mu_q_log") -
                             self.log_mu_n).data.cpu().numpy()[0]
        tau_error = torch.abs(pyro.param("tau_q_log") -
                              self.log_tau_n).data.cpu().numpy()[0]
        # print "mu_error", mu_error
        # print "tau_error", tau_error
        self.assertEqual(0.0, mu_error, prec=0.05)
        self.assertEqual(0.0, tau_error, prec=0.05)
Example no. 5
    def do_test_fixedness(self, model_fixed, guide_fixed):
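        # Gamma-Poisson model; checks that the model_fixed / guide_fixed flags
        # keep the corresponding parameters at their initial values.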
        pyro._param_store._clear_cache()

        def model():
            alpha_p_log = pyro.param(
                "alpha_p_log", Variable(self.alpha_p_log_0,
                                        requires_grad=True))
            beta_p_log = pyro.param(
                "beta_p_log", Variable(self.beta_p_log_0, requires_grad=True))
            alpha_p, beta_p = torch.exp(alpha_p_log), torch.exp(beta_p_log)
            lambda_latent = pyro.sample("lambda_latent",
                                        Gamma(alpha_p, beta_p))
            x_dist = Poisson(lambda_latent)
            pyro.observe("obs", x_dist, self.data)
            return lambda_latent

        def guide():
            alpha_q_log = pyro.param(
                "alpha_q_log", Variable(self.alpha_q_log_0,
                                        requires_grad=True))
            beta_q_log = pyro.param(
                "beta_q_log", Variable(self.beta_q_log_0, requires_grad=True))
            alpha_q, beta_q = torch.exp(alpha_q_log), torch.exp(beta_q_log)
            pyro.sample("lambda_latent", Gamma(alpha_q, beta_q))

        kl_optim = KL_QP(model,
                         guide,
                         pyro.optim(torch.optim.Adam, {"lr": .001}),
                         model_fixed=model_fixed,
                         guide_fixed=guide_fixed)
        for _ in range(10):
            kl_optim.step()

        model_unchanged = (torch.equal(pyro.param("alpha_p_log").data, self.alpha_p_log_0) and
                           torch.equal(pyro.param("beta_p_log").data, self.beta_p_log_0))
        guide_unchanged = (torch.equal(pyro.param("alpha_q_log").data, self.alpha_q_log_0) and
                           torch.equal(pyro.param("beta_q_log").data, self.beta_q_log_0))
        bad = (model_fixed and not model_unchanged) or (guide_fixed and not guide_unchanged)
        return not bad
Example no. 6
    def do_elbo_test(self, reparametrized, n_steps):
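        # Conjugate Normal-Normal model with map_data (batch_size=1); the guide is initialized
        # near, and should recover, the analytic posterior (analytic_mu_n, analytic_log_sig_n).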
        pyro._param_store._clear_cache()

        def model():
            prior_dist = DiagNormal(self.mu0, torch.pow(self.lam0, -0.5))
            mu_latent = pyro.sample("mu_latent", prior_dist)
            x_dist = DiagNormal(mu_latent, torch.pow(self.lam, -0.5))
            # x = pyro.observe("obs", x_dist, self.data)
            pyro.map_data(self.data,
                          lambda i, x: pyro.observe("obs_%d" % i, x_dist, x),
                          batch_size=1)
            return mu_latent

        def guide():
            mu_q = pyro.param(
                "mu_q",
                Variable(self.analytic_mu_n.data + 0.134 * torch.ones(2),
                         requires_grad=True))
            log_sig_q = pyro.param(
                "log_sig_q",
                Variable(self.analytic_log_sig_n.data - 0.09 * torch.ones(2),
                         requires_grad=True))
            sig_q = torch.exp(log_sig_q)
            q_dist = DiagNormal(mu_q, sig_q)
            q_dist.reparametrized = reparametrized
            pyro.sample("mu_latent", q_dist)
            pyro.map_data(self.data, lambda i, x: None, batch_size=1)

        kl_optim = KL_QP(model, guide,
                         pyro.optim(torch.optim.Adam, {"lr": .001}))
        for k in range(n_steps):
            kl_optim.step()
        mu_error = torch.sum(
            torch.pow(self.analytic_mu_n - pyro.param("mu_q"), 2.0))
        log_sig_error = torch.sum(
            torch.pow(self.analytic_log_sig_n - pyro.param("log_sig_q"), 2.0))
        self.assertEqual(0.0, mu_error.data.cpu().numpy()[0], prec=0.02)
        self.assertEqual(0.0, log_sig_error.data.cpu().numpy()[0], prec=0.02)
Example no. 7
    def test_elbo_nonreparametrized(self):
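        # Gamma prior with an Exponential likelihood and a non-reparameterized Gamma guide;
        # checks convergence to self.log_alpha_n / self.log_beta_n.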
        pyro._param_store._clear_cache()

        def model():
            lambda_latent = pyro.sample("lambda_latent",
                                        Gamma(self.alpha0, self.beta0))
            x_dist = Exponential(lambda_latent)
            pyro.observe("obs0", x_dist, self.data[0])
            pyro.observe("obs1", x_dist, self.data[1])
            return lambda_latent

        def guide():
            alpha_q_log = pyro.param(
                "alpha_q_log",
                Variable(self.log_alpha_n.data + 0.17, requires_grad=True))
            beta_q_log = pyro.param(
                "beta_q_log",
                Variable(self.log_beta_n.data - 0.143, requires_grad=True))
            alpha_q, beta_q = torch.exp(alpha_q_log), torch.exp(beta_q_log)
            pyro.sample("lambda_latent", Gamma(alpha_q, beta_q))

        kl_optim = KL_QP(
            model, guide,
            pyro.optim(torch.optim.Adam, {
                "lr": .0003,
                "betas": (0.97, 0.999)
            }))
        for k in range(10001):
            kl_optim.step()

        alpha_error = torch.abs(pyro.param("alpha_q_log") -
                                self.log_alpha_n).data.cpu().numpy()[0]
        beta_error = torch.abs(pyro.param("beta_q_log") -
                               self.log_beta_n).data.cpu().numpy()[0]
        # print "alpha_error", alpha_error
        # print "beta_error", beta_error
        self.assertEqual(0.0, alpha_error, prec=0.05)
        self.assertEqual(0.0, beta_error, prec=0.05)
Example no. 8
    cll = pyro.sample("sample_cll", Categorical(alpha))
    # return img
    return img, img_mu, cll


def per_param_args(name, param):
    if name == "decoder":
        return {"lr": .0001}
    else:
        return {"lr": .0001}


# or alternatively, a single set of optimizer args shared by all parameters:
adam_params = {"lr": .0001}

inference = KL_QP(model_latent, guide_latent,
                  pyro.optim(optim.Adam, adam_params))
inference_c = KL_QP(model_given_c, guide_given_c,
                    pyro.optim(optim.Adam, adam_params))

mnist_data = Variable(train_loader.dataset.train_data.float() / 255.)
mnist_labels = Variable(train_loader.dataset.train_labels)
mnist_size = mnist_data.size(0)
batch_size = 128  # 64

# TODO: handle uneven batching; the last batch is not necessarily full
all_batches = np.arange(0, mnist_size, batch_size)

if all_batches[-1] != mnist_size:
    all_batches = list(all_batches) + [mnist_size]

vis = visdom.Visdom(env='vae_z_c')
Example no. 9
def model_sample(data, cll):
    classifier = pyro.module("classifier", pt_classify)
    alpha_cat = classifier.forward(data)
    cll = pyro.sample('observed_class', Categorical(alpha_cat))
    return cll


def guide(data, cll):
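    # A trivial guide with no sample sites; it just returns a throwaway lambda.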
    return lambda foo: None


# or alternatively
adam_params = {"lr": .0001}

inference_opt = KL_QP(model_obs, guide, pyro.optim(optim.Adam, adam_params))

mnist_data = Variable(train_loader.dataset.train_data.float() / 255.)
mnist_labels = Variable(train_loader.dataset.train_labels)
mnist_size = mnist_data.size(0)
batch_size = 128  # 64

# TODO: handle uneven batching; the last batch is not necessarily full
all_batches = np.arange(0, mnist_size, batch_size)

if all_batches[-1] != mnist_size:
    all_batches = list(all_batches) + [mnist_size]

vis = visdom.Visdom()

for i in range(1000):
Example no. 10
def model_sample_given_class(cll=None):
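    # Stub: class-conditional sampling is not implemented in this snippet.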
    pass


def per_param_args(name, param):
    if name == "decoder":
        return {"lr": .0001}
    else:
        return {"lr": .0001}


# or alternatively, a single set of optimizer args shared by all parameters:
adam_params = {"lr": .0001}
# optim.SGD(lr=.0001)

inference_latent_class = KL_QP(
    model_latent, guide_latent, pyro.optim(
        optim.Adam, adam_params))
inference_observed_class = KL_QP(
    model_observed, guide_observed, pyro.optim(
        optim.Adam, adam_params))

inference_observed_class_scored = KL_QP(
    model_observed, guide_observed2, pyro.optim(
        optim.Adam, adam_params))

mnist_data = Variable(train_loader.dataset.train_data.float() / 255.)
mnist_labels = Variable(train_loader.dataset.train_labels)
mnist_size = mnist_data.size(0)
batch_size = 128  # 64

mnist_data_test = Variable(test_loader.dataset.test_data.float() / 255.)
mnist_labels_test = Variable(test_loader.dataset.test_labels)
Example no. 11
def inspect_posterior_samples(i):
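    # Draw a class label from the guide, then sample an image from that class's
    # mean under a unit-variance DiagNormal.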
    c = local_guide(i, None)
    mean_param = Variable(torch.zeros(784), requires_grad=True)
    # do MLE for class means
    m = pyro.param("mean_of_class_" + str(c[0]), mean_param)
    sigma = Variable(torch.ones(m.size()))
    dat = pyro.sample("obs_" + str(i), DiagNormal(m, sigma))
    return dat


# grad_step = ELBo(local_model, local_guide, model_ML=True, optimizer="adam")
optim_fct = pyro.optim(torch.optim.Adam, {'lr': .0001})

data = Variable(mnist.train_data).float() / 255.
nr_samples = data.size(0)
grad_step = KL_QP(local_model, local_guide, optim_fct)

d0 = inspect_posterior_samples(0)
d1 = inspect_posterior_samples(1)

vis = visdom.Visdom()

nr_epochs = 50
# apply it to minibatches of data by hand:
for epoch in range(nr_epochs):
    total_loss = 0.
    for i in range(nr_samples):
        # print('starting datum '+str(i))
        # mod_forward=local_model(i,data[i])
        # mod_inv=local_guide(i,data[i])
        loss_sample = grad_step(i, data[i])
Example no. 12
    z_q = pyro.sample("latent_z",  # sample-site name assumed; snippet truncated here
                      DiagNormal(guide_mu_z, guide_sigma_z))

    w_q = pyro.sample("factor_weight", DiagNormal(guide_mu_q_w,
                                                  guide_sigma_q_w))

    return z_q, w_q


# grad_step = ELBo(local_model, local_guide, model_ML=True, optimizer="adam")
adam_params = {"lr": .00000000000001}
adam_optim = pyro.optim(torch.optim.Adam, adam_params)

data = Variable(mnist.train_data).float() / 255.
nr_samples = data.size(0)
nr_epochs = 1000
grad_step = KL_QP(local_model, local_guide, adam_optim)

# apply it to minibatches of data by hand:
for j in range(nr_epochs):
    score = 0
    for i in range(nr_batches):
        score_d = grad_step(i, data[i])
        score += score_d / float(nr_samples)
        print('Local Score ' + str(-score))
    print('Epoch score ' + str(-score))
    # bb()

    #print('starting datum '+str(i))
    # mod_forward=local_model(i,data[i])
    # mod_inv=local_guide(i,data[i])
Example no. 13
    return cll


def inspect_posterior_samples(i):
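    # Draw a class label from the guide, then sample a binarized image from that
    # class's Bernoulli mean.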
    cll = local_guide(i, None)
    mean_param = Variable(torch.zeros(1, 784), requires_grad=True)
    # do MLE for class means
    mu = pyro.param("mean_of_class_" + str(cll[0]), mean_param)
    dat = pyro.sample("obs_" + str(i), Bernoulli(mu))
    return dat


# grad_step = ELBo(local_model, local_guide, model_ML=True, optimizer="adam")
optim_fct = pyro.optim(torch.optim.Adam, {'lr': .0001})

inference = KL_QP(local_model, local_guide, optim_fct)

d0 = inspect_posterior_samples(0)
d1 = inspect_posterior_samples(1)

vis = visdom.Visdom()

nr_epochs = 50
# apply it to minibatches of data by hand:

mnist_data = Variable(train_loader.dataset.train_data.float() / 255.)
mnist_labels = Variable(train_loader.dataset.train_labels)
mnist_size = mnist_data.size(0)
batch_size = 1  # 64

all_batches = np.arange(0, mnist_size, batch_size)
Example no. 14
    # return img
    return img, img_mu


def per_param_args(name, param):
    if name == "decoder":
        return {"lr": .0001}
    else:
        return {"lr": .0001}


# or alternatively, a single set of optimizer args shared by all parameters:
adam_params = {"lr": .0001}
# optim.SGD(lr=.0001)

kl_optim = KL_QP(model, guide, pyro.optim(optim.Adam, adam_params))
# kl_optim = KL_QP(model, guide, pyro.optim(optim.Adam, per_param_args))

# num_steps = 1
mnist_data = Variable(train_loader.dataset.train_data.float() / 255.)
mnist_size = mnist_data.size(0)
batch_size = 512  # 64

# TODO: handle uneven batching; the last batch is not necessarily full
all_batches = np.arange(0, mnist_size, batch_size)

if all_batches[-1] != mnist_size:
    all_batches = list(all_batches) + [mnist_size]

vis = visdom.Visdom(env='vae_mnist')
Example no. 15
            alpha = sample("alpha", Exponential(lam=(Tensor([1]))))
        else:
            alpha = (Tensor([alpha]))
        return core(prefix, alpha, condition, generate)

    def guide(prefix, condition=False, generate=0, alpha=None):
        if alpha is None:
            alpha_point_estimate = softplus(param("alpha-point-estimate",
                                                  Variable(torch.ones(1), requires_grad=True)))
            alpha = sample("alpha", Delta(v=alpha_point_estimate))
        else:
            alpha = (Tensor([alpha]))
        return core(prefix, alpha, condition, generate)

    # Set up neural net parameter inference
    optimizer = pyro.optim(torch.optim.Adam, { "lr": .005, "betas": (0.97, 0.999) })
    infer = KL_QP(model, guide, optimizer)

    # Set up softmax alpha inference
    alpha_optimizer = pyro.optim(torch.optim.SGD, { "lr": .005, "momentum": 0.1 })
    char_names = ["char_{0}".format(i) for i in range(1000)]
    alpha_model = block(model, expose=["alpha", "alpha-point-estimate"] + char_names)
    alpha_guide = block(guide, expose=["alpha", "alpha-point-estimate"] + char_names)
    alpha_infer = KL_QP(alpha_model, alpha_guide, alpha_optimizer)
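    # block(..., expose=...) hides every other site, so alpha_infer fits the alpha point
    # estimate (plus the per-character sites) while the neural net parameters stay fixed.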

    for k in range(10000000):

        # Draw a random text sample
        chunk = dataset.random_chunk(chunk_len=200)

        # Fit alpha to current text sample, keeping neural net params fixed
        prev_alpha_loss = float("-inf")
Example no. 16
# and try importance!


def guide():
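    # Broad proposal for the latent (mean 0, scale 5); x_dist is constructed but unused here.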
    latent = pyro.sample(
        "latent",
        DiagNormal(Variable(torch.zeros(1)), 5 * Variable(torch.ones(1))))
    x_dist = DiagNormal(latent, Variable(torch.ones(1)))
    pass


infer = ImportanceSampling(model, guide)

exp = lw_expectation(infer, lambda x: x, 100)
print(exp)


# and try VI!
def guide():
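    # Mean-field guide: learnable location mf_m and scale mf_v for the latent.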
    mf_m = pyro.param("mf_m", Variable(torch.zeros(1)))
    mf_v = pyro.param("mf_v", Variable(torch.ones(1)))
    latent = pyro.sample("latent", DiagNormal(mf_m, mf_v))
    pass


infer = KL_QP(model, guide)

exp = lw_expectation(infer, lambda x: x, 100)
print(exp)