def model(self, *args, **kwargs):
    I, N = self._data['data'].shape
    # Use the configured batch size when given, otherwise the full data size.
    batch = self._params['batch_size'] if self._params['batch_size'] else N
    weights = pyro.sample(
        'mixture_weights',
        dist.Dirichlet((1 / self._params['K']) * torch.ones(self._params['K'])))
    cat_vector = torch.tensor(np.arange(self._params['hidden_dim']) + 1,
                              dtype=torch.float)
    with pyro.plate('segments', I):
        segment_factor = pyro.sample(
            'segment_factor',
            dist.Gamma(self._params['theta_scale'], self._params['theta_rate']))
        with pyro.plate('components', self._params['K']):
            cc = pyro.sample("CNV_probabilities",
                             dist.Dirichlet(self.create_dirichlet_init_values()))
    with pyro.plate('data', N, batch):
        # p(x|z_i) = Poisson(marg(cc * theta * segment_factor))
        segment_fact_cat = torch.matmul(
            segment_factor.reshape([I, 1]),
            cat_vector.reshape([1, self._params['hidden_dim']]))
        segment_fact_marg = segment_fact_cat * cc
        segment_fact_marg = torch.sum(segment_fact_marg, dim=-1)
        # p(z_i | D, X) = lk(z_i) * p(z_i | X) / sum_z_i(lk(z_i) * p(z_i | X))
        # log p(z_i | D, X) = log lk(z_i) + log p(z_i | X)
        #                     - log_sum_exp(log lk(z_i) + log p(z_i | X))
        pyro.factor("lk", self.likelihood(segment_fact_marg, weights,
                                          self._params['theta']))
def model(transition_alphas, emission_alphas, lengths, sequences=None, batch_size=None):
    # From https://pyro.ai/examples/hmm.html
    with ignore_jit_warnings():
        if sequences is not None:
            num_sequences, max_length, data_dim = map(int, sequences.shape)
            assert lengths.shape == (num_sequences,)
            assert lengths.max() <= max_length
        else:
            data_dim = emission_alphas.size(1)
            num_sequences = int(lengths.shape[0])
            max_length = int(lengths.max())
    transition_probs = pyro.sample('transition_probs',
                                   dist.Dirichlet(transition_alphas).to_event(1))
    emission_probs = pyro.sample('emission_probs',
                                 dist.Dirichlet(emission_alphas).to_event(2))
    element_plate = pyro.plate('elements', data_dim, dim=-1)
    with pyro.plate('sequences', num_sequences, batch_size, dim=-2) as batch:
        lengths = lengths[batch]
        state = 0
        for t in pyro.markov(range(max_length)):
            with poutine.mask(mask=(t < lengths).unsqueeze(-1)):
                state = pyro.sample(f'state_{t}',
                                    dist.Categorical(transition_probs[state]),
                                    infer={'enumerate': 'parallel'})
                obs_element = Vindex(sequences)[batch, t] if sequences is not None else None
                with element_plate:
                    element = pyro.sample(f'element_{t}',
                                          dist.Categorical(emission_probs[state.squeeze(-1)]),
                                          obs=obs_element)
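# A hedged usage sketch (not from the source): models whose discrete sites
# carry infer={'enumerate': 'parallel'} are normally trained with
# TraceEnum_ELBO, which marginalizes the states exactly while the guide covers
# only the continuous globals. The AutoDelta guide, toy shapes, learning rate,
# and step count below are illustrative assumptions.
from pyro import poutine
from pyro.infer import SVI, TraceEnum_ELBO
from pyro.infer.autoguide import AutoDelta
from pyro.optim import Adam

num_states, data_dim, vocab = 4, 5, 6
transition_alphas = torch.ones(num_states, num_states)
emission_alphas = torch.ones(num_states, data_dim, vocab)
lengths = torch.full((8,), 12, dtype=torch.long)
sequences = torch.randint(vocab, (8, 12, data_dim)).float()

guide = AutoDelta(
    poutine.block(model, expose=['transition_probs', 'emission_probs']))
elbo = TraceEnum_ELBO(max_plate_nesting=2)  # plates: sequences (dim=-2), elements (dim=-1)
svi = SVI(model, guide, Adam({'lr': 0.05}), elbo)
for step in range(100):
    loss = svi.step(transition_alphas, emission_alphas, lengths,
                    sequences, batch_size=4)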
def model(self, *args, **kwargs):
    I, N = self._data['data'].shape
    # Use the configured batch size when given, otherwise the full data size.
    batch = self._params['batch_size'] if self._params['batch_size'] else N
    weights = pyro.sample(
        'mixture_weights',
        dist.Dirichlet((1 / self._params['K']) * torch.ones(self._params['K'])))
    with pyro.plate('segments', I):
        with pyro.plate('components', self._params['K']):
            cnv_probs = pyro.sample(
                "cnv_probs",
                dist.Dirichlet(self._params['probs'] * 1 /
                               torch.ones(self._params['hidden_dim'])))
    with pyro.plate("data2", N, batch):
        theta = pyro.sample(
            'norm_factor',
            dist.Gamma(self._params['theta_scale'], self._params['theta_rate']))
    with pyro.plate('data', N, batch):
        assignment = pyro.sample('assignment', dist.Categorical(weights),
                                 infer={"enumerate": "parallel"})
        for i in pyro.plate('segments2', I):
            cc = pyro.sample('copy_number_{}'.format(i),
                             dist.Categorical(Vindex(cnv_probs)[assignment, i, :]),
                             infer={"enumerate": "parallel"})
            pyro.sample('obs_{}'.format(i),
                        dist.Poisson((cc * theta * self._data['mu'][i]) + 1e-8),
                        obs=self._data['data'][i, :])
def model(data, batch_size=32):
    alpha = torch.ones(num_topics).to(device)
    eta = torch.ones(num_words).to(device)
    with pyro.plate("topic_loop", num_topics):
        # beta: [num_topics, num_words]
        beta = pyro.sample("beta", dist.Dirichlet(eta)).to(device)
    with pyro.plate("document_loop", num_documents) as ind:
        # theta: [num_documents, num_topics]
        theta = pyro.sample("theta", dist.Dirichlet(alpha)).to(device)
        with pyro.plate("word_loop", words_per_doc):
            # z: [num_words, num_documents]
            z = pyro.sample("z", dist.Categorical(theta)).to(device)
            results = pyro.sample("obs",
                                  dist.Categorical(Vindex(beta)[z]),
                                  obs=data[:, ind])
    return results
def model_4(sequences, lengths, args, batch_size=None, include_prior=True):
    with ignore_jit_warnings():
        num_sequences, max_length, data_dim = map(int, sequences.shape)
        assert lengths.shape == (num_sequences,)
        assert lengths.max() <= max_length
    hidden_dim = int(args.hidden_dim ** 0.5)  # split between w and x
    with poutine.mask(mask=include_prior):
        probs_w = pyro.sample("probs_w",
                              dist.Dirichlet(0.9 * torch.eye(hidden_dim) + 0.1)
                                  .to_event(1))
        probs_x = pyro.sample("probs_x",
                              dist.Dirichlet(0.9 * torch.eye(hidden_dim) + 0.1)
                                  .expand_by([hidden_dim])
                                  .to_event(2))
        probs_y = pyro.sample("probs_y",
                              dist.Beta(0.1, 0.9)
                                  .expand([hidden_dim, hidden_dim, data_dim])
                                  .to_event(3))
    tones_plate = pyro.plate("tones", data_dim, dim=-1)
    with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch:
        lengths = lengths[batch]
        # Note the broadcasting tricks here: we declare a hidden torch.arange and
        # ensure that w and x are always tensors so we can unsqueeze them below,
        # thus ensuring that the x sample sites have correct distribution shape.
        w = x = torch.tensor(0, dtype=torch.long)
        for t in pyro.markov(range(max_length if args.jit else lengths.max())):
            with poutine.mask(mask=(t < lengths).unsqueeze(-1)):
                w = pyro.sample("w_{}".format(t), dist.Categorical(probs_w[w]),
                                infer={"enumerate": "parallel"})
                x = pyro.sample("x_{}".format(t),
                                dist.Categorical(Vindex(probs_x)[w, x]),
                                infer={"enumerate": "parallel"})
                with tones_plate as tones:
                    pyro.sample("y_{}".format(t),
                                dist.Bernoulli(probs_y[w, x, tones]),
                                obs=sequences[batch, t])
def model(self, docs=None, doc_sum=None):
    # Globals.
    with pyro.plate("topics", self.num_topics):
        a = torch.tensor(1. / self.num_topics, device=self.device)
        b = torch.tensor(1., device=self.device)
        topic_weights = pyro.sample("topic_weights", dist.Gamma(a, b))
        alpha = torch.ones(self.vocab_size, device=self.device) / self.vocab_size
        topic_words = pyro.sample("topic_words", dist.Dirichlet(alpha))

    # Locals.
    # We will use nested plates. Pyro convention is to count from the right
    # by using negative indices like -1, -2. This means documents must be at
    # the rightmost dimension, followed by words. For this reason, we
    # transpose the data:
    docs = docs.transpose(0, 1)
    with pyro.plate('documents', docs.shape[-1]):
        doc_topics = pyro.sample("doc_topics", dist.Dirichlet(topic_weights))
        with pyro.plate("words", docs.shape[-2]):
            word_topics = pyro.sample("word_topics", dist.Categorical(doc_topics),
                                      infer={"enumerate": "parallel"})
            data = pyro.sample("doc_words",
                               dist.Categorical(topic_words[word_topics]),
                               obs=docs)
    return topic_words
def model(data=None, args=None, batch_size=None):
    # Globals.
    with pyro.plate("topics", args.num_topics):
        topic_weights = pyro.sample("topic_weights",
                                    dist.Gamma(1. / args.num_topics, 1.))
        topic_words = pyro.sample(
            "topic_words",
            dist.Dirichlet(torch.ones(args.num_words) / args.num_words))

    # Locals.
    with pyro.plate("documents", args.num_docs) as ind:
        if data is not None:
            with pyro.util.ignore_jit_warnings():
                assert data.shape == (args.num_words_per_doc, args.num_docs)
            data = data[:, ind]
        doc_topics = pyro.sample("doc_topics", dist.Dirichlet(topic_weights))
        with pyro.plate("words", args.num_words_per_doc):
            # The word_topics variable is marginalized out during inference,
            # achieved by specifying infer={"enumerate": "parallel"} and using
            # TraceEnum_ELBO for inference. Thus we can ignore this variable in
            # the guide.
            word_topics = pyro.sample("word_topics", dist.Categorical(doc_topics),
                                      infer={"enumerate": "parallel"})
            data = pyro.sample("doc_words",
                               dist.Categorical(topic_words[word_topics]),
                               obs=data)
    return topic_weights, topic_words, data
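# A hedged training sketch for the LDA model above (assumptions, not source
# code): word_topics is enumerated out, so TraceEnum_ELBO is required; the
# remaining continuous sites (topic_weights, topic_words, doc_topics) get a
# simple MAP guide here. The namespace values and toy data are placeholders.
from types import SimpleNamespace
from pyro import poutine
from pyro.infer import SVI, TraceEnum_ELBO
from pyro.infer.autoguide import AutoDelta
from pyro.optim import ClippedAdam

args = SimpleNamespace(num_topics=4, num_words=100,
                       num_docs=50, num_words_per_doc=20)
data = torch.randint(args.num_words, (args.num_words_per_doc, args.num_docs))
guide = AutoDelta(poutine.block(model, hide=['word_topics', 'doc_words']))
svi = SVI(model, guide, ClippedAdam({'lr': 0.01}),
          TraceEnum_ELBO(max_plate_nesting=2))
loss = svi.step(data, args=args)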
def guide_ret(*args, **kwargs):
    I, N = self._data['data'].shape
    # Use the configured batch size when given, otherwise the full data size.
    batch = self._params['batch_size'] if self._params['batch_size'] else N
    param_weights = pyro.param(
        "param_weights",
        lambda: torch.ones(self._params['K']) / self._params['K'],
        constraint=constraints.simplex)
    hidden_vals = pyro.param(
        "param_hidden_weights",
        lambda: self.create_dirichlet_init_values(),
        constraint=constraints.simplex)
    gamma_scale = pyro.param(
        "param_gamma_scale",
        lambda: torch.mean(
            self._data['data'] / (2 * self._data['mu'].reshape(
                self._data['data'].shape[0], 1)),
            axis=0) * self._params['gamma_multiplier'],
        constraint=constraints.positive)
    gamma_rate = pyro.param(
        "param_rate",
        lambda: torch.ones(1) * self._params['gamma_multiplier'],
        constraint=constraints.positive)
    weights = pyro.sample('mixture_weights', dist.Dirichlet(param_weights))
    with pyro.plate('segments', I):
        with pyro.plate('components', self._params['K']):
            pyro.sample("cnv_probs", dist.Dirichlet(hidden_vals))
    with pyro.plate("data2", N, batch):
        pyro.sample('norm_factor', dist.Gamma(gamma_scale, gamma_rate))
def model(doc_word_data=None, category_data=None, args=None, batch_size=None):
    # Globals.
    with pyro.plate("topics", args.num_topics):
        # topic_weights does not seem to come from the usual LDA plate notation,
        # but seems to give an indication of the importance of topics. It might
        # be from the amortized LDA paper.
        topic_weights = pyro.sample("topic_weights",
                                    dist.Gamma(1. / args.num_topics, 1.))
        topic_words = pyro.sample(
            "topic_words",
            dist.Dirichlet(torch.ones(args.num_words) / args.num_words))

    with pyro.plate("categories", args.num_categories):
        category_weights = pyro.sample(
            "category_weights", dist.Gamma(1. / args.num_categories, 1.))
        # TODO category weights might not be necessary in our model
        category_topics = pyro.sample("category_topics",
                                      dist.Dirichlet(topic_weights))

    doc_category_list = []
    doc_word_list = []

    # Locals.
    for index, doc in enumerate(pyro.plate("documents", args.num_docs)):
        if doc_word_data is not None:
            cur_doc_word_data = doc_word_data[doc]
        else:
            cur_doc_word_data = None
        if category_data is not None:
            cur_category_data = category_data[doc]
        else:
            cur_category_data = None

        doc_category_list.append(
            pyro.sample("doc_categories_{}".format(doc),
                        dist.Categorical(category_weights),
                        obs=cur_category_data))

        with pyro.plate("words_{}".format(doc), args.num_words_per_doc[doc]):
            word_topics = pyro.sample(
                "word_topics_{}".format(doc),
                dist.Categorical(category_topics[int(doc_category_list[index].item())]))
            # TODO Enum parallel/sequential optimizing?
            doc_word_list.append(
                pyro.sample("doc_words_{}".format(doc),
                            dist.Categorical(topic_words[word_topics]),
                            obs=cur_doc_word_data))

    results = {"topic_weights": topic_weights,
               "topic_words": topic_words,
               "doc_word_data": doc_word_list,
               "category_weights": category_weights,
               "category_topics": category_topics,
               "doc_category_data": doc_category_list}
    return results
def parametrized_guide(doc_word_data, category_data, args, batch_size=None):
    # Use a conjugate guide for global variables.
    topic_weights_posterior = pyro.param(
        "topic_weights_posterior",
        lambda: torch.ones(args.num_topics),
        constraint=constraints.positive)
    topic_words_posterior = pyro.param(
        "topic_words_posterior",
        lambda: torch.ones(args.num_topics, args.num_words),
        constraint=constraints.greater_than(0.5))
    with pyro.plate("topics", args.num_topics):
        pyro.sample("topic_weights", dist.Gamma(topic_weights_posterior, 1.))
        pyro.sample("topic_words", dist.Dirichlet(topic_words_posterior))

    category_weights_posterior = pyro.param(
        "category_weights_posterior",
        lambda: torch.ones(args.num_categories),
        constraint=constraints.positive)
    category_topics_posterior = pyro.param(
        "category_topics_posterior",
        lambda: torch.ones(args.num_categories, args.num_topics),
        constraint=constraints.greater_than(0.5))
    with pyro.plate("categories", args.num_categories):
        pyro.sample("category_weights", dist.Gamma(category_weights_posterior, 1.))
        pyro.sample("category_topics", dist.Dirichlet(category_topics_posterior))

    # The Categorical over document categories needs one weight per category,
    # so this parameter must have size num_categories (not num_topics).
    doc_category_posterior = pyro.param(
        "doc_category_posterior",
        lambda: torch.ones(args.num_categories),
        constraint=constraints.positive)
    with pyro.plate("documents", args.num_docs, batch_size) as ind:
        pyro.sample("doc_categories", dist.Categorical(doc_category_posterior))
def model(self, *args, **kwargs):
    I, N = self._data['data'].shape
    # Use the configured batch size when given, otherwise the full data size.
    batch = self._params['batch_size'] if self._params['batch_size'] else N
    weights = pyro.sample('mixture_weights',
                          dist.Dirichlet(torch.ones(self._params['K'])))
    with pyro.plate('components', self._params['K']):
        probs_z = pyro.sample(
            "cnv_probs",
            dist.Dirichlet(self._params['t'] * torch.eye(self._params['hidden_dim']) +
                           (1 - self._params['t'])).to_event(1))
    with pyro.plate("data2", N, batch):
        theta = pyro.sample(
            'norm_factor',
            dist.Gamma(self._params['theta_scale'], self._params['theta_rate']))
    with pyro.plate('data', N, batch):
        z = 0
        assignment = pyro.sample('assignment', dist.Categorical(weights),
                                 infer={"enumerate": "parallel"})
        for i in pyro.markov(range(I)):
            z = pyro.sample("z_{}".format(i),
                            dist.Categorical(Vindex(probs_z)[assignment, z]),
                            infer={"enumerate": "parallel"})
            pyro.sample('obs_{}'.format(i),
                        dist.Poisson((z * theta * self._data['mu'][i]) + 1e-8),
                        obs=self._data['data'][i, :])
def model_3(sequences, lengths, args, batch_size=None, include_prior=True):
    with ignore_jit_warnings():
        num_sequences, max_length, data_dim = map(int, sequences.shape)
        assert lengths.shape == (num_sequences,)
        assert lengths.max() <= max_length
    hidden_dim = int(args.hidden_dim ** 0.5)  # split between w and x
    with poutine.mask(mask=include_prior):
        probs_w = pyro.sample(
            "probs_w",
            dist.Dirichlet(0.9 * torch.eye(hidden_dim) + 0.1).to_event(1))
        probs_x = pyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * torch.eye(hidden_dim) + 0.1).to_event(1))
        probs_y = pyro.sample(
            "probs_y",
            dist.Beta(0.1, 0.9).expand([hidden_dim, hidden_dim, data_dim]).to_event(3))
    tones_plate = pyro.plate("tones", data_dim, dim=-1)
    with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch:
        lengths = lengths[batch]
        w, x = 0, 0
        for t in pyro.markov(range(max_length if args.jit else lengths.max())):
            with poutine.mask(mask=(t < lengths).unsqueeze(-1)):
                w = pyro.sample("w_{}".format(t), dist.Categorical(probs_w[w]),
                                infer={"enumerate": "parallel"})
                x = pyro.sample("x_{}".format(t), dist.Categorical(probs_x[x]),
                                infer={"enumerate": "parallel"})
                with tones_plate as tones:
                    pyro.sample("y_{}".format(t),
                                dist.Bernoulli(probs_y[w, x, tones]),
                                obs=sequences[batch, t])
def model(self, *args, **kwargs):
    I = self._data['segments']
    pi = pyro.sample("pi", dist.Dirichlet(self._params['init_probs']))
    probs_z = pyro.sample(
        "cnv_probs",
        dist.Dirichlet((1 - self._params['t']) * torch.eye(self._params['hidden_dim']) +
                       self._params['t']).to_event(1))
    probs_y = torch.tensor([[2., 64., 32., 21.5, 16., 43.],
                            [64., 64., 64., 64., 64., 64.]])
    z = pyro.sample("z_0", dist.Categorical(pi), infer={"enumerate": "parallel"})
    pyro.sample("y_{}".format(0), dist.Beta(probs_y[0, z], probs_y[1, z]),
                obs=self._data['data'][0, 0])
    for i in pyro.markov(range(1, I)):
        z = pyro.sample("z_{}".format(i),
                        dist.Categorical(Vindex(probs_z)[z]),
                        infer={"enumerate": "parallel"})
        pyro.sample("y_{}".format(i), dist.Beta(probs_y[0, z], probs_y[1, z]),
                    obs=self._data['data'][i, 0])
def model(K=None, M=None, N=None, V=None, alpha=None, beta=None, doc=None, w=None):
    theta = sample('theta', ImproperUniform(shape=(M, K)))
    phi = sample('phi', ImproperUniform(shape=(K, V)))
    for m in range(1, M + 1):
        sample('theta' + '__{}'.format(m - 1) + '__1', dist.Dirichlet(alpha),
               obs=theta[m - 1])
    for k in range(1, K + 1):
        sample('phi' + '__{}'.format(k - 1) + '__2', dist.Dirichlet(beta),
               obs=phi[k - 1])
    for n in range(1, N + 1):
        gamma = zeros(K)
        for k in range(1, K + 1):
            gamma[k - 1] = log(theta[doc[n - 1] - 1, k - 1]) + \
                           log(phi[k - 1, w[n - 1] - 1])
        sample('expr' + '__{}'.format(n) + '__3', dist.Exponential(1.0),
               obs=-log_sum_exp(gamma))
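# Why the Exponential(1.0) observation above works (an added explanatory note):
# Exponential(1) has log_prob(v) = -v, so observing v = -log_sum_exp(gamma)
# contributes +log_sum_exp(gamma) to the joint log-density. That is exactly
# the Stan-style increment that marginalizes each word's topic assignment,
# which this converter-generated code is mimicking.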
def guide(self, data):
    alpha_posterior = pyro.param("alpha_posterior",
                                 lambda: torch.ones(self.n_topics),
                                 constraint=positive)
    beta_posterior = pyro.param(
        "beta_posterior",
        lambda: torch.ones(self.n_topics, self.vocab_size),
        constraint=greater_than(0.5))
    with pyro.plate("topics", self.n_topics):
        alpha = pyro.sample("alpha", dist.Gamma(alpha_posterior, 1.))
        betas = pyro.sample("beta", dist.Dirichlet(beta_posterior))

    theta = None
    z = None
    for d in pyro.plate("doc_loop", len(data)):
        gamma_q = pyro.param(f"gamma_{d}",
                             torch.ones(self.n_topics),
                             constraint=positive)
        theta = pyro.sample(f"theta_{d}", dist.Dirichlet(gamma_q))
        nwords = len(data[d])
        for w in pyro.plate(f"word_loop_{d}", nwords):
            phi_q = pyro.param(f"phi{d}_{w}",
                               torch.ones(self.n_topics),
                               constraint=positive)
            z = pyro.sample(f"z{d}_{w}", dist.Categorical(phi_q))
    return theta, z, alpha, betas
def model(data):
    initialize = pyro.sample("initialize", dist.Dirichlet(torch.ones(dim)))
    with pyro.plate("states", dim):
        transition = pyro.sample("transition", dist.Dirichlet(torch.ones(dim, dim)))
        emission_loc = pyro.sample("emission_loc",
                                   dist.Normal(torch.zeros(dim), torch.ones(dim)))
        emission_scale = pyro.sample("emission_scale",
                                     dist.LogNormal(torch.zeros(dim), torch.ones(dim)))
    x = None
    with ignore_jit_warnings([("Iterating over a tensor", RuntimeWarning)]):
        for t, y in pyro.markov(enumerate(data)):
            x = pyro.sample("x_{}".format(t),
                            dist.Categorical(initialize if x is None else transition[x]),
                            infer={"enumerate": "parallel"})
            pyro.sample("y_{}".format(t),
                        dist.Normal(emission_loc[x], emission_scale[x]),
                        obs=y)
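# A hedged sketch (not from the source): because every x_t above is marked for
# parallel enumeration, the discrete chain can be marginalized inside NUTS,
# leaving HMC to sample only the continuous globals. `dim` is the global the
# model reads; its value and the toy data here are assumptions.
from pyro.infer import MCMC, NUTS

dim = 3
data = torch.randn(50)
kernel = NUTS(model)
mcmc = MCMC(kernel, num_samples=100, warmup_steps=50)
mcmc.run(data)
samples = mcmc.get_samples()  # e.g. samples['transition'], samples['emission_loc']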
def model(doc_word_data=None, category_data=None, args=None, batch_size=None):
    # Globals.
    with pyro.plate("topics", args.num_topics):
        # topic_weights does not seem to come from the usual LDA plate notation,
        # but seems to give an indication of the importance of topics. It might
        # be from the amortized LDA paper.
        topic_weights = pyro.sample("topic_weights",
                                    dist.Gamma(1. / args.num_topics, 1.))
        topic_words = pyro.sample(
            "topic_words",
            dist.Dirichlet(torch.ones(args.num_words) / args.num_words))

    with pyro.plate("categories", args.num_categories):
        category_weights = pyro.sample(
            "category_weights", dist.Gamma(1. / args.num_categories, 1.))
        category_topics = pyro.sample("category_topics",
                                      dist.Dirichlet(topic_weights))

    # Locals.
    with pyro.plate("documents", args.num_docs) as ind:
        if doc_word_data is not None:
            with pyro.util.ignore_jit_warnings():
                assert doc_word_data.shape == (args.num_words_per_doc,
                                               args.num_docs)  # Forces the 64x1000 shape
            doc_word_data = doc_word_data[:, ind]
        if category_data is not None:
            category_data = category_data[ind]
        category_data = pyro.sample("doc_categories",
                                    dist.Categorical(category_weights),
                                    obs=category_data)
        with pyro.plate("words", args.num_words_per_doc):
            # The word_topics variable is marginalized out during inference,
            # achieved by specifying infer={"enumerate": "parallel"} and using
            # TraceEnum_ELBO for inference. Thus we can ignore this variable in
            # the guide.
            word_topics = pyro.sample("word_topics",
                                      dist.Categorical(category_topics[category_data]),
                                      infer={"enumerate": "parallel"})
            doc_word_data = pyro.sample("doc_words",
                                        dist.Categorical(topic_words[word_topics]),
                                        obs=doc_word_data)

    results = {"topic_weights": topic_weights,
               "topic_words": topic_words,
               "doc_word_data": doc_word_data,
               "category_weights": category_weights,
               "category_topics": category_topics,
               "category_data": category_data}
    return results
def guide(data):
    qalpha0 = pyro.param("qalpha0", torch.ones(nd, nz, 1, Td),
                         constraint=constraints.positive)
    qalpha1 = pyro.param("qalpha1", torch.ones(nz, 1, nw, ntr),
                         constraint=constraints.positive)
    # CHANGE: use the fact that Dirichlet can draw independent Dirichlets
    pyro.sample("latent0", pdist.Dirichlet(concentration=qalpha0.view(nd, -1)))
    pyro.sample("latent1", pdist.Dirichlet(concentration=qalpha1.view(nz, -1)))
def model(self, enc_in, dec_in, dec_out, T, N):
    pyro.module("AE", self.AE)
    theta_t = pyro.sample('theta_t', dist.Dirichlet(torch.ones(K) * 10))
    theta_d = pyro.sample('theta_d', dist.Dirichlet(torch.ones(K) * 10))
    with pyro.iarange('data.loop', N, dim=-1) as i:
        z_t = pyro.sample('z_t', dist.Categorical(theta_t.expand(N, K)))
        z_d = pyro.sample('z_d', dist.Categorical(theta_d.expand(N, K)))
        pi = self.AE.decode([z_t, z_d, dec_in[i]])
        for t in range(T - 1):
            pyro.sample('y_{}_{}'.format(i, t),
                        dist.Categorical(pi[:, t, :]),
                        obs=dec_out[i, t])
def test_dirichlet_multinomial(sample_shape, batch_shape):
    concentration = torch.randn(batch_shape + (3,)).exp()
    total = 10
    probs = torch.tensor([0.2, 0.3, 0.5])
    obs = dist.Multinomial(total, probs).sample(sample_shape + batch_shape)

    f = dist.Dirichlet(concentration)
    g = dist.Dirichlet(1 + obs)
    fg, log_normalizer = f.conjugate_update(g)

    x = fg.sample(sample_shape)
    assert_close(f.log_prob(x) + g.log_prob(x), fg.log_prob(x) + log_normalizer)
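# The identity the test checks, spelled out (a standard Dirichlet fact, added
# here for context): with f = Dirichlet(alpha) and g = Dirichlet(1 + counts),
# the product of densities is again Dirichlet up to a constant,
#     f(x) * g(x) = fg(x) * Z,    fg = Dirichlet(alpha + counts),
# so log f(x) + log g(x) == log fg(x) + log Z pointwise, with log Z being the
# log_normalizer returned by conjugate_update.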
def unsupervised_hmm(words):
    with pyro.plate("prob_plate", num_categories):
        transition_prob = pyro.sample("transition_prob", dist.Dirichlet(transition_prior))
        emission_prob = pyro.sample("emission_prob", dist.Dirichlet(emission_prior))

    transition_log_prob = transition_prob.log()
    emission_log_prob = emission_prob.log()
    log_prob = emission_log_prob[:, words[0]]
    for t in range(1, len(words)):
        log_prob = forward_log_prob(log_prob, words[t],
                                    transition_log_prob, emission_log_prob)
    prob = log_prob.logsumexp(dim=0).exp()
    # a trick to inject an additional log_prob into model's log_prob
    pyro.sample("forward_prob", dist.Bernoulli(prob), obs=torch.tensor(1.))
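# Added note on the trick above: observing torch.tensor(1.) under
# Bernoulli(prob) contributes log(prob) to the model's joint log-density, so
# the forward-algorithm likelihood computed manually is injected as if it were
# an observed site. pyro.factor("forward_prob", log_prob.logsumexp(dim=0))
# would be the more direct modern equivalent.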
def model(data):
    s0 = (nd, nz, 1, Td)
    s1 = (nz, 1, nw, ntr)
    alpha0 = torch.ones(*s0).cpu()
    alpha1 = torch.ones(*s1).cpu()
    z = pyro.sample("latent0", pdist.Dirichlet(concentration=alpha0.view(nd, -1)))
    motifs = pyro.sample("latent1", pdist.Dirichlet(concentration=alpha1.view(nz, -1)))
    z = z.reshape(*s0)
    motifs = motifs.reshape(*s1)
    p = p_w_ta_d(z, motifs)
    with pyro.iarange("data", len(data)):
        zts = pyro.sample("zts", pdist.Categorical(probs=z))
        pyro.sample("observe", pdist.Multinomial(probs=p), obs=data)
def model(data=None, num_words_per_doc=None, args=None):
    # Globals.
    with pyro.plate("topics", args.num_topics):
        topic_weights = pyro.sample("topic_weights",
                                    dist.Gamma(1. / args.num_topics, 1.))
        topic_words = pyro.sample(
            "topic_words",
            dist.Dirichlet(torch.ones(args.num_words) / args.num_words))

    # Changed from a vectorized `with` plate to iteration, to support a varying
    # number of words per document (num_words_per_doc).
    # with pyro.plate("documents", args.num_docs) as ind:
    for doc in pyro.plate("documents", args.num_docs):
        doc_topics = pyro.sample("doc_topics_{}".format(doc),
                                 dist.Dirichlet(topic_weights))
        with pyro.plate("words_{}".format(doc), num_words_per_doc[doc]):
            word_topics = pyro.sample("word_topics_{}".format(doc),
                                      dist.Categorical(doc_topics))
            pyro.sample("doc_words_{}".format(doc),
                        dist.Categorical(topic_words[word_topics]),
                        obs=data[doc])
    return topic_weights, topic_words
def guide(data):
    qalpha0 = pyro.param("qalpha0", torch.ones(nd, nz, 1, Td).cpu(),
                         constraint=constraints.positive)  # z_ts table
    global step_motif_count
    if flag_ISM:
        qalpha1 = pyro.param("qalpha1", init_motif,
                             constraint=constraints.positive)  # motif
    else:
        qalpha1 = pyro.param("qalpha1", torch.ones(nz, 1, nw, ntr).cpu(),
                             constraint=constraints.positive)  # motif
    if step_motif_count % 5 == 0:
        tem_motif.append(qalpha1)
    # CHANGE: use the fact that Dirichlet can draw independent Dirichlets
    pyro.sample("latent0", pdist.Dirichlet(concentration=qalpha0.view(nd, -1)))
    pyro.sample("latent1", pdist.Dirichlet(concentration=qalpha1.view(nz, -1)))
def model_0(sequences, lengths, args, batch_size=None, include_prior=True):
    assert not torch._C._get_tracing_state()
    num_sequences, max_length, data_dim = sequences.shape
    with poutine.mask(mask=include_prior):
        # Our prior on transition probabilities will be:
        # stay in the same state with 90% probability; uniformly jump to another
        # state with 10% probability.
        probs_x = pyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * torch.eye(args.hidden_dim) + 0.1).to_event(1))
        # We put a weak prior on the conditional probability of a tone sounding.
        # We know that on average about 4 of 88 tones are active, so we'll set a
        # rough weak prior of 10% of the notes being active at any one time.
        probs_y = pyro.sample(
            "probs_y",
            dist.Beta(0.1, 0.9).expand([args.hidden_dim, data_dim]).to_event(2))
    # In this first model we'll sequentially iterate over sequences in a
    # minibatch; this will make it easy to reason about tensor shapes.
    tones_plate = pyro.plate("tones", data_dim, dim=-1)
    for i in pyro.plate("sequences", len(sequences), batch_size):
        length = lengths[i]
        sequence = sequences[i, :length]
        x = 0
        for t in pyro.markov(range(length)):
            # On the next line, we'll overwrite the value of x with an updated
            # value. If we wanted to record all x values, we could instead
            # write x[t] = pyro.sample(...x[t-1]...).
            x = pyro.sample("x_{}_{}".format(i, t),
                            dist.Categorical(probs_x[x]),
                            infer={"enumerate": "parallel"})
            with tones_plate:
                pyro.sample("y_{}_{}".format(i, t),
                            dist.Bernoulli(probs_y[x.squeeze(-1)]),
                            obs=sequence[t])
def model_5(sequences, lengths, args, batch_size=None, include_prior=True):
    with ignore_jit_warnings():
        num_sequences, max_length, data_dim = map(int, sequences.shape)
        assert lengths.shape == (num_sequences,)
        assert lengths.max() <= max_length

    # Initialize a global module instance if needed.
    global tones_generator
    if tones_generator is None:
        tones_generator = TonesGenerator(args, data_dim)
    pyro.module("tones_generator", tones_generator)

    with poutine.mask(mask=include_prior):
        probs_x = pyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * torch.eye(args.hidden_dim) + 0.1).to_event(1))
    with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch:
        lengths = lengths[batch]
        x = 0
        y = torch.zeros(data_dim)
        for t in pyro.markov(range(max_length if args.jit else lengths.max())):
            with poutine.mask(mask=(t < lengths).unsqueeze(-1)):
                x = pyro.sample("x_{}".format(t), dist.Categorical(probs_x[x]),
                                infer={"enumerate": "parallel"})
                # Note that since each tone depends on all tones at a previous
                # time step the tones at different time steps now need to live
                # in separate plates.
                with pyro.plate("tones_{}".format(t), data_dim, dim=-1):
                    y = pyro.sample("y_{}".format(t),
                                    dist.Bernoulli(logits=tones_generator(x, y)),
                                    obs=sequences[batch, t])
def model_2(sequences, lengths, args, batch_size=None, include_prior=True):
    with ignore_jit_warnings():
        num_sequences, max_length, data_dim = map(int, sequences.shape)
        assert lengths.shape == (num_sequences,)
        assert lengths.max() <= max_length
    with poutine.mask(mask=include_prior):
        probs_x = pyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * torch.eye(args.hidden_dim) + 0.1).to_event(1))
        probs_y = pyro.sample(
            "probs_y",
            dist.Beta(0.1, 0.9).expand([args.hidden_dim, 2, data_dim]).to_event(3))
    tones_plate = pyro.plate("tones", data_dim, dim=-1)
    with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch:
        lengths = lengths[batch]
        x, y = 0, 0
        for t in pyro.markov(range(max_length if args.jit else lengths.max())):
            with poutine.mask(mask=(t < lengths).unsqueeze(-1)):
                x = pyro.sample("x_{}".format(t), dist.Categorical(probs_x[x]),
                                infer={"enumerate": "parallel"})
                # Note the broadcasting tricks here: to index probs_y on tensors
                # x and y, we also need a final tensor for the tones dimension.
                # This is conveniently provided by the plate associated with
                # that dimension.
                with tones_plate as tones:
                    y = pyro.sample("y_{}".format(t),
                                    dist.Bernoulli(probs_y[x, y, tones]),
                                    obs=sequences[batch, t]).long()
def model_7(sequences, lengths, args, batch_size=None, include_prior=True):
    with ignore_jit_warnings():
        num_sequences, max_length, data_dim = map(int, sequences.shape)
        assert lengths.shape == (num_sequences,)
        assert lengths.max() <= max_length

    # Initialize a global module instance if needed.
    global tones_generator
    if tones_generator is None:
        tones_generator = TonesGenerator(args, data_dim)
    pyro.module("tones_generator", tones_generator)

    with poutine.mask(mask=include_prior):
        probs_x = pyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * torch.eye(args.hidden_dim) + 0.1).to_event(1))
    with pyro.plate("sequences", num_sequences, batch_size, dim=-1) as batch:
        lengths = lengths[batch]
        y = sequences[batch] if args.jit else sequences[batch, :lengths.max()]
        x = torch.arange(args.hidden_dim)
        t = torch.arange(y.size(1))
        init_logits = torch.full((args.hidden_dim,), -float("inf"))
        init_logits[0] = 0
        trans_logits = probs_x.log()
        with ignore_jit_warnings():
            obs_dist = dist.Bernoulli(
                logits=tones_generator(x, y.unsqueeze(-2))).to_event(1)
            obs_dist = obs_dist.mask((t < lengths.unsqueeze(-1)).unsqueeze(-1))
            hmm_dist = dist.DiscreteHMM(init_logits, trans_logits, obs_dist)
            pyro.sample("y", hmm_dist, obs=y)
def test_dirichlet_shape():
    alpha = ng_ones(3, 2) / 2
    d = dist.Dirichlet(alpha)
    assert d.batch_shape() == (3,)
    assert d.event_shape() == (2,)
    assert d.shape() == (3, 2)
    assert d.sample().size() == d.shape()
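# For comparison (an added note, assuming a modern Pyro / torch.distributions
# API): batch_shape and event_shape are now properties rather than methods,
# but the shape semantics the legacy test above asserts are unchanged.
import torch
import pyro.distributions as dist

d = dist.Dirichlet(torch.ones(3, 2) / 2)
assert d.batch_shape == (3,)
assert d.event_shape == (2,)
assert d.sample().shape == (3, 2)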
def one_hot_model(pseudocounts, classes=None):
    probs_prior = dist.Dirichlet(pseudocounts)
    probs = pyro.sample("probs", probs_prior)
    with pyro.plate("classes", classes.size(0) if classes is not None else 1, dim=-1):
        return pyro.sample("obs", dist.OneHotCategorical(probs), obs=classes)
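# A hedged usage sketch for one_hot_model (names and sizes are assumptions):
# with classes=None the "obs" site is unobserved, so Predictive draws one-hot
# samples from the prior predictive.
import torch
from pyro.infer import Predictive

pseudocounts = torch.ones(4)  # flat Dirichlet over 4 classes (illustrative)
predictive = Predictive(one_hot_model, num_samples=10)
samples = predictive(pseudocounts)  # dict with keys 'probs' and 'obs'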