def _check_abs_probs(mc: cr.tl.estimators.GPCCA, has_main_states: bool = True):
    """
    Check that terminal states and absorption probabilities of ``mc`` are mirrored in ``mc.adata``.

    Parameters
    ----------
    mc
        Estimator whose results are checked.
    has_main_states
        Whether to also check the terminal states in ``adata.obs`` and their colors in ``adata.uns``.

    Returns
    -------
    None
        Nothing, the first failing comparison raises.
    """
    if has_main_states:
        terminal_states = mc._get(P.TERM)
        assert isinstance(terminal_states, pd.Series)
        assert_array_nan_equal(mc.adata.obs[str(TermStatesKey.FORWARD)], terminal_states)
        # stored colors must follow the order of the terminal state categories
        np.testing.assert_array_equal(
            mc.adata.uns[_colors(TermStatesKey.FORWARD)],
            mc._get(A.TERM_ABS_PROBS)[list(terminal_states.cat.categories)].colors,
        )

    priming_degree = mc._get(P.PRIME_DEG)
    abs_probs = mc._get(P.ABS_PROBS)
    assert isinstance(priming_degree, pd.Series)
    assert isinstance(abs_probs, cr.tl.Lineage)
    # every row of the absorption probabilities must sum to 1
    np.testing.assert_array_almost_equal(abs_probs.sum(1), 1.0)

    abs_key = AbsProbKey.FORWARD
    np.testing.assert_array_equal(mc.adata.obsm[str(abs_key)], abs_probs.X)
    np.testing.assert_array_equal(mc.adata.uns[_lin_names(abs_key)], abs_probs.names)
    np.testing.assert_array_equal(mc.adata.uns[_colors(abs_key)], abs_probs.colors)
    np.testing.assert_array_equal(mc.adata.obs[_pd(abs_key)], priming_degree)

    term_key = TermStatesKey.FORWARD
    assert_array_nan_equal(mc.adata.obs[term_key.s], mc._get(P.TERM))
    np.testing.assert_array_equal(mc.adata.obs[_probs(term_key)], mc._get(P.TERM_PROBS))
def _check_abs_probs(mc: cr.tl.estimators.GPCCA, has_main_states: bool = True):
    """
    Check that final states and absorption probabilities of ``mc`` are mirrored in ``mc.adata``.

    Parameters
    ----------
    mc
        Estimator whose results are checked.
    has_main_states
        Whether to also check the final states in ``adata.obs`` and their colors in ``adata.uns``.

    Returns
    -------
    None
        Nothing, the first failing comparison raises.
    """
    if has_main_states:
        final_states = mc._get(P.FIN)
        assert isinstance(final_states, pd.Series)
        assert_array_nan_equal(mc.adata.obs[str(FinalStatesKey.FORWARD)], final_states)
        # stored colors must follow the order of the final state categories
        np.testing.assert_array_equal(
            mc.adata.uns[_colors(FinalStatesKey.FORWARD)],
            mc._get(A.FIN_ABS_PROBS)[list(final_states.cat.categories)].colors,
        )

    diff_potential = mc._get(P.DIFF_POT)
    abs_probs = mc._get(P.ABS_PROBS)
    assert isinstance(diff_potential, pd.Series)
    assert isinstance(abs_probs, cr.tl.Lineage)
    # every row of the absorption probabilities must sum to 1
    np.testing.assert_array_almost_equal(abs_probs.sum(1), 1.0)

    abs_key = AbsProbKey.FORWARD
    np.testing.assert_array_equal(mc.adata.obsm[str(abs_key)], abs_probs.X)
    np.testing.assert_array_equal(mc.adata.uns[_lin_names(abs_key)], abs_probs.names)
    np.testing.assert_array_equal(mc.adata.uns[_colors(abs_key)], abs_probs.colors)
    np.testing.assert_array_equal(mc.adata.obs[_dp(abs_key)], diff_potential)

    fs_key = FinalStatesKey.FORWARD
    assert_array_nan_equal(mc.adata.obs[fs_key.s], mc._get(P.FIN))
    np.testing.assert_array_equal(mc.adata.obs[_probs(fs_key)], mc._get(P.FIN_PROBS))
def _write_terminal_states(self, time=None) -> None:
    """
    Write the terminal states and their annotations to :attr:`adata` and log what was added.

    The categorical states go to ``adata.obs[self._term_key]``, their probabilities to
    ``adata.obs[_probs(self._term_key)]``, colors and category names to ``adata.uns``.
    If terminal absorption probabilities are present and ``_term_abs_prob_key`` is defined,
    they are also written to ``adata.obsm``.

    Parameters
    ----------
    time
        Forwarded to ``logg.info`` as ``time=``.

    Returns
    -------
    None
        Nothing, just updates :attr:`adata`.
    """
    key = self._term_key
    terminal_states = self._get(P.TERM)
    obs, uns = self.adata.obs, self.adata.uns

    obs[key] = terminal_states
    obs[_probs(key)] = self._get(P.TERM_PROBS)
    uns[_colors(key)] = self._get(A.TERM_COLORS)
    uns[_lin_names(key)] = np.array(terminal_states.cat.categories)

    # checking for None because terminal states can be set using `set_terminal_states`
    # without the probabilities in GPCCA
    abs_probs_available = getattr(self, A.TERM_ABS_PROBS.s, None) is not None
    if abs_probs_available and hasattr(self, "_term_abs_prob_key"):
        self.adata.obsm[self._term_abs_prob_key] = self._get(A.TERM_ABS_PROBS)
        extra_msg = f" `adata.obsm[{self._term_abs_prob_key!r}]`\n"
    else:
        extra_msg = ""

    logg.info(
        f"Adding `adata.obs[{_probs(key)!r}]`\n"
        f" `adata.obs[{key!r}]`\n"
        f"{extra_msg}"
        f" `.{P.TERM_PROBS}`\n"
        f" `.{P.TERM}`\n"
        " Finish",
        time=time,
    )
def test_compute_initial_states_from_forward_normal_run(self, adata_large: AnnData):
    """Initial states computed from a forward GPCCA are written to `adata.obs`/`adata.uns` but not `.obsm`."""
    velocity = VelocityKernel(adata_large, backward=False).compute_transition_matrix(softmax_scale=4)
    connectivity = ConnectivityKernel(adata_large, backward=False).compute_transition_matrix()
    combined = 0.8 * velocity + 0.2 * connectivity

    estimator = cr.tl.estimators.GPCCA(combined)
    estimator.compute_schur(n_components=10, method="krylov")
    estimator.compute_macrostates(n_states=2, n_cells=5)

    obsm_keys_before = set(estimator.adata.obsm.keys())
    # the single initial state is expected to be the macrostate with the
    # smallest coarse-grained stationary probability
    coarse_stat = estimator._get(P.COARSE_STAT_D)
    expected = coarse_stat.index[np.argmin(coarse_stat)]

    estimator._compute_initial_states(1)

    key = TermStatesKey.BACKWARD.s
    assert key in estimator.adata.obs
    np.testing.assert_array_equal(estimator.adata.obs[key].cat.categories, [expected])
    assert _probs(key) in estimator.adata.obs
    assert _colors(key) in estimator.adata.uns
    assert _lin_names(key) in estimator.adata.uns

    # make sure that we don't write anything there - it's useless
    assert set(estimator.adata.obsm.keys()) == obsm_keys_before
