class FinalStates(Plottable):
    """Class dealing with final states."""

    __prop_metadata__ = [
        Metadata(
            attr=A.FIN,
            prop=P.FIN,
            dtype=pd.Series,
            doc="Final states.",
        ),
        Metadata(
            attr=A.FIN_PROBS,
            prop=P.FIN_PROBS,
            dtype=pd.Series,
            doc="Final states probabilities.",
        ),
        # registered with `P.NO_PROPERTY` - presumably no public property is generated for the colors
        Metadata(attr=A.FIN_COLORS, prop=P.NO_PROPERTY, dtype=np.ndarray),
    ]

    @abstractmethod
    def set_final_states(self, *args, **kwargs) -> None:
        """Abstract method for setting the final states; subclasses provide the implementation."""

    @abstractmethod
    def compute_final_states(self, *args, **kwargs) -> None:
        """Abstract method for computing the final states; subclasses provide the implementation."""

    @abstractmethod
    def _write_final_states(self, *args, **kwargs) -> None:
        """Abstract method for writing the final states; subclasses provide the implementation."""
class TerminalStates(Plottable):
    """Class dealing with terminal states."""

    __prop_metadata__ = [
        Metadata(
            attr=A.TERM,
            prop=P.TERM,
            dtype=pd.Series,
            doc="Terminal states.",
        ),
        Metadata(
            attr=A.TERM_PROBS,
            prop=P.TERM_PROBS,
            dtype=pd.Series,
            doc="Terminal states probabilities.",
        ),
        # registered with `P.NO_PROPERTY` - presumably no public property is generated for the colors
        Metadata(attr=A.TERM_COLORS, prop=P.NO_PROPERTY, dtype=np.ndarray),
    ]

    @abstractmethod
    def set_terminal_states(self, *args: Any, **kwargs: Any) -> None:
        """Abstract method for setting the terminal states; subclasses provide the implementation."""

    @abstractmethod
    def compute_terminal_states(self, *args: Any, **kwargs: Any) -> None:
        """Abstract method for computing the terminal states; subclasses provide the implementation."""

    @abstractmethod
    def _write_terminal_states(self, *args: Any, **kwargs: Any) -> None:
        """Abstract method for writing the terminal states; subclasses provide the implementation."""
class Macrostates(Plottable):
    """Class dealing with macrostates."""

    __prop_metadata__ = [
        Metadata(
            attr=A.MACRO,
            prop=P.MACRO,
            dtype=pd.Series,
        ),
        Metadata(attr=A.MACRO_MEMBER, prop=P.MACRO_MEMBER, dtype=Lineage),
        # registered with `P.NO_PROPERTY` - presumably no public property is generated for the colors
        Metadata(attr=A.MACRO_COLORS, prop=P.NO_PROPERTY, dtype=np.ndarray),
    ]

    @abstractmethod
    def compute_macrostates(self, *args, **kwargs) -> None:
        """Abstract method for computing the macrostates; subclasses provide the implementation."""
class MetaStates(Plottable):
    """Class dealing with metastable states."""

    __prop_metadata__ = [
        Metadata(
            attr=A.META,
            prop=P.META,
            dtype=pd.Series,
            doc="Metastable states.",
        ),
        Metadata(
            attr=A.META_PROBS,
            prop=P.META_PROBS,
            dtype=Lineage,
            doc="Metastable states probabilities.",
        ),
        # registered with `P.NO_PROPERTY` - presumably no public property is generated for the colors
        Metadata(attr=A.META_COLORS, prop=P.NO_PROPERTY, dtype=np.ndarray),
    ]

    @abstractmethod
    def compute_metastable_states(self, *args, **kwargs) -> None:
        """Abstract method for computing the metastable states; subclasses provide the implementation."""
class LinDrivers(Plottable):
    """Class holding the lineage drivers."""

    __prop_metadata__ = [
        Metadata(
            attr=A.LIN_DRIVERS,
            prop=P.LIN_DRIVERS,
            dtype=pd.DataFrame,
            doc="Lineage drivers.",
            # `F.NO_FUNC` effectively opts this attribute out of the `Plottable` machinery;
            # registering `pd.DataFrame` as plottable would also work, but is less clean
            plot_fmt=F.NO_FUNC,
        )
    ]
# NOTE(review): `AbsProbs` is defined again further down in this module; that later definition
# rebinds the name, so this one (with the priming-degree metadata) is shadowed at import time.
class AbsProbs(Plottable):
    """Class dealing with absorption probabilities."""

    __prop_metadata__ = [
        Metadata(
            attr=A.ABS_PROBS,
            prop=P.ABS_PROBS,
            dtype=Lineage,
            doc="Absorption probabilities.",
        ),
        Metadata(
            attr=A.PRIME_DEG,
            prop=P.PRIME_DEG,
            dtype=pd.Series,
            doc="Priming degree.",
        ),
        Metadata(
            attr=A.LIN_ABS_TIMES,
            prop=P.LIN_ABS_TIMES,
            dtype=pd.DataFrame,
        ),
    ]

    @abstractmethod
    def _write_absorption_probabilities(self, *args, **kwargs) -> None:
        """Abstract method for writing absorption probabilities; subclasses provide the implementation."""
class AbsProbs(Plottable):
    """Class dealing with absorption probabilities."""

    __prop_metadata__ = [
        Metadata(
            attr=A.ABS_PROBS,
            prop=P.ABS_PROBS,
            dtype=Lineage,
            doc="Absorption probabilities.",
        ),
        Metadata(
            attr=A.DIFF_POT,
            prop=P.DIFF_POT,
            dtype=pd.Series,
            doc="Differentiation potential.",
        ),
        Metadata(
            attr=A.LIN_ABS_TIMES,
            prop=P.LIN_ABS_TIMES,
            dtype=pd.DataFrame,
        ),
    ]

    @abstractmethod
    def _write_absorption_probabilities(self, *args, **kwargs) -> None:
        """Abstract method for writing absorption probabilities; subclasses provide the implementation."""
class Eigen(VectorPlottable, Decomposable):
    """Class computing the eigendecomposition."""

    __prop_metadata__ = [
        Metadata(attr=A.EIG, prop=P.EIG, dtype=Mapping[str, Any], compute_fmt=F.NO_FUNC)
    ]

    @d.dedent
    @inject_docs(prop=P.EIG)
    def compute_eigendecomposition(
        self,
        k: int = 20,
        which: str = "LR",
        alpha: float = 1,
        only_evals: bool = False,
        ncv: Optional[int] = None,
    ) -> None:
        """
        Compute eigendecomposition of transition matrix.

        Uses a sparse implementation, if possible, and only computes the top :math:`k` eigenvectors
        to speed up the computation. Computes both left and right eigenvectors.

        Parameters
        ----------
        k
            Number of eigenvalues/vectors to compute.
        %(eigen)s
        only_evals
            Compute only eigenvalues.
        ncv
            Number of Lanczos vectors generated.

        Returns
        -------
        None
            Nothing, but updates the following field:

                - :paramref:`{prop}`
        """

        def sort_descending(evals: np.ndarray) -> np.ndarray:
            # permutation ordering the eigenvalues by decreasing real part
            return np.flip(np.argsort(evals.real))

        def write_evals_only() -> None:
            # store only the top `k` sorted eigenvalues; the eigengap is computed on exactly
            # the values being stored (previously the dense path used the unsorted full spectrum)
            top_k = D[sort_descending(D)][:k]
            self._write_eig_to_adata(
                {
                    "D": top_k,
                    "eigengap": _eigengap(top_k.real, alpha),
                    "params": {"which": which, "k": k, "alpha": alpha},
                }
            )

        start = logg.info("Computing eigendecomposition of the transition matrix")
        if self.issparse:
            logg.debug(f"Computing top `{k}` eigenvalues for sparse matrix")
            D, V_l = eigs(self.transition_matrix.T, k=k, which=which, ncv=ncv)
            if only_evals:
                write_evals_only()
                return
            D_r, V_r = eigs(self.transition_matrix, k=k, which=which, ncv=ncv)
        else:
            logg.warning("This transition matrix is not sparse, computing full eigendecomposition")
            D, V_l = np.linalg.eig(self.transition_matrix.T)
            if only_evals:
                write_evals_only()
                return
            D_r, V_r = np.linalg.eig(self.transition_matrix)

        # sort the eigenvalues and eigenvectors by decreasing real part
        logg.debug("Sorting eigenvalues by their real part")
        p = sort_descending(D)
        D, V_l = D[p], V_l[:, p]
        # the right eigenvectors come from a separate solver call whose eigenvalue order is not
        # guaranteed to match the left one, so order them by their own eigenvalues
        V_r = V_r[:, sort_descending(D_r)]
        e_gap = _eigengap(D.real, alpha)

        # stationary distribution: normalized magnitude of the real part of the top left eigenvector
        pi = np.abs(V_l[:, 0].real)
        pi /= np.sum(pi)

        self._write_eig_to_adata(
            {
                "D": D,
                "stationary_dist": pi,
                "V_l": V_l,
                "V_r": V_r,
                "eigengap": e_gap,
                "params": {"which": which, "k": k, "alpha": alpha},
            },
            start=start,
        )

    @d.dedent
    def plot_eigendecomposition(self, left: bool = False, *args, **kwargs):
        """
        Plot eigenvectors in an embedding.

        Parameters
        ----------
        left
            Whether to plot left or right eigenvectors.
        %(plot_vectors.parameters)s

        Returns
        -------
        %(plot_vectors.returns)s
        """
        eig = getattr(self, P.EIG.s)
        if eig is None:
            # NOTE(review): presumably `_plot_vectors` raises an informative error here - TODO confirm
            self._plot_vectors(None, P.EIG.s)

        side = "left" if left else "right"
        D, V = eig["D"], eig.get(f"V_{side[0]}", None)
        if V is None:
            raise RuntimeError(
                "Compute eigendecomposition first as `.compute_eigendecomposition(..., only_evals=False)`."
            )

        # if irreducible, first right e-vec should be const.
        if side == "right":
            # quick check for irreducibility: exactly one eigenvalue close to 1
            if np.sum(np.isclose(D, 1, rtol=1e2 * EPS, atol=1e2 * EPS)) == 1:
                # work on a copy so the cached eigenvectors are not modified by plotting
                V = V.copy()
                V[:, 0] = 1.0

        self._plot_vectors(
            V,
            P.EIG.s,
            *args,
            D=D,
            **kwargs,
        )

    @d.dedent
    def plot_spectrum(
        self,
        n: Optional[int] = None,
        real_only: bool = False,
        show_eigengap: bool = True,
        show_all_xticks: bool = True,
        legend_loc: Optional[str] = None,
        title: Optional[str] = None,
        figsize: Optional[Tuple[float, float]] = (5, 5),
        dpi: int = 100,
        save: Optional[Union[str, Path]] = None,
        marker: str = ".",
        **kwargs,
    ) -> None:
        """
        Plot the top eigenvalues in real or complex plane.

        Parameters
        ----------
        n
            Number of eigenvalues to show. If `None`, show all that have been computed.
        real_only
            Whether to plot only the real part of the spectrum.
        show_eigengap
            When `real_only=True`, this determines whether to show the inferred eigengap as a dashed line.
        show_all_xticks
            When `real_only=True`, this determines whether to show the indices of all eigenvalues on the x-axis.
        legend_loc
            Location parameter for the legend.
        title
            Title of the figure.
        %(plotting)s
        marker
            Marker symbol used, valid options can be found in :mod:`matplotlib.markers`.
        **kwargs
            Keyword arguments for :func:`matplotlib.pyplot.scatter`.

        Returns
        -------
        %(just_plots)s
        """
        eig = getattr(self, P.EIG.s)
        if eig is None:
            raise RuntimeError(f"Compute `.{P.EIG}` first as `.{F.COMPUTE.fmt(P.EIG)}()`.")

        if n is None:
            n = len(eig["D"])
        elif n <= 0:
            raise ValueError(f"Expected `n` to be > 0, found `{n}`.")

        if real_only:
            fig = self._plot_real_spectrum(
                n,
                show_eigengap=show_eigengap,
                show_all_xticks=show_all_xticks,
                dpi=dpi,
                figsize=figsize,
                legend_loc=legend_loc,
                title=title,
                marker=marker,
                **kwargs,
            )
        else:
            fig = self._plot_complex_spectrum(
                n,
                dpi=dpi,
                figsize=figsize,
                legend_loc=legend_loc,
                title=title,
                marker=marker,
                **kwargs,
            )

        if save:
            save_fig(fig, save)

        fig.show()

    def _plot_complex_spectrum(
        self,
        n: int,
        dpi: int = 100,
        figsize: Optional[Tuple[float, float]] = (None, None),
        legend_loc: Optional[str] = None,
        title: Optional[str] = None,
        marker: str = ".",
        **kwargs,
    ):
        """Scatter the top `n` eigenvalues in the complex plane together with the unit circle."""

        # define a function to make the data limits rectangular
        def adapt_range(min_, max_, range_):
            return (
                min_ + (max_ - min_) / 2 - range_ / 2,
                min_ + (max_ - min_) / 2 + range_ / 2,
            )

        eig = getattr(self, P.EIG.s)
        D, params = eig["D"][:n], eig["params"]

        # create figure and axes
        fig, ax = plt.subplots(nrows=1, ncols=1, dpi=dpi, figsize=figsize)

        # get the original data ranges
        lam_x, lam_y = D.real, D.imag
        x_min, x_max = np.min(lam_x), np.max(lam_x)
        y_min, y_max = np.min(lam_y), np.max(lam_y)
        x_range, y_range = x_max - x_min, y_max - y_min
        final_range = np.max([x_range, y_range]) + 0.05

        x_min_, x_max_ = adapt_range(x_min, x_max, final_range)
        y_min_, y_max_ = adapt_range(y_min, y_max, final_range)

        # plot the data and the unit circle
        ax.scatter(D.real, D.imag, marker=marker, label="eigenvalue", **kwargs)
        t = np.linspace(0, 2 * np.pi, 500)
        x_circle, y_circle = np.sin(t), np.cos(t)
        ax.plot(x_circle, y_circle, "k-", label="unit circle")

        # set labels, ranges and legend
        ax.set_xlabel(r"Re($\lambda$)")
        ax.set_xlim(x_min_, x_max_)

        ax.set_ylabel(r"Im($\lambda$)")
        ax.set_ylim(y_min_, y_max_)

        key = "real part" if params["which"] == "LR" else "magnitude"
        if title is None:
            title = f"top {n} eigenvalues according to their {key}"
        ax.set_title(title)

        ax.legend(loc=legend_loc)

        return fig

    def _plot_real_spectrum(
        self,
        n: int,
        show_eigengap: bool = True,
        show_all_xticks: bool = True,
        dpi: int = 100,
        figsize: Optional[Tuple[float, float]] = None,
        legend_loc: Optional[str] = None,
        title: Optional[str] = None,
        marker: str = ".",
        **kwargs,
    ):
        """Plot the real parts of the top `n` eigenvalues against their index."""
        eig = getattr(self, P.EIG.s)
        D, params = eig["D"][:n], eig["params"]

        D_real, D_imag = D.real, D.imag
        ixs = np.arange(len(D))
        # split real and complex eigenvalues so they get separate legend entries
        mask = D_imag == 0

        # plot the top eigenvalues
        fig, ax = plt.subplots(nrows=1, ncols=1, dpi=dpi, figsize=figsize)
        if np.any(mask):
            ax.scatter(
                ixs[mask],
                D_real[mask],
                marker=marker,
                label="real eigenvalue",
                **kwargs,
            )
        if np.any(~mask):
            ax.scatter(
                ixs[~mask],
                D_real[~mask],
                marker=marker,
                label="complex eigenvalue",
                **kwargs,
            )

        # add dashed line for the eigengap, ticks, labels, title and legend
        if show_eigengap and eig["eigengap"] < n:
            ax.axvline(eig["eigengap"], label="eigengap", ls="--", lw=1)

        ax.set_xlabel("index")
        if show_all_xticks:
            ax.set_xticks(np.arange(len(D)))
        else:
            ax.xaxis.set_major_locator(MultipleLocator(2.0))
            ax.xaxis.set_major_formatter(FormatStrFormatter("%d"))

        ax.set_ylabel(r"Re($\lambda_i$)")

        key = "real part" if params["which"] == "LR" else "magnitude"
        if title is None:
            title = f"real part of top {n} eigenvalues according to their {key}"
        ax.set_title(title)

        ax.legend(loc=legend_loc)

        return fig
