def gen_data(self):
    """Simulate particles and their per-ASV read counts.

    Draws a global ASV relative-abundance vector and a ``D``-dimensional
    embedding for every ASV, then repeatedly proposes particles (a centre
    and a radius) until ``self.numParticles`` of them have been accepted.
    Accepted particles are appended to ``self.particles`` with their
    ``zr``, ``total_reads`` and ``reads`` attributes filled in.

    ``self.data`` is (re)allocated as a zero tensor here; this method does
    not write the sampled reads into it.
    """
    n_asvs = self.numASVs
    dim = self.D

    # Global relative abundance of the ASVs: one draw from a flat Dirichlet.
    self.ASV_rel_abundance = tdist.Dirichlet(torch.ones(n_asvs)).sample()

    # Embedding of each ASV, drawn one row at a time from a standard normal.
    self.w = torch.zeros(n_asvs, dim)
    embedding_prior = tdist.MultivariateNormal(torch.zeros(dim),
                                               torch.eye(dim))
    for asv in range(n_asvs):
        self.w[asv, :] = embedding_prior.sample()

    self.data = torch.zeros(self.numParticles, n_asvs)

    center_prior = tdist.MultivariateNormal(torch.zeros(dim), torch.eye(dim))
    radius_prior = tdist.LogNormal(torch.tensor([self.mu_rad]),
                                   torch.tensor([self.mu_std]))
    # TODO: replace the Poisson with a negative binomial prior.
    reads_prior = tdist.Poisson(torch.tensor([self.avgNumReadsParticle]))

    accepted = 0
    while accepted < self.numParticles:
        # Propose a particle centre and radius.
        center = center_prior.sample()
        radius = radius_prior.sample()

        # Score every ASV by its radius-scaled distance to the centre.
        # NOTE(review): unitboxcar is defined elsewhere; presumably a
        # (smoothed) indicator of the distance lying in [0, 2] -- confirm.
        membership = torch.zeros(1, n_asvs, dtype=torch.float64)
        for asv in range(n_asvs):
            scaled_sq = torch.pow(center - self.w[asv, :], 2.0) / radius
            distance = torch.sum(scaled_sq).sqrt()
            membership[0, asv] = unitboxcar(distance, 0.0, 2.0,
                                            self.step_approx)

        # Proposals whose total membership does not exceed 0.95 are dropped.
        if not (torch.sum(membership) > 0.95):
            continue

        particle = Particle(center, self)
        particle.zr = membership
        self.particles.append(particle)

        # Renormalise the global abundances restricted to this particle.
        weighted = self.ASV_rel_abundance * membership
        weighted = weighted / torch.sum(weighted)

        # Particle-level relative abundances around the renormalised mean.
        particle_abundance = tdist.Dirichlet(weighted * self.conc).sample()

        # Number of reads for the particle
        # (TODO: negative binomial instead of Poisson).
        n_reads = reads_prior.sample().long().item()
        particle.total_reads = n_reads
        particle.reads = tdist.Multinomial(
            n_reads, probs=particle_abundance).sample()

        accepted += 1
def __init__(self, nb_states, obs_dim, act_dim,
             prior, norm, device, **kwargs):
    """Store dimensions, priors and normalisation, and build the Dirichlets.

    Args:
        nb_states: number of discrete states; one Dirichlet is built per state.
        obs_dim: stored as ``self.obs_dim``.
        act_dim: stored as ``self.act_dim``.
        prior: mapping with keys ``'alpha'`` and ``'kappa'``.
        norm: mapping with keys ``'mean'`` and ``'std'``.
        device: device on which every tensor created here is placed.
        **kwargs: accepted but not used by this constructor.
    """
    super(ParametricAugmentationRegressor, self).__init__()

    self.device = device
    self.nb_states = nb_states
    self.obs_dim = obs_dim
    self.act_dim = act_dim

    def _as_float(value):
        # Everything is kept as float32 on the requested device.
        return torch.as_tensor(value, dtype=torch.float32,
                               device=self.device)

    # Dirichlet parameters
    self.prior = {key: _as_float(prior[key]) for key in ('alpha', 'kappa')}

    # Normalization parameters
    self.norm = {key: _as_float(norm[key]) for key in ('mean', 'std')}

    # State k gets concentration alpha everywhere, plus kappa on entry k.
    base = self.prior['alpha'] * torch.ones(
        self.nb_states, dtype=torch.float32, device=self.device)
    self.dirichlets = [
        dist.Dirichlet(
            base + self.prior['kappa'] * _as_float(
                torch.arange(self.nb_states) == state),
            validate_args=True)
        for state in range(self.nb_states)
    ]

    self.optim = None
