def model_2(sequences, lengths, args, batch_size=None, include_prior=True):
    with ignore_jit_warnings():
        num_sequences, max_length, data_dim = map(int, sequences.shape)
        assert lengths.shape == (num_sequences,)
        assert lengths.max() <= max_length
    with poutine.mask(mask=include_prior):
        probs_x = pyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * torch.eye(args.hidden_dim) + 0.1).to_event(1))
        probs_y = pyro.sample(
            "probs_y",
            dist.Beta(0.1, 0.9).expand([args.hidden_dim, 2, data_dim]).to_event(3))
    tones_plate = pyro.plate("tones", data_dim, dim=-1)
    with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch:
        lengths = lengths[batch]
        x, y = 0, 0
        for t in pyro.markov(range(max_length if args.jit else lengths.max())):
            with poutine.mask(mask=(t < lengths).unsqueeze(-1)):
                x = pyro.sample("x_{}".format(t), dist.Categorical(probs_x[x]),
                                infer={"enumerate": "parallel"})
                # Note the broadcasting tricks here: to index probs_y on tensors x and y,
                # we also need a final tensor for the tones dimension. This is conveniently
                # provided by the plate associated with that dimension.
                with tones_plate as tones:
                    y = pyro.sample("y_{}".format(t),
                                    dist.Bernoulli(probs_y[x, y, tones]),
                                    obs=sequences[batch, t]).long()
def model_0(sequences, lengths, args, batch_size=None, include_prior=True):
    assert not torch._C._get_tracing_state()
    num_sequences, max_length, data_dim = sequences.shape
    with poutine.mask(mask=include_prior):
        # Our prior on transition probabilities will be:
        # stay in the same state with 90% probability; uniformly jump to another
        # state with 10% probability.
        probs_x = pyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * torch.eye(args.hidden_dim) + 0.1).to_event(1))
        # We put a weak prior on the conditional probability of a tone sounding.
        # We know that on average about 4 of 88 tones are active, so we'll set a
        # rough weak prior of 10% of the notes being active at any one time.
        probs_y = pyro.sample(
            "probs_y",
            dist.Beta(0.1, 0.9).expand([args.hidden_dim, data_dim]).to_event(2))
    # In this first model we'll sequentially iterate over sequences in a
    # minibatch; this will make it easy to reason about tensor shapes.
    tones_plate = pyro.plate("tones", data_dim, dim=-1)
    for i in pyro.plate("sequences", len(sequences), batch_size):
        length = lengths[i]
        sequence = sequences[i, :length]
        x = 0
        for t in pyro.markov(range(length)):
            # On the next line, we'll overwrite the value of x with an updated
            # value. If we wanted to record all x values, we could instead
            # write x[t] = pyro.sample(...x[t-1]...).
            x = pyro.sample("x_{}_{}".format(i, t), dist.Categorical(probs_x[x]),
                            infer={"enumerate": "parallel"})
            with tones_plate:
                pyro.sample("y_{}_{}".format(i, t),
                            dist.Bernoulli(probs_y[x.squeeze(-1)]),
                            obs=sequence[t])
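# A hedged training sketch for the HMM variants in this file (model_0 above and
# the vectorized model_2/model_4/model_5 elsewhere), following the structure of
# pyro's examples/hmm.py; `sequences`, `lengths`, and `args` are assumed to be
# provided by the surrounding script. The discrete x/w sites are marked for
# parallel enumeration, so TraceEnum_ELBO can sum them out while an AutoDelta
# guide fits point estimates of the global probs_* sites.
from pyro import poutine
from pyro.infer import SVI, TraceEnum_ELBO
from pyro.infer.autoguide import AutoDelta


def train_hmm(model, sequences, lengths, args, num_steps=10, batch_size=8):
    guide = AutoDelta(poutine.block(
        model, expose_fn=lambda msg: msg["name"].startswith("probs_")))
    # max_plate_nesting=2 covers the sequences (dim=-2) and tones (dim=-1)
    # plates of the vectorized models; it is a safe upper bound for model_0.
    elbo = TraceEnum_ELBO(max_plate_nesting=2)
    svi = SVI(model, guide, pyro.optim.Adam({"lr": 0.05}), elbo)
    for step in range(num_steps):
        svi.step(sequences, lengths, args, batch_size=batch_size)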
def model_5(sequences, lengths, args, batch_size=None, include_prior=True):
    with ignore_jit_warnings():
        num_sequences, max_length, data_dim = map(int, sequences.shape)
        assert lengths.shape == (num_sequences,)
        assert lengths.max() <= max_length

    # Initialize a global module instance if needed.
    global tones_generator
    if tones_generator is None:
        tones_generator = TonesGenerator(args, data_dim)
    pyro.module("tones_generator", tones_generator)

    with poutine.mask(mask=include_prior):
        probs_x = pyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * torch.eye(args.hidden_dim) + 0.1).to_event(1))
    with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch:
        lengths = lengths[batch]
        x = 0
        y = torch.zeros(data_dim)
        for t in pyro.markov(range(max_length if args.jit else lengths.max())):
            with poutine.mask(mask=(t < lengths).unsqueeze(-1)):
                x = pyro.sample("x_{}".format(t), dist.Categorical(probs_x[x]),
                                infer={"enumerate": "parallel"})
                # Note that since each tone depends on all tones at a previous time step
                # the tones at different time steps now need to live in separate plates.
                with pyro.plate("tones_{}".format(t), data_dim, dim=-1):
                    y = pyro.sample("y_{}".format(t),
                                    dist.Bernoulli(logits=tones_generator(x, y)),
                                    obs=sequences[batch, t])
def model(self, x, y):
    """
    Generative model for the data.
    """
    pyro.module("MDN", self)
    pi, sigma, mu = self.forward(y)
    muT = torch.transpose(mu, 0, 1)
    sigmaT = torch.transpose(sigma, 0, 1)
    n_samples = y.shape[0]
    assert muT.shape == (n_gaussians, n_samples)
    assert sigmaT.shape == (n_gaussians, n_samples)
    with pyro.plate("samples", n_samples):
        assign = pyro.sample("assign", dist.Categorical(pi))
        # We need this case distinction for the two different
        # cases of assignment: sampling a random assignment and
        # enumerating over mixtures. See
        # http://pyro.ai/examples/enumeration.html for a tutorial.
        if len(assign.shape) == 1:
            sample = pyro.sample(
                'obs',
                dist.Normal(torch.gather(muT, 0, assign.view(1, -1))[0],
                            torch.gather(sigmaT, 0, assign.view(1, -1))[0]),
                obs=x)
        else:
            sample = pyro.sample(
                'obs',
                dist.Normal(muT[assign][:, 0], sigmaT[assign][:, 0]),
                obs=x)
    return sample
def intervention():
    prob_A = torch.tensor([
        .50,  # P(A = 'on')
        .50   # P(A = 'off')
    ])
    prob_B = torch.tensor([
        [
            .90,  # P(B = 'on' | A = 'on')
            .10   # P(B = 'off' | A = 'on')
        ],
        [
            .20,  # P(B = 'on' | A = 'off')
            .80   # P(B = 'off' | A = 'off')
        ]
    ])
    prob_C = torch.tensor([
        [
            [
                .60,  # P(C = 'on' | A = 'on', B = 'on')
                .40   # P(C = 'off' | A = 'on', B = 'on')
            ],
            [
                .01,  # P(C = 'on' | A = 'on', B = 'off')
                .99   # P(C = 'off' | A = 'on', B = 'off')
            ]
        ],
        [
            [
                .90,  # P(C = 'on' | A = 'off', B = 'on')
                .10   # P(C = 'off' | A = 'off', B = 'on')
            ],
            [
                .10,  # P(C = 'on' | A = 'off', B = 'off')
                .90   # P(C = 'off' | A = 'off', B = 'off')
            ]
        ]
    ])
    A = pyro.sample('A', dist.Categorical(probs=prob_A))
    B = pyro.sample('B', dist.Categorical(probs=prob_B[A]))
    C = pyro.sample('C', dist.Categorical(probs=prob_C[A][B]))
    return C
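# Hedged usage sketch for the model above: pyro.poutine.do applies a hard
# intervention, clamping a sample site to a fixed value and severing its
# dependence on upstream sites (here, forcing B = 'on' regardless of A).
from pyro import poutine

do_B_on = poutine.do(intervention, data={"B": torch.tensor(0)})
c_samples = [int(do_B_on()) for _ in range(1000)]  # draws of C under do(B='on')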
def uniform_draw(object_list):
    # Uniform probability over all options.
    probs = [1 / len(object_list)] * len(object_list)
    # A fresh uuid gives each call a unique sample-site name.
    sample_name = str(uuid.uuid1())
    sample = pyro.sample(sample_name,
                         dist.Categorical(probs=torch.tensor(probs)))
    return object_list[sample]
def norm_prior(self):
    prob_tensor = torch.zeros(len(self.agent_norm))
    total = sum(self.n_prior)
    temp_list = [i / total for i in self.n_prior]  # normalize
    for i in range(len(self.agent_norm)):
        prob_tensor[i] = temp_list[i]
    n_prior = pyro.sample("norm", dist.Categorical(prob_tensor))
    return n_prior
def gmm():
    data = torch.tensor([0., 0., 3., 3., 3., 5., 5.])
    mix_proportions = pyro.sample("phi", dist.Dirichlet(torch.ones(3)))
    cluster_means = pyro.sample("cluster_means",
                                dist.Normal(torch.arange(3.), 1.))
    with pyro.plate("data", data.shape[0]):
        assignments = pyro.sample("assignments",
                                  dist.Categorical(mix_proportions))
        pyro.sample("obs", dist.Normal(cluster_means[assignments], 1.), obs=data)
    return cluster_means
def datagen_grp(n_sample, map_est):
    # initialize storage
    stor = []
    if 'concentration' in map_est.keys():
        for i in np.arange(n_sample):
            grp = dist.Categorical(map_est['weights']).sample()
            grp_c = map_est['concentration'][grp]
            samp = dist.Dirichlet(grp_c).sample()
            stor.append(samp)
    else:
        for i in np.arange(n_sample):
            grp = dist.Categorical(map_est['weights']).sample()
            grp_a = map_est['alpha'][grp]
            grp_b = map_est['beta'][grp]
            samp = dist.Beta(grp_a, grp_b).sample()
            stor.append(samp)
    return torch.stack(stor)
def utterance_prior():
    utterances = [
        "some of the blond people are nice",
        "all of the blond people are nice",
        "none of the blond people are nice",
    ]
    ix = pyro.sample("utterance", dist.Categorical(torch.ones(3) / 3.0))
    return utterances[ix]
def price_prior():
    values = [50, 51, 500, 501, 1000, 1001, 5000, 5001, 10000, 10001]
    probs = torch.tensor([
        0.4205, 0.3865, 0.0533, 0.0538, 0.0223,
        0.0211, 0.0112, 0.0111, 0.0083, 0.0120
    ])
    ix = pyro.sample("price", dist.Categorical(probs=probs))
    return values[ix]
def test_value(x_shape, i_shape, j_shape, event_shape):
    x = torch.rand(x_shape + (5, 6) + event_shape)
    i = dist.Categorical(torch.ones(5)).sample(i_shape)
    j = dist.Categorical(torch.ones(6)).sample(j_shape)
    if event_shape:
        actual = Vindex(x)[..., i, j, :]
    else:
        actual = Vindex(x)[..., i, j]

    shape = broadcast_shape(x_shape, i_shape, j_shape)
    x = x.expand(shape + (5, 6) + event_shape)
    i = i.expand(shape)
    j = j.expand(shape)
    expected = x.new_empty(shape + event_shape)
    for ind in (itertools.product(*map(range, shape)) if shape else [()]):
        expected[ind] = x[ind + (i[ind].item(), j[ind].item())]
    assert_equal(actual, expected)
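# A concrete illustration of the Vindex semantics checked above (an addition,
# not part of the original test; assumes `from pyro.ops.indexing import Vindex`):
# the index tensors i and j broadcast against x's leading batch dimensions.
x = torch.rand(4, 5, 6)
i = torch.tensor([0, 1, 2, 3])  # indexes the size-5 dim, one choice per batch
j = torch.tensor([5, 4, 3, 2])  # indexes the size-6 dim, one choice per batch
assert Vindex(x)[..., i, j].shape == (4,)
assert Vindex(x)[..., i, j][0] == x[0, 0, 5]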
def transition(self, state, action):
    nextStates = ["bad", "good", "spectacular"]
    if action == 0:
        prob = torch.tensor((0.2, 0.6, 0.2))
    else:  # french
        prob = torch.tensor((0.05, 0.9, 0.05))
    return nextStates[dist.Categorical(prob).sample()]
def gmm_batch_guide(data):
    # Note: this guide targets the legacy Pyro API (pyro.iarange,
    # torch.autograd.Variable, one-hot Categorical samples); in current Pyro
    # these are pyro.plate, plain tensors with requires_grad=True, and
    # index-valued Categorical samples.
    with pyro.iarange("data", len(data)) as batch:
        n = len(batch)
        ps = pyro.param("ps", Variable(torch.ones(n, 1) * 0.6,
                                       requires_grad=True))
        ps = torch.cat([ps, 1 - ps], dim=1)
        z = pyro.sample("z", dist.Categorical(ps))
        assert z.size() == (n, 2)
def test_dist_to_funsor_categorical(batch_shape, cardinality):
    logits = torch.randn(batch_shape + (cardinality,))
    logits -= logits.logsumexp(dim=-1, keepdim=True)
    d = dist.Categorical(logits=logits)
    f = dist_to_funsor(d)
    assert isinstance(f, Tensor)
    expected = tensor_to_funsor(logits, ("value",))
    assert_close(f, expected)
def gmm(data):
    mix_proportions = pyro.sample("phi", dist.Dirichlet(torch.ones(K)))
    with pyro.plate("num_clusters", K):
        cluster_means = pyro.sample("cluster_means",
                                    dist.Normal(torch.arange(float(K)), 1.))
    with pyro.plate("data", data.shape[0]):
        assignments = pyro.sample("assignments",
                                  dist.Categorical(mix_proportions))
        pyro.sample("obs", dist.Normal(cluster_means[assignments], 1.), obs=data)
    return cluster_means
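# A minimal inference sketch for the gmm model above (an addition; assumes K
# and data are defined by the surrounding script): config_enumerate lets
# TraceEnum_ELBO sum out the discrete "assignments" site, while an AutoDelta
# guide fits point estimates of phi and cluster_means.
from pyro import poutine
from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate
from pyro.infer.autoguide import AutoDelta

enum_gmm = config_enumerate(gmm, "parallel")
guide = AutoDelta(poutine.block(enum_gmm, hide=["assignments"]))
svi = SVI(enum_gmm, guide, pyro.optim.Adam({"lr": 0.05}),
          TraceEnum_ELBO(max_plate_nesting=1))
for step in range(200):
    svi.step(data)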
def guide(data):
    transition_probs = pyro.param("transition_probs",
                                  torch.tensor([[0.75, 0.25], [0.25, 0.75]]),
                                  constraint=constraints.simplex)
    x = None
    for i, y in enumerate(data):
        # init_probs is assumed to be defined at module scope.
        probs = init_probs if x is None else transition_probs[x]
        x = pyro.sample("x_{}".format(i), dist.Categorical(probs))
def _model(self, data=None, labels=None, batch_size=None):
    args = self.args
    # Globals.
    with pyro.plate("topics", args.num_topics):
        topic_weights = pyro.sample("topic_weights",
                                    dist.Gamma(1. / args.num_topics, 1.))
        topic_words = pyro.sample(
            "topic_words",
            dist.Dirichlet(torch.ones(args.num_words) / args.num_words))
        label_prior = pyro.sample(
            "label_prior",
            dist.Beta(*torch.ones(2, args.num_topics)))

    # Locals.
    with pyro.plate("documents", args.num_docs) as ind:
        if data is not None:
            with pyro.util.ignore_jit_warnings():
                assert data.shape == (args.num_words_per_doc, args.num_docs)
            data = data[:, ind]
        if labels is not None:
            with pyro.util.ignore_jit_warnings():
                assert labels.shape == (args.num_docs, args.num_topics)
            labels = labels[ind]
        labels = pyro.sample("labels",
                             dist.Bernoulli(label_prior).to_event(1),
                             obs=labels)
        auxiliary = pyro.sample("auxiliary",
                                dist.Gamma(topic_weights, 1).to_event(1))
        doc_topics = labels * auxiliary
        doc_topics = pyro.sample(
            "doc_topics",
            dist.Delta(doc_topics / doc_topics.sum(axis=1)[..., None]).to_event(1))
        with pyro.plate("words", args.num_words_per_doc):
            word_topics = pyro.sample("word_topics",
                                      dist.Categorical(doc_topics),
                                      infer={"enumerate": "parallel"})
            data = pyro.sample("doc_words",
                               dist.Categorical(topic_words[word_topics]),
                               obs=data)
    return topic_weights, topic_words, data, labels
def model_4(sequences, lengths, args, batch_size=None, include_prior=True):
    with ignore_jit_warnings():
        num_sequences, max_length, data_dim = map(int, sequences.shape)
        assert lengths.shape == (num_sequences,)
        assert lengths.max() <= max_length
    hidden_dim = int(args.hidden_dim ** 0.5)  # split between w and x
    with poutine.mask(mask=include_prior):
        probs_w = pyro.sample(
            "probs_w",
            dist.Dirichlet(0.9 * torch.eye(hidden_dim) + 0.1).to_event(1))
        probs_x = pyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * torch.eye(hidden_dim) + 0.1)
                .expand_by([hidden_dim]).to_event(2),
        )
        probs_y = pyro.sample(
            "probs_y",
            dist.Beta(0.1, 0.9).expand([hidden_dim, hidden_dim, data_dim]).to_event(3),
        )
    tones_plate = pyro.plate("tones", data_dim, dim=-1)
    with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch:
        lengths = lengths[batch]
        # Note the broadcasting tricks here: we declare a hidden torch.arange and
        # ensure that w and x are always tensors so we can unsqueeze them below,
        # thus ensuring that the x sample sites have correct distribution shape.
        w = x = torch.tensor(0, dtype=torch.long)
        for t in pyro.markov(range(max_length if args.jit else lengths.max())):
            with poutine.mask(mask=(t < lengths).unsqueeze(-1)):
                w = pyro.sample(
                    "w_{}".format(t),
                    dist.Categorical(probs_w[w]),
                    infer={"enumerate": "parallel"},
                )
                x = pyro.sample(
                    "x_{}".format(t),
                    dist.Categorical(Vindex(probs_x)[w, x]),
                    infer={"enumerate": "parallel"},
                )
                with tones_plate as tones:
                    pyro.sample(
                        "y_{}".format(t),
                        dist.Bernoulli(probs_y[w, x, tones]),
                        obs=sequences[batch, t],
                    )
def setup_reward_prior(self):
    rew_prior = {}
    for reward_index in range(len(self.reward)):
        rew_prior["reward-" + str(reward_index)] = float(
            pyro.sample(
                "reward-" + str(reward_index),
                dist.Categorical(probs=torch.FloatTensor(
                    self.inferred_reward[reward_index]))))
    return rew_prior
def model(self, data):
    with pyro.plate("topics", self.n_topics):
        alpha = pyro.sample("alpha", dist.Gamma(1. / self.n_topics, 1.))
        beta_param = torch.ones(self.vocab_size) / self.vocab_size
        betas = pyro.sample("beta", dist.Dirichlet(beta_param))
    words = []
    for d in pyro.plate("doc_loop", len(data)):
        doc = data[d]
        theta = pyro.sample(f"theta_{d}", dist.Dirichlet(alpha))
        n_words = len(doc)
        for w in pyro.plate(f"word_loop_{d}", n_words):
            z = pyro.sample(f"z{d}_{w}", dist.Categorical(theta))
            # Sample the observed word under its own name rather than
            # rebinding the loop variable w, which made the original
            # hard to follow.
            word = pyro.sample(f"w{d}_{w}", dist.Categorical(betas[z]),
                               obs=doc[w])
            words.append(word)
    return words
def model(self, x, y):
    # Place a Normal prior over every network parameter, then lift the net
    # into a random module and sample a concrete instance of it.
    priors = {
        name: make_normal_prior(p)
        for name, p in self.net.named_parameters()
    }
    lifted_module = pyro.random_module("module", self.net, priors)
    lifted_reg_model = lifted_module()
    lhat = F.log_softmax(lifted_reg_model(x), dim=1)
    pyro.sample("y", pdist.Categorical(logits=lhat), obs=y)
def model(data):
    weights = pyro.sample('weights', dist.Dirichlet(0.5 * torch.ones(K)))
    locs = pyro.sample('locs', dist.Normal(0, 10).expand_by([K]).to_event(1))
    scale = pyro.sample('scale', dist.LogNormal(0, 1))
    with pyro.plate('data', len(data)):
        # Expand the mixture weights across the data plate before sampling
        # per-datum cluster assignments.
        weights = weights.expand(torch.Size((len(data),)) + weights.shape)
        assignment = pyro.sample('assignment', dist.Categorical(weights))
        pyro.sample('obs', dist.Normal(locs[assignment], scale), obs=data)
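# A hedged MCMC sketch for the model above (assumes K and data are in scope):
# wrapping the model in config_enumerate lets Pyro's NUTS marginalize the
# discrete "assignment" site by enumeration, so HMC runs only over the
# continuous sites (weights, locs, scale).
from pyro.infer import MCMC, NUTS, config_enumerate

kernel = NUTS(config_enumerate(model, "parallel"), max_plate_nesting=1)
mcmc = MCMC(kernel, num_samples=250, warmup_steps=250)
mcmc.run(data)
posterior = mcmc.get_samples()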
def action_model(self, state):
    # draw a random action
    action = dist.Categorical(torch.tensor((0.5, 0.5))).sample()
    # calculate expected utility for the action
    expected_u = self.expected_utility(state, action)
    # add a factor weighting the action by its expected utility
    pyro.factor("state_%saction_%d" % (state, action), self.alpha * expected_u)
    return action
def community_dist_in_range(self) -> dist.Categorical:
    """
    A distribution for the portion of the current normalized community
    prediction that's within the question's range.

    :return: distribution on integers referencing
        0...(len(self.prediction_histogram)-1)
    """
    y2 = [p[2] for p in self.prediction_histogram]
    return dist.Categorical(probs=torch.tensor(y2))
def policy(self, action_posterior):
    """
    Agent policy for selecting an action.
    """
    action = pyro.sample(
        'action_policy',
        dist.Categorical(action_posterior)
    )
    return action
def test_categorical_batch_log_pdf_shape(one_hot):
    # Note: this test targets the legacy Pyro API, where Categorical accepted
    # a one_hot flag and exposed a batch_log_pdf method.
    ps = ng_ones(3, 2, 4) / 4
    if one_hot:
        x = ng_zeros(3, 2, 4)
        x[:, :, 0] = 1
    else:
        x = ng_zeros(3, 2, 1)
    d = dist.Categorical(ps, one_hot=one_hot)
    assert d.batch_log_pdf(x).size() == (3, 2, 1)
def random_choice(options, ps=None):
    if ps is None:
        ps = torch.tensor([1 / len(options)] * len(options))
    else:
        # in case ps are passed in as some array-like type other than torch.tensor
        ps = torch.tensor(ps)
    idx = sample(dist.Categorical(ps))
    return options[idx]
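# Illustrative call (assumes `sample` above is a local helper that wraps
# pyro.sample with an auto-generated site name, as in the webppl-style
# helpers elsewhere in this file): returns "c" with probability 0.6.
choice = random_choice(["a", "b", "c"], ps=[0.1, 0.3, 0.6])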
def _importance_resample(self):
    # TODO: Turn quadratic algo -> linear algo by being lazier
    index = dist.Categorical(logits=self._log_weights).sample(
        sample_shape=(self.num_particles,))
    self._values = {
        name: value[index].contiguous()
        for name, value in self._values.items()
    }
    self._log_weights.fill_(0.)
def forward(self, x, y_data=None):
    output = self.fc1(x)
    output = F.relu(output)
    output = self.out(output)
    # Specify dim explicitly; calling log_softmax without it is deprecated
    # and emits a warning.
    self.lhat = F.log_softmax(output, dim=-1)
    obs = pyro.sample("obs", dist.Categorical(logits=self.lhat), obs=y_data)
    return obs