def _check_renaming_no_write_terminal(mc: cr.tl.estimators.GPCCA) -> None:
    """
    Check that ``mc`` has no terminal states set and that nothing terminal-related is in ``mc.adata``.

    Parameters
    ----------
    mc
        Estimator to check.

    Returns
    -------
    None
        Nothing, raises :class:`AssertionError` on the first violation.
    """
    for unset in (P.TERM, P.TERM_PROBS, A.TERM_ABS_PROBS):
        assert mc._get(unset) is None

    key = TermStatesKey.FORWARD.s
    assert key not in mc.adata.obs
    assert _probs(key) not in mc.adata.obs
    assert _colors(key) not in mc.adata.uns
    assert _lin_names(key) not in mc.adata.uns
def _assert_has_all_keys(adata: AnnData, direction: Direction) -> None:
    """
    Assert that a complete run in ``direction`` left all expected entries in ``adata``.

    Checks the transition matrix in ``adata.obsp`` and its parameters in ``adata.uns``. It then
    checks the absorption probabilities in ``adata.obsm`` with their colors and names in
    ``adata.uns``, the terminal states and their probabilities in ``adata.obs``, and the
    per-lineage entries in ``adata.var``.

    Parameters
    ----------
    adata
        Annotated data object to check.
    direction
        ``Direction.FORWARD`` checks the forward keys, any other value checks the backward keys.

    Returns
    -------
    None
        Nothing, raises :class:`AssertionError` on the first missing or invalid entry.
    """
    tmat_key = _transition(direction)
    assert tmat_key in adata.obsp.keys()
    # check if it's not a dummy transition matrix; `.diagonal()` works for both dense and
    # sparse matrices and avoids densifying the whole matrix (sparse `.A` was removed in SciPy 1.14)
    assert not np.all(np.isclose(adata.obsp[tmat_key].diagonal(), 1.0))
    assert f"{tmat_key}_params" in adata.uns.keys()

    if direction == Direction.FORWARD:
        abs_key, term_key, prefix = AbsProbKey.FORWARD, TermStatesKey.FORWARD, DirPrefix.FORWARD
    else:
        abs_key, term_key, prefix = AbsProbKey.BACKWARD, TermStatesKey.BACKWARD, DirPrefix.BACKWARD

    assert str(abs_key) in adata.obsm
    lin_probs = adata.obsm[str(abs_key)]
    assert isinstance(lin_probs, cr.tl.Lineage)
    assert _colors(abs_key) in adata.uns.keys()
    assert _lin_names(abs_key) in adata.uns.keys()

    assert str(term_key) in adata.obs
    assert is_categorical_dtype(adata.obs[str(term_key)])
    assert _probs(term_key) in adata.obs

    # check the correlations with all lineages have been computed
    # (previously this result was computed but never asserted, so the check could not fail)
    expected_var_keys = [f"{str(prefix)} {name}" for name in lin_probs.names]
    present = np.isin(expected_var_keys, adata.var.keys())
    missing = [k for k, ok in zip(expected_var_keys, present) if not ok]
    assert not missing, f"Missing lineage entries in `adata.var`: {missing}"
def _read_from_adata(self) -> None:
    """
    Restore previously computed attributes from :attr:`adata`.

    Eigendecomposition, cell-cycle scores, terminal states with their probabilities and colors,
    and the priming degree are each restored via ``_set_or_debug``. The absorption probabilities
    are rebuilt with ``_reconstruct_lineage``.

    Returns
    -------
    None
        Nothing, just updates the estimator's attributes.
    """
    term_key = self._term_key
    lookups = (
        (f"eig_{self._direction}", self.adata.uns, "_eig"),
        (self._g2m_key, self.adata.obs, "_G2M_score"),
        (self._s_key, self.adata.obs, "_S_score"),
        (term_key, self.adata.obs, A.TERM.s),
        (_probs(term_key), self.adata.obs, A.TERM_PROBS),
        (_colors(term_key), self.adata.uns, A.TERM_COLORS),
    )
    for adata_key, container, attr in lookups:
        self._set_or_debug(adata_key, container, attr)

    self._reconstruct_lineage(A.ABS_PROBS, self._abs_prob_key)
    self._set_or_debug(_pd(self._abs_prob_key), self.adata.obs, A.PRIME_DEG)
def _read_from_adata(self) -> None:
    """
    Restore previously computed attributes from :attr:`adata`.

    Eigendecomposition, cell-cycle scores, final states with their probabilities and colors,
    and the differentiation potential are each restored via ``_set_or_debug``. The absorption
    probabilities are rebuilt with ``_reconstruct_lineage``.

    Returns
    -------
    None
        Nothing, just updates the estimator's attributes.
    """
    fs_key = self._fs_key
    lookups = (
        (f"eig_{self._direction}", self.adata.uns, "_eig"),
        (self._g2m_key, self.adata.obs, "_G2M_score"),
        (self._s_key, self.adata.obs, "_S_score"),
        (fs_key, self.adata.obs, A.FIN.s),
        (_probs(fs_key), self.adata.obs, A.FIN_PROBS),
        (_colors(fs_key), self.adata.uns, A.FIN_COLORS),
    )
    for adata_key, container, attr in lookups:
        self._set_or_debug(adata_key, container, attr)

    self._reconstruct_lineage(A.ABS_PROBS, self._abs_prob_key)
    self._set_or_debug(_dp(self._abs_prob_key), self.adata.obs, A.DIFF_POT)
def test_compute_approx_normal_run(self, adata_large: AnnData):
    """CFLARE terminal states are set on the estimator and written to `adata`."""
    velocity_kernel = VelocityKernel(adata_large).compute_transition_matrix(softmax_scale=4)
    connectivity_kernel = ConnectivityKernel(adata_large).compute_transition_matrix()
    combined = 0.8 * velocity_kernel + 0.2 * connectivity_kernel

    estimator = cr.tl.estimators.CFLARE(combined)
    estimator.compute_eigendecomposition(k=5)
    estimator.compute_terminal_states(use=2)

    assert is_categorical_dtype(estimator._get(P.TERM))
    assert estimator._get(P.TERM_PROBS) is not None

    key = TermStatesKey.FORWARD
    assert key.s in estimator.adata.obs.keys()
    assert _probs(key) in estimator.adata.obs.keys()
    assert _colors(key) in estimator.adata.uns.keys()