def _expand_node(
        self,
        trees: _MCTSTree,
        n,  # n-th expansion, zero-based
        to_plays,
        model_output: ModelOutput,
        dirichlet_alpha=None,
        exploration_fraction=0.):
    """Write the model output for the ``n``-th expanded node into ``trees``.

    Column ``n`` of the per-node buffers (``to_play``, ``game_over``,
    ``model_state``, ``reward``, ``action``, ``prior``) is filled from
    ``to_plays`` and ``model_output``. Optional buffers that are ``None``
    are skipped.

    When ``exploration_fraction`` is positive, the action prior is mixed
    with Dirichlet noise: noise is zeroed wherever the prior is zero,
    renormalised per row, and blended as
    ``exploration_fraction * noise + (1 - exploration_fraction) * prior``.
    ``dirichlet_alpha`` must be provided in that case.
    """
    if self._is_two_player_game:
        trees.to_play[:, n] = to_plays
    if trees.game_over is not None:
        trees.game_over[:, n] = model_output.game_over

    def _store_state(slot, value):
        slot[:, n] = value

    nest.map_structure(_store_state, trees.model_state, model_output.state)

    if trees.reward is not None:
        trees.reward[:, n] = model_output.reward
    if trees.action is not None:
        trees.action[:, n] = model_output.actions

    action_prior = model_output.action_probs
    if exploration_fraction > 0.:
        n_rows = model_output.action_probs.shape[0]
        concentration = dirichlet_alpha * torch.ones(trees.branch_factor)
        noise = td.Dirichlet(concentration).sample((n_rows, ))
        # Keep noise only on actions with non-zero prior, then renormalise.
        noise = noise * (action_prior != 0)
        noise = noise / noise.sum(dim=1, keepdim=True)
        action_prior = (exploration_fraction * noise +
                        (1 - exploration_fraction) * action_prior)
    trees.prior[:, n] = action_prior
def prev(self) -> dist.Distribution:
    """
    Prevalence prior over the categories: a symmetric Dirichlet whose
    concentration entries are each ``1 / num_categories``, so the
    concentration vector adds up to 1.
    """
    n_categories = self.num_categories
    concentration = torch.full((n_categories, ), 1.0 / n_categories)
    return dist.Dirichlet(concentration)
def __init__(self, K, D):
    """Create the learnable parameters for ``K`` components in ``D`` dims.

    Attributes set:
        alpha: ``(K,)`` parameter initialised to ones.
        mu: ``(K, D)`` parameter initialised to standard normal noise
            scaled by 0.2.
        chol: ``(K, D, D)`` parameter, each slice ``0.3 * I``.
        dir: Dirichlet distribution built from ``alpha``.
    """
    super().__init__()

    self.alpha = nn.Parameter(torch.ones(K), requires_grad=True)

    initial_means = torch.randn(K, D) * 0.2
    self.mu = nn.Parameter(initial_means, requires_grad=True)

    scaled_identity = torch.eye(D, D) * 0.3
    initial_chol = torch.stack([scaled_identity for _ in range(K)])
    self.chol = nn.Parameter(initial_chol, requires_grad=True)

    self.dir = td.Dirichlet(self.alpha)