class Schur(VectorPlottable, Decomposable):
    """Class computing the Schur decomposition."""

    __prop_metadata__ = [
        Metadata(
            attr=A.SCHUR,
            prop=P.SCHUR,
            dtype=np.ndarray,
            compute_fmt=F.NO_FUNC,
            doc="Schur vectors.",
        ),
        Metadata(attr=A.SCHUR_MAT, prop=P.SCHUR_MAT, dtype=np.ndarray),
        Metadata(attr=A.EIG, prop=P.EIG, dtype=Mapping[str, Any]),
        Metadata(attr="_invalid_n_states", prop=P.NO_PROPERTY, dtype=np.ndarray),
        Metadata(attr="_gpcca", prop=P.NO_PROPERTY),
    ]

    @d.dedent
    @inject_docs(schur_vectors=P.SCHUR, schur_matrix=P.SCHUR_MAT, eigendec=P.EIG)
    def compute_schur(
        self,
        n_components: int = 10,
        initial_distribution: Optional[np.ndarray] = None,
        method: str = "krylov",
        which: str = "LR",
        alpha: float = 1,
    ):
        """
        Compute the Schur decomposition.

        Parameters
        ----------
        n_components
            Number of vectors to compute.
        initial_distribution
            Input probability distribution over all cells. If `None`, uniform is chosen.
        method
            Method for calculating the Schur vectors. Valid options are: `'krylov'` or `'brandts'`.
            For benefits of each method, see :class:`msmtools.analysis.dense.gpcca.GPCCA`. The former is
            an iterative procedure that computes a partial, sorted Schur decomposition for large, sparse
            matrices whereas the latter computes a full sorted Schur decomposition of a dense matrix.
        %(eigen)s

        Returns
        -------
        None
            Nothing, but updates the following fields:

                - :paramref:`{schur_vectors}`
                - :paramref:`{schur_matrix}`
                - :paramref:`{eigendec}`
        """
        if n_components < 2:
            raise ValueError(f"Number of components must be `>=2`, found `{n_components}`.")

        self._gpcca = _GPCCA(self.transition_matrix, eta=initial_distribution, z=which, method=method)
        start = logg.info("Computing Schur decomposition")

        try:
            self._gpcca._do_schur_helper(n_components)
        except ValueError:
            # a `ValueError` here is treated as "the cut would split a complex-conjugate pair";
            # retry with one more component so the pair stays together
            logg.warning(
                f"Using `{n_components}` components would split a block of complex conjugates. "
                f"Increasing `n_components` to `{n_components + 1}`"
            )
            self._gpcca._do_schur_helper(n_components + 1)

        # make it available for pl
        setattr(self, A.SCHUR.s, self._gpcca.X)
        setattr(self, A.SCHUR_MAT.s, self._gpcca.R)

        # numbers of states whose cut would split a complex-conjugate pair of eigenvalues
        self._invalid_n_states = np.array(
            [
                i
                for i in range(2, len(self._gpcca.eigenvalues))
                if _check_conj_split(self._gpcca.eigenvalues[:i])
            ]
        )
        if len(self._invalid_n_states):
            logg.info(
                f"When computing macrostates, choose a number of states NOT in `{list(self._invalid_n_states)}`"
            )

        self._write_eig_to_adata(
            {
                "D": self._gpcca.eigenvalues,
                "eigengap": _eigengap(self._gpcca.eigenvalues, alpha),
                "params": {
                    "which": which,
                    "k": len(self._gpcca.eigenvalues),
                    "alpha": alpha,
                },
            },
            start=start,
            extra_msg=f"\n `.{P.SCHUR}`\n `.{P.SCHUR_MAT}`\n Finish",
        )

    plot_schur = _delegate(prop_name=P.SCHUR.s)(VectorPlottable._plot_vectors)

    @d.dedent
    def plot_schur_matrix(
        self,
        title: Optional[str] = "schur matrix",
        cmap: str = "viridis",
        figsize: Optional[Tuple[float, float]] = None,
        dpi: Optional[float] = 80,
        save: Optional[Union[str, Path]] = None,
        **kwargs,
    ):
        """
        Plot the Schur matrix.

        Parameters
        ----------
        title
            Title of the figure.
        cmap
            Colormap to use.
        %(plotting)s
        **kwargs
            Keyword arguments for :func:`seaborn.heatmap`.

        Returns
        -------
        %(just_plots)s
        """
        from seaborn import heatmap

        schur_matrix = getattr(self, P.SCHUR_MAT.s)

        if schur_matrix is None:
            raise RuntimeError(f"Compute Schur matrix first as `.{F.COMPUTE.fmt(P.SCHUR)}()`.")

        fig, ax = plt.subplots(figsize=schur_matrix.shape if figsize is None else figsize, dpi=dpi)

        divider = make_axes_locatable(ax)  # square=True make the colorbar a bit bigger
        cbar_ax = divider.append_axes("right", size="2%", pad=0.1)

        # hide the strictly-lower triangle, except entries that are not (close to) zero;
        # the builtin `bool` is used because the `np.bool` alias was removed in NumPy 1.24
        mask = np.zeros_like(schur_matrix, dtype=bool)
        mask[np.tril_indices_from(mask, k=-1)] = True
        mask[~np.isclose(schur_matrix, 0.0)] = False

        # color limits are computed only over the visible (unmasked) entries
        vmin, vmax = (
            np.min(schur_matrix[~mask]),
            np.max(schur_matrix[~mask]),
        )

        kwargs["fmt"] = kwargs.get("fmt", "0.2f")
        heatmap(
            schur_matrix,
            cmap=cmap,
            square=True,
            annot=True,
            vmin=vmin,
            vmax=vmax,
            cbar_ax=cbar_ax,
            cbar_kws={"ticks": np.linspace(vmin, vmax, 10)},
            mask=mask,
            xticklabels=[],
            yticklabels=[],
            ax=ax,
            **kwargs,
        )

        ax.set_title(title)

        if save is not None:
            save_fig(fig, path=save)
