def _assert_has_all_keys(adata: AnnData, direction: Direction):
    """
    Assert that ``adata`` holds every entry expected for ``direction``.

    Checks the transition key in ``adata.uns``, the lineage entry in
    ``adata.obsm`` (which must be a :class:`cr.tl.Lineage`), the lineage
    colors and names in ``adata.uns``, and the categorical state column
    together with its probabilities column in ``adata.obs``.
    """
    assert _transition(direction) in adata.uns.keys()

    # The forward and backward checks are identical apart from the key enums,
    # so pick the pair once and run a single sequence of assertions.
    if direction == Direction.FORWARD:
        lin_key, state_key = LinKey.FORWARD, StateKey.FORWARD
    else:
        lin_key, state_key = LinKey.BACKWARD, StateKey.BACKWARD

    lin_name = str(lin_key)
    assert lin_name in adata.obsm
    assert isinstance(adata.obsm[lin_name], cr.tl.Lineage)
    assert _colors(lin_key) in adata.uns.keys()
    assert _lin_names(lin_key) in adata.uns.keys()

    state_name = str(state_key)
    assert state_name in adata.obs
    assert is_categorical_dtype(adata.obs[state_name])
    assert _probs(state_key) in adata.obs
def _check_main_states(mc: cr.tl.GPCCA, has_main_states: bool = True):
    """
    Assert that the estimator's attributes agree with what is stored in ``mc.adata``.

    Params
    ------
    mc
        Estimator whose ``main_states``, ``lineage_probabilities``,
        ``diff_potential`` and ``main_states_probabilities`` are compared
        against the corresponding ``adata`` entries.
    has_main_states
        Whether the checks on ``mc.main_states`` itself should be run.
    """
    # NOTE(review): the original indentation was lost; the extent of this `if`
    # (the three checks that read `mc.main_states`) is inferred - confirm.
    if has_main_states:
        assert isinstance(mc.main_states, pd.Series)
        assert_array_nan_equal(mc.adata.obs[str(StateKey.FORWARD)], mc.main_states)
        # colors stored for the states must match the colors of the lineages
        # selected by the main states' category names
        np.testing.assert_array_equal(
            mc.adata.uns[_colors(StateKey.FORWARD)],
            mc.lineage_probabilities[list(mc.main_states.cat.categories)].colors,
        )

    assert isinstance(mc.diff_potential, np.ndarray)
    assert isinstance(mc.lineage_probabilities, cr.tl.Lineage)

    # lineage probabilities: values, names and colors
    np.testing.assert_array_equal(
        mc.adata.obsm[str(LinKey.FORWARD)], mc.lineage_probabilities.X
    )
    np.testing.assert_array_equal(
        mc.adata.uns[_lin_names(LinKey.FORWARD)], mc.lineage_probabilities.names
    )
    np.testing.assert_array_equal(
        mc.adata.uns[_colors(LinKey.FORWARD)], mc.lineage_probabilities.colors
    )

    # per-cell columns in `adata.obs`
    np.testing.assert_array_equal(mc.adata.obs[_dp(LinKey.FORWARD)], mc.diff_potential)
    np.testing.assert_array_equal(
        mc.adata.obs[_probs(StateKey.FORWARD)], mc.main_states_probabilities
    )
def test_compute_approx_normal_run(self, adata_large: AnnData):
    """
    Run the eigendecomposition and approximate-recurrent-class computation
    on a combined kernel and check that the results are exposed on the
    estimator and written to ``adata``.
    """
    velocity = VelocityKernel(adata_large).compute_transition_matrix()
    connectivity = ConnectivityKernel(adata_large).compute_transition_matrix()
    chain = cr.tl.MarkovChain(0.8 * velocity + 0.2 * connectivity)

    chain.compute_eig(k=5)
    chain.compute_approx_rcs(use=2)

    # results available on the estimator ...
    assert is_categorical_dtype(chain.approx_recurrent_classes)
    assert chain.approx_recurrent_classes_probabilities is not None

    # ... and persisted in the underlying annotated data object
    assert _colors(RcKey.FORWARD) in chain.adata.uns.keys()
    assert _probs(RcKey.FORWARD) in chain.adata.obs.keys()
