def loss(
    self,
    tensors,
    inference_outputs,
    generative_outputs,
    n_obs: float = 1.0,
):
    # generative_outputs is the dict returned by `generative(...)`;
    # `n_obs` is assumed to be the total number of training data points
    p_x_c = generative_outputs["p_x_c"]
    gamma = generative_outputs["gamma"]

    # compute Q: take the mean over cells and multiply by n_obs
    # (instead of summing over the n cells in the batch)
    q_per_cell = torch.sum(gamma * -p_x_c, 1)

    # third term of Q is the log probability of the prior terms
    theta_log = F.log_softmax(self.theta_logit, dim=-1)
    theta_log_prior = Dirichlet(self.dirichlet_concentration)
    theta_log_prob = -theta_log_prior.log_prob(
        torch.exp(theta_log) + THETA_LOWER_BOUND
    )
    prior_log_prob = theta_log_prob

    delta_log_prior = Normal(
        self.delta_log_mean, self.delta_log_log_scale.exp().sqrt()
    )
    delta_log_prob = torch.masked_select(
        delta_log_prior.log_prob(self.delta_log), (self.rho > 0)
    )
    prior_log_prob += -torch.sum(delta_log_prob)

    loss = (torch.mean(q_per_cell) * n_obs + prior_log_prob) / n_obs

    return LossRecorder(
        loss, q_per_cell, torch.zeros_like(q_per_cell), prior_log_prob
    )
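# A minimal, self-contained sketch (not part of the module above) of the
# Dirichlet-prior penalty pattern used in `loss`: softmax the logits onto the
# simplex, nudge away from the boundary so log_prob stays finite, and score
# under a Dirichlet prior. The logit size, the concentration of 1.5, and the
# 1e-20 bound are illustrative assumptions, not values from the original code.
import torch
import torch.nn.functional as F
from torch.distributions import Dirichlet

THETA_LOWER_BOUND = 1e-20  # tiny enough that the sample still passes the simplex check
theta_logit = torch.zeros(4, requires_grad=True)
prior = Dirichlet(torch.ones(4) * 1.5)

theta = F.softmax(theta_logit, dim=-1)
penalty = -prior.log_prob(theta + THETA_LOWER_BOUND)  # negative log prior, to be minimized
penalty.backward()  # gradients flow back into theta_logit
print(penalty.item(), theta_logit.grad)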
def select_action(self, obs):
    # Sample an action (a point on the simplex) from the Dirichlet policy and
    # record its log-probability and the value estimate for the update step.
    concentration, value = self.forward(obs)
    m = Dirichlet(concentration)
    action = m.sample()
    self.saved_actions.append(SavedAction(m.log_prob(action), value))
    return list(action.cpu().numpy())
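# A hypothetical actor head (assumed, not from the agent above) showing where
# the `concentration` consumed by `select_action` typically comes from: a
# softplus keeps every concentration strictly positive, as Dirichlet requires.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Dirichlet

class DirichletActor(nn.Module):
    def __init__(self, obs_dim: int, act_dim: int):
        super().__init__()
        self.fc = nn.Linear(obs_dim, act_dim)

    def forward(self, obs: torch.Tensor) -> Dirichlet:
        # +1e-4 guards against exactly-zero concentrations after underflow
        concentration = F.softplus(self.fc(obs)) + 1e-4
        return Dirichlet(concentration)

actor = DirichletActor(obs_dim=8, act_dim=3)
dist = actor(torch.randn(8))
action = dist.sample()            # a point on the 3-simplex (e.g. allocation weights)
log_prob = dist.log_prob(action)  # stored for the policy-gradient update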
def test_dirichlet_shape(self):
    dist = Dirichlet(torch.Tensor([[0.6, 0.3], [1.6, 1.3], [2.6, 2.3]]))
    self.assertEqual(dist._batch_shape, torch.Size((3,)))
    self.assertEqual(dist._event_shape, torch.Size((2,)))
    self.assertEqual(dist.sample().size(), torch.Size((3, 2)))
    self.assertEqual(dist.sample((5, 4)).size(), torch.Size((5, 4, 3, 2)))
    self.assertEqual(dist.log_prob(self.tensor_sample_1).size(), torch.Size((3,)))
    self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)
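# A quick interactive restatement (assumed, not from the test suite) of the
# shape rules asserted above: the trailing dimension of `concentration` is the
# event, every leading dimension is batch, and sample shapes are prepended.
import torch
from torch.distributions import Dirichlet

d = Dirichlet(torch.ones(3, 2))      # batch_shape=(3,), event_shape=(2,)
print(d.batch_shape, d.event_shape)  # torch.Size([3]) torch.Size([2])
print(d.sample().shape)              # torch.Size([3, 2])
print(d.sample((5, 4)).shape)        # torch.Size([5, 4, 3, 2])
print(d.log_prob(d.sample()).shape)  # torch.Size([3]): one density per batch element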
def test_dirichlet_log_prob(self):
    num_samples = 10
    alpha = torch.exp(torch.randn(5))
    dist = Dirichlet(alpha)
    x = dist.sample((num_samples,))
    actual_log_prob = dist.log_prob(x)
    for i in range(num_samples):
        expected_log_prob = scipy.stats.dirichlet.logpdf(x[i].numpy(), alpha.numpy())
        self.assertAlmostEqual(actual_log_prob[i], expected_log_prob, places=3)
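# Beyond matching scipy's density, PyTorch's Dirichlet also supports
# reparameterized sampling (`has_rsample` is True), so gradients flow from a
# sample back to the concentration; a short check with arbitrary values:
import torch
from torch.distributions import Dirichlet

alpha = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
z = Dirichlet(alpha).rsample()
z[0].backward()    # pathwise gradient of the first component w.r.t. alpha
print(alpha.grad)  # finite values: rsample is differentiable in alpha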
def log_prob(self, dist: Dirichlet, samples) -> torch.Tensor:
    return dist.log_prob(samples)
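# Usage note for a thin wrapper like the above (an assumption about typical
# callers, not from the wrapper's source): when argument validation is enabled
# (the default in recent PyTorch), `Dirichlet.log_prob` only accepts points on
# the probability simplex, so arbitrary tensors must be normalized first.
import torch
from torch.distributions import Dirichlet

dist = Dirichlet(torch.ones(4))
on_simplex = torch.softmax(torch.randn(4), dim=-1)
print(dist.log_prob(on_simplex))  # fine: positive components summing to 1
# dist.log_prob(torch.randn(4))   # would raise ValueError under validation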
def forward(self, sentences, sentence_length,
            input_conversation_length, target_sentences, decode=False):
    """
    Args:
        sentences: (Variable, LongTensor) [num_sentences + batch_size, seq_len]
        target_sentences: (Variable, LongTensor) [num_sentences, seq_len]
    Return:
        decoder_outputs: (Variable, FloatTensor)
            - train: [batch_size, seq_len, vocab_size]
            - eval: [batch_size, seq_len]
    """
    batch_size = input_conversation_length.size(0)
    num_sentences = sentences.size(0) - batch_size
    max_len = input_conversation_length.data.max().item()

    # encoder_outputs: [num_sentences + batch_size, max_source_length, hidden_size]
    # encoder_hidden: [num_layers * direction, num_sentences + batch_size, hidden_size]
    encoder_outputs, encoder_hidden = self.encoder(sentences, sentence_length)

    # encoder_hidden: [num_sentences + batch_size, num_layers * direction * hidden_size]
    encoder_hidden = encoder_hidden.transpose(
        1, 0).contiguous().view(num_sentences + batch_size, -1)

    # pad and pack encoder_hidden
    start = torch.cumsum(
        torch.cat((to_var(input_conversation_length.data.new(1).zero_()),
                   input_conversation_length[:-1] + 1)), 0)

    # encoder_hidden: [batch_size, max_len + 1, num_layers * direction * hidden_size]
    encoder_hidden = torch.stack(
        [pad(encoder_hidden.narrow(0, s, l + 1), max_len + 1)
         for s, l in zip(start.data.tolist(),
                         input_conversation_length.data.tolist())], 0)

    # encoder_hidden_inference: [batch_size, max_len, num_layers * direction * hidden_size]
    encoder_hidden_inference = encoder_hidden[:, 1:, :]
    encoder_hidden_inference_flat = torch.cat(
        [encoder_hidden_inference[i, :l, :]
         for i, l in enumerate(input_conversation_length.data)])

    # encoder_hidden_input: [batch_size, max_len, num_layers * direction * hidden_size]
    encoder_hidden_input = encoder_hidden[:, :-1, :]

    # context_outputs: [batch_size, max_len, context_size]
    context_outputs, context_last_hidden = self.context_encoder(
        encoder_hidden_input, input_conversation_length)

    # flatten outputs
    # context_outputs: [num_sentences, context_size]
    context_outputs = torch.cat(
        [context_outputs[i, :l, :]
         for i, l in enumerate(input_conversation_length.data)])

    alpha_prior = self.prior(context_outputs)
    # eps is a leftover from the Gaussian reparameterization
    # (z_sent = mu + sqrt(var) * eps) and is unused on the Dirichlet path
    eps = to_var(torch.randn((num_sentences, self.config.z_sent_size)))

    if not decode:
        alpha_posterior = self.posterior(
            context_outputs, encoder_hidden_inference_flat)

        # reparameterized sample from the Dirichlet posterior; rsample is
        # run on CPU here, hence the round trip when CUDA is available
        if torch.cuda.is_available():
            alpha_posterior = alpha_posterior.cpu()
        dirichlet_dist = Dirichlet(alpha_posterior)
        z_sent = dirichlet_dist.rsample()
        if torch.cuda.is_available():
            z_sent = to_var(z_sent)
            alpha_posterior = to_var(alpha_posterior)

        # log_q_zx and log_p_z are returned for logging; they are not
        # needed to compute the loss itself
        log_q_zx = dirichlet_dist.log_prob(z_sent.cpu()).sum()
        log_p_z = Dirichlet(alpha_prior.cpu()).log_prob(z_sent.cpu()).sum()
        if torch.cuda.is_available():
            log_q_zx = log_q_zx.cuda()
            log_p_z = log_p_z.cuda()

        # kl_div: [num_sentences], summed to a scalar
        kl_div = dirichlet_kl_div(alpha_posterior, alpha_prior)
        kl_div = torch.sum(kl_div)
    else:
        # at decode time, sample from the Dirichlet prior instead of the posterior
        if torch.cuda.is_available():
            alpha_prior = alpha_prior.cpu()
        dirichlet_dist = Dirichlet(alpha_prior)
        z_sent = dirichlet_dist.rsample()
        if torch.cuda.is_available():
            z_sent = z_sent.cuda()
            alpha_prior = alpha_prior.cuda()
        kl_div = None
        log_p_z = dirichlet_logpdf(z_sent, alpha_prior).sum()
        log_q_zx = None

    self.z_sent = z_sent
    latent_context = torch.cat([context_outputs, z_sent], 1)
    decoder_init = self.context2decoder(latent_context)
    decoder_init = decoder_init.view(
        -1, self.decoder.num_layers, self.decoder.hidden_size)
    decoder_init = decoder_init.transpose(1, 0).contiguous()

    if not decode:
        # decoder_outputs (train): [batch_size, seq_len, vocab_size]
        decoder_outputs = self.decoder(
            target_sentences, init_h=decoder_init, decode=decode)
        return decoder_outputs, kl_div, log_p_z, log_q_zx
    else:
        # prediction: [batch_size, beam_size, max_unroll]
        prediction, final_score, length = self.decoder.beam_decode(
            init_h=decoder_init)
        return prediction, kl_div, log_p_z, log_q_zx
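# For reference, torch.distributions registers a closed-form
# Dirichlet-Dirichlet KL, so a hand-rolled `dirichlet_kl_div` like the one
# used above can be cross-checked against it (concentrations here are made up):
import torch
from torch.distributions import Dirichlet, kl_divergence

alpha_posterior = torch.tensor([2.0, 3.0, 4.0])
alpha_prior = torch.ones(3)
kl = kl_divergence(Dirichlet(alpha_posterior), Dirichlet(alpha_prior))
print(kl)  # scalar KL(q || p); batched concentrations give one KL per row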
class CellTypeModel(AstirModel):
    """Class to perform statistical inference to assign cells to cell types.

    :param dset: the input gene expression dataframe
    :param random_seed: the random seed for parameter initialization, defaults to 1234
    :param dtype: the data type of parameters, should be the same as `dset`, defaults to torch.float64
    :param device: the device on which the model is stored, defaults to torch.device("cpu")
    """

    def __init__(
        self,
        dset: Optional[SCDataset] = None,
        random_seed: int = 1234,
        dtype: torch.dtype = torch.float64,
        device: torch.device = torch.device("cpu"),
    ) -> None:
        super().__init__(dset, random_seed, dtype, device)
        if dset is not None:
            self._param_init()

    def _param_init(self) -> None:
        """Initializes parameters and design matrices."""
        if self._dset is None:
            raise Exception("the dataset is not provided")
        G = self._dset.get_n_features()
        C = self._dset.get_n_classes()

        self._recog = TypeRecognitionNet(
            self._dset.get_n_classes(), self._dset.get_n_features()
        ).to(self._device, dtype=self._dtype)

        # Establish data
        self._data: Dict[str, torch.Tensor] = {
            "rho": self._dset.get_marker_mat().to(self._device),
        }
        # Symmetric Dirichlet prior over the C + 1 mixing proportions
        self._alpha_prior = Dirichlet(
            torch.ones(C + 1, dtype=self._dtype).to(self._device) * (C + 1)
        )

        # Initialize mu, log_delta
        delta_init_mean = torch.log(
            torch.log(torch.tensor(3.0, dtype=self._dtype))
        )  # currently unused: the log-delta prior below is centered at 0 instead
        t = torch.distributions.Normal(
            torch.tensor(0, dtype=self._dtype),
            torch.tensor(0.1, dtype=self._dtype),
        )
        log_delta_init = t.sample((G, C + 1))

        mu_init = torch.log(
            torch.tensor(self._dset.get_mu_init(), dtype=self._dtype)
        ).to(self._device)
        mu_init = mu_init.reshape(-1, 1)

        # Create initialization dictionary
        initializations = {
            "mu": mu_init,
            "log_sigma": torch.log(self._dset.get_sigma()).to(self._device),
            "log_delta": log_delta_init,
            "p": torch.zeros((G, C + 1), dtype=self._dtype, device=self._device),
            "alpha_logits": torch.ones(C + 1, dtype=self._dtype, device=self._device),
        }
        P = self._dset.get_design().shape[1]

        # Add additional columns of mu for anything in the design matrix
        initializations["mu"] = torch.cat(
            [
                initializations["mu"],
                torch.zeros((G, P - 1), dtype=self._dtype, device=self._device),
            ],
            1,
        )
        # Create trainable variables
        self._variables: Dict[str, torch.Tensor] = {}
        for (n, v) in initializations.items():
            self._variables[n] = Variable(v.clone()).to(self._device)
            self._variables[n].requires_grad = True

    def load_hdf5(self, hdf5_name: str) -> None:
        """Initializes the cell type model from an HDF5 file.

        :param hdf5_name: file path
        """
        self._assignment = pd.read_hdf(hdf5_name, "celltype_model/celltype_assignments")
        with h5py.File(hdf5_name, "r") as f:
            grp = f["celltype_model"]
            param = grp["parameters"]
            self._variables = {
                "mu": torch.tensor(np.array(param["mu"])),
                "log_sigma": torch.tensor(np.array(param["log_sigma"])),
                "log_delta": torch.tensor(np.array(param["log_delta"])),
                "p": torch.tensor(np.array(param["p"])),
                "alpha_logits": torch.tensor(np.array(param["alpha_logits"])),
            }
            self._data = {
                "rho": torch.tensor(np.array(param["rho"])),
            }
            self._losses = torch.tensor(np.array(grp["losses"]["losses"]))

            rec = grp["recog_net"]
            hidden1_W = torch.tensor(np.array(rec["hidden_1.weight"]))
            hidden2_W = torch.tensor(np.array(rec["hidden_2.weight"]))
            state_dict = {
                "hidden_1.weight": hidden1_W,
                "hidden_1.bias": torch.tensor(np.array(rec["hidden_1.bias"])),
                "hidden_2.weight": hidden2_W,
                "hidden_2.bias": torch.tensor(np.array(rec["hidden_2.bias"])),
            }
            state_dict = OrderedDict(state_dict)
            self._recog = TypeRecognitionNet(
                hidden2_W.shape[0] - 1, hidden1_W.shape[1], hidden1_W.shape[0]
            ).to(device=self._device, dtype=self._dtype)
            self._recog.load_state_dict(state_dict)
            self._recog.eval()

    def _forward(self, Y: torch.Tensor, X: torch.Tensor, design: torch.Tensor) -> torch.Tensor:
        """One forward pass.

        :param Y: a sample from the dataset
        :param X: normalized sample data
        :param design: the corresponding row of the design matrix
        :return: the cost (negative ELBO) of the current pass
        """
        if self._dset is None:
            raise Exception("the dataset is not provided")
        G = self._dset.get_n_features()
        C = self._dset.get_n_classes()
        N = Y.shape[0]

        Y_spread = Y.view(N, 1, G).repeat(1, C + 1, 1)

        delta_tilde = torch.exp(self._variables["log_delta"])
        mean = delta_tilde * self._data["rho"]
        mean2 = torch.mm(design, self._variables["mu"].T)  # (N x P) * (P x G)
        mean2 = mean2.view(-1, G, 1).repeat(1, 1, C + 1)
        mean = mean + mean2

        # now do the variance modelling
        p = torch.sigmoid(self._variables["p"])
        sigma = torch.exp(self._variables["log_sigma"])
        v1 = (self._data["rho"] * p).T * sigma
        v2 = torch.pow(sigma, 2) * (1 - torch.pow(self._data["rho"] * p, 2)).T
        v1 = v1.view(1, C + 1, G, 1).repeat(N, 1, 1, 1)  # extra 1 is the "rank"
        v2 = v2.view(1, C + 1, G).repeat(N, 1, 1) + 1e-6

        dist = LowRankMultivariateNormal(
            loc=torch.exp(mean).permute(0, 2, 1), cov_factor=v1, cov_diag=v2
        )

        log_p_y_on_c = dist.log_prob(Y_spread)
        gamma, log_gamma = self._recog.forward(X)
        log_alpha = F.log_softmax(self._variables["alpha_logits"], dim=0)
        alpha = F.softmax(self._variables["alpha_logits"], dim=0)
        mix_prior = self._alpha_prior.log_prob(alpha)
        elbo = (gamma * (log_p_y_on_c + log_alpha - log_gamma)).sum() + mix_prior
        return -elbo

    def fit(
        self,
        max_epochs: int = 50,
        learning_rate: float = 1e-3,
        batch_size: int = 128,
        delta_loss: float = 1e-3,
        delta_loss_batch: int = 10,
        msg: str = "",
    ) -> None:
        """Runs the training loop until the relative change in loss drops
        below `delta_loss` or `max_epochs` is reached.

        :param max_epochs: number of train loop iterations, defaults to 50
        :param learning_rate: the learning rate, defaults to 0.001
        :param batch_size: the batch size, defaults to 128
        :param delta_loss: stops iteration once the relative change in loss falls below delta_loss, defaults to 0.001
        :param delta_loss_batch: the batch size to consider delta loss, defaults to 10
        :param msg: iterator bar message, defaults to empty string
        """
        if self._dset is None:
            raise Exception("the dataset is not provided")

        # Make dataloader
        dataloader = DataLoader(
            self._dset, batch_size=min(batch_size, len(self._dset)), shuffle=True
        )

        # Run training loop
        losses: List[torch.Tensor] = []
        per = torch.tensor(1)

        # Construct optimizer
        opt_params = list(self._variables.values()) + list(self._recog.parameters())
        optimizer = torch.optim.Adam(opt_params, lr=learning_rate)

        _, exprs_X, _ = self._dset[:]  # calls dset.get_item

        iterator = trange(
            max_epochs,
            desc="training restart" + msg,
            unit="epochs",
            bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{rate_fmt}{postfix}]",
        )
        for ep in iterator:
            L = None
            loss = torch.tensor(0.0, dtype=self._dtype)
            for batch in dataloader:
                Y, X, design = batch
                optimizer.zero_grad()
                L = self._forward(Y, X, design)
                L.backward()
                optimizer.step()
                with torch.no_grad():
                    loss = loss + L
            if len(losses) > 0:
                per = abs((loss - losses[-1]) / losses[-1])
            losses.append(loss)
            iterator.set_postfix_str("current loss: " + str(round(float(loss), 1)))
            if per <= delta_loss:
                self._is_converged = True
                iterator.close()
                break

        # Save output
        self._assignment = pd.DataFrame(
            self._recog.forward(exprs_X)[0].detach().cpu().numpy()
        )
        self._assignment.columns = self._dset.get_classes() + ["Other"]
        self._assignment.index = self._dset.get_cell_names()

        if self._losses.shape[0] == 0:
            self._losses = torch.tensor(losses)
        else:
            self._losses = torch.cat(
                (self._losses.view(self._losses.shape[0]), torch.tensor(losses)), dim=0
            )

    def predict(self, new_dset: SCDataset) -> pd.DataFrame:
        """Feed `new_dset` to the recognition net to get a prediction.

        :param new_dset: the dataset to be predicted
        :return: the resulting cell type assignment probabilities
        """
        _, exprs_X, _ = new_dset[:]
        g = pd.DataFrame(self._recog.forward(exprs_X)[0].detach().cpu().numpy())
        return g

    def get_recognet(self) -> TypeRecognitionNet:
        """Getter for the recognition net.

        :return: the trained recognition net
        """
        return self._recog

    def _most_likely_celltype(
        self,
        row: pd.Series,
        threshold: float,
        cell_types: List[str],
        assignment_type: str,
    ) -> str:
        """Given a row of the assignment matrix, return the most likely cell type.

        :param row: the row of the cell assignment matrix to be evaluated
        :param threshold: probability threshold; if the maximum probability falls below it, the cell is classified as `Unknown`
        :param cell_types: the names of cell types, in the same order as the features of the row
        :param assignment_type: See :meth:`astir.CellTypeModel.get_celltypes` for full documentation
        :return: the most likely cell type of this cell
        """
        row = row.values
        max_prob = np.max(row)

        if assignment_type == "threshold":
            if max_prob < threshold:
                return "Unknown"
        elif assignment_type == "max":
            if sum(row == max_prob) > 1:
                return "Unknown"
        return cell_types[np.argmax(row)]

    def get_celltypes(
        self,
        threshold: float = 0.7,
        assignment_type: str = "threshold",
        prob_assign: Optional[pd.DataFrame] = None,
    ) -> pd.DataFrame:
        """Get the most likely cell types. A cell is assigned to a cell type
        if the probability is greater than threshold. If no cell types have a
        probability higher than threshold, then "Unknown" is returned.

        :param assignment_type: either 'threshold' or 'max'. If 'threshold',
            type assignment is based on whether the maximum probability exceeds
            `threshold`. If 'max', type assignment is based on the maximum
            probability value, or "Unknown" if several types tie at the
            maximum. Defaults to 'threshold'.
        :param threshold: the probability threshold above which a cell is assigned to a cell type, defaults to 0.7
        :param prob_assign: an optional probability matrix to use in place of the model's own assignments
        :return: a data frame with the most likely cell type for each cell
        """
        if prob_assign is None:
            type_probability = self.get_assignment()
        else:
            type_probability = prob_assign

        if assignment_type != "threshold" and assignment_type != "max":
            warnings.warn(
                "Wrong assignment type. Defaults the assignment type to threshold."
            )
            assignment_type = "threshold"
        if assignment_type == "max" and prob_assign is not None:
            warnings.warn(
                "Assignment type is 'max' but a probability threshold value "
                "was passed in. The probability threshold value will be ignored."
            )
        cell_types = list(type_probability.columns)

        cell_type_assignments = type_probability.apply(
            self._most_likely_celltype,
            axis=1,
            assignment_type=assignment_type,
            threshold=threshold,
            cell_types=cell_types,
        )
        cell_type_assignments = pd.DataFrame(cell_type_assignments)
        cell_type_assignments.columns = ["cell_type"]
        return cell_type_assignments

    def _compare_marker_between_types(
        self,
        curr_type: str,
        celltype_to_compare: str,
        marker: str,
        cell_types: List[str],
        alpha: float = 0.05,
    ) -> Optional[dict]:
        """For two cell types and a protein, check that the marker is expressed
        at a higher level for curr_type than for celltype_to_compare.

        :param curr_type: the cell type to assess
        :param celltype_to_compare: the cell type that shouldn't highly express this marker
        :param marker: the marker protein for curr_type
        :param cell_types: list of cell types assigned for cells
        :param alpha: significance level for the t-test, defaults to 0.05
        :return: a dict describing the problem if the check fails, otherwise None
        """
        if self._dset is None:
            raise Exception("the dataset is not provided")
        current_marker_ind = np.array(self._dset.get_features()) == marker
        cells_x = np.array(cell_types) == curr_type
        cells_y = np.array(cell_types) == celltype_to_compare

        # x - marker expression in cells assigned to curr_type
        # y - marker expression in cells assigned to celltype_to_compare
        x = self._dset.get_exprs().detach().cpu().numpy()[cells_x, current_marker_ind]
        y = self._dset.get_exprs().detach().cpu().numpy()[cells_y, current_marker_ind]

        stat = np.nan
        pval = np.inf
        note: Optional[str] = "Only 1 cell in a type: comparison not possible"

        if len(x) > 1 and len(y) > 1:
            tt = stats.ttest_ind(x, y)
            stat = tt.statistic
            pval = tt.pvalue
            note = None

        if not (stat > 0 and pval < alpha):
            rdict = {
                "current_marker": marker,
                "curr_type": curr_type,
                "celltype_to_compare": celltype_to_compare,
                "mean_A": x.mean(),
                "mean_Y": y.mean(),
                "p-val": pval,
                "note": note,
            }
            return rdict
        return None

    def plot_clustermap(
        self,
        plot_name: str = "celltype_protein_cluster.png",
        threshold: float = 0.7,
        figsize: Tuple[float, float] = (7.0, 5.0),
        prob_assign: Optional[pd.DataFrame] = None,
    ) -> None:
        """Save a heatmap of protein content in cells, with cell types labeled.

        :param plot_name: name of the plot; a file extension (e.g. .png or .jpg) is needed, defaults to "celltype_protein_cluster.png"
        :param threshold: the probability threshold above which a cell is assigned to a cell type, defaults to 0.7
        :param figsize: the size of the figure, defaults to (7.0, 5.0)
        :param prob_assign: an optional probability matrix to use in place of the model's own assignments
        """
        if self._dset is None:
            raise Exception("the dataset is not provided")
        expr_df = self._dset.get_exprs_df()
        scaler = StandardScaler()
        for feature in expr_df.columns:
            expr_df[feature] = scaler.fit_transform(
                expr_df[feature].values.reshape((expr_df[feature].shape[0], 1))
            )
        expr_df["cell_type"] = self.get_celltypes(
            threshold=threshold, prob_assign=prob_assign
        )
        expr_df = expr_df.sort_values(by=["cell_type"])
        types = expr_df.pop("cell_type")
        types_uni = types.unique()

        lut = dict(zip(types_uni, sns.color_palette("BrBG", len(types_uni))))
        col_colors = pd.DataFrame(types.map(lut))

        cm = sns.clustermap(
            expr_df.T,
            xticklabels=False,
            cmap="vlag",
            col_cluster=False,
            col_colors=col_colors,
            figsize=figsize,
        )
        for t in types_uni:
            cm.ax_col_dendrogram.bar(0, 0, color=lut[t], label=t, linewidth=0)
        cm.ax_col_dendrogram.legend(
            title="Cell Types", loc="center", ncol=3, bbox_to_anchor=(0.8, 0.8)
        )
        cm.savefig(plot_name, dpi=150)

    def diagnostics(self, cell_type_assignments: list, alpha: float) -> pd.DataFrame:
        """Run diagnostics on cell type assignments.

        See :meth:`astir.Astir.diagnostics_celltype` for full documentation
        """
        if self._dset is None:
            raise Exception("the dataset is not provided")
        problems = []

        # Construct a data frame that models rho with cell type
        # names on the columns and feature names on the rows
        g_df = pd.DataFrame(self._data["rho"].detach().cpu().numpy())
        g_df.columns = self._dset.get_classes() + ["Other"]
        g_df.index = self._dset.get_features()

        for curr_type in self._dset.get_classes():
            if curr_type not in cell_type_assignments:
                continue

            current_markers = g_df.index[g_df[curr_type] == 1]

            for current_marker in current_markers:
                # find all the cell types that shouldn't highly express this marker
                celltypes_to_compare = g_df.columns[g_df.loc[current_marker] == 0]

                for celltype_to_compare in celltypes_to_compare:
                    if celltype_to_compare not in cell_type_assignments:
                        continue

                    is_problem = self._compare_marker_between_types(
                        curr_type,
                        celltype_to_compare,
                        current_marker,
                        cell_type_assignments,
                        alpha,
                    )
                    if is_problem is not None:
                        problems.append(is_problem)

        col_names = [
            "feature",
            "should be expressed higher in",
            "than",
            "mean cell type 1",
            "mean cell type 2",
            "p-value",
            "note",
        ]
        if len(problems) > 0:
            df_issues = pd.DataFrame(problems)
            df_issues.columns = col_names
        else:
            df_issues = pd.DataFrame(columns=col_names)
        return df_issues
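# A minimal sketch (illustrative, not astir's API) of the mixing-proportion
# treatment inside `_forward` above: trainable logits are softmaxed into
# cell-type proportions, and a symmetric Dirichlet prior log-density is added
# to the ELBO. The number of classes and the values here are assumptions.
import torch
import torch.nn.functional as F
from torch.distributions import Dirichlet

C = 5  # number of cell types; the model appends one extra "Other" class
alpha_logits = torch.ones(C + 1, requires_grad=True)
alpha_prior = Dirichlet(torch.ones(C + 1) * (C + 1))

alpha = F.softmax(alpha_logits, dim=0)   # a point on the (C+1)-simplex
mix_prior = alpha_prior.log_prob(alpha)  # prior regularizes the proportions
mix_prior.backward()                     # differentiable w.r.t. the logits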