def encoder(self, fusion_input, enc_hx, enc_cx, lstm, classifier, test=False):
    """Run one recurrent step and (optionally) a Dirichlet output layer.

    Args:
        fusion_input: input passed to ``lstm``.
        enc_hx, enc_cx: recurrent state passed to ``lstm`` as a pair.
        lstm: callable ``lstm(input, (hx, cx)) -> (hx, cx)``.
        classifier: callable mapping ``hx`` to class scores.
        test: when true (and ``self.dirichlet`` is set) also return a
            covariance matrix for the first batch element.

    Returns:
        ``(enc_hx, enc_cx, enc_score)``, or
        ``(enc_hx, enc_cx, enc_score, covariance)`` in test mode when
        ``self.var_method`` is ``'covariance'`` or ``'diagonal'``.
        ``covariance`` is a NumPy ``(C, C)`` array computed from batch
        element 0 only.

    Raises:
        ValueError: in test mode when a covariance is requested but
            ``self.method`` is neither ``'Mean'`` nor ``'Sample'`` (no
            Dirichlet was built).
    """
    # fusion_input = self.feature_extractor(camera_input, sensor_input)
    enc_hx, enc_cx = lstm(fusion_input, (enc_hx, enc_cx))
    enc_score = classifier(enc_hx)

    # NOTE(review): the source formatting was flattened; the test-mode
    # branch is assumed to live inside the `self.dirichlet` branch because
    # it needs the Dirichlet built there -- confirm against the original.
    if not self.dirichlet:
        return enc_hx, enc_cx, enc_score

    dist = None
    if self.method in ('Mean', 'Sample'):
        # softplus keeps the concentration strictly positive.
        dist = distributions.Dirichlet(self.softplus(enc_score))
        enc_score = dist.mean if self.method == 'Mean' else dist.rsample()

    if test and self.var_method in ('covariance', 'diagonal'):
        if dist is None:
            raise ValueError(
                "a covariance was requested but method %r built no "
                "Dirichlet distribution" % (self.method, ))

        variance = dist.variance.detach().cpu().numpy()[0]
        if self.var_method == 'diagonal':
            return enc_hx, enc_cx, enc_score, np.diag(variance)

        # Full Dirichlet covariance for batch element 0:
        #   Cov[i, j] = -a_i * a_j / (a0^2 * (a0 + 1))   for i != j
        #   Cov[i, i] = Var[i]
        con = dist.concentration.detach()
        con0 = con.sum(-1, True)
        # Convert the denominator to NumPy so no torch/NumPy mixing occurs
        # (the mixed form fails on CUDA tensors and for batch size > 1).
        denom = (con0.pow(2) * (con0 + 1)).cpu().numpy()[0]
        con = con.cpu().numpy()[0]
        covariance = -np.outer(con, con) / denom
        np.fill_diagonal(covariance, variance)
        return enc_hx, enc_cx, enc_score, covariance

    return enc_hx, enc_cx, enc_score
def test_masked_dirichlet(K=3):
    """Check MaskedDirichlet against td.Dirichlet on every face of the mask.

    For each row of the mask, the unmasked coordinates define a
    lower-dimensional td.Dirichlet; mean, variance, entropy, dimension,
    log-probability of a fresh sample and both KL directions of the masked
    distribution must agree with it.
    """
    mask = make_faces(K)
    _, con = get_parameters(mask, dir_alpha=None, gamma_alpha=1, gamma_beta=1)

    p = MaskedDirichlet(mask, con)
    q = MaskedDirichlet(mask, torch.ones_like(con))

    # Concentrations on masked-out coordinates (others replaced by 0).
    masked_out = torch.where(torch.logical_not(mask), p.concentration,
                             torch.zeros_like(con))
    assert (masked_out == 0).all(), "Masked concentration parameters should be 0.0"

    # Concentrations on kept coordinates (others replaced by 1).
    kept = torch.where(mask, p.concentration, torch.ones_like(con))
    assert (kept > 0).all(), "Unmasked concentration parameters should be strictly positive"

    def _all_close(lhs, rhs):
        return torch.isclose(lhs, rhs).all()

    for i, face in enumerate(mask):
        idx = tuple(pos for pos, active in enumerate(face) if active)
        face_alphas = con[i, idx]
        p_low = td.Dirichlet(face_alphas)
        q_low = td.Dirichlet(torch.ones_like(face_alphas))

        assert _all_close(p.mean[i, idx], p_low.mean), f"The {i}th face's mean does not match that of td.Dirichlet"
        assert _all_close(p.variance[i, idx], p_low.variance), f"The {i}th face's variance does not match that of td.Dirichlet"
        assert _all_close(p.entropy()[i], p_low.entropy()), f"The {i}th face's entropy does not match that of td.Dirichlet"
        assert (p.dim[i] == len(idx)), f"The dimensionality of the {i}th face is incorrect: got {p.dim[i]}, expected {len(idx)}"

        # A fresh sample is drawn for every face.
        x = p.rsample()
        assert _all_close(p.log_prob(x)[i], p_low.log_prob(x[i, idx])), "The log_prob of a sample does not match that assigned by td.Dirichlet"

        assert _all_close(td.kl_divergence(p, q)[i], td.kl_divergence(p_low, q_low)), "The KL divergence does not match that of td.Dirichlet"
        assert _all_close(td.kl_divergence(q, p)[i], td.kl_divergence(q_low, p_low)), "The KL divergence does not match that of td.Dirichlet"