def _read_from_adata(
    self, g2m_key: Optional[str] = None, s_key: Optional[str] = None, **kwargs
) -> None:
    """
    Populate the estimator's private attributes from ``self._adata``.

    For every supported entry, the value is copied from ``adata`` when its
    key is present; otherwise only a debug message is logged.

    Params
    ------
    g2m_key
        Key in ``adata.obs`` holding the G2M score. Ignored if `None` or missing.
    s_key
        Key in ``adata.obs`` holding the S score. Ignored if `None` or missing.
    **kwargs
        Accepted for interface compatibility; not used in this method.

    Returns
    -------
    None
        Nothing, but may update ``_eig``, ``_meta_states``,
        ``_meta_states_colors``, ``_meta_states_probs``, ``_lin_probs``,
        ``_dp``, ``_G2M_score`` and ``_S_score``, and re-assigns
        ``adata.obsm[self._lin_key]`` when lineage probabilities are found.
    """
    # NOTE(review): the "Setting ... to `None`" messages below do not assign
    # anything; they presumably rely on the attributes having been initialised
    # to `None` before this method is called - confirm against `__init__`.

    # eigendecomposition
    if f"eig_{self._direction}" in self._adata.uns.keys():
        self._eig = self._adata.uns[f"eig_{self._direction}"]
    else:
        logg.debug(
            f"DEBUG: `eig_{self._direction}` not found. Setting `.eig` to `None`"
        )

    # metastable states and their colors
    if self._rc_key in self._adata.obs.keys():
        self._meta_states = self._adata.obs[self._rc_key]
    else:
        logg.debug(
            f"DEBUG: `{self._rc_key}` not found in `adata.obs`. Setting `.metastable_states` to `None`"
        )

    if _colors(self._rc_key) in self._adata.uns.keys():
        self._meta_states_colors = self._adata.uns[_colors(self._rc_key)]
    else:
        logg.debug(
            f"DEBUG: `{_colors(self._rc_key)}` not found in `adata.uns`. "
            f"Setting `.metastable_states_colors` to `None`"
        )

    # lineage probabilities: start with default names/colors, refined further below
    if self._lin_key in self._adata.obsm.keys():
        lineages = range(self._adata.obsm[self._lin_key].shape[1])
        colors = _create_categorical_colors(len(lineages))
        self._lin_probs = Lineage(
            self._adata.obsm[self._lin_key],
            names=[f"Lineage {i + 1}" for i in lineages],
            colors=colors,
        )
        self._adata.obsm[self._lin_key] = self._lin_probs
    else:
        logg.debug(
            f"DEBUG: `{self._lin_key}` not found in `adata.obsm`. Setting `.lin_probs` to `None`"
        )

    # differentiation potential
    if _dp(self._lin_key) in self._adata.obs.keys():
        self._dp = self._adata.obs[_dp(self._lin_key)]
    else:
        logg.debug(
            f"DEBUG: `{_dp(self._lin_key)}` not found in `adata.obs`. Setting `.diff_potential` to `None`"
        )

    # cell cycle scores (only looked up when a key was supplied)
    if g2m_key and g2m_key in self._adata.obs.keys():
        self._G2M_score = self._adata.obs[g2m_key]
    else:
        logg.debug(
            f"DEBUG: `{g2m_key}` not found in `adata.obs`. Setting `.G2M_score` to `None`"
        )

    if s_key and s_key in self._adata.obs.keys():
        self._S_score = self._adata.obs[s_key]
    else:
        logg.debug(
            f"DEBUG: `{s_key}` not found in `adata.obs`. Setting `.S_score` to `None`"
        )

    # metastable state probabilities
    if _probs(self._rc_key) in self._adata.obs.keys():
        self._meta_states_probs = self._adata.obs[_probs(self._rc_key)]
    else:
        logg.debug(
            f"DEBUG: `{_probs(self._rc_key)}` not found in `adata.obs`. "
            f"Setting `.metastable_states_probs` to `None`"
        )

    # Replace the default lineage names/colors with the stored ones, if any.
    # Each step rebuilds the `Lineage` from the current one, so the order of
    # the two steps is kept exactly as is.
    if self._lin_probs is not None:
        if _lin_names(self._lin_key) in self._adata.uns.keys():
            self._lin_probs = Lineage(
                np.array(self._lin_probs),
                names=self._adata.uns[_lin_names(self._lin_key)],
                colors=self._lin_probs.colors,
            )
            self._adata.obsm[self._lin_key] = self._lin_probs
        else:
            logg.debug(
                f"DEBUG: `{_lin_names(self._lin_key)}` not found in `adata.uns`. "
                f"Using default names"
            )

        if _colors(self._lin_key) in self._adata.uns.keys():
            self._lin_probs = Lineage(
                np.array(self._lin_probs),
                names=self._lin_probs.names,
                colors=self._adata.uns[_colors(self._lin_key)],
            )
            self._adata.obsm[self._lin_key] = self._lin_probs
        else:
            logg.debug(
                f"DEBUG: `{_colors(self._lin_key)}` not found in `adata.uns`. "
                f"Using default colors"
            )