def _write_initial_states(
    self,
    membership: Lineage,
    probs: pd.Series,
    cats: pd.Series,
    time=None,
) -> None:
    """
    Write the initial states to :attr:`adata` under the backward terminal-states key and log it.

    Parameters
    ----------
    membership
        Lineage whose ``colors`` and ``names`` are stored in ``adata.uns``.
    probs
        Stored in ``adata.obs`` under the probabilities key.
    cats
        Categorical states stored in ``adata.obs``.
    time
        Forwarded to ``logg.info`` as ``time=``.

    Returns
    -------
    None
        Nothing, just updates :attr:`adata`.
    """
    key = TermStatesKey.BACKWARD.s
    probs_key = _probs(key)

    obs, uns = self.adata.obs, self.adata.uns
    obs[key] = cats
    obs[probs_key] = probs
    uns[_colors(key)] = membership.colors
    uns[_lin_names(key)] = membership.names

    logg.info(
        f"Adding `adata.obs[{probs_key!r}]`\n `adata.obs[{key!r}]`\n",
        time=time,
    )
def _write_final_states(self, time=None) -> None:
    """
    Write the final states and their annotations to :attr:`adata` and log what was added.

    The categorical states go to ``adata.obs[self._fs_key]``, their probabilities to
    ``adata.obs[_probs(self._fs_key)]``, colors and category names to ``adata.uns``.
    If final absorption probabilities are present and ``_fin_abs_prob_key`` is defined,
    they are also written to ``adata.obsm``.

    Parameters
    ----------
    time
        Forwarded to ``logg.info`` as ``time=``.

    Returns
    -------
    None
        Nothing, just updates :attr:`adata`.
    """
    key = self._fs_key
    final_states = self._get(P.FIN)
    obs, uns = self.adata.obs, self.adata.uns

    obs[key] = final_states
    obs[_probs(key)] = self._get(P.FIN_PROBS)
    uns[_colors(key)] = self._get(A.FIN_COLORS)
    uns[_lin_names(key)] = list(final_states.cat.categories)

    # checking for None because final states can be set using `set_final_states`
    # without the probabilities in GPCCA
    abs_probs_available = getattr(self, A.FIN_ABS_PROBS.s, None) is not None
    if abs_probs_available and hasattr(self, "_fin_abs_prob_key"):
        self.adata.obsm[self._fin_abs_prob_key] = self._get(A.FIN_ABS_PROBS)
        extra_msg = f" `adata.obsm[{self._fin_abs_prob_key!r}]`\n"
    else:
        extra_msg = ""

    logg.info(
        f"Adding `adata.obs[{_probs(key)!r}]`\n"
        f" `adata.obs[{key!r}]`\n"
        f"{extra_msg}"
        f" `.{P.FIN_PROBS}`\n"
        f" `.{P.FIN}`",
        time=time,
    )
