Example #1
def test_shape_augmented_gamma_elbo(alpha, beta):
    num_samples = 100000
    alphas = torch.full((num_samples, 1), alpha).requires_grad_()
    betas = torch.full((num_samples, 1), beta).requires_grad_()

    model = Gamma(torch.ones(num_samples, 1), torch.ones(num_samples, 1))
    guide1 = Gamma(alphas, betas)
    guide2 = ShapeAugmentedGamma(alphas, betas)  # implemented using Rejector

    grads = []
    for guide in [guide1, guide2]:
        grads.append(compute_elbo_grad(model, guide, [alphas, betas]))
    expected, actual = grads
    expected = [g.mean() for g in expected]
    actual = [g.mean() for g in actual]
    scale = [(1 + abs(g)) for g in expected]
    assert_equal(
        actual[0] / scale[0],
        expected[0] / scale[0],
        prec=0.05,
        msg="bad grad for alpha",
    )
    assert_equal(actual[1] / scale[1],
                 expected[1] / scale[1],
                 prec=0.05,
                 msg="bad grad for beta")
Example #2
 def init(self, state, N):
     state["λ"] = sample("λ", Gamma(1., 1./1.))
     state["μ"] = sample("μ", Gamma(1., 1./0.5))
     f = (N-1)*log(tensor(2))
     for n in range(2, N+1):
         f -= log(tensor(n))
     factor("factor_orient_labeled", f)
Example #3
def test_standard_gamma_elbo(alpha):
    num_samples = 100000
    alphas = torch.full((num_samples, 1), alpha).requires_grad_()
    betas = torch.ones(num_samples, 1)

    model = Gamma(torch.ones(num_samples, 1), betas)
    guide1 = Gamma(alphas, betas)
    guide2 = RejectionStandardGamma(alphas)  # implemented using Rejector

    grads = []
    for guide in [guide1, guide2]:
        grads.append(compute_elbo_grad(model, guide, [alphas])[0].data)
    expected, actual = grads
    assert_equal(actual.mean(), expected.mean(), prec=0.01, msg='bad grad for alpha')
Example #4
def main(data=None, args=None, batch_size=None):
    data = torch.reshape(data, [64, 1000])

    # Globals.
    with pyro.plate("topics", 8):
        # shape = [8] + []
        topic_weights = pyro.sample("topic_weights", Gamma(1. / 8, 1.))
        # shape = [8] + [1024]
        topic_words = pyro.sample("topic_words",
                                  Dirichlet(torch.ones(1024) / 1024))

    # Locals.
    with pyro.plate("documents", 1000) as ind:
        # shape = [64, 32]
        data = data[:, ind]
        # shape = [32] + [8]
        doc_topics = pyro.sample("doc_topics", Dirichlet(topic_weights))

        with pyro.plate("words", 64):
            # The word_topics variable is marginalized out during inference,
            # achieved by specifying infer={"enumerate": "parallel"} and using
            # TraceEnum_ELBO for inference. Thus we can ignore this variable in
            # the guide.
            # shape = [64, 32] + []
            word_topics = pyro.sample("word_topics",
                                      Categorical(doc_topics),
                                      infer={"enumerate": "parallel"})
            # shape = [64, 32] + []
            data = pyro.sample("doc_words",
                               Categorical(topic_words[word_topics]),
                               obs=data)
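A hypothetical driver for the model above. The guide and optimizer settings here are placeholders (an amortized guide like those later on this page would fit); TraceEnum_ELBO performs the parallel enumeration of word_topics that the model requests, and max_plate_nesting=2 accounts for the documents and words plates:

import torch
import pyro
from pyro.infer import SVI, TraceEnum_ELBO
from pyro.optim import Adam

data = torch.randint(0, 1024, (64, 1000))  # synthetic word indices
svi = SVI(main, guide, Adam({"lr": 0.01}),  # 'guide' assumed defined elsewhere
          loss=TraceEnum_ELBO(max_plate_nesting=2))
loss = svi.step(data)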
Example #5
def generate_model(data=None, args=None, batch_size=None):
    # Globals.
    with pyro.plate("topics", 8):
        topic_weights = pyro.sample("topic_weights", Gamma(1. / 8, 1.))
        topic_words = pyro.sample("topic_words",
                                  Dirichlet(torch.ones(1024) / 1024))

    # Locals.
    with pyro.plate("documents", 1000) as ind:
        if data is not None:
            data = data[:, ind]
        doc_topics = pyro.sample("doc_topics", Dirichlet(topic_weights))
        with pyro.plate("words", 64):
            # The word_topics variable is marginalized out during inference,
            # achieved by specifying infer={"enumerate": "parallel"} and using
            # TraceEnum_ELBO for inference. Thus we can ignore this variable in
            # the guide.
            word_topics = pyro.sample("word_topics",
                                      Categorical(doc_topics),
                                      infer={"enumerate": "parallel"})
            data = pyro.sample("doc_words",
                               Categorical(topic_words[word_topics]),
                               obs=data)

    return topic_weights, topic_words, data
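Because the obs site falls back to sampling when data is None, generate_model doubles as a synthetic-corpus generator; a brief usage sketch:

# With data=None the indexing branch is skipped and "doc_words"
# is sampled rather than conditioned on.
topic_weights, topic_words, data = generate_model()
print(data.shape)  # torch.Size([64, 1000])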
Example #6
 def sample_ws(name, width):
     alpha_w_q = pyro.param("alpha_w_q_%s" % name,
                            lambda: rand_tensor((width,), self.alpha_init, self.sigma_init))
     mean_w_q = pyro.param("mean_w_q_%s" % name,
                           lambda: rand_tensor((width,), self.mean_init, self.sigma_init))
     alpha_w_q, mean_w_q = softplus(alpha_w_q), softplus(mean_w_q)
     pyro.sample("w_%s" % name, Gamma(alpha_w_q, alpha_w_q / mean_w_q))
Example #7
 def sample_zs(name, width):
     alpha_z_q = pyro.param("alpha_z_q_%s" % name,
                            lambda: rand_tensor((x_size, width), self.alpha_init, self.sigma_init))
     mean_z_q = pyro.param("mean_z_q_%s" % name,
                           lambda: rand_tensor((x_size, width), self.mean_init, self.sigma_init))
     alpha_z_q, mean_z_q = softplus(alpha_z_q), softplus(mean_z_q)
     pyro.sample("z_%s" % name, Gamma(alpha_z_q, alpha_z_q / mean_z_q).to_event(1))
Example #8
 def model():
     lambda_latent = pyro.sample("lambda_latent",
                                 Gamma(self.alpha0, self.beta0))
     x_dist = Exponential(lambda_latent)
     pyro.observe("obs0", x_dist, self.data[0])
     pyro.observe("obs1", x_dist, self.data[1])
     return lambda_latent
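pyro.observe comes from a very early Pyro release; in current Pyro the same conditioning is written with the obs= keyword. A sketch preserving the structure above:

import pyro
from pyro.distributions import Exponential, Gamma

def model_modern(alpha0, beta0, data):
    lambda_latent = pyro.sample("lambda_latent", Gamma(alpha0, beta0))
    x_dist = Exponential(lambda_latent)
    pyro.sample("obs0", x_dist, obs=data[0])
    pyro.sample("obs1", x_dist, obs=data[1])
    return lambda_latent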
Example #9
def model(n_samples=None):
    with pyro.plate('observations', n_samples):
        thickness = pyro.sample('thickness', Gamma(10., 5.))

        loc = (thickness - 2.5) * 20
        slant = pyro.sample('slant', Normal(loc, 1.))

    return slant, thickness
Example #10
 def guide():
     alpha_q_log = pyro.param(
         "alpha_q_log", Variable(self.alpha_q_log_0,
                                 requires_grad=True))
     beta_q_log = pyro.param(
         "beta_q_log", Variable(self.beta_q_log_0, requires_grad=True))
     alpha_q, beta_q = torch.exp(alpha_q_log), torch.exp(beta_q_log)
     pyro.sample("lambda_latent", Gamma(alpha_q, beta_q))
Example #11
 def model():
     lambda_latent = pyro.sample("lambda_latent",
                                 Gamma(self.alpha0, self.beta0))
     x_dist = Poisson(lambda_latent)
     # x0 = pyro.observe("obs0", x_dist, self.data[0])
     pyro.map_data(self.data,
                   lambda i, x: pyro.observe("obs", x_dist, x),
                   batch_size=3)
     return lambda_latent
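pyro.map_data was superseded by pyro.plate with subsample_size. A sketch of the same subsampled likelihood in the current API, assuming the data is a 1-D tensor of counts:

import pyro
from pyro.distributions import Gamma, Poisson

def model_modern(alpha0, beta0, data):
    lambda_latent = pyro.sample("lambda_latent", Gamma(alpha0, beta0))
    with pyro.plate("data", len(data), subsample_size=3) as ind:
        pyro.sample("obs", Poisson(lambda_latent), obs=data[ind])
    return lambda_latent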
Example #12
def test_gamma_elbo(alpha, beta):
    num_samples = 100000
    alphas = torch.full((num_samples, 1), alpha).requires_grad_()
    betas = torch.full((num_samples, 1), beta).requires_grad_()

    model = Gamma(torch.ones(num_samples, 1), torch.ones(num_samples, 1))
    guide1 = Gamma(alphas, betas)
    guide2 = RejectionGamma(alphas, betas)  # implemented using Rejector

    grads = []
    for guide in [guide1, guide2]:
        grads.append(compute_elbo_grad(model, guide, [alphas, betas]))
    expected, actual = grads
    expected = [g.mean() for g in expected]
    actual = [g.mean() for g in actual]
    scale = [(1 + abs(g)) for g in expected]
    assert_equal(actual[0] / scale[0], expected[0] / scale[0], prec=0.01, msg='bad grad for alpha')
    assert_equal(actual[1] / scale[1], expected[1] / scale[1], prec=0.01, msg='bad grad for beta')
Example #13
 def guide():
     alpha_q_log = pyro.param(
         "alpha_q_log",
         Variable(self.log_alpha_n.data + 0.17, requires_grad=True))
     beta_q_log = pyro.param(
         "beta_q_log",
         Variable(self.log_beta_n.data - 0.143, requires_grad=True))
     alpha_q, beta_q = torch.exp(alpha_q_log), torch.exp(beta_q_log)
     pyro.sample("lambda_latent", Gamma(alpha_q, beta_q))
     pyro.map_data(self.data, lambda i, x: None, batch_size=3)
Example #14
    def init(self, state, λ, s0, i0, r0):
        obs = lambda value: full((state._num_particles, ), value)

        state["λ"] = sample("λ", Gamma(2., 1. / 5.0), obs=obs(λ))
        state["δ"] = sample("δ", Beta(2., 2.))
        state["γ"] = sample("γ", Beta(2., 2.))

        state["s0"] = obs(s0)
        state["i0"] = obs(i0)
        state["r0"] = obs(r0)
Example #15
def model(n_samples=None, scale=2.):
    with pyro.plate('observations', n_samples):
        thickness = pyro.sample('thickness', Gamma(10., 5.))

        loc = (thickness - 2.5) * 2

        transforms = ComposeTransform([SigmoidTransform(), AffineTransform(10, 15)])

        width = pyro.sample('width', TransformedDistribution(Normal(loc, scale), transforms))

    return thickness, width
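SigmoidTransform squashes the Normal draw into (0, 1), and AffineTransform(10, 15) maps that interval to (10, 25), so width is bounded. A quick check with illustrative values:

import torch
from torch.distributions import Normal, TransformedDistribution
from torch.distributions.transforms import (AffineTransform, ComposeTransform,
                                            SigmoidTransform)

transforms = ComposeTransform([SigmoidTransform(), AffineTransform(10, 15)])
width = TransformedDistribution(Normal(0., 2.), transforms).sample((1000,))
assert width.min() > 10 and width.max() < 25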
Example #16
 def model():
     alpha_p_log = pyro.param(
         "alpha_p_log", Variable(self.alpha_p_log_0,
                                 requires_grad=True))
     beta_p_log = pyro.param(
         "beta_p_log", Variable(self.beta_p_log_0, requires_grad=True))
     alpha_p, beta_p = torch.exp(alpha_p_log), torch.exp(beta_p_log)
     lambda_latent = pyro.sample("lambda_latent",
                                 Gamma(alpha_p, beta_p))
     x_dist = Poisson(lambda_latent)
     pyro.observe("obs", x_dist, self.data)
     return lambda_latent
Example #17
    def model(self, x):
        x_size = x.size(0)

        # sample the global weights
        with pyro.plate("w_top_plate", self.top_width * self.mid_width):
            w_top = pyro.sample("w_top", Gamma(self.alpha_w, self.beta_w))
        with pyro.plate("w_mid_plate", self.mid_width * self.bottom_width):
            w_mid = pyro.sample("w_mid", Gamma(self.alpha_w, self.beta_w))
        with pyro.plate("w_bottom_plate", self.bottom_width * self.image_size):
            w_bottom = pyro.sample("w_bottom", Gamma(self.alpha_w, self.beta_w))

        # sample the local latent random variables
        # (the plate encodes the fact that the z's for different datapoints are conditionally independent)
        with pyro.plate("data", x_size):
            z_top = pyro.sample("z_top", Gamma(self.alpha_z, self.beta_z).expand([self.top_width]).to_event(1))
            # note that we need to use matmul (batch matrix multiplication) as well as appropriate reshaping
            # to make sure our code is fully vectorized
            w_top = w_top.reshape(self.top_width, self.mid_width) if w_top.dim() == 1 else \
                w_top.reshape(-1, self.top_width, self.mid_width)
            mean_mid = torch.matmul(z_top, w_top)
            z_mid = pyro.sample("z_mid", Gamma(self.alpha_z, self.beta_z / mean_mid).to_event(1))

            w_mid = w_mid.reshape(self.mid_width, self.bottom_width) if w_mid.dim() == 1 else \
                w_mid.reshape(-1, self.mid_width, self.bottom_width)
            mean_bottom = torch.matmul(z_mid, w_mid)
            z_bottom = pyro.sample("z_bottom", Gamma(self.alpha_z, self.beta_z / mean_bottom).to_event(1))

            w_bottom = w_bottom.reshape(self.bottom_width, self.image_size) if w_bottom.dim() == 1 else \
                w_bottom.reshape(-1, self.bottom_width, self.image_size)
            mean_obs = torch.matmul(z_bottom, w_bottom)

            # observe the data using a poisson likelihood
            pyro.sample('obs', Poisson(mean_obs).to_event(1), obs=x)
Example #18
def model(x):
    x = torch.reshape(x, [320, 4096])

    with pyro.plate("w_top_plate", 4000):
        w_top = pyro.sample("w_top", Gamma(alpha_w, beta_w))
    with pyro.plate("w_mid_plate", 600):
        w_mid = pyro.sample("w_mid", Gamma(alpha_w, beta_w))
    with pyro.plate("w_bottom_plate", 61440):
        w_bottom = pyro.sample("w_bottom", Gamma(alpha_w, beta_w))

    with pyro.plate("data", 320):
        z_top = pyro.sample(
            "z_top",
            Gamma(alpha_z, beta_z).expand_by([100]).to_event(1))

        w_top = torch.reshape(w_top, [100, 40])
        mean_mid = torch.matmul(z_top, w_top)
        z_mid = pyro.sample("z_mid",
                            Gamma(alpha_z, beta_z / mean_mid).to_event(1))

        w_mid = torch.reshape(w_mid, [40, 15])
        mean_bottom = torch.matmul(z_mid, w_mid)
        z_bottom = pyro.sample(
            "z_bottom",
            Gamma(alpha_z, beta_z / mean_bottom).to_event(1))

        w_bottom = torch.reshape(w_bottom, [15, 4096])
        mean_obs = torch.matmul(z_bottom, w_bottom)

        pyro.sample('obs', Poisson(mean_obs).to_event(1), obs=x)
Example #19
    def model(self, x):
        x_size = x.size(0)

        # sample the global weights
        with pyro.iarange("w_top_iarange", self.top_width * self.mid_width):
            w_top = pyro.sample("w_top", Gamma(self.alpha_w, self.beta_w))
        with pyro.iarange("w_mid_iarange", self.mid_width * self.bottom_width):
            w_mid = pyro.sample("w_mid", Gamma(self.alpha_w, self.beta_w))
        with pyro.iarange("w_bottom_iarange",
                          self.bottom_width * self.image_size):
            w_bottom = pyro.sample("w_bottom", Gamma(self.alpha_w,
                                                     self.beta_w))

        # sample the local latent random variables
        # (the iarange encodes the fact that the z's for different datapoints are conditionally independent)
        with pyro.iarange("data", x_size):
            z_top = pyro.sample(
                "z_top",
                Gamma(self.alpha_z,
                      self.beta_z).expand([self.top_width]).independent(1))
            mean_mid = torch.mm(z_top,
                                w_top.reshape(self.top_width, self.mid_width))
            z_mid = pyro.sample(
                "z_mid",
                Gamma(self.alpha_z, self.beta_z / mean_mid).independent(1))
            mean_bottom = torch.mm(
                z_mid, w_mid.view(self.mid_width, self.bottom_width))
            z_bottom = pyro.sample(
                "z_bottom",
                Gamma(self.alpha_z, self.beta_z / mean_bottom).independent(1))
            mean_obs = torch.mm(
                z_bottom, w_bottom.view(self.bottom_width, self.image_size))

            # observe the data using a poisson likelihood
            pyro.sample('obs', Poisson(mean_obs).independent(1), obs=x)
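Note: pyro.iarange and .independent(n) are the older spellings of pyro.plate and .to_event(n); Example #17 above is the same sparse-Gamma model written with the current names.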
Example #20
def model(data=None, args=None, batch_size=None):
    if debug: print("model:")
    data = torch.reshape(data, [64, 1000])

    # Globals.
    with pyro.plate("topics", 8):
        # shape = [8] + []
        topic_weights = pyro.sample("topic_weights", Gamma(1. / 8, 1.))
        # shape = [8] + [1024]
        topic_words = pyro.sample("topic_words",
                                  Dirichlet(torch.ones(1024) / 1024))
        if debug:
            print("topic_weights\t: shape={}, sum={}".format(
                topic_weights.shape, torch.sum(topic_weights)))
            print("topic_words\t: shape={}".format(topic_words.shape))

    # Locals.
    # with pyro.plate("documents", 1000) as ind:
    with pyro.plate("documents", 1000, 32, dim=-1) as ind:
        # if data is not None:
        #     data = data[:, ind]
        # shape = [64, 32]
        data = data[:, ind]
        # shape = [32] + [8]
        doc_topics = pyro.sample("doc_topics", Dirichlet(topic_weights))
        if debug:
            print("data\t\t: shape={}".format(data.shape))
            print("doc_topics\t: shape={}, [0].sum={}".format(
                doc_topics.shape, torch.sum(doc_topics[0])))

        # with pyro.plate("words", 64):
        with pyro.plate("words", 64, dim=-2):
            # The word_topics variable is marginalized out during inference,
            # achieved by specifying infer={"enumerate": "parallel"} and using
            # TraceEnum_ELBO for inference. Thus we can ignore this variable in
            # the guide.
            # shape = [64, 32] + []
            word_topics =\
                pyro.sample("word_topics", Categorical(doc_topics),
                            infer={"enumerate": "parallel"})
            # pyro.sample("word_topics", Categorical(doc_topics))
            # shape = [64, 32] + []
            data =\
                pyro.sample("doc_words", Categorical(topic_words[word_topics]),
                            obs=data)
            if debug:
                print("word_topics\t: shape={}".format(word_topics.shape))
                print("data\t\t: shape={}".format(data.shape))

    return topic_weights, topic_words, data
Example #21
def main(data, args, batch_size=None):
    data = torch.reshape(data, [64, 1000])
    
    pyro.module("layer1", layer1)
    pyro.module("layer2", layer2)
    pyro.module("layer3", layer3)

    # Use a conjugate guide for global variables.
    topic_weights_posterior = pyro.param(
        "topic_weights_posterior",
        torch.ones(8),
        constraint=constraints.positive)
    topic_words_posterior = pyro.param(
        "topic_words_posterior",
        # WL: edited. =====
        #torch.ones(8, 1024),
        #constraint=constraints.greater_than(0.5)
        torch.ones(8, 1024) * 0.5,
        constraint=constraints.positive
        # =================
    )
    
    with pyro.plate("topics", 8):
        # shape = [8] + []
        topic_weights = pyro.sample("topic_weights", Gamma(topic_weights_posterior, 1.))
        # shape = [8] + [1024]
        # WL: edited. =====
        #topic_words = pyro.sample("topic_words", Dirichlet(topic_words_posterior))
        topic_words = pyro.sample("topic_words", Dirichlet(topic_words_posterior + 0.5))
        # =================

    # Use an amortized guide for local variables.
    with pyro.plate("documents", 1000, 32) as ind:
        # shape =  [64, 32]
        data = data[:, ind]
        # The neural network will operate on histograms rather than word
        # index vectors, so we'll convert the raw data to a histogram.
        counts = torch.zeros(1024, 32)
        counts = torch.Tensor.scatter_add_\
            (counts, 0, data,
             torch.Tensor.expand(torch.tensor(1.), [1024, 32]))
        h1 = sigmoid(layer1(torch.transpose(counts, 0, 1)))
        h2 = sigmoid(layer2(h1))
        # shape = [32, 8]
        doc_topics_w = sigmoid(layer3(h2))
        # shape = [32] + [8]
        doc_topics = softmax(doc_topics_w)
        pyro.sample("doc_topics", Delta(doc_topics, event_dim=1))
Example #22
def model(n_samples=None, scale=0.5, invert=False):
    with pyro.plate('observations', n_samples):
        thickness = 0.5 + pyro.sample('thickness', Gamma(10., 5.))

        if invert:
            loc = (thickness - 2) * -2
        else:
            loc = (thickness - 2.5) * 2

        transforms = ComposeTransform(
            [SigmoidTransform(), AffineTransform(64, 191)])

        intensity = pyro.sample(
            'intensity', TransformedDistribution(Normal(loc, scale),
                                                 transforms))

    return thickness, intensity
Example #23
 def _gamma(self):
     concentration = self.theta
     rate = self.theta / self.mu
     # Important remark: Gamma is parametrized by the rate = 1/scale!
     gamma_d = Gamma(concentration=concentration, rate=rate)
     return gamma_d
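This is the mean parameterization also used by the sample_ws/sample_zs guides above: with concentration=theta and rate=theta/mu, the Gamma has mean mu and variance mu**2/theta. A quick check with illustrative values:

import torch
from torch.distributions import Gamma

theta, mu = torch.tensor(4.0), torch.tensor(2.5)
g = Gamma(concentration=theta, rate=theta / mu)
print(g.mean)      # concentration / rate = mu = 2.5
print(g.variance)  # concentration / rate**2 = mu**2 / theta = 1.5625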
Example #24
def guide(x):
    x = torch.reshape(x, [320, 4096])

    with pyro.plate("w_top_plate", 4000):
        #============ sample_ws
        alpha_w_q =\
            pyro.param("log_alpha_w_q_top",
                       alpha_init * torch.ones(4000) +
                       sigma_init * torch.randn(4000))
        mean_w_q =\
            pyro.param("log_mean_w_q_top",
                       mean_init * torch.ones(4000) +
                       sigma_init * torch.randn(4000))
        alpha_w_q = softplus(alpha_w_q)
        mean_w_q = softplus(mean_w_q)
        pyro.sample("w_top", Gamma(alpha_w_q, alpha_w_q / mean_w_q))
        #============ sample_ws

    with pyro.plate("w_mid_plate", 600):
        #============ sample_ws
        alpha_w_q =\
            pyro.param("log_alpha_w_q_mid",
                       alpha_init * torch.ones(600) +
                       sigma_init * torch.randn(600))
        mean_w_q =\
            pyro.param("log_mean_w_q_mid",
                       mean_init * torch.ones(600) +
                       sigma_init * torch.randn(600))
        alpha_w_q = softplus(alpha_w_q)
        mean_w_q = softplus(mean_w_q)
        pyro.sample("w_mid", Gamma(alpha_w_q, alpha_w_q / mean_w_q))
        #============ sample_ws

    with pyro.plate("w_bottom_plate", 61440):
        #============ sample_ws
        alpha_w_q =\
            pyro.param("log_alpha_w_q_bottom",
                       alpha_init * torch.ones(61440) +
                       sigma_init * torch.randn(61440))
        mean_w_q =\
            pyro.param("log_mean_w_q_bottom",
                       mean_init * torch.ones(61440) +
                       sigma_init * torch.randn(61440))
        alpha_w_q = softplus(alpha_w_q)
        mean_w_q = softplus(mean_w_q)
        pyro.sample("w_bottom", Gamma(alpha_w_q, alpha_w_q / mean_w_q))
        #============ sample_ws

    with pyro.plate("data", 320):
        #============ sample_zs
        alpha_z_q =\
            pyro.param("log_alpha_z_q_top",
                       alpha_init * torch.ones(320, 100) +
                       sigma_init * torch.randn(320, 100))
        mean_z_q =\
            pyro.param("log_mean_z_q_top",
                       mean_init * torch.ones(320, 100) +
                       sigma_init * torch.randn(320, 100))
        alpha_z_q = softplus(alpha_z_q)
        mean_z_q = softplus(mean_z_q)
        pyro.sample("z_top",
                    Gamma(alpha_z_q, alpha_z_q / mean_z_q).to_event(1))
        #============ sample_zs
        #============ sample_zs
        alpha_z_q =\
            pyro.param("log_alpha_z_q_mid",
                       alpha_init * torch.ones(320, 40) +
                       sigma_init * torch.randn(320, 40))
        mean_z_q =\
            pyro.param("log_mean_z_q_mid",
                       mean_init * torch.ones(320, 40) +
                       sigma_init * torch.randn(320, 40))
        alpha_z_q = softplus(alpha_z_q)
        mean_z_q = softplus(mean_z_q)
        pyro.sample("z_mid",
                    Gamma(alpha_z_q, alpha_z_q / mean_z_q).to_event(1))
        #============ sample_zs
        #============ sample_zs
        alpha_z_q =\
            pyro.param("log_alpha_z_q_bottom",
                       alpha_init * torch.ones(320, 15) +
                       sigma_init * torch.randn(320, 15))
        mean_z_q =\
            pyro.param("log_mean_z_q_bottom",
                       mean_init * torch.ones(320, 15) +
                       sigma_init * torch.randn(320, 15))
        alpha_z_q = softplus(alpha_z_q)
        mean_z_q = softplus(mean_z_q)
        pyro.sample("z_bottom",
                    Gamma(alpha_z_q, alpha_z_q / mean_z_q).to_event(1))
Example #25
def guide(data, args, batch_size=None):
    if debug: print("guide:")
    data = torch.reshape(data, [64, 1000])

    pyro.module("layer1", layer1)
    pyro.module("layer2", layer2)
    pyro.module("layer3", layer3)

    # Use a conjugate guide for global variables.
    topic_weights_posterior = pyro.param("topic_weights_posterior",
                                         torch.ones(8) / 8,
                                         constraint=constraints.positive)
    topic_words_posterior = pyro.param("topic_words_posterior",
                                       torch.ones(8, 1024) / 1024,
                                       constraint=constraints.positive)

    with pyro.plate("topics", 8):
        # shape = [8] + []
        topic_weights = pyro.sample("topic_weights",
                                    Gamma(topic_weights_posterior, 1.))
        # shape = [8] + [1024]
        topic_words = pyro.sample("topic_words",
                                  Dirichlet(topic_words_posterior))
        if debug:
            print("topic_weights\t: shape={}, sum={}".format(
                topic_weights.shape, torch.sum(topic_weights)))
            print("topic_words\t: shape={}".format(topic_words.shape))

    # Use an amortized guide for local variables.
    with pyro.plate("documents", 1000, 32, dim=-1) as ind:
        # shape =  [64, 32]
        data = data[:, ind]
        # The neural network will operate on histograms rather than word
        # index vectors, so we'll convert the raw data to a histogram.
        counts = torch.zeros(1024, 32)
        counts = torch.Tensor.scatter_add_\
            (counts, 0, data,
             torch.Tensor.expand(torch.tensor(1.), [1024, 32]))
        if debug:
            print("counts.shape={}, counts_trans.shape={}".format(
                counts.shape,
                torch.transpose(counts, 0, 1).shape))
        h1 = sigmoid(layer1(torch.transpose(counts, 0, 1)))
        h2 = sigmoid(layer2(h1))
        # shape = [32, 8]
        doc_topics = sigmoid(layer3(h2))
        if debug:
            print("doc_topics(nn result)\t: shape={}, [0].sum={}".format(
                doc_topics.shape, torch.sum(doc_topics[0])))
            d = Dirichlet(doc_topics)
            print("Dirichlet(doc_topics)\t: batch_shape={}, event_shape={}".
                  format(d.batch_shape, d.event_shape))
        """
        NOTE: There are three options we can take:
        1. sample doc_topics (from *DELTA*) *ONLY*. (original code)
        2. sample doc_topics (from *DIRICHLET*) *ONLY*.
        3. sample doc_topics (from *DIRICHLET*), word_topics (from Categorical) *BOTH*.

        Expected results:
          | Trace_ELBO | TraceEnum_ELBO 
        -------------------------------
        1 | FAIL       | FAIL
        2 | FAIL       | PASS
        3 | PASS       | FAIL
        """
        # shape = [32] + [8]
        # doc_topics = pyro.sample("doc_topics", Delta(doc_topics, event_dim=1))
        # doc_topics = pyro.sample("doc_topics", Delta(doc_topics).to_event(1))
        doc_topics = pyro.sample("doc_topics", Dirichlet(doc_topics))
        if debug:
            print("doc_topics(sampled)\t: shape={}, [0].sum = {}".format(
                doc_topics.shape, torch.sum(doc_topics[0])))
        with pyro.plate("words", 64, dim=-2):
            # shape = [64, 32, 8]
            word_topics_posterior = doc_topics * topic_words.transpose(
                0, 1)[data, :]
            word_topics_posterior = word_topics_posterior / (
                word_topics_posterior.sum(dim=-1, keepdim=True))
            word_topics =\
                pyro.sample("word_topics", Categorical(word_topics_posterior))
            if debug:
                print("word_tpics_posterior\t: shape={}".format(
                    word_topics_posterior.shape))
                print("word_topics(sampled)\t: shape={}".format(
                    word_topics.shape))
                d = Categorical(word_topics_posterior)
                print(
                    "Categorical(word_topics_posterior)\t: batch_shape={}, event_shape={}"
                    .format(d.batch_shape, d.event_shape))
Example #26
 def predict(self, x) -> Independent:
     rate, conc = self(x)
     event_ndim = len(rate.shape[1:])  # keep only batch dimension
     return Gamma(conc, rate).to_event(event_ndim)  # Gamma(concentration, rate)
Example #27
def test_log_prob(concentration, rate, value):
    value = torch.tensor(value)
    log_prob = InverseGamma(concentration, rate).log_prob(value)
    expected_log_prob = (Gamma(concentration, rate).log_prob(1.0 / value) -
                         2.0 * value.log())
    assert_equal(log_prob, expected_log_prob, prec=1e-6)
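The expected value here is the change-of-variables identity: if X ~ Gamma(c, r), then Y = 1/X is InverseGamma(c, r) with log p_Y(y) = log p_X(1/y) - 2*log(y), the -2*log(y) term being the log-Jacobian of y = 1/x. The same density can be built from a PowerTransform (a sketch mirroring how Pyro defines InverseGamma):

import torch
from torch.distributions import Gamma, TransformedDistribution
from torch.distributions.transforms import PowerTransform

c, r, y = torch.tensor(2.0), torch.tensor(3.0), torch.tensor(0.5)
inv_gamma = TransformedDistribution(Gamma(c, r), PowerTransform(torch.tensor(-1.0)))
assert torch.allclose(inv_gamma.log_prob(y),
                      Gamma(c, r).log_prob(1.0 / y) - 2.0 * y.log())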
Example #28
def guide(data, args, batch_size=None):
    if debug: print("guide:")
    data = torch.reshape(data, [64, 1000])

    pyro.module("layer1", layer1)
    pyro.module("layer2", layer2)
    pyro.module("layer3", layer3)

    # Use a conjugate guide for global variables.
    topic_weights_posterior = pyro.param(
        "topic_weights_posterior",
        # lambda: torch.ones(8) / 8,
        torch.ones(8) / 8,
        constraint=constraints.positive)
    topic_words_posterior = pyro.param(
        "topic_words_posterior",
        # lambda: torch.ones(8, 1024) / 1024,
        torch.ones(8, 1024) / 1024,
        constraint=constraints.positive)
    """
    # wy: dummy param for word_topics
    word_topics_posterior = pyro.param(
        "word_topics_posterior",
        torch.ones(64, 1024, 8) / 8,
        constraint=constraints.positive)
    """

    with pyro.plate("topics", 8):
        # shape = [8] + []
        topic_weights = pyro.sample("topic_weights",
                                    Gamma(topic_weights_posterior, 1.))
        # shape = [8] + [1024]
        topic_words = pyro.sample("topic_words",
                                  Dirichlet(topic_words_posterior))
        if debug:
            print("topic_weights\t: shape={}, sum={}".format(
                topic_weights.shape, torch.sum(topic_weights)))
            print("topic_words\t: shape={}".format(topic_words.shape))

    # Use an amortized guide for local variables.
    with pyro.plate("documents", 1000, 32) as ind:
        # shape =  [64, 32]
        data = data[:, ind]
        # The neural network will operate on histograms rather than word
        # index vectors, so we'll convert the raw data to a histogram.
        counts = torch.zeros(1024, 32)
        counts = torch.Tensor.scatter_add_\
            (counts, 0, data,
             torch.Tensor.expand(torch.tensor(1.), [1024, 32]))
        if debug:
            print("counts.shape={}, counts_trans.shape={}".format(
                counts.shape,
                torch.transpose(counts, 0, 1).shape))
        h1 = sigmoid(layer1(torch.transpose(counts, 0, 1)))
        h2 = sigmoid(layer2(h1))
        # shape = [32, 8]
        doc_topics_w = sigmoid(layer3(h2))
        if debug:
            print("doc_topics_w(nn result)\t: shape={}, [0].sum={}".format(
                doc_topics_w.shape, torch.sum(doc_topics_w[0])))
        # shape = [32] + [8]
        # # doc_topics = pyro.sample("doc_topics", Delta(doc_topics_w, event_dim=1))
        # doc_topics = pyro.sample("doc_topics", Delta(doc_topics_w).to_event(1))
        doc_topics = pyro.sample("doc_topics", Dirichlet(doc_topics_w))
        """
def guide(data, args, batch_size=None):
    if debug: print("guide:")
    data = torch.reshape(data, [64, 1000])
    
    pyro.module("layer1", layer1)
    pyro.module("layer2", layer2)
    pyro.module("layer3", layer3)

    # Use a conjugate guide for global variables.
    topic_weights_posterior = pyro.param(
        "topic_weights_posterior",
        # lambda: torch.ones(8) / 8,
        torch.ones(8) / 8,
        constraint=constraints.positive)
    topic_words_posterior = pyro.param(
        "topic_words_posterior",
        # lambda: torch.ones(8, 1024) / 1024,
        torch.ones(8, 1024) / 1024,
        constraint=constraints.positive)
    """
    # wy: dummy param for word_topics
    word_topics_posterior = pyro.param(
        "word_topics_posterior",
        torch.ones(64, 1024, 8) / 8,
        constraint=constraints.positive)
    """
    
    with pyro.plate("topics", 8):
        # shape = [8] + []
        topic_weights = pyro.sample("topic_weights", Gamma(topic_weights_posterior, 1.))
        # shape = [8] + [1024]
        topic_words = pyro.sample("topic_words", Dirichlet(topic_words_posterior))
        if debug:
            print("topic_weights\t: shape={}, sum={}".
                  format(topic_weights.shape, torch.sum(topic_weights)))
            print("topic_words\t: shape={}".format(topic_words.shape))

    # Use an amortized guide for local variables.
    with pyro.plate("documents", 1000, 32) as ind:
        # shape =  [64, 32]
        data = data[:, ind]
        # The neural network will operate on histograms rather than word
        # index vectors, so we'll convert the raw data to a histogram.
        counts = torch.zeros(1024, 32)
        counts = torch.Tensor.scatter_add_\
            (counts, 0, data,
             torch.Tensor.expand(torch.tensor(1.), [1024, 32]))
        if debug:
            print("counts.shape={}, counts_trans.shape={}".
                  format(counts.shape, torch.transpose(counts, 0, 1).shape))
        h1 = sigmoid(layer1(torch.transpose(counts, 0, 1)))
        h2 = sigmoid(layer2(h1))
        # shape = [32, 8]
        doc_topics_w = sigmoid(layer3(h2))
        if debug:
            print("doc_topics_w(nn result)\t: shape={}, [0].sum={}".
                  format(doc_topics_w.shape, torch.sum(doc_topics_w[0])))
            d = Dirichlet(doc_topics_w)
            print("Dirichlet(doc_topics_w)\t: batch_shape={}, event_shape={}".
                    format(d.batch_shape, d.event_shape))
        # shape = [32] + [8]
        # # doc_topics = pyro.sample("doc_topics", Delta(doc_topics_w, event_dim=1))
        # doc_topics = pyro.sample("doc_topics", Delta(doc_topics_w).to_event(1))
        doc_topics = pyro.sample("doc_topics", Dirichlet(doc_topics_w))

        # wy: sample from exact posterior of word_topics
        with pyro.plate("words", 64):
            # ks : [K, D] = [8, 32]
            # ks = torch.arange(0,8).expand(32,8).transpose(0,1)
            ks = torch.arange(0, 8)
            ks = torch.Tensor.expand(ks, 32, 8)
            ks = torch.Tensor.transpose(ks, 0, 1)
            # logprob1 : [N, D, K] = [32, 8]
            # logprob1 = Categorical(doc_topics).log_prob(ks).transpose(0,1).expand(64,32,8)
            logprob1 = Categorical.log_prob(Categorical(doc_topics), ks)
            logprob1 = torch.Tensor.transpose(logprob1, 0, 1)
            logprob1 = torch.Tensor.expand(logprob1, 64, 32, 8)
            # data2 : [N, D, K] = [64, 32, 8]
            # data2 = data.expand(8,64,32).transpose(0,1).transpose(1,2)
            data2 = torch.Tensor.expand(data, 8, 64, 32)
            data2 = torch.Tensor.transpose(data2, 0, 1)
            data2 = torch.Tensor.transpose(data2, 1, 2)
            # logprob2 : [N, D, K] = [64, 32, 8]
            # logprob2 = Categorical(topic_words).log_prob(data2)
            logprob2 = Categorical.log_prob(Categorical(topic_words), data2)
            # prob : [N, D, K] = [64, 32, 8]
            prob = torch.exp(logprob1 + logprob2)
            # word_topics : [N, D] = [64, 32]
            word_topics = pyro.sample("word_topics", Categorical(prob))

        """