def test_calculate_exploration_policy(self):
    """Time calculate_exploration_policy and check its rows sum to one.

    The policy is computed ten times on the same random inputs, logging
    the wall-clock time and iteration count of each run; the returned
    policy must be normalised per row to within ``tol``.
    """
    dim, batch_size, tol = 400, 1000, 1e-6

    prior = td.Dirichlet(torch.full([dim], 0.25)).sample((batch_size, ))
    value = torch.rand([batch_size, dim])
    c = torch.rand([batch_size, 1]) + 0.01

    for _ in range(10):
        started = time.time()
        p, iterations = calculate_exploration_policy(value, prior, c, tol)
        elapsed = time.time() - started
        logging.info("time=%s iterations=%s" % (elapsed, iterations))

    # NOTE(review): source formatting was flattened; the check is assumed
    # to run once after the timing loop -- confirm against the original.
    row_error = (p.sum(dim=1) - 1).abs()
    self.assertTrue((row_error < tol).all())
def find_params(self, data: List[List[str]]) -> List[float]:
    """Randomly initialise the model state and run the Gibbs sampler.

    Args:
        data: documents, each given as a list of tokens.

    Returns:
        The value of ``self.get_perplexity(data, -1)`` recorded after each
        of the ``self.num_optim_steps`` Gibbs steps.
    """
    n_docs = len(data)

    # phi: one flat-Dirichlet draw over the vocabulary per topic
    phi_prior = dists.Dirichlet(
        torch.ones(self.num_topics, self.vocabulary_size))
    self.word_topics_distribution = phi_prior.sample()

    # theta: one flat-Dirichlet draw over the topics per document
    theta_prior = dists.Dirichlet(torch.ones(n_docs, self.num_topics))
    self.document_topic_distribution = theta_prior.sample()

    # z: per document, a Categorical with the document's topic
    # probabilities repeated once per token
    self.topic_assignments = []
    for doc_id, topic_probs in enumerate(self.document_topic_distribution):
        per_token = topic_probs[None].expand(
            [len(data[doc_id]), self.num_topics])
        self.topic_assignments.append(dists.Categorical(per_token))

    # Documents are looked up later by their space-joined text.
    for doc_id, document in enumerate(data):
        self.document_mapping[" ".join(document)] = doc_id

    history = []
    for _ in tqdm.trange(self.num_optim_steps, desc="Optim step"):
        self.run_gibbs_step(data)
        history.append(self.get_perplexity(data, -1))

    return history
def encoder(self, camera_input, sensor_input, enc_hx, enc_cx, enc_score):
    """Run one recurrent step conditioned on the previous class scores.

    The previous scores are turned into a weighted sum of the rows of
    ``self.weight`` and added to the new hidden state before
    classification. The classifier output is passed through softplus and
    used as a Dirichlet concentration; the Dirichlet mean is returned as
    the new score.

    Returns:
        ``(enc_hx, enc_cx, enc_score)`` -- the new recurrent state (without
        the score embedding added) and the Dirichlet-mean scores.
    """
    previous_score = enc_score

    fused = self.feature_extractor(camera_input, sensor_input)
    enc_hx, enc_cx = self.lstm(fused, (enc_hx, enc_cx))

    # Shape comments kept from the original:
    # [32,22,1] * weight -> [32,22,4096] -> sum over classes -> [32,4096]
    score_embedding = torch.sum(previous_score.unsqueeze(2) * self.weight,
                                dim=1)
    conditioned_hx = torch.add(enc_hx, score_embedding)

    logits = self.classifier(conditioned_hx)

    # Dirichlet output layer: softplus keeps the concentration positive.
    concentration = self.softplus(logits)
    enc_score = distributions.Dirichlet(concentration).mean

    return enc_hx, enc_cx, enc_score