class GPCCA(BaseEstimator, Macrostates, Schur, Eigen): """ Generalized Perron Cluster Cluster Analysis :cite:`reuter:18` as implemented in `pyGPCCA <https://pygpcca.readthedocs.io/en/latest/>`_. Coarse-grains a discrete Markov chain into a set of macrostates and computes coarse-grained transition probabilities among the macrostates. Each macrostate corresponds to an area of the state space, i.e. to a subset of cells. The assignment is soft, i.e. each cell is assigned to every macrostate with a certain weight, where weights sum to one per cell. Macrostates are computed by maximizing the 'crispness' which can be thought of as a measure for minimal overlap between macrostates in a certain inner-product sense. Once the macrostates have been computed, we project the large transition matrix onto a coarse-grained transition matrix among the macrostates via a Galerkin projection. This projection is based on invariant subspaces of the original transition matrix which are obtained using the real Schur decomposition :cite:`reuter:18`. 
Parameters ---------- %(base_estimator.parameters)s """ # noqa: E501 __prop_metadata__ = [ Metadata( attr=A.COARSE_T, prop=P.COARSE_T, compute_fmt=F.NO_FUNC, plot_fmt=F.NO_FUNC, dtype=pd.DataFrame, doc="Coarse-grained transition matrix.", ), Metadata(attr=A.TERM_ABS_PROBS, prop=P.NO_PROPERTY, dtype=Lineage), Metadata(attr=A.COARSE_INIT_D, prop=P.COARSE_INIT_D, dtype=pd.Series), Metadata(attr=A.COARSE_STAT_D, prop=P.COARSE_STAT_D, dtype=pd.Series), ] def _read_from_adata(self) -> None: super()._read_from_adata() self._reconstruct_lineage( A.TERM_ABS_PROBS, self._term_abs_prob_key, ) @inject_docs( ms=P.MACRO, msp=P.MACRO_MEMBER, schur=P.SCHUR.s, coarse_T=P.COARSE_T, coarse_stat=P.COARSE_STAT_D, ) @d.dedent def compute_macrostates( self, n_states: Optional[Union[int, Tuple[int, int], List[int], Dict[str, int]]] = None, n_cells: Optional[int] = 30, use_min_chi: bool = False, cluster_key: str = None, en_cutoff: Optional[float] = 0.7, p_thresh: float = 1e-15, ): """ Compute the macrostates. Parameters ---------- n_states Number of macrostates. If `None`, use the `eigengap` heuristic. %(n_cells)s use_min_chi Whether to use :meth:`pygpcca.GPCCA.minChi` to calculate the number of macrostates. If `True`, ``n_states`` corresponds to a closed interval `[min, max]` inside of which the potentially optimal number of macrostates is searched. cluster_key If a key to cluster labels is given, names and colors of the states will be associated with the clusters. %(en_cutoff_p_thresh)s Returns ------- None Nothing, but updates the following fields: - :attr:`{msp}` - :attr:`{ms}` - :attr:`{schur}` - :attr:`{coarse_T}` - :attr:`{coarse_stat}` """ was_from_eigengap = False if use_min_chi: n_states = self._get_n_states_from_minchi(n_states) if n_states is None: if self._get(P.EIG) is None: raise RuntimeError( "Compute eigendecomposition first as `.compute_eigendecomposition()` or `.compute_schur()`." 
) was_from_eigengap = True n_states = self._get(P.EIG)["eigengap"] + 1 logg.info(f"Using `{n_states}` states based on eigengap") elif not isinstance(n_states, int): raise ValueError( f"Expected `n_states` to be an integer when `use_min_chi=False`, " f"found `{type(n_states).__name__!r}`.") if n_states <= 0: raise ValueError( f"Expected `n_states` to be positive or `None`, found `{n_states}`." ) n_states = self._check_states_validity(n_states) if n_states == 1: self._compute_one_macrostate( n_cells=n_cells, cluster_key=cluster_key, p_thresh=p_thresh, en_cutoff=en_cutoff, ) return if self._gpcca is None: if not was_from_eigengap: raise RuntimeError( "Compute Schur decomposition first as `.compute_schur()`.") logg.warning( f"Number of states `{n_states}` was automatically determined by `eigengap` " "but no Schur decomposition was found. Computing with default parameters" ) # this cannot fail if splitting occurs # if it were to split, it's automatically increased in `compute_schur` self.compute_schur(n_states) # pre-computed X if self._gpcca._p_X.shape[1] < n_states: logg.warning( f"Requested more macrostates `{n_states}` than available " f"Schur vectors `{self._gpcca._p_X.shape[1]}`. Recomputing the decomposition" ) start = logg.info(f"Computing `{n_states}` macrostates") try: self._gpcca = self._gpcca.optimize(m=n_states) except ValueError as e: # this is the following case - we have 4 Schur vectors, user requests 5 states, but it splits the conj. ev. 
# in the try block, Schur decomposition with 5 vectors is computed, but it fails (no way of knowing) # so in this case, we increase it by 1 n_states += 1 logg.warning(f"{e}\nIncreasing `n_states` to `{n_states}`") self._gpcca = self._gpcca.optimize(m=n_states) self._set_macrostates( memberships=self._gpcca.memberships, n_cells=n_cells, cluster_key=cluster_key, p_thresh=p_thresh, en_cutoff=en_cutoff, ) # cache the results and make sure we don't overwrite self._set(A.SCHUR, self._gpcca._p_X) self._set(A.SCHUR_MAT, self._gpcca._p_R) names = self._get(P.MACRO_MEMBER).names self._set( A.COARSE_T, pd.DataFrame( self._gpcca.coarse_grained_transition_matrix, index=names, columns=names, ), ) self._set( A.COARSE_INIT_D, pd.Series(self._gpcca.coarse_grained_input_distribution, index=names), ) # careful here, in case computing the stat. dist failed if self._gpcca.coarse_grained_stationary_probability is not None: self._set( A.COARSE_STAT_D, pd.Series( self._gpcca.coarse_grained_stationary_probability, index=names, ), ) logg.info( f"Adding `.{P.MACRO_MEMBER}`\n" f" `.{P.MACRO}`\n" f" `.{P.SCHUR}`\n" f" `.{P.COARSE_T}`\n" f" `.{P.COARSE_STAT_D}`\n" f" Finish", time=start, ) else: logg.warning("No stationary distribution found in GPCCA object") logg.info( f"Adding `.{P.MACRO_MEMBER}`\n" f" `.{P.MACRO}`\n" f" `.{P.SCHUR}`\n" f" `.{P.COARSE_T}`\n" f" Finish", time=start, ) @d.dedent @inject_docs(fs=P.TERM, fsp=P.TERM_PROBS) def set_terminal_states_from_macrostates( self, names: Optional[Union[Sequence[str], Mapping[str, str], str]] = None, n_cells: int = 30, ): """ Manually select terminal states from macrostates. Parameters ---------- names Names of the macrostates to be marked as terminal. Multiple states can be combined using `','`, such as ``["Alpha, Beta", "Epsilon"]``. If a :class:`dict`, keys correspond to the names of the macrostates and the values to the new names. If `None`, select all macrostates. 
%(n_cells)s Returns ------- None Nothing, just updates the following fields: - :attr:`{fsp}` - :attr:`{fs}` """ if not isinstance(n_cells, int): raise TypeError( f"Expected `n_cells` to be of type `int`, found `{type(n_cells).__name__}`." ) if n_cells <= 0: raise ValueError( f"Expected `n_cells` to be positive, found `{n_cells}`.") probs = self._get(P.MACRO_MEMBER) if probs is None: raise RuntimeError( "Compute macrostates first as `.compute_macrostates()`.") rename = True if names is None: names = probs.names rename = False if isinstance(names, str): names = [names] rename = False if not isinstance(names, dict): names = {n: n for n in names} rename = False if not len(names): raise ValueError("No macrostates have been selected.") if not all(isinstance(old, str) for old in names.keys()): raise TypeError("Not all new names are strings.") if not all(isinstance(new, (str, int)) for new in names.values()): raise TypeError( "Not all macrostates names are strings or integers.") # this also checks that the names are correct before renaming macrostates_probs = probs[list(names.keys())] # we do this also here because if `rename_terminal_states` fails # invalid states would've been written to this object and nothing to adata new_names = {k: str(v) for k, v in names.items()} names_after_renaming = [new_names.get(n, n) for n in probs.names] if len(set(names_after_renaming)) != probs.shape[1]: raise ValueError( f"After renaming, the names will not be unique: `{names_after_renaming}`." ) if probs.shape[1] == 1: self._set(A.TERM, self._create_states(probs, n_cells=n_cells)) self._set(A.TERM_COLORS, self._get(A.MACRO_COLORS)) self._set( A.TERM_PROBS, pd.Series(probs.X.squeeze() / probs.X.max(), index=self.adata.obs_names), ) self._set(A.TERM_ABS_PROBS, probs) if rename: # access lineage renames join states, e.g. 
'Alpha, Beta' becomes 'Alpha or Beta' + whitespace stripping self.rename_terminal_states( dict(zip(self._get(P.TERM).cat.categories, names.values()))) self._write_terminal_states() return # compute the aggregated probability of being a initial/terminal state (no matter which) scaled_probs = macrostates_probs.copy() scaled_probs /= scaled_probs.max(0) self._set(A.TERM, self._create_states(macrostates_probs, n_cells=n_cells)) self._set(A.TERM_PROBS, pd.Series(scaled_probs.X.max(1), index=self.adata.obs_names)) self._set( A.TERM_COLORS, macrostates_probs[list(self._get(P.TERM).cat.categories)].colors, ) self._set(A.TERM_ABS_PROBS, scaled_probs) if rename: self.rename_terminal_states( dict(zip(self._get(P.TERM).cat.categories, names.values()))) self._write_terminal_states() @inject_docs(fs=P.TERM, fsp=P.TERM_PROBS) @d.dedent def compute_terminal_states( self, method: str = "stability", n_cells: int = 30, alpha: Optional[float] = 1, stability_threshold: float = 0.96, n_states: Optional[int] = None, ): """ Automatically select terminal states from macrostates. Parameters ---------- method One of following: - `'eigengap'` - select the number of states based on the `eigengap` of the transition matrix. - `'eigengap_coarse'` - select the number of states based on the `eigengap` of the diagonal of the coarse-grained transition matrix. - `'top_n'` - select top ``n_states`` based on the probability of the diagonal of the coarse-grained transition matrix. - `'stability'` - select states which have a stability index >= ``stability_threshold``. The stability index is given by the diagonal elements of the coarse-grained transition matrix. %(n_cells)s alpha Weight given to the deviation of an eigenvalue from one. Used when ``method='eigengap'`` or ``method='eigengap_coarse'``. stability_threshold Threshold used when ``method='stability'``. n_states Numer of states used when ``method='top_n'``. 
Returns ------- None Nothing, just updates the following fields: - :attr:`{fsp}` - :attr:`{fs}` """ if len(self._get(P.MACRO).cat.categories) == 1: logg.warning( "Found only one macrostate. Making it the single main state") self.set_terminal_states_from_macrostates(None, n_cells=n_cells) return coarse_T = self._get(P.COARSE_T) if method == "eigengap": if self._get(P.EIG) is None: raise RuntimeError( "Compute eigendecomposition first as `.compute_eigendecomposition()`." ) n_states = _eigengap(self._get(P.EIG)["D"], alpha=alpha) + 1 elif method == "eigengap_coarse": if coarse_T is None: raise RuntimeError( "Compute macrostates first as `.compute_macrostates()`.") n_states = _eigengap(np.sort(np.diag(coarse_T)[::-1]), alpha=alpha) elif method == "top_n": if n_states is None: raise ValueError( "Argument `n_states` must be != `None` for `method='top_n'`." ) elif n_states <= 0: raise ValueError( f"Expected `n_states` to be positive, found `{n_states}`.") elif method == "stability": if stability_threshold is None: raise ValueError( "Argument `stability_threshold` must be != `None` for `method='stability'`." ) self_probs = pd.Series(np.diag(coarse_T), index=coarse_T.columns) names = self_probs[self_probs.values >= stability_threshold].index self.set_terminal_states_from_macrostates(names, n_cells=n_cells) return else: raise ValueError( f"Invalid method `{method!r}`. Valid options are `'eigengap', 'eigengap_coarse', " f"'top_n' and 'min_self_prob'`.") names = coarse_T.columns[np.argsort(np.diag(coarse_T))][-n_states:] self.set_terminal_states_from_macrostates(names, n_cells=n_cells) def compute_gdpt(self, n_components: int = 10, key_added: str = "gdpt_pseudotime", **kwargs): """ Compute generalized Diffusion pseudotime from :cite:`haghverdi:16` using the real Schur decomposition. Parameters ---------- n_components Number of real Schur vectors to consider. key_added Key in :attr:`adata` ``.obs`` where to save the pseudotime. 
kwargs Keyword arguments for :meth:`cellrank.tl.GPCCA.compute_schur` if Schur decomposition is not found. Returns ------- None Nothing, just updates :attr:`adata` ``.obs[key_added]`` with the computed pseudotime. """ def _get_dpt_row(e_vals: np.ndarray, e_vecs: np.ndarray, i: int): row = sum( (np.abs(e_vals[eval_ix]) / (1 - np.abs(e_vals[eval_ix])) * (e_vecs[i, eval_ix] - e_vecs[:, eval_ix]))**2 # account for float32 precision for eval_ix in range(0, e_vals.size) if np.abs(e_vals[eval_ix]) < 0.9994) return np.sqrt(row) if "iroot" not in self.adata.uns.keys(): raise KeyError("Key `'iroot'` not found in `adata.uns`.") iroot = self.adata.uns["iroot"] if isinstance(iroot, str): iroot = np.where(self.adata.obs_names == iroot)[0] if not len(iroot): raise ValueError( f"Unable to find cell with name `{self.adata.uns['iroot']!r}` in `adata.obs_names`." ) iroot = iroot[0] if n_components < 2: raise ValueError( f"Expected number of components >= 2, found `{n_components}`.") if self._get(P.SCHUR) is None: logg.warning("No Schur decomposition found. Computing") self.compute_schur(n_components, **kwargs) elif self._get(P.SCHUR_MAT).shape[1] < n_components: logg.warning( f"Requested `{n_components}` components, but only `{self._get(P.SCHUR_MAT).shape[1]}` were found. 
" f"Recomputing using default values") self.compute_schur(n_components) else: logg.debug("Using cached Schur decomposition") start = logg.info( f"Computing Generalized Diffusion Pseudotime using `n_components={n_components}`" ) Q, eigenvalues = ( self._get(P.SCHUR), self._get(P.EIG)["D"], ) # may have to remove some values if too many converged Q, eigenvalues = Q[:, :n_components], eigenvalues[:n_components] D = _get_dpt_row(eigenvalues, Q, i=iroot) pseudotime = D / np.max(D[np.isfinite(D)]) self.adata.obs[key_added] = pseudotime logg.info(f"Adding `{key_added!r}` to `adata.obs`\n Finish", time=start) @d.dedent def plot_coarse_T( self, show_stationary_dist: bool = True, show_initial_dist: bool = False, cmap: Union[str, mcolors.ListedColormap] = "viridis", xtick_rotation: float = 45, annotate: bool = True, show_cbar: bool = True, title: Optional[str] = None, figsize: Tuple[float, float] = (8, 8), dpi: int = 80, save: Optional[Union[Path, str]] = None, text_kwargs: Mapping[str, Any] = MappingProxyType({}), **kwargs, ) -> None: """ Plot the coarse-grained transition matrix between macrostates. Parameters ---------- show_stationary_dist Whether to show the stationary distribution, if present. show_initial_dist Whether to show the initial distribution. cmap Colormap to use. xtick_rotation Rotation of ticks on the x-axis. annotate Whether to display the text on each cell. show_cbar Whether to show colorbar. title Title of the figure. %(plotting)s text_kwargs Keyword arguments for :func:`matplotlib.pyplot.text`. kwargs Keyword arguments for :func:`matplotlib.pyplot.imshow`. 
Returns ------- %(just_plots)s """ def stylize_dist(ax, data: np.ndarray, xticks_labels: Union[List[str], Tuple[str]] = ()): _ = ax.imshow(data, aspect="auto", cmap=cmap, norm=norm) for spine in ax.spines.values(): spine.set_visible(False) if xticks_labels is not None: ax.set_xticks(np.arange(data.shape[1])) ax.set_xticklabels(xticks_labels) plt.setp( ax.get_xticklabels(), rotation=xtick_rotation, ha="right", rotation_mode="anchor", ) else: ax.set_xticks([]) ax.tick_params(which="both", top=False, right=False, bottom=False, left=False) ax.set_yticks([]) def annotate_heatmap(im, valfmt: str = "{x:.2f}"): # modified from matplotlib's site data = im.get_array() kw = {"ha": "center", "va": "center"} kw.update(**text_kwargs) # Get the formatter in case a string is supplied if isinstance(valfmt, str): valfmt = mpl.ticker.StrMethodFormatter(valfmt) # Loop over the data and create a `Text` for each "pixel". # Change the text's color depending on the data. texts = [] for i in range(data.shape[0]): for j in range(data.shape[1]): kw.update( color=_get_black_or_white(im.norm(data[i, j]), cmap)) text = im.axes.text(j, i, valfmt(data[i, j], None), **kw) texts.append(text) def annotate_dist_ax(ax, data: np.ndarray, valfmt: str = "{x:.2f}"): if ax is None: return if isinstance(valfmt, str): valfmt = mpl.ticker.StrMethodFormatter(valfmt) kw = {"ha": "center", "va": "center"} kw.update(**text_kwargs) for i, val in enumerate(data): kw.update(color=_get_black_or_white(im.norm(val), cmap)) ax.text( i, 0, valfmt(val, None), **kw, ) coarse_T = self._get(P.COARSE_T) coarse_stat_d = self._get(P.COARSE_STAT_D) coarse_init_d = self._get(P.COARSE_INIT_D) if coarse_T is None: raise RuntimeError( "Compute coarse-grained transition matrix first as `.compute_macrostates()` with `n_states > 1`." 
) if show_stationary_dist and coarse_stat_d is None: logg.warning("Coarse stationary distribution is `None`, ignoring") show_stationary_dist = False if show_initial_dist and coarse_init_d is None: logg.warning("Coarse initial distribution is `None`, ignoring") show_initial_dist = False hrs, wrs = [1], [1] if show_stationary_dist: hrs += [0.05] if show_initial_dist: hrs += [0.05] if show_cbar: wrs += [0.025] dont_show_dist = not show_initial_dist and not show_stationary_dist fig = plt.figure(constrained_layout=False, figsize=figsize, dpi=dpi) gs = plt.GridSpec( 1 + show_stationary_dist + show_initial_dist, 1 + show_cbar, height_ratios=hrs, width_ratios=wrs, wspace=0.05, hspace=0.05, ) if isinstance(cmap, str): cmap = plt.get_cmap(cmap) ax = fig.add_subplot(gs[0, 0]) cax = fig.add_subplot(gs[:1, -1]) if show_cbar else None init_ax, stat_ax = None, None labels = list(self.coarse_T.columns) tmp = coarse_T if show_initial_dist: tmp = np.c_[tmp, coarse_stat_d] if show_initial_dist: tmp = np.c_[tmp, coarse_init_d] minn, maxx = np.nanmin(tmp), np.nanmax(tmp) norm = mpl.colors.Normalize(vmin=minn, vmax=maxx) if show_stationary_dist: stat_ax = fig.add_subplot(gs[1, 0]) stylize_dist( stat_ax, np.array(coarse_stat_d).reshape(1, -1), xticks_labels=labels if not show_initial_dist else None, ) stat_ax.yaxis.set_label_position("right") stat_ax.set_ylabel("stationary dist", rotation=0, ha="left", va="center") if show_initial_dist: init_ax = fig.add_subplot(gs[show_stationary_dist + show_initial_dist, 0]) stylize_dist(init_ax, np.array(coarse_init_d).reshape(1, -1), xticks_labels=labels) init_ax.yaxis.set_label_position("right") init_ax.set_ylabel("initial dist", rotation=0, ha="left", va="center") im = ax.imshow(coarse_T, aspect="auto", cmap=cmap, norm=norm, **kwargs) ax.set_title( "coarse-grained transition matrix" if title is None else title) if cax is not None: _ = mpl.colorbar.ColorbarBase( cax, cmap=cmap, norm=norm, ticks=np.linspace(minn, maxx, 10), format="%0.3f", ) 
ax.set_yticks(np.arange(coarse_T.shape[0])) ax.set_yticklabels(labels) ax.tick_params( top=False, bottom=dont_show_dist, labeltop=False, labelbottom=dont_show_dist, ) for spine in ax.spines.values(): spine.set_visible(False) if dont_show_dist: ax.set_xticks(np.arange(coarse_T.shape[1])) ax.set_xticklabels(labels) plt.setp( ax.get_xticklabels(), rotation=xtick_rotation, ha="right", rotation_mode="anchor", ) else: ax.set_xticks([]) ax.set_yticks(np.arange(coarse_T.shape[0] + 1) - 0.5, minor=True) ax.tick_params(which="minor", bottom=dont_show_dist, left=False, top=False) if annotate: annotate_heatmap(im) if show_stationary_dist: annotate_dist_ax(stat_ax, coarse_stat_d.values) if show_initial_dist: annotate_dist_ax(init_ax, coarse_init_d) if save: save_fig(fig, save) @d.dedent def plot_macrostate_composition( self, key: str, width: float = 0.8, title: Optional[str] = None, labelrot: float = 45, legend_loc: Optional[str] = "upper right out", figsize: Optional[Tuple[float, float]] = None, dpi: Optional[int] = None, save: Optional[Union[str, Path]] = None, show: bool = True, ) -> Optional[Axes]: """ Plot stacked histogram of macrostates over categorical annotations. Parameters ---------- %(adata)s key Key from :attr:`adata` ``.obs`` containing categorical annotations. width Bar width in `[0, 1]`. title Title of the figure. If `None`, create one automatically. labelrot Rotation of labels on x-axis. legend_loc Position of the legend. If `None`, don't show legend. %(plotting)s show If `False`, return :class:`matplotlib.pyplot.Axes`. Returns ------- :class:`matplotlib.pyplot.Axes` The axis object if ``show=False``. 
%(just_plots)s """ from cellrank.pl._utils import _position_legend macrostates = self._get(P.MACRO) if macrostates is None: raise RuntimeError( "Compute macrostates first as `.compute_macrostates()`.") if key not in self.adata.obs: raise KeyError(f"Key `{key}` not found in `adata.obs`.") if not is_categorical_dtype(self.adata.obs[key]): raise TypeError( f"Expected `adata.obs[{key!r}]` to be `categorical`, " f"found `{infer_dtype(self.adata.obs[key])}`.") mask = ~macrostates.isnull() df = (pd.DataFrame({ "macrostates": macrostates, key: self.adata.obs[key] })[mask].groupby([key, "macrostates"]).size()) try: cats_colors = self.adata.uns[f"{key}_colors"] except KeyError: cats_colors = _create_categorical_colors( len(self.adata.obs[key].cat.categories)) cat_color_mapper = dict( zip(self.adata.obs[key].cat.categories, cats_colors)) x_indices = np.arange(len(macrostates.cat.categories)) bottom = np.zeros_like(x_indices, dtype=np.float32) width = min(1, max(0, width)) fig, ax = plt.subplots(figsize=figsize, dpi=dpi, tight_layout=True) for cat, color in cat_color_mapper.items(): frequencies = df.loc[cat] # do not add to legend if category is missing if np.sum(frequencies) > 0: ax.bar( x_indices, frequencies, width, label=cat, color=color, bottom=bottom, ec="black", lw=0.5, ) bottom += np.array(frequencies) ax.set_xticks(x_indices) ax.set_xticklabels( # assuming at least 1 category frequencies.index, rotation=labelrot, ha="center" if labelrot in (0, 90) else "right", ) y_max = bottom.max() ax.set_ylim([0, y_max + 0.05 * y_max]) ax.set_yticks(np.linspace(0, y_max, 5)) ax.margins(0.05) ax.set_xlabel("macrostate") ax.set_ylabel("frequency") if title is None: title = f"distribution over {key}" ax.set_title(title) if legend_loc not in (None, "none"): _position_legend(ax, legend_loc=legend_loc) if save is not None: save_fig(fig, save) if not show: return ax def _compute_one_macrostate( self, n_cells: int, cluster_key: Optional[str], en_cutoff: Optional[float], p_thresh: float, ) 
-> None: start = logg.warning( "For 1 macrostate, stationary distribution is computed") eig = self._get(P.EIG) if (eig is not None and "stationary_dist" in eig and eig["params"]["which"] == "LR"): stationary_dist = eig["stationary_dist"] else: self.compute_eigendecomposition(only_evals=False, which="LR") stationary_dist = self._get(P.EIG)["stationary_dist"] self._set_macrostates( memberships=stationary_dist[:, None], n_cells=n_cells, cluster_key=cluster_key, p_thresh=p_thresh, en_cutoff=en_cutoff, ) self._set( A.MACRO_MEMBER, Lineage( stationary_dist, names=list(self._get(A.MACRO).cat.categories), colors=self._get(A.MACRO_COLORS), ), ) # reset all the things for key in ( A.ABS_PROBS, A.PRIME_DEG, A.SCHUR, A.SCHUR_MAT, A.COARSE_T, A.COARSE_STAT_D, A.COARSE_STAT_D, ): self._set(key.s, None) logg.info( f"Adding `.{P.MACRO_MEMBER}`\n `.{P.MACRO}`\n Finish", time=start, ) def _get_n_states_from_minchi( self, n_states: Union[Tuple[int, int], List[int], Dict[str, int]]) -> int: if self._gpcca is None: raise RuntimeError( "Compute Schur decomposition first as `.compute_schur()` when `use_min_chi=True`." ) if not isinstance(n_states, (dict, tuple, list)): raise TypeError( f"Expected `n_states` to be either `dict`, `tuple` or a `list`, " f"found `{type(n_states).__name__}`.") if len(n_states) != 2: raise ValueError( f"Expected `n_states` to be of size `2`, found `{len(n_states)}`." ) if isinstance(n_states, dict): if "min" not in n_states or "max" not in n_states: raise KeyError( f"Expected the dictionary to have `'min'` and `'max'` keys, " f"found `{tuple(n_states.keys())}`.") minn, maxx = n_states["min"], n_states["max"] else: minn, maxx = n_states if minn > maxx: logg.debug( f"Swapping minimum and maximum because `{minn}` > `{maxx}`") minn, maxx = maxx, minn if minn <= 1: raise ValueError(f"Minimum value must be > `1`, found `{minn}`.") elif minn == 2: logg.warning( "In most cases, 2 clusters will always be optimal. 
" "If you really expect 2 clusters, use `n_states=2` and `use_minchi=False`. Setting minimum to `3`" ) minn = 3 if minn >= maxx: maxx = minn + 1 logg.debug( f"Setting maximum to `{maxx}` because it was <= than minimum `{minn}`" ) logg.info(f"Calculating minChi within interval `[{minn}, {maxx}]`") return int( np.arange(minn, maxx + 1)[np.argmax(self._gpcca.minChi(minn, maxx))]) @d.dedent def _set_macrostates( self, memberships: np.ndarray, n_cells: Optional[int] = 30, cluster_key: str = "clusters", en_cutoff: Optional[float] = 0.7, p_thresh: float = 1e-15, check_row_sums: bool = True, ) -> None: """ Map fuzzy clustering to pre-computed annotations to get names and colors. Given the fuzzy clustering, we would like to select the most likely cells from each state and use these to give each state a name and a color by comparing with pre-computed, categorical cluster annotations. Parameters ---------- memberships Fuzzy clustering. %(n_cells)s cluster_key Key from :attr:`adata` ``.obs`` to get reference cluster annotations. en_cutoff Threshold to decide when we we want to warn the user about an uncertain name mapping. This happens when one fuzzy state overlaps with several reference clusters, and the most likely cells are distributed almost evenly across the reference clusters. p_thresh Only used to detect cell cycle stages. These have to be present in :attr:`adata` ``.obs`` as `'G2M_score'` and `'S_score'`. check_row_sums Check whether rows in `memberships` sum to `1`. Returns ------- None Writes a :class:`cellrank.tl.Lineage` object which mapped names and colors. Also writes a categorical :class:`pandas.Series`, where top ``n_cells`` cells represent each fuzzy state. 
""" if n_cells is None: logg.debug("Setting the macrostates using macrostate assignment") # fmt: off max_assignment = np.argmax(memberships, axis=1) _macro_assignment = pd.Series(index=self.adata.obs_names, data=max_assignment, dtype="category") # sometimes, the assignment can have a missing category and the Lineage creation therefore fails # keep it as ints when `n_cells != None` _macro_assignment = _macro_assignment.cat.set_categories( list(range(memberships.shape[1]))) macrostates = _macro_assignment.astype(str).astype( "category").copy() not_enough_cells = [] # fmt: on else: logg.debug("Setting the macrostates using macrostates memberships") # select the most likely cells from each macrostate macrostates, not_enough_cells = self._create_states( memberships, n_cells=n_cells, check_row_sums=check_row_sums, return_not_enough_cells=True, ) not_enough_cells = not_enough_cells.astype("str") # _set_categorical_labels creates the names, we still need to remap the group names orig_cats = macrostates.cat.categories self._set_categorical_labels( attr_key=A.MACRO.v, color_key=A.MACRO_COLORS.v, pretty_attr_key=P.MACRO.v, add_to_existing_error_msg= "Compute macrostates first as `.compute_macrostates()`.", categories=macrostates, cluster_key=cluster_key, en_cutoff=en_cutoff, p_thresh=p_thresh, add_to_existing=False, ) name_mapper = dict(zip(orig_cats, self._get(P.MACRO).cat.categories)) _print_insufficient_number_of_cells( [name_mapper.get(n, n) for n in not_enough_cells], n_cells) logg.debug( "Setting macrostates memberships based on GPCCA membership vectors" ) self._set( A.MACRO_MEMBER, Lineage( memberships, names=list(macrostates.cat.categories), colors=self._get(A.MACRO_COLORS), ), ) def _create_states( self, probs: Union[np.ndarray, Lineage], n_cells: int, check_row_sums: bool = False, return_not_enough_cells: bool = False, ) -> pd.Series: if n_cells <= 0: raise ValueError( f"Expected `n_cells` to be positive, found `{n_cells}`.") a_discrete, not_enough_cells = 
_fuzzy_to_discrete( a_fuzzy=probs, n_most_likely=n_cells, remove_overlap=False, raise_threshold=0.2, check_row_sums=check_row_sums, ) states = _series_from_one_hot_matrix( membership=a_discrete, index=self.adata.obs_names, names=probs.names if isinstance(probs, Lineage) else None, ) return (states, not_enough_cells) if return_not_enough_cells else states def _check_states_validity(self, n_states: int) -> int: if self._invalid_n_states is not None and n_states in self._invalid_n_states: logg.warning( f"Unable to compute macrostates with `n_states={n_states}` because it will " f"split the conjugate eigenvalues. Increasing `n_states` to `{n_states + 1}`" ) n_states += 1 # cannot force recomputation of the Schur decomposition assert n_states not in self._invalid_n_states, "Sanity check failed." return n_states def _fit_terminal_states( self, n_lineages: Optional[int] = None, cluster_key: Optional[str] = None, method: str = "krylov", **kwargs, ) -> None: if n_lineages is None or n_lineages == 1: self.compute_eigendecomposition() if n_lineages is None: n_lineages = self.eigendecomposition["eigengap"] + 1 if n_lineages > 1: self.compute_schur(n_lineages, method=method) try: self.compute_macrostates(n_states=n_lineages, cluster_key=cluster_key, **kwargs) except ValueError: logg.warning( f"Computing `{n_lineages}` macrostates cuts through a block of complex conjugates. 
" f"Increasing `n_lineages` to {n_lineages + 1}") self.compute_macrostates(n_states=n_lineages + 1, cluster_key=cluster_key, **kwargs) fs_kwargs = { "n_cells": kwargs["n_cells"] } if "n_cells" in kwargs else {} if n_lineages is None: self.compute_terminal_states(method="eigengap", **fs_kwargs) else: self.set_terminal_states_from_macrostates(**fs_kwargs) @d.dedent # because of fit @d.dedent @inject_docs( ms=P.MACRO, msp=P.MACRO_MEMBER, fs=P.TERM, fsp=P.TERM_PROBS, ap=P.ABS_PROBS, pd=P.PRIME_DEG, ) def fit( self, n_lineages: Optional[int] = None, cluster_key: Optional[str] = None, keys: Optional[Sequence[str]] = None, method: str = "krylov", compute_absorption_probabilities: bool = True, **kwargs, ): """ Run the pipeline, computing the macrostates, %(initial_or_terminal)s states \ and optionally the absorption probabilities. It is equivalent to running:: if n_lineages is None or n_lineages == 1: compute_eigendecomposition(...) # get the stationary distribution if n_lineages > 1: compute_schur(...) compute_macrostates(...) if n_lineages is None: compute_terminal_states(...) else: set_terminal_states_from_macrostates(...) if compute_absorption_probabilities: compute_absorption_probabilities(...) Parameters ---------- %(fit)s method Method to use when computing the Schur decomposition. Valid options are: `'krylov'` or `'brandts'`. compute_absorption_probabilities Whether to compute the absorption probabilities or only the %(initial_or_terminal)s states. kwargs Keyword arguments for :meth:`cellrank.tl.estimators.GPCCA.compute_macrostates`. 
Returns ------- None Nothing, just makes available the following fields: - :attr:`{msp}` - :attr:`{ms}` - :attr:`{fsp}` - :attr:`{fs}` - :attr:`{ap}` - :attr:`{pd}` """ super().fit( n_lineages=n_lineages, cluster_key=cluster_key, keys=keys, method=method, compute_absorption_probabilities=compute_absorption_probabilities, **kwargs, ) @d.dedent def _compute_initial_states(self, n_states: int = 1, n_cells: int = 30) -> None: """ Compute initial states from macrostates. Parameters ---------- n_states Number of initial states. %(n_cells)s Returns ------- %(set_initial_states_from_macrostates.returns)s """ if n_states <= 0: raise ValueError( f"Expected `n_states` to be positive, found `{n_states}`.") if n_cells <= 0: raise ValueError( f"Expected `n_cells` to be positive, found `{n_cells}`.") probs = self._get(P.MACRO_MEMBER) if probs is None: raise RuntimeError( "Compute macrostates first as `.compute_macrostates()`.") if n_states > probs.shape[1]: raise ValueError( f"Requested `{n_states}` initial states, but only `{probs.shape[1]}` macrostates have been computed." ) if probs.shape[1] == 1: self._set_initial_states_from_macrostates(n_cells=n_cells) return stat_dist = self._get(P.COARSE_STAT_D) if stat_dist is None: raise RuntimeError( "No coarse-grained stationary distribution found.") self._set_initial_states_from_macrostates( stat_dist[np.argsort(stat_dist)][:n_states].index, n_cells=n_cells) @d.get_sections(base="set_initial_states_from_macrostates", sections=["Returns"]) @d.dedent @inject_docs(key=TermStatesKey.BACKWARD.s, probs_key=_probs(TermStatesKey.BACKWARD.s)) def _set_initial_states_from_macrostates( self, names: Optional[Union[Iterable[str], str]] = None, n_cells: int = 30, ) -> None: """ Manually select initial states from macrostates. Note that no check is performed to ensure initial and terminal states are distinct. Parameters ---------- names Names of the macrostates to be marked as initial states. 
Multiple states can be combined using `','`, such as `["Alpha, Beta", "Epsilon"]`. %(n_cells)s Returns ------- None Nothing, just writes to :attr:`adata`: - ``.obs[{key!r}]`` - probability of being an initial state. - ``.obs[{probs_key!r}]`` - top ``n_cells`` from each initial state. """ if not isinstance(n_cells, int): raise TypeError( f"Expected `n_cells` to be of type `int`, found `{type(n_cells).__name__!r}`." ) if n_cells <= 0: raise ValueError( f"Expected `n_cells` to be positive, found `{n_cells}`.") probs = self._get(P.MACRO_MEMBER) if probs is None: raise RuntimeError( "Compute macrostates first as `.compute_macrostates()`.") elif probs.shape[1] == 1: categorical = self._create_states(probs, n_cells=n_cells) scaled = probs / probs.max() else: if names is None: names = probs.names if isinstance(names, str): names = [names] probs = probs[list(names)] categorical = self._create_states(probs, n_cells=n_cells) probs /= probs.max(0) # compute the aggregated probability of being a initial/terminal state (no matter which) scaled = probs.X.max(1) self._write_initial_states(membership=probs, probs=scaled, cats=categorical) def _write_initial_states(self, membership: Lineage, probs: pd.Series, cats: pd.Series, time=None) -> None: key = TermStatesKey.BACKWARD.s self.adata.obs[key] = cats self.adata.obs[_probs(key)] = probs self.adata.uns[_colors(key)] = membership.colors self.adata.uns[_lin_names(key)] = membership.names logg.info( f"Adding `adata.obs[{_probs(key)!r}]`\n `adata.obs[{key!r}]`\n", time=time, ) def _write_terminal_states(self, time=None) -> None: super()._write_terminal_states(time=time) term_abs_probs = self._get(A.TERM_ABS_PROBS) if term_abs_probs is None: # possibly remove previous value if it's inconsistent term_abs_probs = self.adata.obsm.get(self._term_abs_prob_key, None) if term_abs_probs is not None: new = list(self._get(P.TERM).cat.categories) old = list(term_abs_probs.names) if term_abs_probs.shape[1] == len(new) and new == old: 
self.adata.obsm[self._term_abs_prob_key] = term_abs_probs else: logg.warning( f"Removing previously computed `adata.obsm[{self._term_abs_prob_key!r}]` because the " f"names mismatch `{new}` (new), `{old}` (old).") self._set(A.TERM_ABS_PROBS, None) self.adata.obsm.pop(self._term_abs_prob_key, None)
def test_find_final(self, adata: AnnData):
    """Check that `cr.tl.terminal_states` adds the forward state key and its `_probs` key to `adata.obs`."""
    cr.tl.terminal_states(adata, n_states=5, fit_kwargs={"n_cells": 5})

    forward = FinalStatesKey.FORWARD
    obs_columns = adata.obs.keys()
    assert str(forward) in obs_columns
    assert _probs(forward) in obs_columns
def test_find_root(self, adata: AnnData):
    """Check that `cr.tl.initial_states` adds the backward state key and its `_probs` key to `adata.obs`."""
    cr.tl.initial_states(adata)

    backward = FinalStatesKey.BACKWARD
    obs_columns = adata.obs.keys()
    assert str(backward) in obs_columns
    assert _probs(backward) in obs_columns