Example #1
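All of the snippets below are excerpts from larger Pyro programs, so they rely on module-level imports and globals that are not shown. A minimal preamble along the following lines is assumed throughout; the constant values (and the `T_prev`/`prev_*` state used by the sequential-topic examples) are placeholders for whatever the original sources define:

import torch
import pyro
import pyro.poutine as poutine
from pyro.distributions import (Bernoulli, Beta, Binomial, Categorical, Delta,
                                Dirichlet, Gamma, Multinomial,
                                MultivariateNormal, Uniform)
from torch.distributions import constraints

debug = False  # several snippets gate shape printouts on this flag
T = 10         # number of topics (truncation level)
C = 1024       # vocabulary size
N = 64         # number of observations
alpha = 0.1    # Dirichlet-process concentration parameter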
def guide_t0(data):
    # T-1 alpha params for beta sampling
    kappa = pyro.param('kappa',
                       lambda: Uniform(0, 2).sample([T - 1]),
                       constraint=constraints.positive)

    # concentration params for q_theta #[T,C]
    tau = pyro.param('tau',
                     lambda: MultivariateNormal(0.5 * torch.ones(C), 0.25 *
                                                torch.eye(C)).sample([T]),
                     constraint=constraints.unit_interval)

    # N params for categorical dist; topic weights; symmetric prior
    phi = pyro.param('phi',
                     lambda: Dirichlet(1 / T * torch.ones(T)).sample([N]),
                     constraint=constraints.simplex)

    with pyro.plate("beta_plate", T - 1):
        q_beta = 0
        q_beta += pyro.sample("beta", Beta(torch.ones(T - 1), kappa))
        # q_beta *= 1

    # sample probs for multinomial distributions
    with pyro.plate("theta_plate", T):
        # outputs multinomial probabilities for each topic
        q_theta = 0
        q_theta += pyro.sample("theta", Dirichlet(tau))
        # q_theta *= 1

    with pyro.plate("data", N):
        z = 0
        z += pyro.sample("z", Categorical(phi))
Example #2
def generate_model(data=None, args=None, batch_size=None):
    # Globals.
    with pyro.plate("topics", 8):
        topic_weights = pyro.sample("topic_weights", Gamma(1. / 8, 1.))
        topic_words = pyro.sample("topic_words",
                                  Dirichlet(torch.ones(1024) / 1024))

    # Locals.
    with pyro.plate("documents", 1000) as ind:
        if data is not None:
            data = data[:, ind]
        doc_topics = pyro.sample("doc_topics", Dirichlet(topic_weights))
        with pyro.plate("words", 64):
            # The word_topics variable is marginalized out during inference,
            # achieved by specifying infer={"enumerate": "parallel"} and using
            # TraceEnum_ELBO for inference. Thus we can ignore this variable in
            # the guide.
            word_topics = pyro.sample("word_topics",
                                      Categorical(doc_topics),
                                      infer={"enumerate": "parallel"})
            data = pyro.sample("doc_words",
                               Categorical(topic_words[word_topics]),
                               obs=data)

    return topic_weights, topic_words, data
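
The comments in generate_model refer to marginalizing word_topics out with TraceEnum_ELBO. A minimal sketch of how such a model is usually driven, assuming it is paired with one of the amortized guides shown later (optimizer settings and step count are illustrative):

from pyro.infer import SVI, TraceEnum_ELBO
from pyro.optim import Adam

# max_plate_nesting=2: "words" sits at dim=-2, "documents" at dim=-1.
elbo = TraceEnum_ELBO(max_plate_nesting=2)
svi = SVI(generate_model, guide, Adam({"lr": 0.01}), elbo)
for step in range(1000):
    loss = svi.step(data)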
Example #3
def main(data=None, args=None, batch_size=None):
    data = torch.reshape(data, [64, 1000])

    # Globals.
    with pyro.plate("topics", 8):
        # shape = [8] + []
        topic_weights = pyro.sample("topic_weights", Gamma(1. / 8, 1.))
        # shape = [8] + [1024]
        topic_words = pyro.sample("topic_words",
                                  Dirichlet(torch.ones(1024) / 1024))

    # Locals.
    with pyro.plate("documents", 1000) as ind:
        # shape = [64, 32]
        data = data[:, ind]
        # shape = [32] + [8]
        doc_topics = pyro.sample("doc_topics", Dirichlet(topic_weights))

        with pyro.plate("words", 64):
            # The word_topics variable is marginalized out during inference,
            # achieved by specifying infer={"enumerate": "parallel"} and using
            # TraceEnum_ELBO for inference. Thus we can ignore this variable in
            # the guide.
            # shape = [64, 32] + []
            word_topics = pyro.sample("word_topics",
                                      Categorical(doc_topics),
                                      infer={"enumerate": "parallel"})
            # shape = [64, 32] + []
            data = pyro.sample("doc_words",
                               Categorical(topic_words[word_topics]),
                               obs=data)
Example #4
def model(data=None, args=None, batch_size=None):
    if debug: print("model:")
    data = torch.reshape(data, [64, 1000])

    # Globals.
    with pyro.plate("topics", 8):
        # shape = [8] + []
        topic_weights = pyro.sample("topic_weights", Gamma(1. / 8, 1.))
        # shape = [8] + [1024]
        topic_words = pyro.sample("topic_words",
                                  Dirichlet(torch.ones(1024) / 1024))
        if debug:
            print("topic_weights\t: shape={}, sum={}".format(
                topic_weights.shape, torch.sum(topic_weights)))
            print("topic_words\t: shape={}".format(topic_words.shape))

    # Locals.
    # with pyro.plate("documents", 1000) as ind:
    with pyro.plate("documents", 1000, 32, dim=-1) as ind:
        # if data is not None:
        #     data = data[:, ind]
        # shape = [64, 32]
        data = data[:, ind]
        # shape = [32] + [8]
        doc_topics = pyro.sample("doc_topics", Dirichlet(topic_weights))
        if debug:
            print("data\t\t: shape={}".format(data.shape))
            print("doc_topics\t: shape={}, [0].sum={}".format(
                doc_topics.shape, torch.sum(doc_topics[0])))

        # with pyro.plate("words", 64):
        with pyro.plate("words", 64, dim=-2):
            # The word_topics variable is marginalized out during inference,
            # achieved by specifying infer={"enumerate": "parallel"} and using
            # TraceEnum_ELBO for inference. Thus we can ignore this variable in
            # the guide.
            # shape = [64, 32] + []
            word_topics =\
                pyro.sample("word_topics", Categorical(doc_topics),
                            infer={"enumerate": "parallel"})
            # pyro.sample("word_topics", Categorical(doc_topics))
            # shape = [64, 32] + []
            data =\
                pyro.sample("doc_words", Categorical(topic_words[word_topics]),
                            obs=data)
            if debug:
                print("word_topics\t: shape={}".format(word_topics.shape))
                print("data\t\t: shape={}".format(data.shape))

    return topic_weights, topic_words, data
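
Since most annotations in this variant track tensor shapes, Pyro can also print them directly instead of relying on the manual debug statements. A small sketch, assuming data is a [64, 1000] tensor of word indices:

trace = poutine.trace(model).get_trace(data)
trace.compute_log_prob()
print(trace.format_shapes())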
Example #5
def model_t0(data):
    with pyro.plate("beta_plate", T - 1):
        beta = pyro.sample("beta", Beta(1, alpha))

    with pyro.plate("theta_plate", T):  # shape [T,C]
        # sample probabilities for Mult dist.; Dirichlet with symmetric prior
        theta = pyro.sample("theta", Dirichlet(torch.ones(C) / C))

    with pyro.plate("data", N):
        # z==which topic
        z = pyro.sample("z", Categorical(probs=mix_weights(beta)))
        pyro.sample("obs", Multinomial(probs=theta[z]), obs=data)
Example #6
def model(data):
    # whether new topic or not; prior=0.5; random choice whether old/new
    with pyro.plate("new_topic_plate", T):
        new_topic = pyro.sample("new_topic", Binomial(probs=0.5))

    # if new topic, if linked to old topic, prior=0.5
    with pyro.plate("linked_plate", T):
        linked = pyro.sample("linked", Binomial(probs=0.5))

    # if old topic, which old topic
    with pyro.plate("old_topic_plate", T):
        which_old_topic = pyro.sample("which_old_topic",
                                      Multinomial(probs=prev_topic_freq))

    # beta sampling for topic weights
    with pyro.plate("beta_plate", T - 1):
        beta = pyro.sample("beta", Beta(1, alpha))

    with pyro.plate("theta_plate", T):  # shape [T,C]
        # Dirichlet distribution (conjugate prior of Mult); symmetric prior
        theta = pyro.sample("theta", Dirichlet(torch.ones(C) / C))

    with pyro.plate("gamma_plate", T_prev):
        gamma = pyro.sample("gamma", Dirichlet(prev_taus))

    with pyro.plate("data", N):
        z = pyro.sample("z", Categorical(probs=mix_weights(beta)))
        old = get_old_topics(which_old_topic)
        a = (new_topic) * (linked)
        b = (1 - new_topic)
        c = (new_topic) * (1 - linked)
        a = a[z].reshape(N, 1)
        b = b[z].reshape(N, 1)
        c = c[z].reshape(N, 1)
        mult_probs = a * gamma[old[z]] + b * prev_theta[old[z]] + c * theta[z]
        pyro.sample("obs", Multinomial(probs=mult_probs), obs=data)
Example #7
def main(data, args, batch_size=None):
    data = torch.reshape(data, [64, 1000])
    
    pyro.module("layer1", layer1)
    pyro.module("layer2", layer2)
    pyro.module("layer3", layer3)

    # Use a conjugate guide for global variables.
    topic_weights_posterior = pyro.param(
        "topic_weights_posterior",
        torch.ones(8),
        constraint=constraints.positive)
    topic_words_posterior = pyro.param(
        "topic_words_posterior",
        # WL: edited. =====
        #torch.ones(8, 1024),
        #constraint=constraints.greater_than(0.5)
        torch.ones(8, 1024) * 0.5,
        constraint=constraints.positive
        # =================
    )
    
    with pyro.plate("topics", 8):
        # shape = [8] + []
        topic_weights = pyro.sample("topic_weights", Gamma(topic_weights_posterior, 1.))
        # shape = [8] + [1024]
        # WL: edited. =====
        #topic_words = pyro.sample("topic_words", Dirichlet(topic_words_posterior))
        topic_words = pyro.sample("topic_words", Dirichlet(topic_words_posterior + 0.5))
        # =================

    # Use an amortized guide for local variables.
    with pyro.plate("documents", 1000, 32) as ind:
        # shape =  [64, 32]
        data = data[:, ind]
        # The neural network will operate on histograms rather than word
        # index vectors, so we'll convert the raw data to a histogram.
        counts = torch.zeros(1024, 32)
        counts = torch.Tensor.scatter_add_\
            (counts, 0, data,
             torch.Tensor.expand(torch.tensor(1.), [1024, 32]))
        h1 = sigmoid(layer1(torch.transpose(counts, 0, 1)))
        h2 = sigmoid(layer2(h1))
        # shape = [32, 8]
        doc_topics_w = sigmoid(layer3(h2))
        # shape = [32] + [8]
        doc_topics = softmax(doc_topics_w)
        pyro.sample("doc_topics", Delta(doc_topics, event_dim=1))
Example #8
def guide(sequences):
    theta = pyro.param("theta", torch.ones(16))
    alpha = pyro.param("alpha", torch.rand(1))
    beta = pyro.param("beta", torch.rand(1))
    p = pyro.param("p", torch.rand(1))
    q = pyro.param("q", torch.rand(1))
    w = p * torch.eye(16) + q
    with poutine.mask(mask=False):
        probs_x = pyro.sample("probs_x", Dirichlet(w).to_event(1))
        probs_y = pyro.sample("probs_y",
                              Beta(alpha, beta).expand([16, 51]).to_event(2))

    for i in pyro.plate("sequences", len(sequences), 8):
        length = lengths[i]
        sequence = sequences[i, :length]
        x = 0
        for t in pyro.markov(range(length)):
            x = pyro.sample("x_{}_{}".format(i, t), Categorical(probs_x[x]))
Example #9
def model(sequences):
    with poutine.mask(mask=False):
        probs_x = pyro.sample("probs_x",
                              Dirichlet(0.9 * torch.eye(16) + 0.1).to_event(1))
        probs_y = pyro.sample("probs_y",
                              Beta(0.1, 0.9).expand([16, 51]).to_event(2))
    tones_plate = pyro.plate("tones", 51, dim=-1)
    for i in pyro.plate("sequences", len(sequences)):
        length = lengths[i]
        sequence = sequences[i, :length]
        x = 0
        for t in pyro.markov(range(length)):
            x = pyro.sample("x_{}_{}".format(i, t),
                            Categorical(probs_x[x]),
                            infer={"enumerate": "parallel"})
            with tones_plate:
                pyro.sample("y_{}_{}".format(i, t),
                            Bernoulli(probs_y[x.squeeze(-1)]),
                            obs=sequence[t])
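
Both HMM snippets index into a module-level lengths tensor that is not shown. A sketch of the assumed data layout, following Pyro's HMM example (a zero-padded batch of binary tone indicators over 51 tones; the sizes are illustrative):

num_sequences, max_length = 32, 100
# Binary tone indicators per time step, zero-padded past each sequence's length.
sequences = torch.zeros(num_sequences, max_length, 51)
lengths = torch.randint(1, max_length + 1, (num_sequences,))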
Example #10
def naive_acceptability_and_central_weight(number_criterion,
                                           number_alternatives,
                                           number_iterations):
    central_weight_vector = torch.zeros(
        [number_alternatives,
         number_criterion])  # at i,j for alternative i at coordinate j
    weight_shape = torch.ones([
        number_criterion,
    ])
    count_matrix = torch.zeros([number_alternatives, number_alternatives
                                ])  # at i,j for alternative i ranked j-th

    for _ in range(number_iterations):
        weights = Dirichlet(weight_shape).sample()
        crit_alt_mat = crit_alt_matrix(
            number_alternatives,
            number_criterion)  # at i,j for alternative i against criterion j
        rank_vector = [
            rank(i, crit_alt_mat, weights) for i in range(number_alternatives)
        ]  # best rank is 0
        for i in range(number_alternatives):
            count_matrix[i, rank_vector[i]] += 1
            if rank_vector[i] == 0:
                central_weight_vector[i] += weights
    acceptability_index = torch.zeros_like(
        count_matrix
    )  # at i,j: approx. probability that alternative i is ranked j-th

    for i in range(number_alternatives):
        if count_matrix[i, 0] > 0:
            central_weight_vector[i] /= count_matrix[
                i, 0]  # average of the weight vectors that put alternative i on top
        for j in range(number_alternatives):
            acceptability_index[i, j] = count_matrix[
                i, j] / number_iterations  # approx. probability that alternative i is ranked j-th
    return central_weight_vector, acceptability_index
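
The function above depends on crit_alt_matrix and rank helpers that are not shown. A hypothetical sketch, assuming criterion scores are drawn uniformly at random and alternatives are ranked by weighted score (higher is better, best rank is 0):

def crit_alt_matrix(number_alternatives, number_criterion):
    # Hypothetical: random performance of each alternative on each criterion.
    return torch.rand(number_alternatives, number_criterion)

def rank(i, crit_alt_mat, weights):
    # Rank of alternative i under the weighted-sum score (0 = best).
    scores = crit_alt_mat @ weights
    return int((scores > scores[i]).sum())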
Example #11
def guide(data):
    # pyro params
    new_topic_prob = pyro.param("new_topic_prob",
                                lambda: Uniform(0, 1).sample([T]),
                                constraint=constraints.unit_interval)

    linked_prob = pyro.param("linked_prob",
                             lambda: Uniform(0, 1).sample([T]),
                             constraint=constraints.unit_interval)

    which_topic_probs = pyro.param("which_topic_probs",
                                   lambda: Uniform(0, 1).sample([T_prev]),
                                   constraint=constraints.simplex)

    kappa = pyro.param('kappa',
                       lambda: Uniform(0, 2).sample([T - 1]),
                       constraint=constraints.positive)

    tau = pyro.param('tau',
                     lambda: MultivariateNormal(0.5 * torch.ones(C), 0.25 *
                                                torch.eye(C)).sample([T]),
                     constraint=constraints.unit_interval)

    # N params for categorical dist; topic weights; symmetric prior
    phi = pyro.param('phi',
                     lambda: Dirichlet(1 / T * torch.ones(T)).sample([N]),
                     constraint=constraints.simplex)

    # model params
    with pyro.plate("new_topic_plate", T):
        # print(new_topic_prob)
        new_topic = pyro.sample("new_topic", Binomial(probs=new_topic_prob))

    # if new topic, if linked to old topic, prior=0.5
    with pyro.plate("linked_plate", T):
        linked = pyro.sample("linked", Binomial(probs=linked_prob))

    # if old topic, which old topic
    with pyro.plate("old_topic_plate", T):
        which_old_topic = pyro.sample("which_old_topic",
                                      Multinomial(probs=which_topic_probs))

    with pyro.plate("beta_plate", T - 1):
        q_beta = 0
        q_beta += pyro.sample("beta", Beta(torch.ones(T - 1), kappa))

    # new topic with symmetric prior
    with pyro.plate("theta_plate", T):
        theta = pyro.sample("theta", Dirichlet(tau))

    # new topic linked to old topic
    with pyro.plate("gamma_plate", T_prev):
        gamma = pyro.sample("gamma", Dirichlet(prev_taus))

    with pyro.plate("data", N):
        z = pyro.sample("z", Categorical(phi))
        old = get_old_topics(which_old_topic)
        a = ((new_topic) * (linked))
        b = (1 - new_topic)
        c = ((new_topic) * (1 - linked))
        a = a[z].reshape(N, 1)
        b = b[z].reshape(N, 1)
        c = c[z].reshape(N, 1)
        mult_probs = a * gamma[old[z]] + b * prev_theta[old[z]] + c * theta[z]
Example #12
def guide(data, args, batch_size=None):
    if debug: print("guide:")
    data = torch.reshape(data, [64, 1000])
    
    pyro.module("layer1", layer1)
    pyro.module("layer2", layer2)
    pyro.module("layer3", layer3)

    # Use a conjugate guide for global variables.
    topic_weights_posterior = pyro.param(
        "topic_weights_posterior",
        # lambda: torch.ones(8) / 8,
        torch.ones(8) / 8,
        constraint=constraints.positive)
    topic_words_posterior = pyro.param(
        "topic_words_posterior",
        # lambda: torch.ones(8, 1024) / 1024,
        torch.ones(8, 1024) / 1024,
        constraint=constraints.positive)
    """
    # wy: dummy param for word_topics
    word_topics_posterior = pyro.param(
        "word_topics_posterior",
        torch.ones(64, 1024, 8) / 8,
        constraint=constraints.positive)
    """
    
    with pyro.plate("topics", 8):
        # shape = [8] + []
        topic_weights = pyro.sample("topic_weights", Gamma(topic_weights_posterior, 1.))
        # shape = [8] + [1024]
        topic_words = pyro.sample("topic_words", Dirichlet(topic_words_posterior))
        if debug:
            print("topic_weights\t: shape={}, sum={}".
                  format(topic_weights.shape, torch.sum(topic_weights)))
            print("topic_words\t: shape={}".format(topic_words.shape))

    # Use an amortized guide for local variables.
    with pyro.plate("documents", 1000, 32) as ind:
        # shape =  [64, 32]
        data = data[:, ind]
        # The neural network will operate on histograms rather than word
        # index vectors, so we'll convert the raw data to a histogram.
        counts = torch.zeros(1024, 32)
        counts = torch.Tensor.scatter_add_\
            (counts, 0, data,
             torch.Tensor.expand(torch.tensor(1.), [1024, 32]))
        if debug:
            print("counts.shape={}, counts_trans.shape={}".
                  format(counts.shape, torch.transpose(counts, 0, 1).shape))
        h1 = sigmoid(layer1(torch.transpose(counts, 0, 1)))
        h2 = sigmoid(layer2(h1))
        # shape = [32, 8]
        doc_topics_w = sigmoid(layer3(h2))
        if debug:
            print("doc_topics_w(nn result)\t: shape={}, [0].sum={}".
                  format(doc_topics_w.shape, torch.sum(doc_topics_w[0])))
            d = Dirichlet(doc_topics_w)
            print("Dirichlet(doc_topics_w)\t: batch_shape={}, event_shape={}".
                    format(d.batch_shape, d.event_shape))
        # shape = [32] + [8]
        # # doc_topics = pyro.sample("doc_topics", Delta(doc_topics_w, event_dim=1))
        # doc_topics = pyro.sample("doc_topics", Delta(doc_topics_w).to_event(1))
        doc_topics = pyro.sample("doc_topics", Dirichlet(doc_topics_w))

        # wy: sample from exact posterior of word_topics
        with pyro.plate("words", 64):
            # ks : [K, D] = [8, 32]
            # ks = torch.arange(0,8).expand(32,8).transpose(0,1)
            ks = torch.arange(0, 8)
            ks = torch.Tensor.expand(ks, 32, 8)
            ks = torch.Tensor.transpose(ks, 0, 1)
            # logprob1 : [N, D, K] = [64, 32, 8]
            # logprob1 = Categorical(doc_topics).log_prob(ks).transpose(0,1).expand(64,32,8)
            logprob1 = Categorical.log_prob(Categorical(doc_topics), ks)
            logprob1 = torch.Tensor.transpose(logprob1, 0, 1)
            logprob1 = torch.Tensor.expand(logprob1, 64, 32, 8)
            # data2 : [N, D, K] = [64, 32, 8]
            # data2 = data.expand(8,64,32).transpose(0,1).transpose(1,2)
            data2 = torch.Tensor.expand(data, 8, 64, 32)
            data2 = torch.Tensor.transpose(data2, 0, 1)
            data2 = torch.Tensor.transpose(data2, 1, 2)
            # logprob2 : [N, D, K] = [64, 32, 8]
            # logprob2 = Categorical(topic_words).log_prob(data2)
            logprob2 = Categorical.log_prob(Categorical(topic_words), data2)
            # prob : [N, D, K] = [64, 32, 8]
            prob = torch.exp(logprob1 + logprob2)
            # word_topics : [N, D] = [64, 32]
            word_topics = pyro.sample("word_topics", Categorical(prob))

        """
Example #13
def guide(data, args, batch_size=None):
    if debug: print("guide:")
    data = torch.reshape(data, [64, 1000])

    pyro.module("layer1", layer1)
    pyro.module("layer2", layer2)
    pyro.module("layer3", layer3)

    # Use a conjugate guide for global variables.
    topic_weights_posterior = pyro.param(
        "topic_weights_posterior",
        # lambda: torch.ones(8) / 8,
        torch.ones(8) / 8,
        constraint=constraints.positive)
    topic_words_posterior = pyro.param(
        "topic_words_posterior",
        # lambda: torch.ones(8, 1024) / 1024,
        torch.ones(8, 1024) / 1024,
        constraint=constraints.positive)
    """
    # wy: dummy param for word_topics
    word_topics_posterior = pyro.param(
        "word_topics_posterior",
        torch.ones(64, 1024, 8) / 8,
        constraint=constraints.positive)
    """

    with pyro.plate("topics", 8):
        # shape = [8] + []
        topic_weights = pyro.sample("topic_weights",
                                    Gamma(topic_weights_posterior, 1.))
        # shape = [8] + [1024]
        topic_words = pyro.sample("topic_words",
                                  Dirichlet(topic_words_posterior))
        if debug:
            print("topic_weights\t: shape={}, sum={}".format(
                topic_weights.shape, torch.sum(topic_weights)))
            print("topic_words\t: shape={}".format(topic_words.shape))

    # Use an amortized guide for local variables.
    with pyro.plate("documents", 1000, 32) as ind:
        # shape =  [64, 32]
        data = data[:, ind]
        # The neural network will operate on histograms rather than word
        # index vectors, so we'll convert the raw data to a histogram.
        counts = torch.zeros(1024, 32)
        counts = torch.Tensor.scatter_add_\
            (counts, 0, data,
             torch.Tensor.expand(torch.tensor(1.), [1024, 32]))
        if debug:
            print("counts.shape={}, counts_trans.shape={}".format(
                counts.shape,
                torch.transpose(counts, 0, 1).shape))
        h1 = sigmoid(layer1(torch.transpose(counts, 0, 1)))
        h2 = sigmoid(layer2(h1))
        # shape = [32, 8]
        doc_topics_w = sigmoid(layer3(h2))
        if debug:
            print("doc_topics_w(nn result)\t: shape={}, [0].sum={}".format(
                doc_topics_w.shape, torch.sum(doc_topics_w[0])))
        # shape = [32] + [8]
        # # doc_topics = pyro.sample("doc_topics", Delta(doc_topics_w, event_dim=1))
        # doc_topics = pyro.sample("doc_topics", Delta(doc_topics_w).to_event(1))
        doc_topics = pyro.sample("doc_topics", Dirichlet(doc_topics_w))
        """
Example #14
def guide(data, args, batch_size=None):
    if debug: print("guide:")
    data = torch.reshape(data, [64, 1000])

    pyro.module("layer1", layer1)
    pyro.module("layer2", layer2)
    pyro.module("layer3", layer3)

    # Use a conjugate guide for global variables.
    topic_weights_posterior = pyro.param("topic_weights_posterior",
                                         torch.ones(8) / 8,
                                         constraint=constraints.positive)
    topic_words_posterior = pyro.param("topic_words_posterior",
                                       torch.ones(8, 1024) / 1024,
                                       constraint=constraints.positive)

    with pyro.plate("topics", 8):
        # shape = [8] + []
        topic_weights = pyro.sample("topic_weights",
                                    Gamma(topic_weights_posterior, 1.))
        # shape = [8] + [1024]
        topic_words = pyro.sample("topic_words",
                                  Dirichlet(topic_words_posterior))
        if debug:
            print("topic_weights\t: shape={}, sum={}".format(
                topic_weights.shape, torch.sum(topic_weights)))
            print("topic_words\t: shape={}".format(topic_words.shape))

    # Use an amortized guide for local variables.
    with pyro.plate("documents", 1000, 32, dim=-1) as ind:
        # shape =  [64, 32]
        data = data[:, ind]
        # The neural network will operate on histograms rather than word
        # index vectors, so we'll convert the raw data to a histogram.
        counts = torch.zeros(1024, 32)
        counts = torch.Tensor.scatter_add_\
            (counts, 0, data,
             torch.Tensor.expand(torch.tensor(1.), [1024, 32]))
        if debug:
            print("counts.shape={}, counts_trans.shape={}".format(
                counts.shape,
                torch.transpose(counts, 0, 1).shape))
        h1 = sigmoid(layer1(torch.transpose(counts, 0, 1)))
        h2 = sigmoid(layer2(h1))
        # shape = [32, 8]
        doc_topics = sigmoid(layer3(h2))
        if debug:
            print("doc_topics(nn result)\t: shape={}, [0].sum={}".format(
                doc_topics.shape, torch.sum(doc_topics[0])))
            d = Dirichlet(doc_topics)
            print("Dirichlet(doc_topics)\t: batch_shape={}, event_shape={}".
                  format(d.batch_shape, d.event_shape))
        """
        NOTE: There are three options we can take:
        1. sample doc_topics (from *DELTA*) *ONLY*. (original code)
        2. sample doc_topics (from *DIRICHLET*) *ONLY*.
        3. sample doc_topics (from *DIRICHLET*), word_topics (from Categorical) *BOTH*.

        Expected results:
          | Trace_ELBO | TraceEnum_ELBO 
        -------------------------------
        1 | FAIL       | FAIL
        2 | FAIL       | PASS
        3 | PASS       | FAIL
        """
        # shape = [32] + [8]
        # doc_topics = pyro.sample("doc_topics", Delta(doc_topics, event_dim=1))
        # doc_topics = pyro.sample("doc_topics", Delta(doc_topics).to_event(1))
        doc_topics = pyro.sample("doc_topics", Dirichlet(doc_topics))
        if debug:
            print("doc_topics(sampled)\t: shape={}, [0].sum = {}".format(
                doc_topics.shape, torch.sum(doc_topics[0])))
        with pyro.plate("words", 64, dim=-2):
            # shape = [64, 32, 8]
            word_topics_posterior = doc_topics * topic_words.transpose(
                0, 1)[data, :]
            word_topics_posterior = word_topics_posterior / (
                word_topics_posterior.sum(dim=-1, keepdim=True))
            word_topics =\
                pyro.sample("word_topics", Categorical(word_topics_posterior))
            if debug:
                print("word_tpics_posterior\t: shape={}".format(
                    word_topics_posterior.shape))
                print("word_topics(sampled)\t: shape={}".format(
                    word_topics.shape))
                d = Categorical(word_topics_posterior)
                print(
                    "Categorical(word_topics_posterior)\t: batch_shape={}, event_shape={}"
                    .format(d.batch_shape, d.event_shape))
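
For reference, word_topics_posterior above is the exact conditional of a word's topic assignment given the sampled globals: for word w in document d, q(z = k) is proportional to doc_topics[d, k] * topic_words[k, w], which is precisely the element-wise product and normalization performed just before the word_topics sample statement.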
Example #15
from torch.distributions.kl import register_kl


# Register so that torch's kl_divergence dispatches to this implementation.
@register_kl(Factorised, Factorised)
def _kl_factorised_factorised(p: Factorised, q: Factorised):
    return sum(
        kl_divergence(p_factor, q_factor)
        for p_factor, q_factor in zip(p.factors, q.factors))


if __name__ == '__main__':
    from pyro.distributions import Dirichlet, MultivariateNormal
    from torch.distributions import kl_divergence
    from distributions.mixture import Mixture

    B, D1, D2 = 5, 3, 4
    N = 1000

    dist1 = MultivariateNormal(torch.zeros(D1), torch.eye(D1)).expand((B, ))
    dist2 = Dirichlet(torch.ones(D2)).expand((B, ))
    print(dist1.batch_shape, dist1.event_shape)
    print(dist2.batch_shape, dist2.event_shape)
    fact = Factorised([dist1, dist2])
    print(fact.batch_shape, fact.event_shape)
    samples = fact.rsample((N, ))
    print(samples[0])
    print(samples.shape)
    logp = fact.log_prob(samples)
    print(logp.shape)
    entropy = fact.entropy()
    print(entropy.shape)
    print(entropy, -logp.mean())
    print()

    print(kl_divergence(fact, fact))
Example #16
        post_logits = self.mixing.logits + post_lognorm

        return NaturalMultivariateNormalMixture(
            Categorical(logits=post_logits), post_components)


def eval_grid(xx, yy, fcn):
    xy = torch.stack([xx.flatten(), yy.flatten()], dim=1)
    return fcn(xy).reshape_as(xx)


if __name__ == '__main__':
    from pyro.distributions import Dirichlet

    N, K, D = 200, 4, 2
    props = Dirichlet(5 * torch.ones(K)).sample()
    mean = torch.arange(K).float().view(K, 1).expand(K, D)
    var = .1 * torch.eye(D).expand(K, -1, -1)
    mixing = Categorical(props)
    components = MultivariateNormal(mean, var)
    print("mixing", mixing.batch_shape, mixing.event_shape)
    print("components", components.batch_shape, components.event_shape)
    mixture = Mixture(mixing,
                      NaturalMultivariateNormal.from_standard(components))
    mixture.rename(['x', 'y'])
    print("mixture names", mixture.variable_names)
    print("mixture", mixture.batch_shape, mixture.event_shape)
    probe = MultivariateNormal(mean[:3] + 1 * torch.tensor([1., -1.]),
                               .2 * var[:3])
    post_mixture = mixture.posterior(probe)
    print("post_mixture names", post_mixture.variable_names)