def guide_t0(data):
    # T-1 alpha params for beta sampling
    kappa = pyro.param('kappa', lambda: Uniform(0, 2).sample([T - 1]),
                       constraint=constraints.positive)
    # concentration params for q_theta; shape [T, C]
    tau = pyro.param('tau', lambda: MultivariateNormal(
        0.5 * torch.ones(C), 0.25 * torch.eye(C)).sample([T]),
                     constraint=constraints.unit_interval)
    # N params for categorical dist; topic weights; symmetric prior
    phi = pyro.param('phi', lambda: Dirichlet(1 / T * torch.ones(T)).sample([N]),
                     constraint=constraints.simplex)

    with pyro.plate("beta_plate", T - 1):
        q_beta = pyro.sample("beta", Beta(torch.ones(T - 1), kappa))

    # sample probs for multinomial distributions
    with pyro.plate("theta_plate", T):
        # outputs multinomial probabilities for each topic
        q_theta = pyro.sample("theta", Dirichlet(tau))

    with pyro.plate("data", N):
        z = pyro.sample("z", Categorical(phi))

def generate_model(data=None, args=None, batch_size=None):
    # Globals.
    with pyro.plate("topics", 8):
        topic_weights = pyro.sample("topic_weights", Gamma(1. / 8, 1.))
        topic_words = pyro.sample("topic_words",
                                  Dirichlet(torch.ones(1024) / 1024))

    # Locals.
    with pyro.plate("documents", 1000) as ind:
        if data is not None:
            data = data[:, ind]
        doc_topics = pyro.sample("doc_topics", Dirichlet(topic_weights))
        with pyro.plate("words", 64):
            # The word_topics variable is marginalized out during inference,
            # achieved by specifying infer={"enumerate": "parallel"} and using
            # TraceEnum_ELBO for inference. Thus we can ignore this variable in
            # the guide.
            word_topics = pyro.sample("word_topics", Categorical(doc_topics),
                                      infer={"enumerate": "parallel"})
            data = pyro.sample("doc_words",
                               Categorical(topic_words[word_topics]),
                               obs=data)

    return topic_weights, topic_words, data

def main(data=None, args=None, batch_size=None):
    data = torch.reshape(data, [64, 1000])

    # Globals.
    with pyro.plate("topics", 8):
        # shape = [8] + []
        topic_weights = pyro.sample("topic_weights", Gamma(1. / 8, 1.))
        # shape = [8] + [1024]
        topic_words = pyro.sample("topic_words",
                                  Dirichlet(torch.ones(1024) / 1024))

    # Locals.
    with pyro.plate("documents", 1000) as ind:
        # shape = [64, 32]
        data = data[:, ind]
        # shape = [32] + [8]
        doc_topics = pyro.sample("doc_topics", Dirichlet(topic_weights))
        with pyro.plate("words", 64):
            # The word_topics variable is marginalized out during inference,
            # achieved by specifying infer={"enumerate": "parallel"} and using
            # TraceEnum_ELBO for inference. Thus we can ignore this variable in
            # the guide.
            # shape = [64, 32] + []
            word_topics = pyro.sample("word_topics", Categorical(doc_topics),
                                      infer={"enumerate": "parallel"})
            # shape = [64, 32] + []
            data = pyro.sample("doc_words",
                               Categorical(topic_words[word_topics]),
                               obs=data)

def model(data=None, args=None, batch_size=None):
    if debug:
        print("model:")
    data = torch.reshape(data, [64, 1000])

    # Globals.
    with pyro.plate("topics", 8):
        # shape = [8] + []
        topic_weights = pyro.sample("topic_weights", Gamma(1. / 8, 1.))
        # shape = [8] + [1024]
        topic_words = pyro.sample("topic_words",
                                  Dirichlet(torch.ones(1024) / 1024))
    if debug:
        print("topic_weights\t: shape={}, sum={}".format(
            topic_weights.shape, torch.sum(topic_weights)))
        print("topic_words\t: shape={}".format(topic_words.shape))

    # Locals.
    # with pyro.plate("documents", 1000) as ind:
    with pyro.plate("documents", 1000, 32, dim=-1) as ind:
        # if data is not None:
        #     data = data[:, ind]
        # shape = [64, 32]
        data = data[:, ind]
        # shape = [32] + [8]
        doc_topics = pyro.sample("doc_topics", Dirichlet(topic_weights))
        if debug:
            print("data\t\t: shape={}".format(data.shape))
            print("doc_topics\t: shape={}, [0].sum={}".format(
                doc_topics.shape, torch.sum(doc_topics[0])))
        # with pyro.plate("words", 64):
        with pyro.plate("words", 64, dim=-2):
            # The word_topics variable is marginalized out during inference,
            # achieved by specifying infer={"enumerate": "parallel"} and using
            # TraceEnum_ELBO for inference. Thus we can ignore this variable in
            # the guide.
            # shape = [64, 32] + []
            word_topics = pyro.sample("word_topics", Categorical(doc_topics),
                                      infer={"enumerate": "parallel"})
            # pyro.sample("word_topics", Categorical(doc_topics))
            # shape = [64, 32] + []
            data = pyro.sample("doc_words",
                               Categorical(topic_words[word_topics]),
                               obs=data)
            if debug:
                print("word_topics\t: shape={}".format(word_topics.shape))
                print("data\t\t: shape={}".format(data.shape))

    return topic_weights, topic_words, data

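# A minimal training sketch for the enumerated LDA model above; this is an
# assumed setup, not part of the original snippets. It presumes the amortized
# `guide` defined later in this file, and that `debug` and the `layer*`
# modules are in scope. `num_steps` and the learning rate are illustrative.
import pyro
from pyro.infer import SVI, TraceEnum_ELBO
from pyro.optim import Adam

def train(data, num_steps=1000):
    pyro.clear_param_store()
    # max_plate_nesting=2 covers the nested "words" (dim=-2) and
    # "documents" (dim=-1) plates; TraceEnum_ELBO sums out the
    # enumerated word_topics sites.
    elbo = TraceEnum_ELBO(max_plate_nesting=2)
    svi = SVI(model, guide, Adam({"lr": 0.01}), elbo)
    for step in range(num_steps):
        loss = svi.step(data, args=None)
        if step % 100 == 0:
            print("step {: >5d}\tloss = {:0.4g}".format(step, loss))
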
def model_t0(data):
    with pyro.plate("beta_plate", T - 1):
        beta = pyro.sample("beta", Beta(1, alpha))

    with pyro.plate("theta_plate", T):
        # shape [T, C]
        # sample probabilities for Mult dist.; Dirichlet with symmetric prior
        theta = pyro.sample("theta", Dirichlet(torch.ones(C) / C))

    with pyro.plate("data", N):
        # z == which topic
        z = pyro.sample("z", Categorical(probs=mix_weights(beta)))
        pyro.sample("obs", Multinomial(probs=theta[z]), obs=data)

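# `mix_weights` is called by the stick-breaking models in this file but never
# defined in these snippets. A plausible definition (assumed here) is the
# standard truncated stick-breaking transform from Pyro's Dirichlet-process
# mixture tutorial: T-1 Beta draws become T mixture weights.
import torch.nn.functional as F

def mix_weights(beta):
    # Cumulative products of "remaining stick" lengths, padded so the last
    # weight absorbs whatever probability mass is left.
    beta1m_cumprod = (1 - beta).cumprod(-1)
    return F.pad(beta, (0, 1), value=1) * F.pad(beta1m_cumprod, (1, 0), value=1)
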
def model(data):
    # whether topic is new or not; prior=0.5; random choice between old/new
    with pyro.plate("new_topic_plate", T):
        new_topic = pyro.sample("new_topic", Binomial(probs=0.5))

    # if new topic, whether it is linked to an old topic; prior=0.5
    with pyro.plate("linked_plate", T):
        linked = pyro.sample("linked", Binomial(probs=0.5))

    # if old topic, which old topic
    with pyro.plate("old_topic_plate", T):
        which_old_topic = pyro.sample("which_old_topic",
                                      Multinomial(probs=prev_topic_freq))

    # beta sampling for topic weights
    with pyro.plate("beta_plate", T - 1):
        beta = pyro.sample("beta", Beta(1, alpha))

    with pyro.plate("theta_plate", T):
        # shape [T, C]
        # Dirichlet distribution (conjugate prior of Mult); symmetric prior
        theta = pyro.sample("theta", Dirichlet(torch.ones(C) / C))

    with pyro.plate("gamma_plate", T_prev):
        gamma = pyro.sample("gamma", Dirichlet(prev_taus))

    with pyro.plate("data", N):
        z = pyro.sample("z", Categorical(probs=mix_weights(beta)))
        old = get_old_topics(which_old_topic)
        a = new_topic * linked
        b = 1 - new_topic
        c = new_topic * (1 - linked)
        a = a[z].reshape(N, 1)
        b = b[z].reshape(N, 1)
        c = c[z].reshape(N, 1)
        mult_probs = a * gamma[old[z]] + b * prev_theta[old[z]] + c * theta[z]
        pyro.sample("obs", Multinomial(probs=mult_probs), obs=data)

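# `get_old_topics` is also undefined in these snippets. Since
# `which_old_topic` is a one-hot Multinomial draw over previous topics, a
# plausible (assumed) implementation recovers the selected index per row:
def get_old_topics(which_old_topic):
    # Assumed helper: [T, T_prev] one-hot rows -> [T] integer indices
    # into the previous model's topics.
    return which_old_topic.argmax(dim=-1)
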
def main(data, args, batch_size=None):
    data = torch.reshape(data, [64, 1000])
    pyro.module("layer1", layer1)
    pyro.module("layer2", layer2)
    pyro.module("layer3", layer3)

    # Use a conjugate guide for global variables.
    topic_weights_posterior = pyro.param(
        "topic_weights_posterior",
        torch.ones(8),
        constraint=constraints.positive)
    topic_words_posterior = pyro.param(
        "topic_words_posterior",
        # WL: edited. =====
        # torch.ones(8, 1024),
        # constraint=constraints.greater_than(0.5)
        torch.ones(8, 1024) * 0.5,
        constraint=constraints.positive
        # =================
    )

    with pyro.plate("topics", 8):
        # shape = [8] + []
        topic_weights = pyro.sample("topic_weights",
                                    Gamma(topic_weights_posterior, 1.))
        # shape = [8] + [1024]
        # WL: edited. =====
        # topic_words = pyro.sample("topic_words",
        #                           Dirichlet(topic_words_posterior))
        topic_words = pyro.sample("topic_words",
                                  Dirichlet(topic_words_posterior + 0.5))
        # =================

    # Use an amortized guide for local variables.
    with pyro.plate("documents", 1000, 32) as ind:
        # shape = [64, 32]
        data = data[:, ind]
        # The neural network will operate on histograms rather than word
        # index vectors, so we'll convert the raw data to a histogram.
        counts = torch.zeros(1024, 32)
        counts = torch.Tensor.scatter_add_(
            counts, 0, data,
            torch.Tensor.expand(torch.tensor(1.), [1024, 32]))
        h1 = sigmoid(layer1(torch.transpose(counts, 0, 1)))
        h2 = sigmoid(layer2(h1))
        # shape = [32, 8]
        doc_topics_w = sigmoid(layer3(h2))
        # shape = [32] + [8]
        doc_topics = softmax(doc_topics_w)
        pyro.sample("doc_topics", Delta(doc_topics, event_dim=1))

def guide(sequences):
    theta = pyro.param("theta", torch.ones(16))
    alpha = pyro.param("alpha", torch.rand(1))
    beta = pyro.param("beta", torch.rand(1))
    p = pyro.param("p", torch.rand(1))
    q = pyro.param("q", torch.rand(1))
    w = p * torch.eye(16) + q
    with poutine.mask(mask=False):
        probs_x = pyro.sample("probs_x", Dirichlet(w).to_event(1))
        probs_y = pyro.sample("probs_y",
                              Beta(alpha, beta).expand([16, 51]).to_event(2))
    for i in pyro.plate("sequences", len(sequences), 8):
        length = lengths[i]
        sequence = sequences[i, :length]
        x = 0
        for t in pyro.markov(range(length)):
            x = pyro.sample("x_{}_{}".format(i, t), Categorical(probs_x[x]))

def model(sequences):
    with poutine.mask(mask=False):
        probs_x = pyro.sample("probs_x",
                              Dirichlet(0.9 * torch.eye(16) + 0.1).to_event(1))
        probs_y = pyro.sample("probs_y",
                              Beta(0.1, 0.9).expand([16, 51]).to_event(2))
    tones_plate = pyro.plate("tones", 51, dim=-1)
    for i in pyro.plate("sequences", len(sequences)):
        length = lengths[i]
        sequence = sequences[i, :length]
        x = 0
        for t in pyro.markov(range(length)):
            x = pyro.sample("x_{}_{}".format(i, t), Categorical(probs_x[x]),
                            infer={"enumerate": "parallel"})
            with tones_plate:
                pyro.sample("y_{}_{}".format(i, t),
                            Bernoulli(probs_y[x.squeeze(-1)]),
                            obs=sequence[t])

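# A minimal inference sketch for this HMM; an assumed setup, not part of the
# original snippets. Note that under TraceEnum_ELBO the enumerated x_{i,t}
# sites must not be sampled in the guide, so this presumes a guide over the
# global probs_x/probs_y sites only; the global `lengths` tensor is assumed
# to be in scope.
import pyro
from pyro.infer import SVI, TraceEnum_ELBO
from pyro.optim import Adam

def train_hmm(sequences, num_steps=500):
    pyro.clear_param_store()
    # Only the vectorized "tones" plate (dim=-1) counts toward plate
    # nesting; "sequences" and pyro.markov are sequential.
    elbo = TraceEnum_ELBO(max_plate_nesting=1)
    svi = SVI(model, guide, Adam({"lr": 0.05}), elbo)
    for step in range(num_steps):
        loss = svi.step(sequences)
        if step % 50 == 0:
            print("step {: >4d}\tloss = {:0.4g}".format(step, loss))
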
def naive_acceptability_and_central_weight(number_criterion,
                                           number_alternatives,
                                           number_iterations):
    # at i,j: accumulated weight for alternative i at coordinate j
    central_weight_vector = torch.zeros([number_alternatives, number_criterion])
    weight_shape = torch.ones([number_criterion])
    # at i,j: count of iterations where alternative i was ranked j-th
    count_matrix = torch.zeros([number_alternatives, number_alternatives])
    for _ in range(number_iterations):
        weights = Dirichlet(weight_shape).sample()
        # at i,j: evaluation of alternative i against criterion j
        crit_alt_mat = crit_alt_matrix(number_alternatives, number_criterion)
        # best rank is 0
        rank_vector = [rank(i, crit_alt_mat, weights)
                       for i in range(number_alternatives)]
        for i in range(number_alternatives):
            count_matrix[i, rank_vector[i]] += 1
            if rank_vector[i] == 0:
                central_weight_vector[i] += weights
    # at i,j: approx. probability that alternative i is ranked j-th
    acceptability_index = torch.zeros_like(count_matrix)
    for i in range(number_alternatives):
        if count_matrix[i, 0] > 0:
            # average of the weight vectors ranking alternative i on top
            central_weight_vector[i] /= count_matrix[i, 0]
        for j in range(number_alternatives):
            acceptability_index[i, j] = count_matrix[i, j] / number_iterations
    return central_weight_vector, acceptability_index

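# Hypothetical usage, assuming the helpers `crit_alt_matrix` (sampling an
# alternatives-by-criteria evaluation matrix) and `rank` (ranking alternative
# i under a given weight vector) are defined elsewhere in the project:
central_weights, acceptability = naive_acceptability_and_central_weight(
    number_criterion=3, number_alternatives=4, number_iterations=10_000)
print(acceptability)    # row i: approx. probability of each rank for alternative i
print(central_weights)  # row i: average weight vector ranking alternative i first
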
def guide(data):
    # pyro params
    new_topic_prob = pyro.param("new_topic_prob",
                                lambda: Uniform(0, 1).sample([T]),
                                constraint=constraints.unit_interval)
    linked_prob = pyro.param("linked_prob",
                             lambda: Uniform(0, 1).sample([T]),
                             constraint=constraints.unit_interval)
    which_topic_probs = pyro.param("which_topic_probs",
                                   lambda: Uniform(0, 1).sample([T_prev]),
                                   constraint=constraints.simplex)
    kappa = pyro.param('kappa', lambda: Uniform(0, 2).sample([T - 1]),
                       constraint=constraints.positive)
    tau = pyro.param('tau', lambda: MultivariateNormal(
        0.5 * torch.ones(C), 0.25 * torch.eye(C)).sample([T]),
                     constraint=constraints.unit_interval)
    # N params for categorical dist; topic weights; symmetric prior
    phi = pyro.param('phi', lambda: Dirichlet(1 / T * torch.ones(T)).sample([N]),
                     constraint=constraints.simplex)

    # model params
    with pyro.plate("new_topic_plate", T):
        new_topic = pyro.sample("new_topic", Binomial(probs=new_topic_prob))

    # if new topic, whether it is linked to an old topic
    with pyro.plate("linked_plate", T):
        linked = pyro.sample("linked", Binomial(probs=linked_prob))

    # if old topic, which old topic
    with pyro.plate("old_topic_plate", T):
        which_old_topic = pyro.sample("which_old_topic",
                                      Multinomial(probs=which_topic_probs))

    with pyro.plate("beta_plate", T - 1):
        q_beta = pyro.sample("beta", Beta(torch.ones(T - 1), kappa))

    # new topic with symmetric prior
    with pyro.plate("theta_plate", T):
        theta = pyro.sample("theta", Dirichlet(tau))

    # new topic linked to old topic
    with pyro.plate("gamma_plate", T_prev):
        gamma = pyro.sample("gamma", Dirichlet(prev_taus))

    with pyro.plate("data", N):
        z = pyro.sample("z", Categorical(phi))
        old = get_old_topics(which_old_topic)
        a = new_topic * linked
        b = 1 - new_topic
        c = new_topic * (1 - linked)
        a = a[z].reshape(N, 1)
        b = b[z].reshape(N, 1)
        c = c[z].reshape(N, 1)
        mult_probs = a * gamma[old[z]] + b * prev_theta[old[z]] + c * theta[z]

def guide(data, args, batch_size=None):
    if debug:
        print("guide:")
    data = torch.reshape(data, [64, 1000])
    pyro.module("layer1", layer1)
    pyro.module("layer2", layer2)
    pyro.module("layer3", layer3)

    # Use a conjugate guide for global variables.
    topic_weights_posterior = pyro.param(
        "topic_weights_posterior",
        # lambda: torch.ones(8) / 8,
        torch.ones(8) / 8,
        constraint=constraints.positive)
    topic_words_posterior = pyro.param(
        "topic_words_posterior",
        # lambda: torch.ones(8, 1024) / 1024,
        torch.ones(8, 1024) / 1024,
        constraint=constraints.positive)
    """
    # wy: dummy param for word_topics
    word_topics_posterior = pyro.param(
        "word_topics_posterior",
        torch.ones(64, 1024, 8) / 8,
        constraint=constraints.positive)
    """

    with pyro.plate("topics", 8):
        # shape = [8] + []
        topic_weights = pyro.sample("topic_weights",
                                    Gamma(topic_weights_posterior, 1.))
        # shape = [8] + [1024]
        topic_words = pyro.sample("topic_words",
                                  Dirichlet(topic_words_posterior))
    if debug:
        print("topic_weights\t: shape={}, sum={}".format(
            topic_weights.shape, torch.sum(topic_weights)))
        print("topic_words\t: shape={}".format(topic_words.shape))

    # Use an amortized guide for local variables.
    with pyro.plate("documents", 1000, 32) as ind:
        # shape = [64, 32]
        data = data[:, ind]
        # The neural network will operate on histograms rather than word
        # index vectors, so we'll convert the raw data to a histogram.
        counts = torch.zeros(1024, 32)
        counts = torch.Tensor.scatter_add_(
            counts, 0, data,
            torch.Tensor.expand(torch.tensor(1.), [1024, 32]))
        if debug:
            print("counts.shape={}, counts_trans.shape={}".format(
                counts.shape, torch.transpose(counts, 0, 1).shape))
        h1 = sigmoid(layer1(torch.transpose(counts, 0, 1)))
        h2 = sigmoid(layer2(h1))
        # shape = [32, 8]
        doc_topics_w = sigmoid(layer3(h2))
        if debug:
            print("doc_topics_w(nn result)\t: shape={}, [0].sum={}".format(
                doc_topics_w.shape, torch.sum(doc_topics_w[0])))
            d = Dirichlet(doc_topics_w)
            print("Dirichlet(doc_topics_w)\t: batch_shape={}, event_shape={}".format(
                d.batch_shape, d.event_shape))
        # shape = [32] + [8]
        # doc_topics = pyro.sample("doc_topics", Delta(doc_topics_w, event_dim=1))
        # doc_topics = pyro.sample("doc_topics", Delta(doc_topics_w).to_event(1))
        doc_topics = pyro.sample("doc_topics", Dirichlet(doc_topics_w))

        # wy: sample from exact posterior of word_topics
        with pyro.plate("words", 64):
            # ks : [K, D] = [8, 32]
            # ks = torch.arange(0,8).expand(32,8).transpose(0,1)
            ks = torch.arange(0, 8)
            ks = torch.Tensor.expand(ks, 32, 8)
            ks = torch.Tensor.transpose(ks, 0, 1)
            # logprob1 : [N, D, K] = [32, 8]
            # logprob1 = Categorical(doc_topics).log_prob(ks).transpose(0,1).expand(64,32,8)
            logprob1 = Categorical.log_prob(Categorical(doc_topics), ks)
            logprob1 = torch.Tensor.transpose(logprob1, 0, 1)
            logprob1 = torch.Tensor.expand(logprob1, 64, 32, 8)
            # data2 : [N, D, K] = [64, 32, 8]
            # data2 = data.expand(8,64,32).transpose(0,1).transpose(1,2)
            data2 = torch.Tensor.expand(data, 8, 64, 32)
            data2 = torch.Tensor.transpose(data2, 0, 1)
            data2 = torch.Tensor.transpose(data2, 1, 2)
            # logprob2 : [N, D, K] = [64, 32, 8]
            # logprob2 = Categorical(topic_words).log_prob(data2)
            logprob2 = Categorical.log_prob(Categorical(topic_words), data2)
            # prob : [N, D, K] = [64, 32, 8]
            prob = torch.exp(logprob1 + logprob2)
            # word_topics : [N, D] = [64, 32]
            word_topics = pyro.sample("word_topics", Categorical(prob))

def guide(data, args, batch_size=None):
    if debug:
        print("guide:")
    data = torch.reshape(data, [64, 1000])
    pyro.module("layer1", layer1)
    pyro.module("layer2", layer2)
    pyro.module("layer3", layer3)

    # Use a conjugate guide for global variables.
    topic_weights_posterior = pyro.param(
        "topic_weights_posterior",
        # lambda: torch.ones(8) / 8,
        torch.ones(8) / 8,
        constraint=constraints.positive)
    topic_words_posterior = pyro.param(
        "topic_words_posterior",
        # lambda: torch.ones(8, 1024) / 1024,
        torch.ones(8, 1024) / 1024,
        constraint=constraints.positive)
    """
    # wy: dummy param for word_topics
    word_topics_posterior = pyro.param(
        "word_topics_posterior",
        torch.ones(64, 1024, 8) / 8,
        constraint=constraints.positive)
    """

    with pyro.plate("topics", 8):
        # shape = [8] + []
        topic_weights = pyro.sample("topic_weights",
                                    Gamma(topic_weights_posterior, 1.))
        # shape = [8] + [1024]
        topic_words = pyro.sample("topic_words",
                                  Dirichlet(topic_words_posterior))
    if debug:
        print("topic_weights\t: shape={}, sum={}".format(
            topic_weights.shape, torch.sum(topic_weights)))
        print("topic_words\t: shape={}".format(topic_words.shape))

    # Use an amortized guide for local variables.
    with pyro.plate("documents", 1000, 32) as ind:
        # shape = [64, 32]
        data = data[:, ind]
        # The neural network will operate on histograms rather than word
        # index vectors, so we'll convert the raw data to a histogram.
        counts = torch.zeros(1024, 32)
        counts = torch.Tensor.scatter_add_(
            counts, 0, data,
            torch.Tensor.expand(torch.tensor(1.), [1024, 32]))
        if debug:
            print("counts.shape={}, counts_trans.shape={}".format(
                counts.shape, torch.transpose(counts, 0, 1).shape))
        h1 = sigmoid(layer1(torch.transpose(counts, 0, 1)))
        h2 = sigmoid(layer2(h1))
        # shape = [32, 8]
        doc_topics_w = sigmoid(layer3(h2))
        if debug:
            print("doc_topics_w(nn result)\t: shape={}, [0].sum={}".format(
                doc_topics_w.shape, torch.sum(doc_topics_w[0])))
        # shape = [32] + [8]
        # doc_topics = pyro.sample("doc_topics", Delta(doc_topics_w, event_dim=1))
        # doc_topics = pyro.sample("doc_topics", Delta(doc_topics_w).to_event(1))
        doc_topics = pyro.sample("doc_topics", Dirichlet(doc_topics_w))

def guide(data, args, batch_size=None):
    if debug:
        print("guide:")
    data = torch.reshape(data, [64, 1000])
    pyro.module("layer1", layer1)
    pyro.module("layer2", layer2)
    pyro.module("layer3", layer3)

    # Use a conjugate guide for global variables.
    topic_weights_posterior = pyro.param("topic_weights_posterior",
                                         torch.ones(8) / 8,
                                         constraint=constraints.positive)
    topic_words_posterior = pyro.param("topic_words_posterior",
                                       torch.ones(8, 1024) / 1024,
                                       constraint=constraints.positive)

    with pyro.plate("topics", 8):
        # shape = [8] + []
        topic_weights = pyro.sample("topic_weights",
                                    Gamma(topic_weights_posterior, 1.))
        # shape = [8] + [1024]
        topic_words = pyro.sample("topic_words",
                                  Dirichlet(topic_words_posterior))
    if debug:
        print("topic_weights\t: shape={}, sum={}".format(
            topic_weights.shape, torch.sum(topic_weights)))
        print("topic_words\t: shape={}".format(topic_words.shape))

    # Use an amortized guide for local variables.
    with pyro.plate("documents", 1000, 32, dim=-1) as ind:
        # shape = [64, 32]
        data = data[:, ind]
        # The neural network will operate on histograms rather than word
        # index vectors, so we'll convert the raw data to a histogram.
        counts = torch.zeros(1024, 32)
        counts = torch.Tensor.scatter_add_(
            counts, 0, data,
            torch.Tensor.expand(torch.tensor(1.), [1024, 32]))
        if debug:
            print("counts.shape={}, counts_trans.shape={}".format(
                counts.shape, torch.transpose(counts, 0, 1).shape))
        h1 = sigmoid(layer1(torch.transpose(counts, 0, 1)))
        h2 = sigmoid(layer2(h1))
        # shape = [32, 8]
        doc_topics = sigmoid(layer3(h2))
        if debug:
            print("doc_topics(nn result)\t: shape={}, [0].sum={}".format(
                doc_topics.shape, torch.sum(doc_topics[0])))
            d = Dirichlet(doc_topics)
            print("Dirichlet(doc_topics)\t: batch_shape={}, event_shape={}".format(
                d.batch_shape, d.event_shape))
        """
        NOTE: There are three options we can take:
        1. sample doc_topics (from *DELTA*) *ONLY*. (original code)
        2. sample doc_topics (from *DIRICHLET*) *ONLY*.
        3. sample doc_topics (from *DIRICHLET*), word_topics (from Categorical) *BOTH*.
        Expected results:
          | Trace_ELBO | TraceEnum_ELBO
        -------------------------------
        1 | FAIL       | FAIL
        2 | FAIL       | PASS
        3 | PASS       | FAIL
        """
        # shape = [32] + [8]
        # doc_topics = pyro.sample("doc_topics", Delta(doc_topics, event_dim=1))
        # doc_topics = pyro.sample("doc_topics", Delta(doc_topics).to_event(1))
        doc_topics = pyro.sample("doc_topics", Dirichlet(doc_topics))
        if debug:
            print("doc_topics(sampled)\t: shape={}, [0].sum = {}".format(
                doc_topics.shape, torch.sum(doc_topics[0])))

        with pyro.plate("words", 64, dim=-2):
            # shape = [64, 32, 8]
            word_topics_posterior = doc_topics * topic_words.transpose(0, 1)[data, :]
            word_topics_posterior = word_topics_posterior / (
                word_topics_posterior.sum(dim=-1, keepdim=True))
            word_topics = pyro.sample("word_topics",
                                      Categorical(word_topics_posterior))
            if debug:
                print("word_topics_posterior\t: shape={}".format(
                    word_topics_posterior.shape))
                print("word_topics(sampled)\t: shape={}".format(
                    word_topics.shape))
                d = Categorical(word_topics_posterior)
                print("Categorical(word_topics_posterior)\t: "
                      "batch_shape={}, event_shape={}".format(
                          d.batch_shape, d.event_shape))

def _kl_factorised_factorised(p: Factorised, q: Factorised):
    return sum(kl_divergence(p_factor, q_factor)
               for p_factor, q_factor in zip(p.factors, q.factors))


if __name__ == '__main__':
    from pyro.distributions import Dirichlet, MultivariateNormal
    from torch.distributions import kl_divergence

    from distributions.mixture import Mixture

    B, D1, D2 = 5, 3, 4
    N = 1000
    dist1 = MultivariateNormal(torch.zeros(D1), torch.eye(D1)).expand((B,))
    dist2 = Dirichlet(torch.ones(D2)).expand((B,))
    print(dist1.batch_shape, dist1.event_shape)
    print(dist2.batch_shape, dist2.event_shape)
    fact = Factorised([dist1, dist2])
    print(fact.batch_shape, fact.event_shape)
    samples = fact.rsample((N,))
    print(samples[0])
    print(samples.shape)
    logp = fact.log_prob(samples)
    print(logp.shape)
    entropy = fact.entropy()
    print(entropy.shape)
    print(entropy, -logp.mean())
    print()
    print(kl_divergence(fact, fact))

    # (tail of the mixture's posterior method; earlier lines are not
    # included in this excerpt)
    post_logits = self.mixing.logits + post_lognorm
    return NaturalMultivariateNormalMixture(
        Categorical(logits=post_logits), post_components)


def eval_grid(xx, yy, fcn):
    xy = torch.stack([xx.flatten(), yy.flatten()], dim=1)
    return fcn(xy).reshape_as(xx)


if __name__ == '__main__':
    from pyro.distributions import Dirichlet

    N, K, D = 200, 4, 2
    props = Dirichlet(5 * torch.ones(K)).sample()
    mean = torch.arange(K).float().view(K, 1).expand(K, D)
    var = .1 * torch.eye(D).expand(K, -1, -1)
    mixing = Categorical(props)
    components = MultivariateNormal(mean, var)
    print("mixing", mixing.batch_shape, mixing.event_shape)
    print("components", components.batch_shape, components.event_shape)
    mixture = Mixture(mixing,
                      NaturalMultivariateNormal.from_standard(components))
    mixture.rename(['x', 'y'])
    print("mixture names", mixture.variable_names)
    print("mixture", mixture.batch_shape, mixture.event_shape)

    probe = MultivariateNormal(mean[:3] + 1 * torch.tensor([1., -1.]),
                               .2 * var[:3])
    post_mixture = mixture.posterior(probe)
    print("post_mixture names", post_mixture.variable_names)