def forward(self, encoder_out: Dict[str, torch.LongTensor],
            salience_values) -> Dict[str, torch.Tensor]:
    """Sample per-token salience from a Dirichlet and return the mean loss.

    Args:
        encoder_out: mapping providing ``'source_mask'`` (summed over
            dim 1 to obtain each sequence length) and
            ``'encoder_outputs'`` (fed to ``self.regression``).
        salience_values: per-example target salience; only the first
            ``seq_len`` entries of each example are compared.

    Returns:
        ``{'loss': loss}`` where ``loss`` is the mean over the batch of
        ``self._get_loss(sample, target)``.

    Raises:
        ValueError: if the averaged loss is NaN.
    """
    mask = encoder_out['source_mask']
    seq_len = mask.sum(1)

    # shape: (batch_size, seq_len, 1)
    regression_output = self.regression(encoder_out['encoder_outputs'])

    # Dirichlet concentrations must be strictly positive, hence the offset.
    alphas = torch.relu(regression_output).squeeze(dim=2) + 1e-6

    losses = []
    for idx, alpha in enumerate(alphas):
        length = seq_len[idx]
        # rsample() already returns a tensor that carries gradients; the
        # legacy torch.Tensor(...) wrapper that used to surround it was
        # redundant, so it is dropped here.
        predicted_salience = D.Dirichlet(alpha[:length]).rsample()
        losses.append(
            self._get_loss(predicted_salience,
                           salience_values[idx][:length]))

    loss = torch.stack(losses).mean()
    if torch.isnan(loss):
        raise ValueError("nan loss encountered")

    return {'loss': loss}
def pick_cell_types(uni_labels, alpha, min_n_cells):
    '''
    Pick cell types to include in synthetic spots with proportions from
    Dirichlet distribution.

    Parameters
    ----------
    uni_labels: np.array
        unique labels
    alpha: np.array
        dirichlet distribution concentration value, one entry per label
        (can be from cell type proportions in ST)
    min_n_cells: int
        upper limit on the number of types to pick, so that a spot never
        has more types than cells

    Return
    ------
    tuple of picked cell types (indices into ``uni_labels``) and their
    proportions

    '''
    # get number of different
    # cell types present
    n_labels = uni_labels.shape[0]

    # sample number of types to be present at current spot
    # w/o having more types than cells
    max_types = min([n_labels, min_n_cells])
    if max_types == 1:
        # Uniform(low=1, high=1) is rejected by torch (it requires
        # low < high), and there is nothing to sample anyway.
        n_types = 1
    else:
        n_types = int(
            dists.uniform.Uniform(low=1, high=max_types).sample().round())

    # select which types to include
    pick_types = t.randperm(n_labels)[0:n_types]

    # concentration of the selected types only (float32, as t.Tensor gives)
    picked_alpha = np.asarray(alpha)[pick_types.cpu().numpy()]
    alpha = t.as_tensor(np.array(picked_alpha), dtype=t.float32)

    # select cell type proportions
    member_props = dists.Dirichlet(concentration=alpha).sample()

    return ((pick_types, member_props))
def log_prob(self, value):
    """Return the log-density of ``value`` under ``Dirichlet(self.alpha)``."""
    dirichlet = dists.Dirichlet(self.alpha)
    return dirichlet.log_prob(value)