class GPCCA(BaseEstimator, Macrostates, Schur, Eigen): """ Generalized Perron Cluster Cluster Analysis :cite:`reuter:18` as implemented in `pyGPCCA <https://pygpcca.readthedocs.io/en/latest/>`_. Coarse-grains a discrete Markov chain into a set of macrostates and computes coarse-grained transition probabilities among the macrostates. Each macrostate corresponds to an area of the state space, i.e. to a subset of cells. The assignment is soft, i.e. each cell is assigned to every macrostate with a certain weight, where weights sum to one per cell. Macrostates are computed by maximizing the 'crispness' which can be thought of as a measure for minimal overlap between macrostates in a certain inner-product sense. Once the macrostates have been computed, we project the large transition matrix onto a coarse-grained transition matrix among the macrostates via a Galerkin projection. This projection is based on invariant subspaces of the original transition matrix which are obtained using the real Schur decomposition :cite:`reuter:18`. 
Parameters ---------- %(base_estimator.parameters)s """ # noqa: E501 __prop_metadata__ = [ Metadata( attr=A.COARSE_T, prop=P.COARSE_T, compute_fmt=F.NO_FUNC, plot_fmt=F.NO_FUNC, dtype=pd.DataFrame, doc="Coarse-grained transition matrix.", ), Metadata(attr=A.TERM_ABS_PROBS, prop=P.NO_PROPERTY, dtype=Lineage), Metadata(attr=A.COARSE_INIT_D, prop=P.COARSE_INIT_D, dtype=pd.Series), Metadata(attr=A.COARSE_STAT_D, prop=P.COARSE_STAT_D, dtype=pd.Series), ] def _read_from_adata(self) -> None: super()._read_from_adata() self._reconstruct_lineage( A.TERM_ABS_PROBS, self._term_abs_prob_key, ) @inject_docs( ms=P.MACRO, msp=P.MACRO_MEMBER, schur=P.SCHUR.s, coarse_T=P.COARSE_T, coarse_stat=P.COARSE_STAT_D, ) @d.dedent def compute_macrostates( self, n_states: Optional[Union[int, Tuple[int, int], List[int], Dict[str, int]]] = None, n_cells: Optional[int] = 30, use_min_chi: bool = False, cluster_key: str = None, en_cutoff: Optional[float] = 0.7, p_thresh: float = 1e-15, ): """ Compute the macrostates. Parameters ---------- n_states Number of macrostates. If `None`, use the `eigengap` heuristic. %(n_cells)s use_min_chi Whether to use :meth:`pygpcca.GPCCA.minChi` to calculate the number of macrostates. If `True`, ``n_states`` corresponds to a closed interval `[min, max]` inside of which the potentially optimal number of macrostates is searched. cluster_key If a key to cluster labels is given, names and colors of the states will be associated with the clusters. %(en_cutoff_p_thresh)s Returns ------- None Nothing, but updates the following fields: - :attr:`{msp}` - :attr:`{ms}` - :attr:`{schur}` - :attr:`{coarse_T}` - :attr:`{coarse_stat}` """ was_from_eigengap = False if use_min_chi: n_states = self._get_n_states_from_minchi(n_states) if n_states is None: if self._get(P.EIG) is None: raise RuntimeError( "Compute eigendecomposition first as `.compute_eigendecomposition()` or `.compute_schur()`." 
) was_from_eigengap = True n_states = self._get(P.EIG)["eigengap"] + 1 logg.info(f"Using `{n_states}` states based on eigengap") elif not isinstance(n_states, int): raise ValueError( f"Expected `n_states` to be an integer when `use_min_chi=False`, " f"found `{type(n_states).__name__!r}`.") if n_states <= 0: raise ValueError( f"Expected `n_states` to be positive or `None`, found `{n_states}`." ) n_states = self._check_states_validity(n_states) if n_states == 1: self._compute_one_macrostate( n_cells=n_cells, cluster_key=cluster_key, p_thresh=p_thresh, en_cutoff=en_cutoff, ) return if self._gpcca is None: if not was_from_eigengap: raise RuntimeError( "Compute Schur decomposition first as `.compute_schur()`.") logg.warning( f"Number of states `{n_states}` was automatically determined by `eigengap` " "but no Schur decomposition was found. Computing with default parameters" ) # this cannot fail if splitting occurs # if it were to split, it's automatically increased in `compute_schur` self.compute_schur(n_states) # pre-computed X if self._gpcca._p_X.shape[1] < n_states: logg.warning( f"Requested more macrostates `{n_states}` than available " f"Schur vectors `{self._gpcca._p_X.shape[1]}`. Recomputing the decomposition" ) start = logg.info(f"Computing `{n_states}` macrostates") try: self._gpcca = self._gpcca.optimize(m=n_states) except ValueError as e: # this is the following case - we have 4 Schur vectors, user requests 5 states, but it splits the conj. ev. 
# in the try block, Schur decomposition with 5 vectors is computed, but it fails (no way of knowing) # so in this case, we increase it by 1 n_states += 1 logg.warning(f"{e}\nIncreasing `n_states` to `{n_states}`") self._gpcca = self._gpcca.optimize(m=n_states) self._set_macrostates( memberships=self._gpcca.memberships, n_cells=n_cells, cluster_key=cluster_key, p_thresh=p_thresh, en_cutoff=en_cutoff, ) # cache the results and make sure we don't overwrite self._set(A.SCHUR, self._gpcca._p_X) self._set(A.SCHUR_MAT, self._gpcca._p_R) names = self._get(P.MACRO_MEMBER).names self._set( A.COARSE_T, pd.DataFrame( self._gpcca.coarse_grained_transition_matrix, index=names, columns=names, ), ) self._set( A.COARSE_INIT_D, pd.Series(self._gpcca.coarse_grained_input_distribution, index=names), ) # careful here, in case computing the stat. dist failed if self._gpcca.coarse_grained_stationary_probability is not None: self._set( A.COARSE_STAT_D, pd.Series( self._gpcca.coarse_grained_stationary_probability, index=names, ), ) logg.info( f"Adding `.{P.MACRO_MEMBER}`\n" f" `.{P.MACRO}`\n" f" `.{P.SCHUR}`\n" f" `.{P.COARSE_T}`\n" f" `.{P.COARSE_STAT_D}`\n" f" Finish", time=start, ) else: logg.warning("No stationary distribution found in GPCCA object") logg.info( f"Adding `.{P.MACRO_MEMBER}`\n" f" `.{P.MACRO}`\n" f" `.{P.SCHUR}`\n" f" `.{P.COARSE_T}`\n" f" Finish", time=start, ) @d.dedent @inject_docs(fs=P.TERM, fsp=P.TERM_PROBS) def set_terminal_states_from_macrostates( self, names: Optional[Union[Sequence[str], Mapping[str, str], str]] = None, n_cells: int = 30, ): """ Manually select terminal states from macrostates. Parameters ---------- names Names of the macrostates to be marked as terminal. Multiple states can be combined using `','`, such as ``["Alpha, Beta", "Epsilon"]``. If a :class:`dict`, keys correspond to the names of the macrostates and the values to the new names. If `None`, select all macrostates. 
%(n_cells)s Returns ------- None Nothing, just updates the following fields: - :attr:`{fsp}` - :attr:`{fs}` """ if not isinstance(n_cells, int): raise TypeError( f"Expected `n_cells` to be of type `int`, found `{type(n_cells).__name__}`." ) if n_cells <= 0: raise ValueError( f"Expected `n_cells` to be positive, found `{n_cells}`.") probs = self._get(P.MACRO_MEMBER) if probs is None: raise RuntimeError( "Compute macrostates first as `.compute_macrostates()`.") rename = True if names is None: names = probs.names rename = False if isinstance(names, str): names = [names] rename = False if not isinstance(names, dict): names = {n: n for n in names} rename = False if not len(names): raise ValueError("No macrostates have been selected.") if not all(isinstance(old, str) for old in names.keys()): raise TypeError("Not all new names are strings.") if not all(isinstance(new, (str, int)) for new in names.values()): raise TypeError( "Not all macrostates names are strings or integers.") # this also checks that the names are correct before renaming macrostates_probs = probs[list(names.keys())] # we do this also here because if `rename_terminal_states` fails # invalid states would've been written to this object and nothing to adata new_names = {k: str(v) for k, v in names.items()} names_after_renaming = [new_names.get(n, n) for n in probs.names] if len(set(names_after_renaming)) != probs.shape[1]: raise ValueError( f"After renaming, the names will not be unique: `{names_after_renaming}`." ) if probs.shape[1] == 1: self._set(A.TERM, self._create_states(probs, n_cells=n_cells)) self._set(A.TERM_COLORS, self._get(A.MACRO_COLORS)) self._set( A.TERM_PROBS, pd.Series(probs.X.squeeze() / probs.X.max(), index=self.adata.obs_names), ) self._set(A.TERM_ABS_PROBS, probs) if rename: # access lineage renames join states, e.g. 
'Alpha, Beta' becomes 'Alpha or Beta' + whitespace stripping self.rename_terminal_states( dict(zip(self._get(P.TERM).cat.categories, names.values()))) self._write_terminal_states() return # compute the aggregated probability of being a initial/terminal state (no matter which) scaled_probs = macrostates_probs.copy() scaled_probs /= scaled_probs.max(0) self._set(A.TERM, self._create_states(macrostates_probs, n_cells=n_cells)) self._set(A.TERM_PROBS, pd.Series(scaled_probs.X.max(1), index=self.adata.obs_names)) self._set( A.TERM_COLORS, macrostates_probs[list(self._get(P.TERM).cat.categories)].colors, ) self._set(A.TERM_ABS_PROBS, scaled_probs) if rename: self.rename_terminal_states( dict(zip(self._get(P.TERM).cat.categories, names.values()))) self._write_terminal_states() @inject_docs(fs=P.TERM, fsp=P.TERM_PROBS) @d.dedent def compute_terminal_states( self, method: str = "stability", n_cells: int = 30, alpha: Optional[float] = 1, stability_threshold: float = 0.96, n_states: Optional[int] = None, ): """ Automatically select terminal states from macrostates. Parameters ---------- method One of following: - `'eigengap'` - select the number of states based on the `eigengap` of the transition matrix. - `'eigengap_coarse'` - select the number of states based on the `eigengap` of the diagonal of the coarse-grained transition matrix. - `'top_n'` - select top ``n_states`` based on the probability of the diagonal of the coarse-grained transition matrix. - `'stability'` - select states which have a stability index >= ``stability_threshold``. The stability index is given by the diagonal elements of the coarse-grained transition matrix. %(n_cells)s alpha Weight given to the deviation of an eigenvalue from one. Used when ``method='eigengap'`` or ``method='eigengap_coarse'``. stability_threshold Threshold used when ``method='stability'``. n_states Numer of states used when ``method='top_n'``. 
Returns ------- None Nothing, just updates the following fields: - :attr:`{fsp}` - :attr:`{fs}` """ if len(self._get(P.MACRO).cat.categories) == 1: logg.warning( "Found only one macrostate. Making it the single main state") self.set_terminal_states_from_macrostates(None, n_cells=n_cells) return coarse_T = self._get(P.COARSE_T) if method == "eigengap": if self._get(P.EIG) is None: raise RuntimeError( "Compute eigendecomposition first as `.compute_eigendecomposition()`." ) n_states = _eigengap(self._get(P.EIG)["D"], alpha=alpha) + 1 elif method == "eigengap_coarse": if coarse_T is None: raise RuntimeError( "Compute macrostates first as `.compute_macrostates()`.") n_states = _eigengap(np.sort(np.diag(coarse_T)[::-1]), alpha=alpha) elif method == "top_n": if n_states is None: raise ValueError( "Argument `n_states` must be != `None` for `method='top_n'`." ) elif n_states <= 0: raise ValueError( f"Expected `n_states` to be positive, found `{n_states}`.") elif method == "stability": if stability_threshold is None: raise ValueError( "Argument `stability_threshold` must be != `None` for `method='stability'`." ) self_probs = pd.Series(np.diag(coarse_T), index=coarse_T.columns) names = self_probs[self_probs.values >= stability_threshold].index self.set_terminal_states_from_macrostates(names, n_cells=n_cells) return else: raise ValueError( f"Invalid method `{method!r}`. Valid options are `'eigengap', 'eigengap_coarse', " f"'top_n' and 'min_self_prob'`.") names = coarse_T.columns[np.argsort(np.diag(coarse_T))][-n_states:] self.set_terminal_states_from_macrostates(names, n_cells=n_cells) def compute_gdpt(self, n_components: int = 10, key_added: str = "gdpt_pseudotime", **kwargs): """ Compute generalized Diffusion pseudotime from :cite:`haghverdi:16` using the real Schur decomposition. Parameters ---------- n_components Number of real Schur vectors to consider. key_added Key in :attr:`adata` ``.obs`` where to save the pseudotime. 
kwargs Keyword arguments for :meth:`cellrank.tl.GPCCA.compute_schur` if Schur decomposition is not found. Returns ------- None Nothing, just updates :attr:`adata` ``.obs[key_added]`` with the computed pseudotime. """ def _get_dpt_row(e_vals: np.ndarray, e_vecs: np.ndarray, i: int): row = sum( (np.abs(e_vals[eval_ix]) / (1 - np.abs(e_vals[eval_ix])) * (e_vecs[i, eval_ix] - e_vecs[:, eval_ix]))**2 # account for float32 precision for eval_ix in range(0, e_vals.size) if np.abs(e_vals[eval_ix]) < 0.9994) return np.sqrt(row) if "iroot" not in self.adata.uns.keys(): raise KeyError("Key `'iroot'` not found in `adata.uns`.") iroot = self.adata.uns["iroot"] if isinstance(iroot, str): iroot = np.where(self.adata.obs_names == iroot)[0] if not len(iroot): raise ValueError( f"Unable to find cell with name `{self.adata.uns['iroot']!r}` in `adata.obs_names`." ) iroot = iroot[0] if n_components < 2: raise ValueError( f"Expected number of components >= 2, found `{n_components}`.") if self._get(P.SCHUR) is None: logg.warning("No Schur decomposition found. Computing") self.compute_schur(n_components, **kwargs) elif self._get(P.SCHUR_MAT).shape[1] < n_components: logg.warning( f"Requested `{n_components}` components, but only `{self._get(P.SCHUR_MAT).shape[1]}` were found. 
" f"Recomputing using default values") self.compute_schur(n_components) else: logg.debug("Using cached Schur decomposition") start = logg.info( f"Computing Generalized Diffusion Pseudotime using `n_components={n_components}`" ) Q, eigenvalues = ( self._get(P.SCHUR), self._get(P.EIG)["D"], ) # may have to remove some values if too many converged Q, eigenvalues = Q[:, :n_components], eigenvalues[:n_components] D = _get_dpt_row(eigenvalues, Q, i=iroot) pseudotime = D / np.max(D[np.isfinite(D)]) self.adata.obs[key_added] = pseudotime logg.info(f"Adding `{key_added!r}` to `adata.obs`\n Finish", time=start) @d.dedent def plot_coarse_T( self, show_stationary_dist: bool = True, show_initial_dist: bool = False, cmap: Union[str, mcolors.ListedColormap] = "viridis", xtick_rotation: float = 45, annotate: bool = True, show_cbar: bool = True, title: Optional[str] = None, figsize: Tuple[float, float] = (8, 8), dpi: int = 80, save: Optional[Union[Path, str]] = None, text_kwargs: Mapping[str, Any] = MappingProxyType({}), **kwargs, ) -> None: """ Plot the coarse-grained transition matrix between macrostates. Parameters ---------- show_stationary_dist Whether to show the stationary distribution, if present. show_initial_dist Whether to show the initial distribution. cmap Colormap to use. xtick_rotation Rotation of ticks on the x-axis. annotate Whether to display the text on each cell. show_cbar Whether to show colorbar. title Title of the figure. %(plotting)s text_kwargs Keyword arguments for :func:`matplotlib.pyplot.text`. kwargs Keyword arguments for :func:`matplotlib.pyplot.imshow`. 
Returns ------- %(just_plots)s """ def stylize_dist(ax, data: np.ndarray, xticks_labels: Union[List[str], Tuple[str]] = ()): _ = ax.imshow(data, aspect="auto", cmap=cmap, norm=norm) for spine in ax.spines.values(): spine.set_visible(False) if xticks_labels is not None: ax.set_xticks(np.arange(data.shape[1])) ax.set_xticklabels(xticks_labels) plt.setp( ax.get_xticklabels(), rotation=xtick_rotation, ha="right", rotation_mode="anchor", ) else: ax.set_xticks([]) ax.tick_params(which="both", top=False, right=False, bottom=False, left=False) ax.set_yticks([]) def annotate_heatmap(im, valfmt: str = "{x:.2f}"): # modified from matplotlib's site data = im.get_array() kw = {"ha": "center", "va": "center"} kw.update(**text_kwargs) # Get the formatter in case a string is supplied if isinstance(valfmt, str): valfmt = mpl.ticker.StrMethodFormatter(valfmt) # Loop over the data and create a `Text` for each "pixel". # Change the text's color depending on the data. texts = [] for i in range(data.shape[0]): for j in range(data.shape[1]): kw.update( color=_get_black_or_white(im.norm(data[i, j]), cmap)) text = im.axes.text(j, i, valfmt(data[i, j], None), **kw) texts.append(text) def annotate_dist_ax(ax, data: np.ndarray, valfmt: str = "{x:.2f}"): if ax is None: return if isinstance(valfmt, str): valfmt = mpl.ticker.StrMethodFormatter(valfmt) kw = {"ha": "center", "va": "center"} kw.update(**text_kwargs) for i, val in enumerate(data): kw.update(color=_get_black_or_white(im.norm(val), cmap)) ax.text( i, 0, valfmt(val, None), **kw, ) coarse_T = self._get(P.COARSE_T) coarse_stat_d = self._get(P.COARSE_STAT_D) coarse_init_d = self._get(P.COARSE_INIT_D) if coarse_T is None: raise RuntimeError( "Compute coarse-grained transition matrix first as `.compute_macrostates()` with `n_states > 1`." 
) if show_stationary_dist and coarse_stat_d is None: logg.warning("Coarse stationary distribution is `None`, ignoring") show_stationary_dist = False if show_initial_dist and coarse_init_d is None: logg.warning("Coarse initial distribution is `None`, ignoring") show_initial_dist = False hrs, wrs = [1], [1] if show_stationary_dist: hrs += [0.05] if show_initial_dist: hrs += [0.05] if show_cbar: wrs += [0.025] dont_show_dist = not show_initial_dist and not show_stationary_dist fig = plt.figure(constrained_layout=False, figsize=figsize, dpi=dpi) gs = plt.GridSpec( 1 + show_stationary_dist + show_initial_dist, 1 + show_cbar, height_ratios=hrs, width_ratios=wrs, wspace=0.05, hspace=0.05, ) if isinstance(cmap, str): cmap = plt.get_cmap(cmap) ax = fig.add_subplot(gs[0, 0]) cax = fig.add_subplot(gs[:1, -1]) if show_cbar else None init_ax, stat_ax = None, None labels = list(self.coarse_T.columns) tmp = coarse_T if show_initial_dist: tmp = np.c_[tmp, coarse_stat_d] if show_initial_dist: tmp = np.c_[tmp, coarse_init_d] minn, maxx = np.nanmin(tmp), np.nanmax(tmp) norm = mpl.colors.Normalize(vmin=minn, vmax=maxx) if show_stationary_dist: stat_ax = fig.add_subplot(gs[1, 0]) stylize_dist( stat_ax, np.array(coarse_stat_d).reshape(1, -1), xticks_labels=labels if not show_initial_dist else None, ) stat_ax.yaxis.set_label_position("right") stat_ax.set_ylabel("stationary dist", rotation=0, ha="left", va="center") if show_initial_dist: init_ax = fig.add_subplot(gs[show_stationary_dist + show_initial_dist, 0]) stylize_dist(init_ax, np.array(coarse_init_d).reshape(1, -1), xticks_labels=labels) init_ax.yaxis.set_label_position("right") init_ax.set_ylabel("initial dist", rotation=0, ha="left", va="center") im = ax.imshow(coarse_T, aspect="auto", cmap=cmap, norm=norm, **kwargs) ax.set_title( "coarse-grained transition matrix" if title is None else title) if cax is not None: _ = mpl.colorbar.ColorbarBase( cax, cmap=cmap, norm=norm, ticks=np.linspace(minn, maxx, 10), format="%0.3f", ) 
ax.set_yticks(np.arange(coarse_T.shape[0])) ax.set_yticklabels(labels) ax.tick_params( top=False, bottom=dont_show_dist, labeltop=False, labelbottom=dont_show_dist, ) for spine in ax.spines.values(): spine.set_visible(False) if dont_show_dist: ax.set_xticks(np.arange(coarse_T.shape[1])) ax.set_xticklabels(labels) plt.setp( ax.get_xticklabels(), rotation=xtick_rotation, ha="right", rotation_mode="anchor", ) else: ax.set_xticks([]) ax.set_yticks(np.arange(coarse_T.shape[0] + 1) - 0.5, minor=True) ax.tick_params(which="minor", bottom=dont_show_dist, left=False, top=False) if annotate: annotate_heatmap(im) if show_stationary_dist: annotate_dist_ax(stat_ax, coarse_stat_d.values) if show_initial_dist: annotate_dist_ax(init_ax, coarse_init_d) if save: save_fig(fig, save) @d.dedent def plot_macrostate_composition( self, key: str, width: float = 0.8, title: Optional[str] = None, labelrot: float = 45, legend_loc: Optional[str] = "upper right out", figsize: Optional[Tuple[float, float]] = None, dpi: Optional[int] = None, save: Optional[Union[str, Path]] = None, show: bool = True, ) -> Optional[Axes]: """ Plot stacked histogram of macrostates over categorical annotations. Parameters ---------- %(adata)s key Key from :attr:`adata` ``.obs`` containing categorical annotations. width Bar width in `[0, 1]`. title Title of the figure. If `None`, create one automatically. labelrot Rotation of labels on x-axis. legend_loc Position of the legend. If `None`, don't show legend. %(plotting)s show If `False`, return :class:`matplotlib.pyplot.Axes`. Returns ------- :class:`matplotlib.pyplot.Axes` The axis object if ``show=False``. 
%(just_plots)s """ from cellrank.pl._utils import _position_legend macrostates = self._get(P.MACRO) if macrostates is None: raise RuntimeError( "Compute macrostates first as `.compute_macrostates()`.") if key not in self.adata.obs: raise KeyError(f"Key `{key}` not found in `adata.obs`.") if not is_categorical_dtype(self.adata.obs[key]): raise TypeError( f"Expected `adata.obs[{key!r}]` to be `categorical`, " f"found `{infer_dtype(self.adata.obs[key])}`.") mask = ~macrostates.isnull() df = (pd.DataFrame({ "macrostates": macrostates, key: self.adata.obs[key] })[mask].groupby([key, "macrostates"]).size()) try: cats_colors = self.adata.uns[f"{key}_colors"] except KeyError: cats_colors = _create_categorical_colors( len(self.adata.obs[key].cat.categories)) cat_color_mapper = dict( zip(self.adata.obs[key].cat.categories, cats_colors)) x_indices = np.arange(len(macrostates.cat.categories)) bottom = np.zeros_like(x_indices, dtype=np.float32) width = min(1, max(0, width)) fig, ax = plt.subplots(figsize=figsize, dpi=dpi, tight_layout=True) for cat, color in cat_color_mapper.items(): frequencies = df.loc[cat] # do not add to legend if category is missing if np.sum(frequencies) > 0: ax.bar( x_indices, frequencies, width, label=cat, color=color, bottom=bottom, ec="black", lw=0.5, ) bottom += np.array(frequencies) ax.set_xticks(x_indices) ax.set_xticklabels( # assuming at least 1 category frequencies.index, rotation=labelrot, ha="center" if labelrot in (0, 90) else "right", ) y_max = bottom.max() ax.set_ylim([0, y_max + 0.05 * y_max]) ax.set_yticks(np.linspace(0, y_max, 5)) ax.margins(0.05) ax.set_xlabel("macrostate") ax.set_ylabel("frequency") if title is None: title = f"distribution over {key}" ax.set_title(title) if legend_loc not in (None, "none"): _position_legend(ax, legend_loc=legend_loc) if save is not None: save_fig(fig, save) if not show: return ax def _compute_one_macrostate( self, n_cells: int, cluster_key: Optional[str], en_cutoff: Optional[float], p_thresh: float, ) 
-> None: start = logg.warning( "For 1 macrostate, stationary distribution is computed") eig = self._get(P.EIG) if (eig is not None and "stationary_dist" in eig and eig["params"]["which"] == "LR"): stationary_dist = eig["stationary_dist"] else: self.compute_eigendecomposition(only_evals=False, which="LR") stationary_dist = self._get(P.EIG)["stationary_dist"] self._set_macrostates( memberships=stationary_dist[:, None], n_cells=n_cells, cluster_key=cluster_key, p_thresh=p_thresh, en_cutoff=en_cutoff, ) self._set( A.MACRO_MEMBER, Lineage( stationary_dist, names=list(self._get(A.MACRO).cat.categories), colors=self._get(A.MACRO_COLORS), ), ) # reset all the things for key in ( A.ABS_PROBS, A.PRIME_DEG, A.SCHUR, A.SCHUR_MAT, A.COARSE_T, A.COARSE_STAT_D, A.COARSE_STAT_D, ): self._set(key.s, None) logg.info( f"Adding `.{P.MACRO_MEMBER}`\n `.{P.MACRO}`\n Finish", time=start, ) def _get_n_states_from_minchi( self, n_states: Union[Tuple[int, int], List[int], Dict[str, int]]) -> int: if self._gpcca is None: raise RuntimeError( "Compute Schur decomposition first as `.compute_schur()` when `use_min_chi=True`." ) if not isinstance(n_states, (dict, tuple, list)): raise TypeError( f"Expected `n_states` to be either `dict`, `tuple` or a `list`, " f"found `{type(n_states).__name__}`.") if len(n_states) != 2: raise ValueError( f"Expected `n_states` to be of size `2`, found `{len(n_states)}`." ) if isinstance(n_states, dict): if "min" not in n_states or "max" not in n_states: raise KeyError( f"Expected the dictionary to have `'min'` and `'max'` keys, " f"found `{tuple(n_states.keys())}`.") minn, maxx = n_states["min"], n_states["max"] else: minn, maxx = n_states if minn > maxx: logg.debug( f"Swapping minimum and maximum because `{minn}` > `{maxx}`") minn, maxx = maxx, minn if minn <= 1: raise ValueError(f"Minimum value must be > `1`, found `{minn}`.") elif minn == 2: logg.warning( "In most cases, 2 clusters will always be optimal. 
" "If you really expect 2 clusters, use `n_states=2` and `use_minchi=False`. Setting minimum to `3`" ) minn = 3 if minn >= maxx: maxx = minn + 1 logg.debug( f"Setting maximum to `{maxx}` because it was <= than minimum `{minn}`" ) logg.info(f"Calculating minChi within interval `[{minn}, {maxx}]`") return int( np.arange(minn, maxx + 1)[np.argmax(self._gpcca.minChi(minn, maxx))]) @d.dedent def _set_macrostates( self, memberships: np.ndarray, n_cells: Optional[int] = 30, cluster_key: str = "clusters", en_cutoff: Optional[float] = 0.7, p_thresh: float = 1e-15, check_row_sums: bool = True, ) -> None: """ Map fuzzy clustering to pre-computed annotations to get names and colors. Given the fuzzy clustering, we would like to select the most likely cells from each state and use these to give each state a name and a color by comparing with pre-computed, categorical cluster annotations. Parameters ---------- memberships Fuzzy clustering. %(n_cells)s cluster_key Key from :attr:`adata` ``.obs`` to get reference cluster annotations. en_cutoff Threshold to decide when we we want to warn the user about an uncertain name mapping. This happens when one fuzzy state overlaps with several reference clusters, and the most likely cells are distributed almost evenly across the reference clusters. p_thresh Only used to detect cell cycle stages. These have to be present in :attr:`adata` ``.obs`` as `'G2M_score'` and `'S_score'`. check_row_sums Check whether rows in `memberships` sum to `1`. Returns ------- None Writes a :class:`cellrank.tl.Lineage` object which mapped names and colors. Also writes a categorical :class:`pandas.Series`, where top ``n_cells`` cells represent each fuzzy state. 
""" if n_cells is None: logg.debug("Setting the macrostates using macrostate assignment") # fmt: off max_assignment = np.argmax(memberships, axis=1) _macro_assignment = pd.Series(index=self.adata.obs_names, data=max_assignment, dtype="category") # sometimes, the assignment can have a missing category and the Lineage creation therefore fails # keep it as ints when `n_cells != None` _macro_assignment = _macro_assignment.cat.set_categories( list(range(memberships.shape[1]))) macrostates = _macro_assignment.astype(str).astype( "category").copy() not_enough_cells = [] # fmt: on else: logg.debug("Setting the macrostates using macrostates memberships") # select the most likely cells from each macrostate macrostates, not_enough_cells = self._create_states( memberships, n_cells=n_cells, check_row_sums=check_row_sums, return_not_enough_cells=True, ) not_enough_cells = not_enough_cells.astype("str") # _set_categorical_labels creates the names, we still need to remap the group names orig_cats = macrostates.cat.categories self._set_categorical_labels( attr_key=A.MACRO.v, color_key=A.MACRO_COLORS.v, pretty_attr_key=P.MACRO.v, add_to_existing_error_msg= "Compute macrostates first as `.compute_macrostates()`.", categories=macrostates, cluster_key=cluster_key, en_cutoff=en_cutoff, p_thresh=p_thresh, add_to_existing=False, ) name_mapper = dict(zip(orig_cats, self._get(P.MACRO).cat.categories)) _print_insufficient_number_of_cells( [name_mapper.get(n, n) for n in not_enough_cells], n_cells) logg.debug( "Setting macrostates memberships based on GPCCA membership vectors" ) self._set( A.MACRO_MEMBER, Lineage( memberships, names=list(macrostates.cat.categories), colors=self._get(A.MACRO_COLORS), ), ) def _create_states( self, probs: Union[np.ndarray, Lineage], n_cells: int, check_row_sums: bool = False, return_not_enough_cells: bool = False, ) -> pd.Series: if n_cells <= 0: raise ValueError( f"Expected `n_cells` to be positive, found `{n_cells}`.") a_discrete, not_enough_cells = 
_fuzzy_to_discrete( a_fuzzy=probs, n_most_likely=n_cells, remove_overlap=False, raise_threshold=0.2, check_row_sums=check_row_sums, ) states = _series_from_one_hot_matrix( membership=a_discrete, index=self.adata.obs_names, names=probs.names if isinstance(probs, Lineage) else None, ) return (states, not_enough_cells) if return_not_enough_cells else states def _check_states_validity(self, n_states: int) -> int: if self._invalid_n_states is not None and n_states in self._invalid_n_states: logg.warning( f"Unable to compute macrostates with `n_states={n_states}` because it will " f"split the conjugate eigenvalues. Increasing `n_states` to `{n_states + 1}`" ) n_states += 1 # cannot force recomputation of the Schur decomposition assert n_states not in self._invalid_n_states, "Sanity check failed." return n_states def _fit_terminal_states( self, n_lineages: Optional[int] = None, cluster_key: Optional[str] = None, method: str = "krylov", **kwargs, ) -> None: if n_lineages is None or n_lineages == 1: self.compute_eigendecomposition() if n_lineages is None: n_lineages = self.eigendecomposition["eigengap"] + 1 if n_lineages > 1: self.compute_schur(n_lineages, method=method) try: self.compute_macrostates(n_states=n_lineages, cluster_key=cluster_key, **kwargs) except ValueError: logg.warning( f"Computing `{n_lineages}` macrostates cuts through a block of complex conjugates. 
" f"Increasing `n_lineages` to {n_lineages + 1}") self.compute_macrostates(n_states=n_lineages + 1, cluster_key=cluster_key, **kwargs) fs_kwargs = { "n_cells": kwargs["n_cells"] } if "n_cells" in kwargs else {} if n_lineages is None: self.compute_terminal_states(method="eigengap", **fs_kwargs) else: self.set_terminal_states_from_macrostates(**fs_kwargs) @d.dedent # because of fit @d.dedent @inject_docs( ms=P.MACRO, msp=P.MACRO_MEMBER, fs=P.TERM, fsp=P.TERM_PROBS, ap=P.ABS_PROBS, pd=P.PRIME_DEG, ) def fit( self, n_lineages: Optional[int] = None, cluster_key: Optional[str] = None, keys: Optional[Sequence[str]] = None, method: str = "krylov", compute_absorption_probabilities: bool = True, **kwargs, ): """ Run the pipeline, computing the macrostates, %(initial_or_terminal)s states \ and optionally the absorption probabilities. It is equivalent to running:: if n_lineages is None or n_lineages == 1: compute_eigendecomposition(...) # get the stationary distribution if n_lineages > 1: compute_schur(...) compute_macrostates(...) if n_lineages is None: compute_terminal_states(...) else: set_terminal_states_from_macrostates(...) if compute_absorption_probabilities: compute_absorption_probabilities(...) Parameters ---------- %(fit)s method Method to use when computing the Schur decomposition. Valid options are: `'krylov'` or `'brandts'`. compute_absorption_probabilities Whether to compute the absorption probabilities or only the %(initial_or_terminal)s states. kwargs Keyword arguments for :meth:`cellrank.tl.estimators.GPCCA.compute_macrostates`. 
Returns ------- None Nothing, just makes available the following fields: - :attr:`{msp}` - :attr:`{ms}` - :attr:`{fsp}` - :attr:`{fs}` - :attr:`{ap}` - :attr:`{pd}` """ super().fit( n_lineages=n_lineages, cluster_key=cluster_key, keys=keys, method=method, compute_absorption_probabilities=compute_absorption_probabilities, **kwargs, ) @d.dedent def _compute_initial_states(self, n_states: int = 1, n_cells: int = 30) -> None: """ Compute initial states from macrostates. Parameters ---------- n_states Number of initial states. %(n_cells)s Returns ------- %(set_initial_states_from_macrostates.returns)s """ if n_states <= 0: raise ValueError( f"Expected `n_states` to be positive, found `{n_states}`.") if n_cells <= 0: raise ValueError( f"Expected `n_cells` to be positive, found `{n_cells}`.") probs = self._get(P.MACRO_MEMBER) if probs is None: raise RuntimeError( "Compute macrostates first as `.compute_macrostates()`.") if n_states > probs.shape[1]: raise ValueError( f"Requested `{n_states}` initial states, but only `{probs.shape[1]}` macrostates have been computed." ) if probs.shape[1] == 1: self._set_initial_states_from_macrostates(n_cells=n_cells) return stat_dist = self._get(P.COARSE_STAT_D) if stat_dist is None: raise RuntimeError( "No coarse-grained stationary distribution found.") self._set_initial_states_from_macrostates( stat_dist[np.argsort(stat_dist)][:n_states].index, n_cells=n_cells) @d.get_sections(base="set_initial_states_from_macrostates", sections=["Returns"]) @d.dedent @inject_docs(key=TermStatesKey.BACKWARD.s, probs_key=_probs(TermStatesKey.BACKWARD.s)) def _set_initial_states_from_macrostates( self, names: Optional[Union[Iterable[str], str]] = None, n_cells: int = 30, ) -> None: """ Manually select initial states from macrostates. Note that no check is performed to ensure initial and terminal states are distinct. Parameters ---------- names Names of the macrostates to be marked as initial states. 
Multiple states can be combined using `','`, such as `["Alpha, Beta", "Epsilon"]`. %(n_cells)s Returns ------- None Nothing, just writes to :attr:`adata`: - ``.obs[{key!r}]`` - probability of being an initial state. - ``.obs[{probs_key!r}]`` - top ``n_cells`` from each initial state. """ if not isinstance(n_cells, int): raise TypeError( f"Expected `n_cells` to be of type `int`, found `{type(n_cells).__name__!r}`." ) if n_cells <= 0: raise ValueError( f"Expected `n_cells` to be positive, found `{n_cells}`.") probs = self._get(P.MACRO_MEMBER) if probs is None: raise RuntimeError( "Compute macrostates first as `.compute_macrostates()`.") elif probs.shape[1] == 1: categorical = self._create_states(probs, n_cells=n_cells) scaled = probs / probs.max() else: if names is None: names = probs.names if isinstance(names, str): names = [names] probs = probs[list(names)] categorical = self._create_states(probs, n_cells=n_cells) probs /= probs.max(0) # compute the aggregated probability of being a initial/terminal state (no matter which) scaled = probs.X.max(1) self._write_initial_states(membership=probs, probs=scaled, cats=categorical) def _write_initial_states(self, membership: Lineage, probs: pd.Series, cats: pd.Series, time=None) -> None: key = TermStatesKey.BACKWARD.s self.adata.obs[key] = cats self.adata.obs[_probs(key)] = probs self.adata.uns[_colors(key)] = membership.colors self.adata.uns[_lin_names(key)] = membership.names logg.info( f"Adding `adata.obs[{_probs(key)!r}]`\n `adata.obs[{key!r}]`\n", time=time, ) def _write_terminal_states(self, time=None) -> None: super()._write_terminal_states(time=time) term_abs_probs = self._get(A.TERM_ABS_PROBS) if term_abs_probs is None: # possibly remove previous value if it's inconsistent term_abs_probs = self.adata.obsm.get(self._term_abs_prob_key, None) if term_abs_probs is not None: new = list(self._get(P.TERM).cat.categories) old = list(term_abs_probs.names) if term_abs_probs.shape[1] == len(new) and new == old: 
self.adata.obsm[self._term_abs_prob_key] = term_abs_probs else: logg.warning( f"Removing previously computed `adata.obsm[{self._term_abs_prob_key!r}]` because the " f"names mismatch `{new}` (new), `{old}` (old).") self._set(A.TERM_ABS_PROBS, None) self.adata.obsm.pop(self._term_abs_prob_key, None)
class GPCCA(BaseEstimator, MetaStates, Schur, Eigen): """ Generalized Perron Cluster Cluster Analysis [GPCCA18]_. Parameters ---------- %(base_estimator.parameters)s """ __prop_metadata__ = [ Metadata( attr=A.COARSE_T, prop=P.COARSE_T, compute_fmt=F.NO_FUNC, plot_fmt=F.NO_FUNC, dtype=pd.DataFrame, doc="Coarse-grained transition matrix.", ), Metadata(attr=A.FIN_ABS_PROBS, prop=P.NO_PROPERTY, dtype=Lineage), Metadata(attr=A.COARSE_INIT_D, prop=P.COARSE_INIT_D, dtype=pd.Series), Metadata(attr=A.COARSE_STAT_D, prop=P.COARSE_STAT_D, dtype=pd.Series), ] def _read_from_adata(self) -> None: super()._read_from_adata() self._reconstruct_lineage( A.FIN_ABS_PROBS, self._fin_abs_prob_key, ) @inject_docs( ms=P.META, msp=P.META_PROBS, schur=P.SCHUR.s, coarse_T=P.COARSE_T, coarse_stat=P.COARSE_STAT_D, ) @d.dedent def compute_metastable_states( self, n_states: Optional[Union[int, Tuple[int, int], List[int], Dict[str, int]]] = None, n_cells: Optional[int] = 30, use_min_chi: bool = False, cluster_key: str = None, en_cutoff: Optional[float] = 0.7, p_thresh: float = 1e-15, ): """ Compute the metastable states. Parameters ---------- n_states Number of metastable states. If `None`, use the `eigengap` heuristic. %(n_cells)s use_min_chi Whether to use :meth:`msmtools.analysis.dense.gpcca.GPCCA.minChi` to calculate the number of metastable states. If `True`, ``n_states`` corresponds to an interval `[min, max]` inside of which the potentially optimal number of metastable states is searched. cluster_key If a key to cluster labels is given, names and colors of the states will be associated with the clusters. en_cutoff If ``cluster_key`` is given, this parameter determines when an approximate recurrent class will be labelled as *'Unknown'*, based on the entropy of the distribution of cells over transcriptomic clusters. p_thresh If cell cycle scores were provided, a *Wilcoxon rank-sum test* is conducted to identify cell-cycle driven start- or endpoints. 
If the test returns a positive statistic and a p-value smaller than ``p_thresh``, a warning will be issued. Returns ------- None Nothing, but updates the following fields: - :paramref:`{msp}` - :paramref:`{ms}` - :paramref:`{schur}` - :paramref:`{coarse_T}` - :paramref:`{coarse_stat}` """ was_from_eigengap = False if use_min_chi: n_states = self._get_n_states_from_minchi(n_states) if n_states is None: if self._get(P.EIG) is None: raise RuntimeError( "Compute eigendecomposition first as `.compute_eigendecomposition()` or `.compute_schur()`." ) was_from_eigengap = True n_states = self._get(P.EIG)["eigengap"] + 1 logg.info(f"Using `{n_states}` states based on eigengap") elif not isinstance(n_states, int): raise ValueError( f"Expected `n_states` to be an integer when `use_min_chi=False`, " f"found `{type(n_states).__name__!r}`.") if n_states <= 0: raise ValueError( f"Expected `n_states` to be positive or `None`, found `{n_states}`." ) n_states = self._check_states_validity(n_states) if n_states == 1: self._compute_meta_for_one_state( n_cells=n_cells, cluster_key=cluster_key, p_thresh=p_thresh, en_cutoff=en_cutoff, ) return if self._gpcca is None: if not was_from_eigengap: raise RuntimeError( "Compute Schur decomposition first as `.compute_schur()`.") logg.warning( f"Number of states `{n_states}` was automatically determined by `eigengap` " "but no Schur decomposition was found. Computing with default parameters" ) # this cannot fail if splitting occurs # if it were to split, it's automatically increased in `compute_schur` self.compute_schur(n_states + 1) if self._gpcca.X.shape[1] < n_states: logg.warning( f"Requested more metastable states `{n_states}` than available " f"Schur vectors `{self._gpcca.X.shape[1]}`. 
Recomputing the decomposition" ) start = logg.info(f"Computing `{n_states}` metastable states") try: self._gpcca = self._gpcca.optimize(m=n_states) except ValueError as e: # this is the following cage - we have 4 Schur vectors, user requests 5 states, but it splits the conj. ev. # in the try block, schur decomposition with 5 vectors is computed, but it fails (no way of knowing) # so in this case, we increate it by 1 n_states += 1 logg.warning(f"{e}\nIncreasing `n_states` to `{n_states}`") self._gpcca = self._gpcca.optimize(m=n_states) self._set_meta_states( memberships=self._gpcca.memberships, n_cells=n_cells, cluster_key=cluster_key, p_thresh=p_thresh, en_cutoff=en_cutoff, ) # cache the results and make sure we don't overwrite self._set(A.SCHUR, self._gpcca.X) self._set(A.SCHUR_MAT, self._gpcca.R) names = self._get(P.META_PROBS).names self._set( A.COARSE_T, pd.DataFrame( self._gpcca.coarse_grained_transition_matrix, index=names, columns=names, ), ) self._set( A.COARSE_INIT_D, pd.Series(self._gpcca.coarse_grained_input_distribution, index=names), ) # careful here, in case computing the stat. dist failed if self._gpcca.coarse_grained_stationary_probability is not None: self._set( A.COARSE_STAT_D, pd.Series( self._gpcca.coarse_grained_stationary_probability, index=names, ), ) logg.info( f"Adding `.{P.META_PROBS}`\n" f" `.{P.META}`\n" f" `.{P.SCHUR}`\n" f" `.{P.COARSE_T}`\n" f" `.{P.COARSE_STAT_D}`\n" f" Finish", time=start, ) else: logg.warning("No stationary distribution found in GPCCA object") logg.info( f"Adding `.{P.META_PROBS}`\n" f" `.{P.META}`\n" f" `.{P.SCHUR}`\n" f" `.{P.COARSE_T}`\n" f" Finish", time=start, ) @d.dedent @inject_docs(fs=P.FIN, fsp=P.FIN_PROBS) def set_final_states_from_metastable_states( self, names: Optional[Union[Iterable[str], str]] = None, n_cells: int = 30, ): """ Manually select the main states from the metastable states. Parameters ---------- names Names of the main states. 
Multiple states can be combined using `','`, such as `['Alpha, Beta', 'Epsilon']`. %(n_cells)s Returns ------- None Nothing, just updates the following fields: - :paramref:`{fsp}` - :paramref:`{fs}` """ if not isinstance(n_cells, int): raise TypeError( f"Expected `n_cells` to be of type `int`, found `{type(n_cells).__name__}`." ) if n_cells <= 0: raise ValueError( f"Expected `n_cells` to be positive, found `{n_cells}`.") probs = self._get(P.META_PROBS) if self._get(P.META_PROBS) is None: raise RuntimeError( "Compute metastable_states first as `.compute_metastable_states()`." ) elif probs.shape[1] == 1: self._set(A.FIN, self._create_states(probs, n_cells=n_cells)) self._set(A.FIN_COLORS, self._get(A.META_COLORS)) self._set(A.FIN_PROBS, probs / probs.max()) self._set(A.FIN_ABS_PROBS, probs) self._write_final_states() return if names is None: names = probs.names if isinstance(names, str): names = [names] meta_states_probs = probs[list(names)] # compute the aggregated probability of being a initial/terminal state (no matter which) scaled_probs = meta_states_probs[[ n for n in meta_states_probs.names if n != "rest" ]].copy() scaled_probs /= scaled_probs.max(0) self._set(A.FIN, self._create_states(meta_states_probs, n_cells)) self._set(A.FIN_PROBS, pd.Series(scaled_probs.X.max(1), index=self.adata.obs_names)) self._set( A.FIN_COLORS, meta_states_probs[list(self._get(P.FIN).cat.categories)].colors, ) self._set(A.FIN_ABS_PROBS, scaled_probs) self._write_final_states() @inject_docs(fs=P.FIN, fsp=P.FIN_PROBS) @d.dedent def compute_final_states( self, method: str = "eigengap", n_cells: int = 30, alpha: Optional[float] = 1, min_self_prob: Optional[float] = None, n_final_states: Optional[int] = None, ): """ Automatically select the main states from metastable states. Parameters ---------- method One of following: - `'eigengap'` - select the number of states based on the eigengap of the transition matrix. 
- `'eigengap_coarse'` - select the number of states based on the eigengap of the diagonal of the coarse-grained transition matrix. - `'top_n'` - select top ``n_final_states`` based on the probability of the diagonal \ of the coarse-grained transition matrix. - `'min_self_prob'` - select states which have the given minimum probability of the diagonal of the coarse-grained transition matrix. %(n_cells)s alpha Weight given to the deviation of an eigenvalue from one. Used when ``method='eigengap'`` or ``method='eigengap_coarse'``. min_self_prob Used when ``method='min_self_prob'``. n_final_states Used when ``method='top_n'``. Returns ------- None Nothing, just updates the following fields: - :paramref:`{fsp}` - :paramref:`{fs}` """ # noqa if len(self._get(P.META).cat.categories) == 1: logg.warning( "Found only one metastable state. Making it the single main state" ) self.set_final_states_from_metastable_states(None, n_cells=n_cells) return coarse_T = self._get(P.COARSE_T) if method == "eigengap": if self._get(P.EIG) is None: raise RuntimeError( "Compute eigendecomposition first as `.compute_eigendecomposition()`." ) n_final_states = _eigengap(self._get(P.EIG)["D"], alpha=alpha) + 1 elif method == "eigengap_coarse": if coarse_T is None: raise RuntimeError( "Compute metastable states first as `.compute_metastable_states()`." ) n_final_states = _eigengap(np.sort(np.diag(coarse_T)[::-1]), alpha=alpha) elif method == "top_n": if n_final_states is None: raise ValueError( "Argument `n_final_states` must be != `None` for `method='top_n'`." ) elif n_final_states <= 0: raise ValueError( f"Expected `n_final_states` to be positive, found `{n_final_states}`." ) elif method == "min_self_prob": if min_self_prob is None: raise ValueError( "Argument `min_self_prob` must be != `None` for `method='min_self_prob'`." 
) self_probs = pd.Series(np.diag(coarse_T), index=coarse_T.columns) names = self_probs[self_probs.values >= min_self_prob].index self.set_final_states_from_metastable_states(names, n_cells=n_cells) return else: raise ValueError( f"Invalid method `{method!r}`. Valid options are `'eigengap', 'eigengap_coarse', " f"'top_n' and 'min_self_prob'`.") names = coarse_T.columns[np.argsort( np.diag(coarse_T))][-n_final_states:] self.set_final_states_from_metastable_states(names, n_cells=n_cells) def compute_gdpt(self, n_components: int = 10, key_added: str = "gdpt_pseudotime", **kwargs): """ Compute generalized Diffusion pseudotime from [Haghverdi16]_ making use of the real Schur decomposition. Parameters ---------- n_components Number of real Schur vectors to consider. key_added Key in :paramref:`adata` ``.obs`` where to save the pseudotime. **kwargs Keyword arguments for :meth:`cellrank.tl.GPCCA.compute_schur` if Schur decomposition is not found. Returns ------- None Nothing, just updates :paramref:`adata` ``.obs[key_added]`` with the computed pseudotime. """ def _get_dpt_row(e_vals: np.ndarray, e_vecs: np.ndarray, i: int): row = sum( (np.abs(e_vals[eval_ix]) / (1 - np.abs(e_vals[eval_ix])) * (e_vecs[i, eval_ix] - e_vecs[:, eval_ix]))**2 # account for float32 precision for eval_ix in range(0, e_vals.size) if np.abs(e_vals[eval_ix]) < 0.9994) return np.sqrt(row) if "iroot" not in self.adata.uns.keys(): raise KeyError("Key `'iroot'` not found in `adata.uns`.") iroot = self.adata.uns["iroot"] if isinstance(iroot, str): iroot = np.where(self.adata.obs_names == iroot)[0] if not len(iroot): raise ValueError( f"Unable to find cell with name `{self.adata.uns['iroot']!r}` in `adata.obs_names`." ) iroot = iroot[0] if n_components < 2: raise ValueError( f"Expected number of components >= 2, found `{n_components}`.") if self._get(P.SCHUR) is None: logg.warning("No Schur decomposition found. 
Computing") self.compute_schur(n_components, **kwargs) elif self._get(P.SCHUR_MAT).shape[1] < n_components: logg.warning( f"Requested `{n_components}` components, but only `{self._get(P.SCHUR_MAT).shape[1]}` were found. " f"Recomputing using default values") self.compute_schur(n_components) else: logg.debug("Using cached Schur decomposition") start = logg.info( f"Computing Generalized Diffusion Pseudotime using `n_components={n_components}`" ) Q, eigenvalues = ( self._get(P.SCHUR), self._get(P.EIG)["D"], ) # may have to remove some values if too many converged Q, eigenvalues = Q[:, :n_components], eigenvalues[:n_components] D = _get_dpt_row(eigenvalues, Q, i=iroot) pseudotime = D / np.max(D[np.isfinite(D)]) self.adata.obs[key_added] = pseudotime logg.info(f"Adding `{key_added!r}` to `adata.obs`\n Finish", time=start) @d.dedent def plot_coarse_T( self, show_stationary_dist: bool = True, show_initial_dist: bool = False, cmap: Union[str, mcolors.ListedColormap] = "viridis", xtick_rotation: float = 45, annotate: bool = True, show_cbar: bool = True, title: Optional[str] = None, figsize: Tuple[float, float] = (8, 8), dpi: int = 80, save: Optional[Union[Path, str]] = None, text_kwargs: Mapping[str, Any] = MappingProxyType({}), **kwargs, ) -> None: """ Plot the coarse-grained transition matrix between metastable states. Parameters ---------- show_stationary_dist Whether to show the stationary distribution, if present. show_initial_dist Whether to show the initial distribution. cmap Colormap to use. xtick_rotation Rotation of ticks on the x-axis. annotate Whether to display the text on each cell. show_cbar Whether to show colorbar. title Title of the figure. %(plotting)s text_kwargs Keyword arguments for :func:`matplotlib.pyplot.text`. **kwargs Keyword arguments for :func:`matplotlib.pyplot.imshow`. 
Returns ------- %(just_plots)s """ def stylize_dist(ax, data: np.ndarray, xticks_labels: Union[List[str], Tuple[str]] = ()): _ = ax.imshow(data, aspect="auto", cmap=cmap, norm=norm) for spine in ax.spines.values(): spine.set_visible(False) if xticks_labels is not None: ax.set_xticklabels(xticks_labels) ax.set_xticks(np.arange(data.shape[1])) plt.setp( ax.get_xticklabels(), rotation=xtick_rotation, ha="right", rotation_mode="anchor", ) else: ax.set_xticks([]) ax.tick_params(which="both", top=False, right=False, bottom=False, left=False) ax.set_yticks([]) def annotate_heatmap(im, valfmt: str = "{x:.2f}"): # modified from matplotlib's site data = im.get_array() kw = {"ha": "center", "va": "center"} kw.update(**text_kwargs) # Get the formatter in case a string is supplied if isinstance(valfmt, str): valfmt = mpl.ticker.StrMethodFormatter(valfmt) # Loop over the data and create a `Text` for each "pixel". # Change the text's color depending on the data. texts = [] for i in range(data.shape[0]): for j in range(data.shape[1]): kw.update( color=_get_black_or_white(im.norm(data[i, j]), cmap)) text = im.axes.text(j, i, valfmt(data[i, j], None), **kw) texts.append(text) def annotate_dist_ax(ax, data: np.ndarray, valfmt: str = "{x:.2f}"): if ax is None: return if isinstance(valfmt, str): valfmt = mpl.ticker.StrMethodFormatter(valfmt) kw = {"ha": "center", "va": "center"} kw.update(**text_kwargs) for i, val in enumerate(data): kw.update(color=_get_black_or_white(im.norm(val), cmap)) ax.text( i, 0, valfmt(val, None), **kw, ) coarse_T = self._get(P.COARSE_T) coarse_stat_d = self._get(P.COARSE_STAT_D) coarse_init_d = self._get(P.COARSE_INIT_D) if coarse_T is None: raise RuntimeError( "Compute coarse-grained transition matrix first as `.compute_metastable_states()` with `n_states > 1`." 
) if show_stationary_dist and coarse_stat_d is None: logg.warning("Coarse stationary distribution is `None`, ignoring") show_stationary_dist = False if show_initial_dist and coarse_init_d is None: logg.warning("Coarse initial distribution is `None`, ignoring") show_initial_dist = False hrs, wrs = [1], [1] if show_stationary_dist: hrs += [0.05] if show_initial_dist: hrs += [0.05] if show_cbar: wrs += [0.025] dont_show_dist = not show_initial_dist and not show_stationary_dist fig = plt.figure(constrained_layout=False, figsize=figsize, dpi=dpi) gs = plt.GridSpec( 1 + show_stationary_dist + show_initial_dist, 1 + show_cbar, height_ratios=hrs, width_ratios=wrs, wspace=0.05, hspace=0.05, ) if isinstance(cmap, str): cmap = plt.get_cmap(cmap) ax = fig.add_subplot(gs[0, 0]) cax = fig.add_subplot(gs[:1, -1]) if show_cbar else None init_ax, stat_ax = None, None labels = list(self.coarse_T.columns) tmp = coarse_T if show_initial_dist: tmp = np.c_[tmp, coarse_stat_d] if show_initial_dist: tmp = np.c_[tmp, coarse_init_d] minn, maxx = np.nanmin(tmp), np.nanmax(tmp) norm = mpl.colors.Normalize(vmin=minn, vmax=maxx) if show_stationary_dist: stat_ax = fig.add_subplot(gs[1, 0]) stylize_dist( stat_ax, np.array(coarse_stat_d).reshape(1, -1), xticks_labels=labels if not show_initial_dist else None, ) stat_ax.yaxis.set_label_position("right") stat_ax.set_ylabel("stationary dist", rotation=0, ha="left", va="center") if show_initial_dist: init_ax = fig.add_subplot(gs[show_stationary_dist + show_initial_dist, 0]) stylize_dist(init_ax, np.array(coarse_init_d).reshape(1, -1), xticks_labels=labels) init_ax.yaxis.set_label_position("right") init_ax.set_ylabel("initial dist", rotation=0, ha="left", va="center") im = ax.imshow(coarse_T, aspect="auto", cmap=cmap, norm=norm, **kwargs) ax.set_title( "coarse-grained transition matrix" if title is None else title) if cax is not None: _ = mpl.colorbar.ColorbarBase( cax, cmap=cmap, norm=norm, ticks=np.linspace(minn, maxx, 10), format="%0.3f", ) 
ax.set_yticks(np.arange(coarse_T.shape[0])) ax.set_yticklabels(labels) ax.tick_params( top=False, bottom=dont_show_dist, labeltop=False, labelbottom=dont_show_dist, ) for spine in ax.spines.values(): spine.set_visible(False) if dont_show_dist: ax.set_xticks(np.arange(coarse_T.shape[1])) ax.set_xticklabels(labels) plt.setp( ax.get_xticklabels(), rotation=xtick_rotation, ha="right", rotation_mode="anchor", ) else: ax.set_xticks([]) ax.set_yticks(np.arange(coarse_T.shape[0] + 1) - 0.5, minor=True) ax.tick_params(which="minor", bottom=dont_show_dist, left=False, top=False) if annotate: annotate_heatmap(im) annotate_dist_ax(stat_ax, coarse_stat_d.values) annotate_dist_ax(init_ax, coarse_init_d) if save: save_fig(fig, save) fig.show() def _compute_meta_for_one_state( self, n_cells: int, cluster_key: Optional[str], en_cutoff: Optional[float], p_thresh: float, ) -> None: start = logg.info("Computing metastable states") logg.warning("For `n_states=1`, stationary distribution is computed") eig = self._get(P.EIG) if (eig is not None and "stationary_dist" in eig and eig["params"]["which"] == "LR"): stationary_dist = eig["stationary_dist"] else: self.compute_eigendecomposition(only_evals=False, which="LR") stationary_dist = self._get(P.EIG)["stationary_dist"] self._set_meta_states( memberships=stationary_dist[:, None], n_cells=n_cells, cluster_key=cluster_key, p_thresh=p_thresh, en_cutoff=en_cutoff, ) self._set( A.META_PROBS, Lineage( stationary_dist, names=list(self._get(A.META).cat.categories), colors=self._get(A.META_COLORS), ), ) # reset all the things for key in ( A.ABS_PROBS, A.SCHUR, A.SCHUR_MAT, A.COARSE_T, A.COARSE_STAT_D, A.COARSE_STAT_D, ): self._set(key.s, None) logg.info( f"Adding `.{P.META_PROBS}`\n `.{P.META}`\n Finish", time=start, ) def _get_n_states_from_minchi( self, n_states: Union[Tuple[int, int], List[int], Dict[str, int]]) -> int: if self._gpcca is None: raise RuntimeError( "Compute Schur decomposition first as `.compute_schur()` when `use_min_chi=True`." 
) if not isinstance(n_states, (dict, tuple, list)): raise TypeError( f"Expected `n_states` to be either `dict`, `tuple` or a `list`, " f"found `{type(n_states).__name__}`.") if len(n_states) != 2: raise ValueError( f"Expected `n_states` to be of size `2`, found `{len(n_states)}`." ) if isinstance(n_states, dict): if "min" not in n_states or "max" not in n_states: raise KeyError( f"Expected the dictionary to have `'min'` and `'max'` keys, " f"found `{tuple(n_states.keys())}`.") minn, maxx = n_states["min"], n_states["max"] else: minn, maxx = n_states if minn > maxx: logg.debug( f"Swapping minimum and maximum because `{minn}` > `{maxx}`") minn, maxx = maxx, minn if minn <= 1: raise ValueError(f"Minimum value must be > `1`, found `{minn}`.") elif minn == 2: logg.warning( "In most cases, 2 clusters will always be optimal. " "If you really expect 2 clusters, use `n_states=2` and `use_minchi=False`. Setting minimum to `3`" ) minn = 3 if minn >= maxx: maxx = minn + 1 logg.debug( f"Setting maximum to `{maxx}` because it was <= than minimum `{minn}`" ) logg.info(f"Calculating minChi within interval `[{minn}, {maxx}]`") return int( np.arange(minn, maxx)[np.argmax(self._gpcca.minChi(minn, maxx))]) @d.dedent def _set_meta_states( self, memberships: np.ndarray, n_cells: Optional[int] = 30, cluster_key: str = "clusters", en_cutoff: Optional[float] = 0.7, p_thresh: float = 1e-15, check_row_sums: bool = True, ) -> None: """ Map a fuzzy clustering to pre-computed annotations to get names and colors. Given the fuzzy clustering we have computed, we would like to select the most likely cells from each state and use these to give each state a name and a color by comparing with pre-computed, categorical cluster annotations. Parameters -------- memberships Fuzzy clustering. %(n_cells)s cluster_key Key from :paramref:`adata` ``.obs`` to get reference cluster annotations. en_cutoff Threshold to decide when we we want to warn the user about an uncertain name mapping. 
This happens when one fuzzy state overlaps with several reference clusters, and the most likely cells are distributed almost evenly across the reference clusters. p_thresh Only used to detect cell cycle stages. These have to be present in :paramref:`adata` ``.obs`` as `'G2M_score'` and `'S_score'`. check_row_sums Check whether rows in `memberships` sum to `1`. Returns ------- None Writes a :class:`cellrank.tl.Lineage` object which mapped names and colors. Also writes a categorical :class:`pandas.Series`, where top ``n_cells`` cells represent each fuzzy state. """ if n_cells is None: logg.debug( "Setting the metastable states using metastable assignment") max_assignment = np.argmax(memberships, axis=1) _meta_assignment = pd.Series(index=self.adata.obs_names, data=max_assignment, dtype="category") # sometimes, the assignment can have a missing category and the Lineage creation therefore fails # keep it as ints when `n_cells != None` _meta_assignment.cat.set_categories(list( range(memberships.shape[1])), inplace=True) metastable_states = _meta_assignment.astype(str).astype( "category").copy() not_enough_cells = [] else: logg.debug( "Setting the metastable states using metastable memberships") # select the most likely cells from each metastable state metastable_states, not_enough_cells = self._create_states( memberships, n_cells=n_cells, check_row_sums=check_row_sums, return_not_enough_cells=True, ) not_enough_cells = not_enough_cells.astype("str") # _set_categorical_labels creates the names, we still need to remap the group names orig_cats = metastable_states.cat.categories self._set_categorical_labels( attr_key=A.META.v, color_key=A.META_COLORS.v, pretty_attr_key=P.META.v, add_to_existing_error_msg= "Compute metastable states first as `.compute_metastable_states()`.", categories=metastable_states, cluster_key=cluster_key, en_cutoff=en_cutoff, p_thresh=p_thresh, add_to_existing=False, ) name_mapper = dict(zip(orig_cats, self._get(P.META).cat.categories)) 
_print_insufficient_number_of_cells( [name_mapper.get(n, n) for n in not_enough_cells], n_cells) logg.debug( "Setting metastable lineage probabilities based on GPCCA membership vectors" ) self._set( A.META_PROBS, Lineage( memberships, names=list(metastable_states.cat.categories), colors=self._get(A.META_COLORS), ), ) def _create_states( self, probs: Union[np.ndarray, Lineage], n_cells: int, check_row_sums: bool = False, return_not_enough_cells: bool = False, ) -> pd.Series: if n_cells <= 0: raise ValueError( f"Expected `n_cells` to be positive, found `{n_cells}`.") if isinstance(probs, Lineage): probs = probs[[n for n in probs.names if n != "rest"]] a_discrete, not_enough_cells = _fuzzy_to_discrete( a_fuzzy=probs, n_most_likely=n_cells, remove_overlap=False, raise_threshold=0.2, check_row_sums=check_row_sums, ) states = _series_from_one_hot_matrix( membership=a_discrete, index=self.adata.obs_names, names=probs.names if isinstance(probs, Lineage) else None, ) return (states, not_enough_cells) if return_not_enough_cells else states def _check_states_validity(self, n_states: int) -> int: if self._invalid_n_states is not None and n_states in self._invalid_n_states: logg.warning( f"Unable to compute metastable states with `n_states={n_states}` because it will " f"split the conjugate eigenvalues. Increasing `n_states` to `{n_states + 1}`" ) n_states += 1 # cannot force recomputation of Schur decomposition assert n_states not in self._invalid_n_states, "Sanity check failed." 
return n_states def _fit_final_states( self, n_lineages: Optional[int] = None, cluster_key: Optional[str] = None, method: str = "krylov", **kwargs, ) -> None: if n_lineages is None or n_lineages == 1: self.compute_eigendecomposition() if n_lineages is None: n_lineages = self.eigendecomposition["eigengap"] + 1 if n_lineages > 1: self.compute_schur(n_lineages + 1, method=method) try: self.compute_metastable_states(n_states=n_lineages, cluster_key=cluster_key, **kwargs) except ValueError: logg.warning( f"Computing `{n_lineages}` metastable states cuts through a block of complex conjugates. " f"Increasing `n_lineages` to {n_lineages + 1}") self.compute_metastable_states(n_states=n_lineages + 1, cluster_key=cluster_key, **kwargs) fs_kwargs = { "n_cells": kwargs["n_cells"] } if "n_cells" in kwargs else {} if n_lineages is None: self.compute_final_states(method="eigengap", **fs_kwargs) else: self.set_final_states_from_metastable_states(**fs_kwargs) @d.dedent # because of fit @d.dedent @inject_docs( ms=P.META, msp=P.META_PROBS, fs=P.FIN, fsp=P.FIN_PROBS, ap=P.ABS_PROBS, dp=P.DIFF_POT, ) def fit( self, n_lineages: Optional[int] = None, cluster_key: Optional[str] = None, keys: Optional[Sequence[str]] = None, method: str = "krylov", compute_absorption_probabilities: bool = True, **kwargs, ): """ Run the pipeline, computing the metastable states, %(final)s states and optionally the absorption probabilities. It is equivalent to running:: if n_lineages is None or n_lineages == 1: compute_eigendecomposition(...) # get the stationary distribution if n_lineages > 1: compute_schur(...) compute_metastable_states(...) if n_lineages is None: compute_final_states(...) else: set_final_states_from_metastable_states(...) if compute_absorption_probabilities: compute_absorption_probabilities(...) Parameters ---------- %(fit)s method Method to use when computing the Schur decomposition. Valid options are: `'krylov'` or `'brandts'`. 
compute_absorption_probabilities Whether to compute absorption probabilities or only final states. **kwargs Keyword arguments for :meth:`cellrank.tl.estimators.GPCCA.compute_metastable_states`. Returns ------- None Nothing, just makes available the following fields: - :paramref:`{msp}` - :paramref:`{ms}` - :paramref:`{fsp}` - :paramref:`{fs}` - :paramref:`{ap}` - :paramref:`{dp}` """ super().fit( n_lineages=n_lineages, cluster_key=cluster_key, keys=keys, method=method, compute_absorption_probabilities=compute_absorption_probabilities, **kwargs, )
def __new__(cls, clsname, superclasses, attributedict):
    """
    Create a new class, wiring up the properties described by its property metadata.

    The metadata is popped from ``attributedict`` under ``META_KEY``. If it is absent, the class is
    created unchanged. Otherwise, every metadata entry is passed to ``PropertyMeta.update_attributes``.
    The first (main) entry also determines the names of the public compute and plot methods.

    Parameters
    ----------
    clsname
        Name of class to be constructed.
    superclasses
        List of superclasses.
    attributedict
        Dictionary of attributes. It is modified in place: the metadata entry is popped, and entries may be
        added to it (e.g. ``_compute`` aliased under its public name).

    Returns
    -------
    The newly created class.

    Raises
    ------
    TypeError
        If the metadata is neither a :class:`str` nor a :class:`list`/:class:`tuple`. Also raised if the
        class lists ``VectorPlottable`` among its direct superclasses, does not define the plotting method
        required by the main metadata, and ``is_abstract(clsname)`` is falsy.
    ValueError
        If the metadata is an empty :class:`list`/:class:`tuple`.
    """
    compute_md, metadata = attributedict.pop(META_KEY, None), []
    if compute_md is None:
        # nothing to wire up, behave like a plain metaclass
        return super().__new__(cls, clsname, superclasses, attributedict)

    if isinstance(compute_md, str):
        compute_md = Metadata(attr=compute_md)
    elif not isinstance(compute_md, (tuple, list)):
        raise TypeError(
            f"Expected property metadata to be `list` or `tuple`, "
            f"found `{type(compute_md).__name__!r}`."
        )
    elif len(compute_md) == 0:
        raise ValueError("No metadata found.")
    else:
        # strings are shorthand for `Metadata(attr=...)`; the first entry is the main one
        compute_md, *metadata = (
            Metadata(attr=md) if isinstance(md, str) else md for md in compute_md
        )

    # the returned name is used to format the public compute/plot method names below
    prop_name = PropertyMeta.update_attributes(compute_md, attributedict)
    plot_name = str(compute_md.plot_fmt).format(prop_name)

    if compute_md.compute_fmt != F.NO_FUNC and "_compute" in attributedict:
        # expose the private `_compute` implementation under its formatted public name
        attributedict[str(compute_md.compute_fmt).format(prop_name)] = attributedict["_compute"]

    # concrete vector-plottable classes must implement their plotting method themselves
    if (
        compute_md.plot_fmt != F.NO_FUNC
        and VectorPlottable in superclasses
        and plot_name not in attributedict
        and not is_abstract(clsname)
    ):
        raise TypeError(
            f"Method `{plot_name}` is not implemented for class `{clsname}`."
        )

    # remaining (non-main) metadata only contribute attributes, not method names
    for md in metadata:
        PropertyMeta.update_attributes(md, attributedict)

    res = super().__new__(cls, clsname, superclasses, attributedict)

    if compute_md.plot_fmt != F.NO_FUNC and Plottable in res.mro():
        # `_plot` is intended to be a singledispatchmethod.
        # Unfortunately, `_plot` is not always in `attributedict` (it may be inherited), so we can't just
        # check for it there, and `res._plot` is just a regular function at this point.
        # If this gets buggy in the future, consider switching away from singledispatchmethod.
        setattr(
            res,
            plot_name,
            _delegate_method_dispatch(res._plot, "_plot", prop_name, skip=2),
        )

    return res