def compute_metastable_states(
    self,
    use: Optional[Union[int, Tuple[int], List[int], range]] = None,
    percentile: Optional[int] = 98,
    method: str = "kmeans",
    cluster_key: Optional[str] = None,
    n_clusters_kmeans: Optional[int] = None,
    n_neighbors_louvain: int = 20,
    resolution_louvain: float = 0.1,
    n_matches_min: Optional[int] = 0,
    n_neighbors_filtering: int = 15,
    basis: Optional[str] = None,
    n_comps: int = 5,
    scale: bool = False,
    en_cutoff: Optional[float] = 0.7,
    p_thresh: float = 1e-15,
) -> None:
    """
    Find approximate recurrent classes in the Markov chain.

    Filter to obtain recurrent states in left eigenvectors.
    Cluster to obtain approximate recurrent classes in right eigenvectors.

    Params
    ------
    use
        Which or how many first eigenvectors to use as features for clustering/filtering.
        If `None`, use `eigengap` statistic. Python and NumPy integers are accepted.
    percentile
        Threshold used for filtering out cells which are most likely transient states.
        Cells which are in the lower :paramref:`percentile` percent
        of each eigenvector will be removed from the data matrix.
        If `None`, no cells are filtered out.
    method
        Method to be used for clustering. Must be one of `['louvain', 'kmeans']`.
    cluster_key
        If a key to cluster labels is given, :paramref:`metastable_states` will be associated
        with these for naming and colors.
    n_clusters_kmeans
        If `None`, this is set to :paramref:`use` `+ 1`.
    n_neighbors_louvain
        If we use `'louvain'` for clustering cells, we need to build a KNN graph.
        This is the K parameter for that, the number of neighbors for each cell.
    resolution_louvain
        Resolution parameter from the `louvain` algorithm. Should be chosen relatively small.
    n_matches_min
        Filters out cells which don't have at least n_matches_min neighbors from the same class.
        This filters out some cells which are transient but have been misassigned.
        If `None` or `0`, this filtering step is skipped.
    n_neighbors_filtering
        Parameter for filtering cells. Cells are filtered out if they don't have at
        least :paramref:`n_matches_min` neighbors among their n_neighbors_filtering nearest cells.
    basis
        Key from :paramref:`adata` `.obsm` to be used as additional features for the clustering.
    n_comps
        Number of embedding components to be used.
    scale
        Scale to z-scores. Consider using if appending embedding to features.
    en_cutoff
        If :paramref:`cluster_key` is given, this parameter determines when an approximate recurrent class will
        be labelled as *'Unknown'*, based on the entropy of the distribution of cells over transcriptomic clusters.
    p_thresh
        If cell cycle scores were provided, a *Wilcoxon rank-sum test* is conducted to identify cell-cycle driven
        start- or endpoints.
        If the test returns a positive statistic and a p-value smaller than :paramref:`p_thresh`,
        a warning will be issued.

    Returns
    -------
    None
        Nothing, but stores the class probabilities on the object and in
        :paramref:`adata` `.obs`, and passes the class labels on to
        :meth:`set_metastable_states`.

    Raises
    ------
    RuntimeError
        If the eigendecomposition has not been computed.
    TypeError
        If :paramref:`use` has an unsupported type or contains non-integers.
    ValueError
        If :paramref:`method` is invalid, :paramref:`use` selects no eigenvector or
        one that was not computed, or :paramref:`percentile` is outside `[0, 100]`.
    KeyError
        If :paramref:`basis` is not present in :paramref:`adata` `.obsm`.
    """

    if self._eig is None:
        raise RuntimeError("Compute eigendecomposition first as `.compute_eig()`")

    # ---- validate all arguments before logging or mutating any state ----
    if method not in ["kmeans", "louvain"]:
        raise ValueError(
            f"Invalid method `{method!r}`. Valid options are `'kmeans', 'louvain'`."
        )

    if use is None:
        use = self._eig["eigengap"] + 1  # add one b/c indexing starts at 0

    # NumPy integer scalars are not `int` subclasses, accept them explicitly
    if isinstance(use, (int, np.integer)):
        use = list(range(int(use)))
    elif not isinstance(use, (tuple, list, range)):
        raise TypeError(
            f"Argument `use` must be either `int`, `tuple`, `list` or `range`, "
            f"found `{type(use).__name__}`."
        )
    else:
        if not all(isinstance(u, (int, np.integer)) for u in use):
            raise TypeError("Not all values in `use` argument are integers.")
        use = [int(u) for u in use]

    if not use:
        # would otherwise surface as an opaque error from `max()` below
        raise ValueError("Argument `use` must select at least one eigenvector.")

    muse = max(use)
    if muse >= self._eig["V_l"].shape[1] or muse >= self._eig["V_r"].shape[1]:
        # index `muse` requires `muse + 1` computed eigenvectors
        raise ValueError(
            f"Maximum specified eigenvector ({muse}) is larger "
            f'than the number of computed eigenvectors ({self._eig["V_l"].shape[1]}). '
            f"Use `.compute_eig(k={muse + 1})` to recompute the eigendecomposition."
        )

    if percentile is not None and (percentile < 0 or percentile > 100):
        raise ValueError(
            f"Percentile must be in interval `[0, 100]`, found `{percentile}`."
        )

    if basis is not None and f"X_{basis}" not in self._adata.obsm.keys():
        raise KeyError(f"Compute basis `{basis!r}` first.")

    start = logg.info("Computing approximate recurrent classes")

    logg.debug("DEBUG: Retrieving eigendecomposition")
    # we check for complex values only in the left, that's okay because the complex pattern
    # will be identical for left and right
    V_l, V_r = self._eig["V_l"][:, use], self._eig["V_r"].real[:, use]
    V_l = _complex_warning(V_l, use, use_imag=False)

    # compute a rc probability
    logg.debug("DEBUG: Computing probabilities of approximate recurrent classes")
    probs = self._compute_metastable_states_prob(use)
    self._meta_states_probs = probs
    self._adata.obs[_probs(self._rc_key)] = probs

    # retrieve embedding and concatenate
    if basis is not None:
        X_em = self._adata.obsm[f"X_{basis}"][:, :n_comps]
        X = np.concatenate([V_r, X_em], axis=1)
    else:
        logg.debug("DEBUG: Basis is `None`. Setting X equal to right eigenvectors")
        X = V_r

    # filter out cells which are in the lowest q percentile in abs value in each eigenvector
    if percentile is not None:
        logg.debug("DEBUG: Filtering out cells according to percentile")
        cutoffs = np.percentile(np.abs(V_l), percentile, axis=0)
        # keep a cell unless it is below the cutoff in every selected eigenvector
        ixs = np.sum(np.abs(V_l) < cutoffs, axis=1) < V_l.shape[1]
        X = X[ixs, :]

    # scale
    if scale:
        X = zscore(X, axis=0)

    # cluster X
    logg.debug(
        f"DEBUG: Using `{use}` eigenvectors, basis `{basis!r}` and method `{method!r}` for clustering"
    )
    labels = _cluster_X(
        X,
        method=method,
        n_clusters_kmeans=n_clusters_kmeans,
        percentile=percentile,
        use=use,
        n_neighbors_louvain=n_neighbors_louvain,
        resolution_louvain=resolution_louvain,
    )

    # fill in the labels in case we filtered out cells before
    if percentile is not None:
        rc_labels = np.repeat(None, self._adata.n_obs)
        rc_labels[ixs] = labels
    else:
        rc_labels = labels
    rc_labels = Series(rc_labels, index=self._adata.obs_names, dtype="category")
    # `rename_categories` instead of assigning to `.cat.categories`,
    # which newer pandas versions no longer allow
    rc_labels = rc_labels.cat.rename_categories(
        list(rc_labels.cat.categories.astype("str"))
    )

    # filtering to get rid of some of the left over transient states
    if n_matches_min is not None and n_matches_min > 0:
        logg.debug("DEBUG: Filtering according to `n_matches_min`")
        distances = _get_connectivities(
            self._adata, mode="distances", n_neighbors=n_neighbors_filtering
        )
        rc_labels = _filter_cells(
            distances, rc_labels=rc_labels, n_matches_min=n_matches_min
        )

    self.set_metastable_states(
        labels=rc_labels,
        cluster_key=cluster_key,
        en_cutoff=en_cutoff,
        p_thresh=p_thresh,
        add_to_existing=False,
    )

    logg.info(
        f"Adding `adata.obs[{_probs(self._rc_key)!r}]`\n"
        f" `adata.obs[{self._rc_key!r}]`\n"
        f" `.approx_recurrent_classes_probabilities`\n"
        f" `.approx_recurrent_classes`\n"
        f" Finish",
        time=start,
    )