def encoder(self, camera_input, sensor_input, enc_hx, enc_cx, d_enc_hx,
            d_enc_cx, enc_score, delta, test=False):
    """Run one step of the OAD and delta recurrent branches.

    The OAD branch classifies its own hidden state. The delta branch
    classifies its hidden state plus an embedding of the previous scores
    (a weighted sum of the rows of ``self.weight``). With
    ``self.dirichlet`` set, the OAD scores become the mean or a sample of a
    Dirichlet and the delta scores become a sample of a diagonal Gaussian.

    Args:
        camera_input, sensor_input: passed to ``self.feature_extractor``.
        enc_hx, enc_cx: recurrent state of the OAD branch.
        d_enc_hx, d_enc_cx: recurrent state of the delta branch.
        enc_score: previous class scores.
        delta: in test mode with ``var_method == 'diagonal'``, selects
            which branch is returned (``False`` -> OAD, ``True`` -> delta).
        test: enables the test-mode returns below.

    Returns:
        Test mode (``self.dirichlet`` set), first batch element only for
        the NumPy covariance:
          * ``var_method == 'covariance'``:
            ``(enc_hx, enc_cx, enc_score, covariance)``
          * ``var_method == 'diagonal'`` and ``delta == False``:
            ``(enc_hx, enc_cx, enc_score, covariance)``
          * ``var_method == 'diagonal'`` and ``delta == True``:
            ``(d_enc_hx, d_enc_cx, delta_score, covariance)``
        Otherwise, depending on ``self.loss_method``:
          * ``'oad_before'``: ``(enc_hx, enc_cx, enc_score, d_enc_hx,
            d_enc_cx, delta_score)``
          * ``'state_before'``: ``(enc_hx, enc_cx, enc_score, enc_var,
            d_enc_hx, d_enc_cx, delta_score, delta_var)``
    """
    before_score = enc_score
    fusion_input = self.feature_extractor(camera_input, sensor_input)
    enc_hx, enc_cx = self.lstm_oad(fusion_input, (enc_hx, enc_cx))
    d_enc_hx, d_enc_cx = self.lstm_delta(fusion_input,
                                         (d_enc_hx, d_enc_cx))  # [32,4096]

    # Weighted embedding of the previous scores
    # ([32,22,1] * weight -> [32,22,4096] -> [32,4096], per the original
    # shape comments).
    before_score = torch.sum(before_score.unsqueeze(2) * self.weight, dim=1)
    hx_enc = torch.add(d_enc_hx, before_score)  # [32,4096]

    enc_score = self.classifier_oad(enc_hx)
    delta_score = self.classifier_delta(hx_enc)
    delta_var = self.classifier_deltav(hx_enc)

    # NOTE(review): the source formatting was flattened. Everything that
    # needs `dist` / `norm_dist` is assumed to be nested under
    # `self.dirichlet` -- confirm against the original.
    if self.dirichlet:
        if self.method == 'Mean':
            dist = distributions.Dirichlet(self.softplus(enc_score))
            enc_score = dist.mean
        elif self.method == 'Sample':
            dist = distributions.Dirichlet(self.softplus(enc_score))
            enc_score = dist.rsample()

        # Diagonal Gaussian over the delta scores; diag_embed builds the
        # per-sample diagonal covariance for the whole batch at once.
        delta_cov = torch.diag_embed(self.softplus(delta_var))
        norm_dist = distributions.MultivariateNormal(delta_score, delta_cov)
        delta_score = norm_dist.rsample()

        if self.loss_method == 'state_before':
            enc_var = torch.diag_embed(dist.variance)  # (32,22,22)
            delta_var = norm_dist.covariance_matrix  # (32,22,22)

        if test:
            if self.var_method == 'covariance':
                # Full Dirichlet covariance for batch element 0:
                #   Cov[i, j] = -a_i * a_j / (a0^2 * (a0 + 1)),  i != j
                #   Cov[i, i] = Var[i]
                # The class count is taken from the data instead of being
                # fixed at 22.
                con = dist.concentration.detach()
                con0 = con.sum(-1, True)
                denom = (con0.pow(2) * (con0 + 1)).cpu().numpy()[0]
                con = con.cpu().numpy()[0]
                covariance = -np.outer(con, con) / denom
                np.fill_diagonal(covariance,
                                 dist.variance.detach().cpu().numpy()[0])
                return enc_hx, enc_cx, enc_score, covariance
            elif self.var_method == 'diagonal':
                # Equality (not truthiness) is kept on purpose so that
                # values such as None match neither branch, as before.
                if delta == False:  # noqa: E712
                    var = dist.variance.detach()
                    enc_diagonal = np.diag(var.cpu().numpy()[0])
                    return enc_hx, enc_cx, enc_score, enc_diagonal
                elif delta == True:  # noqa: E712
                    delta_vari = norm_dist.variance.detach()
                    delta_diagonal = np.diag(delta_vari.cpu().numpy()[0])
                    return d_enc_hx, d_enc_cx, delta_score, delta_diagonal

    if self.loss_method == 'oad_before':
        return enc_hx, enc_cx, enc_score, d_enc_hx, d_enc_cx, delta_score
    elif self.loss_method == 'state_before':
        return (enc_hx, enc_cx, enc_score, enc_var, d_enc_hx, d_enc_cx,
                delta_score, delta_var)
new.lambda1 = self.lambda1.expand(lambda1_shape) new.lambda2 = self.lambda2.expand(lambda2_shape) super(NaturalNormalWishart, new).__init__(batch_shape, self.event_shape, validate_args=False) new._validate_args = self._validate_args return new if __name__ == '__main__': N, K, D = 1000, 3, 2 mean = torch.zeros(D) nu = torch.tensor(1.) a = torch.tensor(float(D) - 1.) B = torch.eye(D) niw = NaturalNormalWishart.from_standard(mean, nu, a, B) data = torch.randn(N, D) post_niw = niw.posterior(data) print(post_niw.to_standard()) mix = td.Dirichlet(torch.ones(K)) weights = mix.sample((N, )) expanded_niw = niw.expand((K, )) post_niw = expanded_niw.posterior(data) print(post_niw.to_standard()) post_niw = expanded_niw.posterior(data, weights) print(post_niw.to_standard()) samples = niw.rsample((K, )) print(samples) print(samples.batch_shape, samples.event_shape)
def entropy(self):
    """Return the entropy of ``Dirichlet(self.alpha)``."""
    dirichlet = dists.Dirichlet(self.alpha)
    return dirichlet.entropy()
def confusion_matrix(self, j: int, c: int) -> dist.Distribution:
    """
    Confusion-matrix row for labeler ``j`` and category ``c`` as a
    Dirichlet distribution with concentration ``self.alpha[c]``.

    The labeler index ``j`` is accepted but does not affect the result.
    """
    row_concentration = self.alpha[c]
    return dist.Dirichlet(row_concentration)
def sample(self, batch_size):
    """Draw ``batch_size`` reparameterised samples from ``Dirichlet(self.alpha)``."""
    dirichlet = dists.Dirichlet(self.alpha)
    sample_shape = (batch_size, )
    return dirichlet.rsample(sample_shape)
def __init__(self, a):
    """Wrap ``Dirichlet(a[0])`` and hand it to the parent constructor.

    The parent receives, in order: the distribution, the name
    ``"Dirichlet"``, ``-1``, and the untouched argument ``a``.
    """
    super().__init__(dists.Dirichlet(a[0]), "Dirichlet", -1, a)
def init_parameters(self, alpha=1.):
    """Re-initialise each direct parameter with log-Dirichlet samples.

    For every parameter returned by ``self.parameters(recurse=False)``, a
    symmetric Dirichlet with concentration ``alpha`` is defined over the
    parameter's last dimension; one sample is drawn for every index of the
    leading dimensions and its logarithm replaces ``param.data``.
    """
    for param in self.parameters(recurse=False):
        n_categories = param.shape[-1]
        leading_shape = param.shape[:-1]
        concentration = th.tensor([alpha] * n_categories)
        simplex_points = distr.Dirichlet(concentration).sample(leading_shape)
        param.data = simplex_points.log()
def _assemble_spot( cnt: np.ndarray, labels: np.ndarray, alpha: float = 1.0, fraction: float = 0.1, bounds: List[int] = [10, 30], ) -> Dict[str, t.Tensor]: """Assemble single spot generates one synthetic ST-spot from provided single cell data Parameter: --------- cnt : np.ndarray single cell count data [n_cells x n_genes] labels : np.ndarray single cell annotations [n_cells] alpha : float dirichlet distribution concentration value fraction : float fraction of transcripts from each cell being observed in ST-spot Returns: ------- Dictionary with expression data, proportion values and number of cells from each type at every spot """ # sample between 10 to 30 cells to be present # at spot n_cells = dists.uniform.Uniform( low=bounds[0], high=bounds[1]).sample().round().type(t.int) # get unique labels found in single cell data uni_labs, uni_counts = np.unique(labels, return_counts=True) # make sure sufficient number # of cells are present within # all cell types assert np.all(uni_counts >= 30), \ "Insufficient number of cells" # get number of different # cell types present n_labels = uni_labs.shape[0] # sample number of types to # be present at current spot n_types = dists.uniform.Uniform(low=1, high=n_labels).sample() n_types = n_types.round().type(t.int) # select which types to include pick_types = t.randperm(n_labels)[0:n_types] # pick at least one cell for spot members = t.zeros(n_labels).type(t.float) while members.sum() < 1: # draw proportion values from probability simplex member_props = dists.Dirichlet(concentration=alpha * t.ones(n_types)).sample() # get integer number of cells based on proportions members[pick_types] = (n_cells * member_props).round() # get proportion of each type props = members / members.sum() # convert to ints members = members.type(t.int) # get number of cells from each cell type # generate spot expression data spot_expr = t.zeros(cnt.shape[1]).type(t.float32) for z in range(n_types): # get indices of selected type idx = np.where(labels == 
uni_labs[pick_types[z]])[0] # pick random cells from type np.random.shuffle(idx) idx = idx[0:members[pick_types[z]]] # add fraction of transcripts to spot expression spot_expr += t.tensor( (cnt[idx, :] * fraction).sum(axis=0).round().astype(np.float32)) return { 'expr': spot_expr, 'proportions': props, 'members': members, }