Beispiel #1
0
class FinalStates(Plottable):
    """Mixin declaring the final-states attributes and the hooks that fill them."""

    # One entry per managed attribute: the attribute key, the property that
    # exposes it (``P.NO_PROPERTY`` for none) and the expected value type.
    __prop_metadata__ = [
        Metadata(
            attr=A.FIN,
            prop=P.FIN,
            dtype=pd.Series,
            doc="Final states.",
        ),
        Metadata(
            attr=A.FIN_PROBS,
            prop=P.FIN_PROBS,
            dtype=pd.Series,
            doc="Final states probabilities.",
        ),
        Metadata(
            attr=A.FIN_COLORS,
            prop=P.NO_PROPERTY,
            dtype=np.ndarray,
        ),
    ]

    @abstractmethod
    def set_final_states(self, *args, **kwargs) -> None:
        """Set the final states; must be implemented by concrete subclasses."""

    @abstractmethod
    def compute_final_states(self, *args, **kwargs) -> None:
        """Compute the final states; must be implemented by concrete subclasses."""

    @abstractmethod
    def _write_final_states(self, *args, **kwargs) -> None:
        """Write the final states; must be implemented by concrete subclasses."""
Beispiel #2
0
class TerminalStates(Plottable):
    """Mixin declaring the terminal-states attributes and the hooks that fill them."""

    # One entry per managed attribute: the attribute key, the property that
    # exposes it (``P.NO_PROPERTY`` for none) and the expected value type.
    __prop_metadata__ = [
        Metadata(
            attr=A.TERM,
            prop=P.TERM,
            dtype=pd.Series,
            doc="Terminal states.",
        ),
        Metadata(
            attr=A.TERM_PROBS,
            prop=P.TERM_PROBS,
            dtype=pd.Series,
            doc="Terminal states probabilities.",
        ),
        Metadata(
            attr=A.TERM_COLORS,
            prop=P.NO_PROPERTY,
            dtype=np.ndarray,
        ),
    ]

    @abstractmethod
    def set_terminal_states(self, *args: Any, **kwargs: Any) -> None:
        """Set the terminal states; must be implemented by concrete subclasses."""

    @abstractmethod
    def compute_terminal_states(self, *args: Any, **kwargs: Any) -> None:
        """Compute the terminal states; must be implemented by concrete subclasses."""

    @abstractmethod
    def _write_terminal_states(self, *args: Any, **kwargs: Any) -> None:
        """Write the terminal states; must be implemented by concrete subclasses."""
Beispiel #3
0
class Macrostates(Plottable):
    """Mixin declaring the macrostate attributes and the hook that computes them."""

    # One entry per managed attribute: the attribute key, the property that
    # exposes it (``P.NO_PROPERTY`` for none) and the expected value type.
    __prop_metadata__ = [
        Metadata(
            attr=A.MACRO,
            prop=P.MACRO,
            dtype=pd.Series,
        ),
        Metadata(
            attr=A.MACRO_MEMBER,
            prop=P.MACRO_MEMBER,
            dtype=Lineage,
        ),
        Metadata(
            attr=A.MACRO_COLORS,
            prop=P.NO_PROPERTY,
            dtype=np.ndarray,
        ),
    ]

    @abstractmethod
    def compute_macrostates(self, *args, **kwargs) -> None:
        """Compute the macrostates; must be implemented by concrete subclasses."""
Beispiel #4
0
class MetaStates(Plottable):
    """Mixin declaring the metastable-state attributes and the hook that computes them."""

    # One entry per managed attribute: the attribute key, the property that
    # exposes it (``P.NO_PROPERTY`` for none) and the expected value type.
    __prop_metadata__ = [
        Metadata(
            attr=A.META,
            prop=P.META,
            dtype=pd.Series,
            doc="Metastable states.",
        ),
        Metadata(
            attr=A.META_PROBS,
            prop=P.META_PROBS,
            dtype=Lineage,
            doc="Metastable states probabilities.",
        ),
        Metadata(
            attr=A.META_COLORS,
            prop=P.NO_PROPERTY,
            dtype=np.ndarray,
        ),
    ]

    @abstractmethod
    def compute_metastable_states(self, *args, **kwargs) -> None:
        """Compute the metastable states; must be implemented by concrete subclasses."""
Beispiel #5
0
class LinDrivers(Plottable):
    """Mixin declaring the lineage-drivers attribute."""

    __prop_metadata__ = [
        Metadata(
            attr=A.LIN_DRIVERS,
            prop=P.LIN_DRIVERS,
            dtype=pd.DataFrame,
            doc="Lineage drivers.",
            # In essence, ignore Plottable for this attribute (this could also
            # be done by registering DataFrame, but that would be ugly).
            plot_fmt=F.NO_FUNC,
        ),
    ]
Beispiel #6
0
class AbsProbs(Plottable):
    """Mixin declaring the absorption-probability attributes and their write hook."""

    # One entry per managed attribute: the attribute key, the property that
    # exposes it and the expected value type.
    __prop_metadata__ = [
        Metadata(
            attr=A.ABS_PROBS,
            prop=P.ABS_PROBS,
            dtype=Lineage,
            doc="Absorption probabilities.",
        ),
        Metadata(
            attr=A.PRIME_DEG,
            prop=P.PRIME_DEG,
            dtype=pd.Series,
            doc="Priming degree.",
        ),
        Metadata(
            attr=A.LIN_ABS_TIMES,
            prop=P.LIN_ABS_TIMES,
            dtype=pd.DataFrame,
        ),
    ]

    @abstractmethod
    def _write_absorption_probabilities(self, *args, **kwargs) -> None:
        """Write the absorption probabilities; must be implemented by concrete subclasses."""
Beispiel #7
0
class AbsProbs(Plottable):
    """Mixin declaring the absorption-probability attributes and their write hook."""

    # One entry per managed attribute: the attribute key, the property that
    # exposes it and the expected value type.
    __prop_metadata__ = [
        Metadata(
            attr=A.ABS_PROBS,
            prop=P.ABS_PROBS,
            dtype=Lineage,
            doc="Absorption probabilities.",
        ),
        Metadata(
            attr=A.DIFF_POT,
            prop=P.DIFF_POT,
            dtype=pd.Series,
            doc="Differentiation potential.",
        ),
        Metadata(
            attr=A.LIN_ABS_TIMES,
            prop=P.LIN_ABS_TIMES,
            dtype=pd.DataFrame,
        ),
    ]

    @abstractmethod
    def _write_absorption_probabilities(self, *args, **kwargs) -> None:
        """Write the absorption probabilities; must be implemented by concrete subclasses."""
Beispiel #8
0
class Eigen(VectorPlottable, Decomposable):
    """Class computing the eigendecomposition."""

    __prop_metadata__ = [
        Metadata(attr=A.EIG,
                 prop=P.EIG,
                 dtype=Mapping[str, Any],
                 compute_fmt=F.NO_FUNC)
    ]

    @d.dedent
    @inject_docs(prop=P.EIG)
    def compute_eigendecomposition(
        self,
        k: int = 20,
        which: str = "LR",
        alpha: float = 1,
        only_evals: bool = False,
        ncv: Optional[int] = None,
    ) -> None:
        """
        Compute eigendecomposition of transition matrix.

        Uses a sparse implementation, if possible, and only computes the top :math:`k` eigenvectors
        to speed up the computation. Computes both left and right eigenvectors. For a dense matrix,
        the full decomposition is computed and the top :math:`k` eigenvalues/vectors, sorted by
        the real part of the eigenvalues, are kept.

        Parameters
        ----------
        k
            Number of eigenvalues/vectors to compute.
        %(eigen)s
        only_evals
            Compute only eigenvalues.
        ncv
            Number of Lanczos vectors generated.

        Returns
        -------
        None
            Nothing, but updates the following field:

                - :paramref:`{prop}`
        """
        params = {"which": which, "k": k, "alpha": alpha}

        start = logg.info(
            "Computing eigendecomposition of the transition matrix")

        # Left eigenvectors are the eigenvectors of the transposed matrix.
        if self.issparse:
            logg.debug(f"Computing top `{k}` eigenvalues for sparse matrix")
            D, V_l = eigs(self.transition_matrix.T, k=k, which=which, ncv=ncv)
        else:
            logg.warning(
                "This transition matrix is not sparse, computing full eigendecomposition"
            )
            D, V_l = np.linalg.eig(self.transition_matrix.T)

        # Indices of the top `k` eigenvalues, in descending order of real part.
        top = np.flip(np.argsort(D.real))[:k]

        if only_evals:
            top_evals = D[top]
            # The eigengap must be derived from the same sorted eigenvalues
            # that are stored, for sparse and dense matrices alike.
            self._write_eig_to_adata({
                "D": top_evals,
                "eigengap": _eigengap(top_evals.real, alpha),
                "params": params,
            })
            return

        if self.issparse:
            _, V_r = eigs(self.transition_matrix, k=k, which=which, ncv=ncv)
        else:
            _, V_r = np.linalg.eig(self.transition_matrix)

        # Sort the eigenvalues and eigenvectors and take the real part
        # NOTE(review): `V_r` is reordered with the permutation obtained from the
        # eigenvalues of the transposed problem; this assumes both solver calls
        # return their eigenvalues in the same order - confirm.
        logg.debug("Sorting eigenvalues by their real part")
        D, V_l, V_r = D[top], V_l[:, top], V_r[:, top]
        e_gap = _eigengap(D.real, alpha)

        # Stationary distribution: normalized leading left eigenvector.
        pi = np.abs(V_l[:, 0].real)
        pi /= np.sum(pi)

        self._write_eig_to_adata(
            {
                "D": D,
                "stationary_dist": pi,
                "V_l": V_l,
                "V_r": V_r,
                "eigengap": e_gap,
                "params": params,
            },
            start=start,
        )

    @d.dedent
    def plot_eigendecomposition(self, left: bool = False, *args, **kwargs):
        """
        Plot eigenvectors in an embedding.

        The stored eigendecomposition is never modified by this method.

        Parameters
        ----------
        left
            Whether to plot left or right eigenvectors.
        %(plot_vectors.parameters)s

        Returns
        -------
        %(plot_vectors.returns)s

        Raises
        ------
        RuntimeError
            If the requested eigenvectors have not been computed, e.g. because only
            the eigenvalues were requested.
        """

        eig = getattr(self, P.EIG.s)

        if eig is None:
            # Delegate the "nothing computed yet" case to the generic vector plotter
            # (presumably it raises - confirm); either way stop here, because
            # indexing `None` below would fail with an unrelated `TypeError`.
            self._plot_vectors(None, P.EIG.s)
            return

        side = "left" if left else "right"
        D, V = (
            eig["D"],
            eig.get(f"V_{side[0]}", None),
        )
        if V is None:
            raise RuntimeError(
                "Compute eigendecomposition first as `.compute_eigendecomposition(..., only_evals=False)`."
            )

        # if irreducible, first right e-vec should be const.
        if side == "right":
            # quick check for irreducibility:
            if np.sum(np.isclose(D, 1, rtol=1e2 * EPS, atol=1e2 * EPS)) == 1:
                # work on a copy so that plotting does not overwrite the stored vectors
                V = V.copy()
                V[:, 0] = 1.0

        self._plot_vectors(
            V,
            P.EIG.s,
            *args,
            D=D,
            **kwargs,
        )

    @d.dedent
    def plot_spectrum(
        self,
        n: Optional[int] = None,
        real_only: bool = False,
        show_eigengap: bool = True,
        show_all_xticks: bool = True,
        legend_loc: Optional[str] = None,
        title: Optional[str] = None,
        figsize: Optional[Tuple[float, float]] = (5, 5),
        dpi: int = 100,
        save: Optional[Union[str, Path]] = None,
        marker: str = ".",
        **kwargs,
    ) -> None:
        """
        Plot the top eigenvalues in real or complex plane.

        Parameters
        ----------
        n
            How many eigenvalues to show. If `None`, all computed eigenvalues are shown.
        real_only
            Whether only the real part of the spectrum should be plotted.
        show_eigengap
            Only used when `real_only=True`: whether to mark the inferred eigengap
            with a dotted line.
        show_all_xticks
            Only used when `real_only=True`: whether to show the index of every eigenvalue
            on the x-axis.
        legend_loc
            Location parameter for the legend.
        title
            Title of the figure.
        %(plotting)s
        marker
            Marker symbol used, valid options can be found in :mod:`matplotlib.markers`.
        **kwargs
            Keyword arguments for :func:`matplotlib.pyplot.scatter`.

        Returns
        -------
        %(just_plots)s
        """

        eig = getattr(self, P.EIG.s)
        if eig is None:
            raise RuntimeError(
                f"Compute `.{P.EIG}` first as `.{F.COMPUTE.fmt(P.EIG)}()`.")

        if n is None:
            n = len(eig["D"])
        elif n <= 0:
            raise ValueError(f"Expected `n` to be > 0, found `{n}`.")

        # options understood by both plotting helpers
        shared = {
            "dpi": dpi,
            "figsize": figsize,
            "legend_loc": legend_loc,
            "title": title,
            "marker": marker,
        }

        if real_only:
            fig = self._plot_real_spectrum(
                n,
                show_eigengap=show_eigengap,
                show_all_xticks=show_all_xticks,
                **shared,
                **kwargs,
            )
        else:
            fig = self._plot_complex_spectrum(n, **shared, **kwargs)

        if save:
            save_fig(fig, save)

        fig.show()

    def _plot_complex_spectrum(
        self,
        n: int,
        dpi: int = 100,
        figsize: Optional[Tuple[float, float]] = None,
        legend_loc: Optional[str] = None,
        title: Optional[str] = None,
        marker: str = ".",
        **kwargs,
    ):
        """
        Scatter the first ``n`` stored eigenvalues in the complex plane together with the unit circle.

        Parameters
        ----------
        n
            Number of leading stored eigenvalues to show.
        dpi
            Resolution of the figure.
        figsize
            Size of the figure. If `None`, the plotting library's default is used.
        legend_loc
            Location parameter for the legend.
        title
            Title of the figure. If `None`, a title describing the selection is generated.
        marker
            Marker symbol passed to the scatter plot.
        **kwargs
            Further keyword arguments passed to the scatter plot.

        Returns
        -------
        The created figure.
        """

        # center an interval of length `range_` on the midpoint of [min_, max_],
        # so that both axes end up with the same extent
        def adapt_range(min_, max_, range_):
            mid = min_ + (max_ - min_) / 2
            return mid - range_ / 2, mid + range_ / 2

        eig = getattr(self, P.EIG.s)
        D, params = eig["D"][:n], eig["params"]

        # create figure and axes
        fig, ax = plt.subplots(nrows=1, ncols=1, dpi=dpi, figsize=figsize)

        # get the original data ranges
        lam_x, lam_y = D.real, D.imag
        x_min, x_max = np.min(lam_x), np.max(lam_x)
        y_min, y_max = np.min(lam_y), np.max(lam_y)
        x_range, y_range = x_max - x_min, y_max - y_min
        # the larger of the two extents, plus a small margin
        final_range = np.max([x_range, y_range]) + 0.05

        x_min_, x_max_ = adapt_range(x_min, x_max, final_range)
        y_min_, y_max_ = adapt_range(y_min, y_max, final_range)

        # plot the data and the unit circle
        ax.scatter(lam_x, lam_y, marker=marker, label="eigenvalue", **kwargs)
        t = np.linspace(0, 2 * np.pi, 500)
        x_circle, y_circle = np.sin(t), np.cos(t)
        ax.plot(x_circle, y_circle, "k-", label="unit circle")

        # set labels, ranges and legend
        ax.set_xlabel(r"Re($\lambda$)")
        ax.set_xlim(x_min_, x_max_)

        ax.set_ylabel(r"Im($\lambda$)")
        ax.set_ylim(y_min_, y_max_)

        key = "real part" if params["which"] == "LR" else "magnitude"
        if title is None:
            title = f"top {n} eigenvalues according to their {key}"

        ax.set_title(title)
        ax.legend(loc=legend_loc)

        return fig

    def _plot_real_spectrum(
        self,
        n: int,
        show_eigengap: bool = True,
        show_all_xticks: bool = True,
        dpi: int = 100,
        figsize: Optional[Tuple[float, float]] = None,
        legend_loc: Optional[str] = None,
        title: Optional[str] = None,
        marker: str = ".",
        **kwargs,
    ):
        """
        Plot the real part of the first ``n`` stored eigenvalues against their index.

        Eigenvalues with a zero imaginary part and those with a non-zero one are drawn as
        two separate, differently labelled series. Returns the created figure.
        """
        eig = getattr(self, P.EIG.s)
        evals, params = eig["D"][:n], eig["params"]

        positions = np.arange(len(evals))
        real_part = evals.real
        is_real = evals.imag == 0

        # plot the top eigenvalues, real ones first
        fig, ax = plt.subplots(nrows=1, ncols=1, dpi=dpi, figsize=figsize)
        series = (
            (is_real, "real eigenvalue"),
            (~is_real, "complex eigenvalue"),
        )
        for selector, label in series:
            if np.any(selector):
                ax.scatter(
                    positions[selector],
                    real_part[selector],
                    marker=marker,
                    label=label,
                    **kwargs,
                )

        # the eigengap is only drawn when it lies within the shown indices
        if show_eigengap and eig["eigengap"] < n:
            ax.axvline(eig["eigengap"], label="eigengap", ls="--", lw=1)

        ax.set_xlabel("index")
        if not show_all_xticks:
            ax.xaxis.set_major_locator(MultipleLocator(2.0))
            ax.xaxis.set_major_formatter(FormatStrFormatter("%d"))
        else:
            ax.set_xticks(positions)

        ax.set_ylabel(r"Re($\lambda_i$)")

        if title is None:
            key = "real part" if params["which"] == "LR" else "magnitude"
            title = f"real part of top {n} eigenvalues according to their {key}"

        ax.set_title(title)
        ax.legend(loc=legend_loc)

        return fig
Beispiel #9
0
class Schur(VectorPlottable, Decomposable):
    """Class computing the Schur decomposition."""

    __prop_metadata__ = [
        Metadata(
            attr=A.SCHUR,
            prop=P.SCHUR,
            dtype=np.ndarray,
            compute_fmt=F.NO_FUNC,
            doc="Schur vectors.",
        ),
        Metadata(attr=A.SCHUR_MAT, prop=P.SCHUR_MAT, dtype=np.ndarray),
        Metadata(attr=A.EIG, prop=P.EIG, dtype=Mapping[str, Any]),
        Metadata(attr="_invalid_n_states",
                 prop=P.NO_PROPERTY,
                 dtype=np.ndarray),
        Metadata(attr="_gpcca", prop=P.NO_PROPERTY),
    ]

    @d.dedent
    @inject_docs(schur_vectors=P.SCHUR,
                 schur_matrix=P.SCHUR_MAT,
                 eigendec=P.EIG)
    def compute_schur(
        self,
        n_components: int = 10,
        initial_distribution: Optional[np.ndarray] = None,
        method: str = "krylov",
        which: str = "LR",
        alpha: float = 1,
    ):
        """
        Compute the Schur decomposition.

        Parameters
        ----------
        n_components
            Number of vectors to compute.
        initial_distribution
            Input probability distribution over all cells. If `None`, uniform is chosen.
        method
            Method for calculating the Schur vectors. Valid options are: `'krylov'` or `'brandts'`.
            For benefits of each method, see :class:`msmtools.analysis.dense.gpcca.GPCCA`. The former is
            an iterative procedure that computes a partial, sorted Schur decomposition for large, sparse
            matrices whereas the latter computes a full sorted Schur decomposition of a dense matrix.
        %(eigen)s

        Returns
        -------
        None
            Nothing, but updates the following fields:

                - :paramref:`{schur_vectors}`
                - :paramref:`{schur_matrix}`
                - :paramref:`{eigendec}`
        """

        if n_components < 2:
            raise ValueError(
                f"Number of components must be `>=2`, found `{n_components}`.")

        self._gpcca = _GPCCA(
            self.transition_matrix,
            eta=initial_distribution,
            z=which,
            method=method,
        )
        start = logg.info("Computing Schur decomposition")

        try:
            self._gpcca._do_schur_helper(n_components)
        except ValueError:
            # retry once with one more component
            bumped = n_components + 1
            logg.warning(
                f"Using `{n_components}` components would split a block of complex conjugates. "
                f"Increasing `n_components` to `{bumped}`")
            self._gpcca._do_schur_helper(bumped)

        # make it available for pl
        setattr(self, A.SCHUR.s, self._gpcca.X)
        setattr(self, A.SCHUR_MAT.s, self._gpcca.R)

        # remember every number of states whose leading eigenvalues are
        # flagged by `_check_conj_split`
        n_evals = len(self._gpcca.eigenvalues)
        flagged = []
        for n_states in range(2, n_evals):
            if _check_conj_split(self._gpcca.eigenvalues[:n_states]):
                flagged.append(n_states)
        self._invalid_n_states = np.array(flagged)

        if len(self._invalid_n_states):
            logg.info(
                f"When computing macrostates, choose a number of states NOT in `{list(self._invalid_n_states)}`"
            )

        eig_summary = {
            "D": self._gpcca.eigenvalues,
            "eigengap": _eigengap(self._gpcca.eigenvalues, alpha),
            "params": {
                "which": which,
                "k": len(self._gpcca.eigenvalues),
                "alpha": alpha,
            },
        }
        self._write_eig_to_adata(
            eig_summary,
            start=start,
            extra_msg=
            f"\n       `.{P.SCHUR}`\n       `.{P.SCHUR_MAT}`\n    Finish",
        )

    plot_schur = _delegate(prop_name=P.SCHUR.s)(VectorPlottable._plot_vectors)

    @d.dedent
    def plot_schur_matrix(
        self,
        title: Optional[str] = "schur matrix",
        cmap: str = "viridis",
        figsize: Optional[Tuple[float, float]] = None,
        dpi: Optional[float] = 80,
        save: Optional[Union[str, Path]] = None,
        **kwargs,
    ):
        """
        Plot the Schur matrix.

        Entries below the diagonal that are numerically zero are hidden; the color range
        is taken from the entries that remain visible.

        Parameters
        ----------
        title
            Title of the figure.
        cmap
            Colormap to use.
        %(plotting)s
        **kwargs
            Keyword arguments for :func:`seaborn.heatmap`.

        Returns
        -------
        %(just_plots)s
        """

        from seaborn import heatmap

        schur_matrix = getattr(self, P.SCHUR_MAT.s)

        if schur_matrix is None:
            raise RuntimeError(
                f"Compute Schur matrix first as `.{F.COMPUTE.fmt(P.SCHUR)}()`."
            )

        # by default, scale the figure with the size of the matrix
        fig, ax = plt.subplots(
            figsize=schur_matrix.shape if figsize is None else figsize,
            dpi=dpi)

        divider = make_axes_locatable(
            ax)  # square=True make the colorbar a bit bigger
        cbar_ax = divider.append_axes("right", size="2%", pad=0.1)

        # builtin `bool` instead of the `np.bool` alias, which newer NumPy no longer provides
        mask = np.zeros_like(schur_matrix, dtype=bool)
        # hide the strictly lower triangle ...
        mask[np.tril_indices_from(mask, k=-1)] = True
        # ... except for entries that are not (numerically) zero
        mask[~np.isclose(schur_matrix, 0.0)] = False

        vmin, vmax = (
            np.min(schur_matrix[~mask]),
            np.max(schur_matrix[~mask]),
        )

        # default annotation format, unless the caller supplied one
        kwargs["fmt"] = kwargs.get("fmt", "0.2f")
        heatmap(
            schur_matrix,
            cmap=cmap,
            square=True,
            annot=True,
            vmin=vmin,
            vmax=vmax,
            cbar_ax=cbar_ax,
            cbar_kws={"ticks": np.linspace(vmin, vmax, 10)},
            mask=mask,
            xticklabels=[],
            yticklabels=[],
            ax=ax,
            **kwargs,
        )

        ax.set_title(title)

        if save is not None:
            save_fig(fig, path=save)
Beispiel #10
0
class GPCCA(BaseEstimator, Macrostates, Schur, Eigen):
    """
    Generalized Perron Cluster Cluster Analysis :cite:`reuter:18` as implemented in `pyGPCCA <https://pygpcca.readthedocs.io/en/latest/>`_.

    Coarse-grains a discrete Markov chain into a set of macrostates and computes coarse-grained transition probabilities
    among the macrostates. Each macrostate corresponds to an area of the state space, i.e. to a subset of cells. The
    assignment is soft, i.e. each cell is assigned to every macrostate with a certain weight, where weights sum to
    one per cell. Macrostates are computed by maximizing the 'crispness' which can be thought of as a measure for
    minimal overlap between macrostates in a certain inner-product sense. Once the macrostates have been computed,
    we project the large transition matrix onto a coarse-grained transition matrix among the macrostates via
    a Galerkin projection. This projection is based on invariant subspaces of the original transition matrix which
    are obtained using the real Schur decomposition :cite:`reuter:18`.

    Parameters
    ----------
    %(base_estimator.parameters)s
    """  # noqa: E501

    __prop_metadata__ = [
        Metadata(
            attr=A.COARSE_T,
            prop=P.COARSE_T,
            compute_fmt=F.NO_FUNC,
            plot_fmt=F.NO_FUNC,
            dtype=pd.DataFrame,
            doc="Coarse-grained transition matrix.",
        ),
        Metadata(attr=A.TERM_ABS_PROBS, prop=P.NO_PROPERTY, dtype=Lineage),
        Metadata(attr=A.COARSE_INIT_D, prop=P.COARSE_INIT_D, dtype=pd.Series),
        Metadata(attr=A.COARSE_STAT_D, prop=P.COARSE_STAT_D, dtype=pd.Series),
    ]

    def _read_from_adata(self) -> None:
        """Run the parent class' read, then reconstruct the lineage kept under ``A.TERM_ABS_PROBS``."""
        super()._read_from_adata()

        lineage_key = self._term_abs_prob_key
        self._reconstruct_lineage(A.TERM_ABS_PROBS, lineage_key)

    @inject_docs(
        ms=P.MACRO,
        msp=P.MACRO_MEMBER,
        schur=P.SCHUR.s,
        coarse_T=P.COARSE_T,
        coarse_stat=P.COARSE_STAT_D,
    )
    @d.dedent
    def compute_macrostates(
        self,
        n_states: Optional[Union[int, Tuple[int, int], List[int],
                                 Dict[str, int]]] = None,
        n_cells: Optional[int] = 30,
        use_min_chi: bool = False,
        cluster_key: str = None,
        en_cutoff: Optional[float] = 0.7,
        p_thresh: float = 1e-15,
    ):
        """
        Compute the macrostates.

        Parameters
        ----------
        n_states
            Number of macrostates. If `None`, use the `eigengap` heuristic.
        %(n_cells)s
        use_min_chi
            Whether to use :meth:`pygpcca.GPCCA.minChi` to calculate the number of macrostates.
            If `True`, ``n_states`` corresponds to a closed interval `[min, max]` inside of which the potentially
            optimal number of macrostates is searched.
        cluster_key
            If a key to cluster labels is given, names and colors of the states will be associated with the clusters.
        %(en_cutoff_p_thresh)s

        Returns
        -------
        None
            Nothing, but updates the following fields:

                - :attr:`{msp}`
                - :attr:`{ms}`
                - :attr:`{schur}`
                - :attr:`{coarse_T}`
                - :attr:`{coarse_stat}`
        """

        if use_min_chi:
            n_states = self._get_n_states_from_minchi(n_states)

        # without an explicit number of states, fall back to the eigengap
        was_from_eigengap = n_states is None
        if was_from_eigengap:
            eig = self._get(P.EIG)
            if eig is None:
                raise RuntimeError(
                    "Compute eigendecomposition first as `.compute_eigendecomposition()` or `.compute_schur()`."
                )
            n_states = eig["eigengap"] + 1
            logg.info(f"Using `{n_states}` states based on eigengap")
        elif not isinstance(n_states, int):
            raise ValueError(
                f"Expected `n_states` to be an integer when `use_min_chi=False`, "
                f"found `{type(n_states).__name__!r}`.")

        if n_states <= 0:
            raise ValueError(
                f"Expected `n_states` to be positive or `None`, found `{n_states}`."
            )

        # options shared by both ways of creating the macrostates
        state_opts = {
            "n_cells": n_cells,
            "cluster_key": cluster_key,
            "p_thresh": p_thresh,
            "en_cutoff": en_cutoff,
        }

        n_states = self._check_states_validity(n_states)
        if n_states == 1:
            # a single macrostate is handled by a dedicated helper
            self._compute_one_macrostate(**state_opts)
            return

        if self._gpcca is None:
            if not was_from_eigengap:
                raise RuntimeError(
                    "Compute Schur decomposition first as `.compute_schur()`.")

            logg.warning(
                f"Number of states `{n_states}` was automatically determined by `eigengap` "
                "but no Schur decomposition was found. Computing with default parameters"
            )
            # this cannot fail if splitting occurs
            # if it were to split, it's automatically increased in `compute_schur`
            self.compute_schur(n_states)

        # pre-computed X
        n_available = self._gpcca._p_X.shape[1]
        if n_available < n_states:
            logg.warning(
                f"Requested more macrostates `{n_states}` than available "
                f"Schur vectors `{n_available}`. Recomputing the decomposition"
            )

        start = logg.info(f"Computing `{n_states}` macrostates")
        try:
            self._gpcca = self._gpcca.optimize(m=n_states)
        except ValueError as e:
            # this is the following case - we have 4 Schur vectors, user requests 5 states, but it splits the conj. ev.
            # in the try block, Schur decomposition with 5 vectors is computed, but it fails (no way of knowing)
            # so in this case, we increase it by 1
            n_states += 1
            logg.warning(f"{e}\nIncreasing `n_states` to `{n_states}`")
            self._gpcca = self._gpcca.optimize(m=n_states)

        self._set_macrostates(memberships=self._gpcca.memberships,
                              **state_opts)

        # cache the results and make sure we don't overwrite
        self._set(A.SCHUR, self._gpcca._p_X)
        self._set(A.SCHUR_MAT, self._gpcca._p_R)

        names = self._get(P.MACRO_MEMBER).names

        coarse_T = pd.DataFrame(
            self._gpcca.coarse_grained_transition_matrix,
            index=names,
            columns=names,
        )
        self._set(A.COARSE_T, coarse_T)

        coarse_init = pd.Series(
            self._gpcca.coarse_grained_input_distribution,
            index=names,
        )
        self._set(A.COARSE_INIT_D, coarse_init)

        # careful here, in case computing the stat. dist failed
        stat_dist = self._gpcca.coarse_grained_stationary_probability
        has_stat_dist = stat_dist is not None
        if has_stat_dist:
            self._set(A.COARSE_STAT_D, pd.Series(stat_dist, index=names))
        else:
            logg.warning("No stationary distribution found in GPCCA object")

        # the stationary distribution is only listed when it has been set
        msg_parts = [
            f"Adding `.{P.MACRO_MEMBER}`\n",
            f"       `.{P.MACRO}`\n",
            f"       `.{P.SCHUR}`\n",
            f"       `.{P.COARSE_T}`\n",
        ]
        if has_stat_dist:
            msg_parts.append(f"       `.{P.COARSE_STAT_D}`\n")
        msg_parts.append("    Finish")

        logg.info("".join(msg_parts), time=start)

    @d.dedent
    @inject_docs(fs=P.TERM, fsp=P.TERM_PROBS)
    def set_terminal_states_from_macrostates(
        self,
        names: Optional[Union[Sequence[str], Mapping[str, str], str]] = None,
        n_cells: int = 30,
    ):
        """
        Manually select terminal states from macrostates.

        Parameters
        ----------
        names
            Names of the macrostates to be marked as terminal. Multiple states can be combined using `','`,
            such as ``["Alpha, Beta", "Epsilon"]``.  If a mapping, keys correspond to the names
            of the macrostates and the values to the new names.  If `None`, select all macrostates.
        %(n_cells)s

        Returns
        -------
        None
            Nothing, just updates the following fields:

                - :attr:`{fsp}`
                - :attr:`{fs}`

        Raises
        ------
        TypeError
            If ``n_cells`` is not an integer, if a macrostate name is not a string or if a new name
            is neither a string nor an integer.
        ValueError
            If ``n_cells`` is not positive, if no macrostates were selected or if the names
            would not be unique after renaming.
        RuntimeError
            If the macrostates have not been computed yet.
        """

        if not isinstance(n_cells, int):
            raise TypeError(
                f"Expected `n_cells` to be of type `int`, found `{type(n_cells).__name__}`."
            )

        if n_cells <= 0:
            raise ValueError(
                f"Expected `n_cells` to be positive, found `{n_cells}`.")

        probs = self._get(P.MACRO_MEMBER)
        if probs is None:
            raise RuntimeError(
                "Compute macrostates first as `.compute_macrostates()`.")

        if names is None:
            names = probs.names
        if isinstance(names, str):
            names = [names]

        # Renaming is only requested when an explicit old -> new mapping was given.
        # Any mapping is accepted, as annotated, not just `dict`; other inputs are
        # turned into an identity mapping.
        rename = isinstance(names, Mapping)
        if not rename:
            names = {n: n for n in names}

        if not len(names):
            raise ValueError("No macrostates have been selected.")

        # keys are the existing macrostate names, values are the new names
        if not all(isinstance(old, str) for old in names.keys()):
            raise TypeError("Not all macrostates names are strings.")

        if not all(isinstance(new, (str, int)) for new in names.values()):
            raise TypeError("Not all new names are strings or integers.")

        # this also checks that the names are correct before renaming
        macrostates_probs = probs[list(names.keys())]

        # we do this also here because if `rename_terminal_states` fails
        # invalid states would've been written to this object and nothing to adata
        new_names = {k: str(v) for k, v in names.items()}
        names_after_renaming = [new_names.get(n, n) for n in probs.names]
        if len(set(names_after_renaming)) != probs.shape[1]:
            raise ValueError(
                f"After renaming, the names will not be unique: `{names_after_renaming}`."
            )

        if probs.shape[1] == 1:
            # only one macrostate exists: use it directly
            self._set(A.TERM, self._create_states(probs, n_cells=n_cells))
            self._set(A.TERM_COLORS, self._get(A.MACRO_COLORS))
            self._set(
                A.TERM_PROBS,
                pd.Series(probs.X.squeeze() / probs.X.max(),
                          index=self.adata.obs_names),
            )
            self._set(A.TERM_ABS_PROBS, probs)
            if rename:
                # access lineage renames join states, e.g. 'Alpha, Beta' becomes 'Alpha or Beta' + whitespace stripping
                self.rename_terminal_states(
                    dict(zip(self._get(P.TERM).cat.categories,
                             names.values())))

            self._write_terminal_states()
            return

        # compute the aggregated probability of being a initial/terminal state (no matter which)
        scaled_probs = macrostates_probs.copy()
        scaled_probs /= scaled_probs.max(0)

        self._set(A.TERM,
                  self._create_states(macrostates_probs, n_cells=n_cells))
        self._set(A.TERM_PROBS,
                  pd.Series(scaled_probs.X.max(1), index=self.adata.obs_names))
        self._set(
            A.TERM_COLORS,
            macrostates_probs[list(self._get(P.TERM).cat.categories)].colors,
        )
        self._set(A.TERM_ABS_PROBS, scaled_probs)
        if rename:
            self.rename_terminal_states(
                dict(zip(self._get(P.TERM).cat.categories, names.values())))

        self._write_terminal_states()

    @inject_docs(fs=P.TERM, fsp=P.TERM_PROBS)
    @d.dedent
    def compute_terminal_states(
        self,
        method: str = "stability",
        n_cells: int = 30,
        alpha: Optional[float] = 1,
        stability_threshold: float = 0.96,
        n_states: Optional[int] = None,
    ) -> None:
        """
        Automatically select terminal states from macrostates.

        If only one macrostate is present, it is used as the single terminal state regardless of ``method``.

        Parameters
        ----------
        method
            One of following:

                - `'eigengap'` - select the number of states based on the `eigengap` of the transition matrix.
                - `'eigengap_coarse'` - select the number of states based on the `eigengap` of the diagonal
                  of the coarse-grained transition matrix.
                - `'top_n'` - select top ``n_states`` based on the probability of the diagonal
                  of the coarse-grained transition matrix.
                - `'stability'` - select states which have a stability index >= ``stability_threshold``. The stability
                  index is given by the diagonal elements of the coarse-grained transition matrix.
        %(n_cells)s
        alpha
            Weight given to the deviation of an eigenvalue from one. Used when ``method='eigengap'``
            or ``method='eigengap_coarse'``.
        stability_threshold
            Threshold used when ``method='stability'``.
        n_states
            Number of states used when ``method='top_n'``.

        Returns
        -------
        None
            Nothing, just updates the following fields:

                - :attr:`{fsp}`
                - :attr:`{fs}`

        Raises
        ------
        RuntimeError
            If the macrostates, the coarse-grained transition matrix or, for ``method='eigengap'``,
            the eigendecomposition are not available.
        ValueError
            If ``method`` is not a valid option or the argument required by the chosen method is invalid.
        """

        macrostates = self._get(P.MACRO)
        if macrostates is None:
            raise RuntimeError(
                "Compute macrostates first as `.compute_macrostates()`.")

        if len(macrostates.cat.categories) == 1:
            logg.warning(
                "Found only one macrostate. Making it the single main state")
            self.set_terminal_states_from_macrostates(None, n_cells=n_cells)
            return

        # 1. validate the method and its method-specific prerequisites
        if method == "eigengap":
            if self._get(P.EIG) is None:
                raise RuntimeError(
                    "Compute eigendecomposition first as `.compute_eigendecomposition()`."
                )
        elif method == "top_n":
            if n_states is None:
                raise ValueError(
                    "Argument `n_states` must be != `None` for `method='top_n'`."
                )
            if n_states <= 0:
                raise ValueError(
                    f"Expected `n_states` to be positive, found `{n_states}`.")
        elif method == "stability":
            if stability_threshold is None:
                raise ValueError(
                    "Argument `stability_threshold` must be != `None` for `method='stability'`."
                )
        elif method != "eigengap_coarse":
            raise ValueError(
                f"Invalid method `{method!r}`. Valid options are `'eigengap', 'eigengap_coarse', "
                f"'top_n' and 'stability'`.")

        # 2. every method ranks the macrostates by the diagonal of the coarse-grained transition matrix
        coarse_T = self._get(P.COARSE_T)
        if coarse_T is None:
            raise RuntimeError(
                "Compute macrostates first as `.compute_macrostates()`.")
        stability = np.diag(coarse_T)

        # 3. select the names of the terminal macrostates
        if method == "stability":
            self_probs = pd.Series(stability, index=coarse_T.columns)
            names = self_probs[self_probs.values >= stability_threshold].index
        else:
            if method == "eigengap":
                n_states = _eigengap(self._get(P.EIG)["D"], alpha=alpha) + 1
            elif method == "eigengap_coarse":
                # NOTE(review): `np.sort` returns ascending values, so the `[::-1]` has no effect here;
                # confirm whether `_eigengap` expects descending input (and a `+ 1` as for `'eigengap'`)
                n_states = _eigengap(np.sort(stability[::-1]), alpha=alpha)
            # keep the `n_states` macrostates with the largest self-transition probability
            names = coarse_T.columns[np.argsort(stability)][-n_states:]

        self.set_terminal_states_from_macrostates(names, n_cells=n_cells)

    def compute_gdpt(self,
                     n_components: int = 10,
                     key_added: str = "gdpt_pseudotime",
                     **kwargs):
        """
        Compute generalized Diffusion pseudotime from :cite:`haghverdi:16` using the real Schur decomposition.

        The root cell is read from :attr:`adata` ``.uns['iroot']`` and may be given either as a positional
        index or as a cell name from :attr:`adata` ``.obs_names``.

        Parameters
        ----------
        n_components
            Number of real Schur vectors to consider.
        key_added
            Key in :attr:`adata` ``.obs`` where to save the pseudotime.
        kwargs
            Keyword arguments for :meth:`cellrank.tl.GPCCA.compute_schur` if Schur decomposition is not found.

        Returns
        -------
        None
            Nothing, just updates :attr:`adata` ``.obs[key_added]`` with the computed pseudotime.
        """
        def _distance_from_root(evals: np.ndarray, vecs: np.ndarray,
                                root: int):
            # eigenvalues with magnitude too close to 1 are skipped to account for float32 precision
            magnitudes = np.abs(evals)
            squared = sum(
                (magnitudes[k] / (1 - magnitudes[k]) *
                 (vecs[root, k] - vecs[:, k]))**2
                for k in range(evals.size) if magnitudes[k] < 0.9994)

            return np.sqrt(squared)

        if "iroot" not in self.adata.uns:
            raise KeyError("Key `'iroot'` not found in `adata.uns`.")

        root = self.adata.uns["iroot"]
        if isinstance(root, str):
            # translate the cell name into its position
            matches = np.where(self.adata.obs_names == root)[0]
            if not len(matches):
                raise ValueError(
                    f"Unable to find cell with name `{self.adata.uns['iroot']!r}` in `adata.obs_names`."
                )
            root = matches[0]

        if n_components < 2:
            raise ValueError(
                f"Expected number of components >= 2, found `{n_components}`.")

        # make sure a Schur decomposition with enough components is available
        if self._get(P.SCHUR) is None:
            logg.warning("No Schur decomposition found. Computing")
            self.compute_schur(n_components, **kwargs)
        elif self._get(P.SCHUR_MAT).shape[1] < n_components:
            logg.warning(
                f"Requested `{n_components}` components, but only `{self._get(P.SCHUR_MAT).shape[1]}` were found. "
                f"Recomputing using default values")
            self.compute_schur(n_components)
        else:
            logg.debug("Using cached Schur decomposition")

        start = logg.info(
            f"Computing Generalized Diffusion Pseudotime using `n_components={n_components}`"
        )

        # may have to remove some values if too many converged
        schur_vectors = self._get(P.SCHUR)[:, :n_components]
        eigenvalues = self._get(P.EIG)["D"][:n_components]

        distances = _distance_from_root(eigenvalues, schur_vectors, root)
        # scale by the largest finite distance
        largest_finite = np.max(distances[np.isfinite(distances)])
        self.adata.obs[key_added] = distances / largest_finite

        logg.info(f"Adding `{key_added!r}` to `adata.obs`\n    Finish",
                  time=start)

    @d.dedent
    def plot_coarse_T(
        self,
        show_stationary_dist: bool = True,
        show_initial_dist: bool = False,
        cmap: Union[str, mcolors.ListedColormap] = "viridis",
        xtick_rotation: float = 45,
        annotate: bool = True,
        show_cbar: bool = True,
        title: Optional[str] = None,
        figsize: Tuple[float, float] = (8, 8),
        dpi: int = 80,
        save: Optional[Union[Path, str]] = None,
        text_kwargs: Mapping[str, Any] = MappingProxyType({}),
        **kwargs,
    ) -> None:
        """
        Plot the coarse-grained transition matrix between macrostates.

        The heatmap, the optional distribution rows below it and the colorbar share one color normalization
        which spans all the displayed values.

        Parameters
        ----------
        show_stationary_dist
            Whether to show the stationary distribution, if present.
        show_initial_dist
            Whether to show the initial distribution.
        cmap
            Colormap to use.
        xtick_rotation
            Rotation of ticks on the x-axis.
        annotate
            Whether to display the text on each cell.
        show_cbar
            Whether to show colorbar.
        title
            Title of the figure.
        %(plotting)s
        text_kwargs
            Keyword arguments for :func:`matplotlib.pyplot.text`.
        kwargs
            Keyword arguments for :func:`matplotlib.pyplot.imshow`.

        Returns
        -------
        %(just_plots)s
        """
        def stylize_dist(
            ax,
            data: np.ndarray,
            xticks_labels: Optional[Union[List[str], Tuple[str, ...]]] = None,
        ):
            # draw a distribution as a single-row heatmap using the shared colormap and normalization
            ax.imshow(data, aspect="auto", cmap=cmap, norm=norm)
            for spine in ax.spines.values():
                spine.set_visible(False)

            if xticks_labels is not None:
                ax.set_xticks(np.arange(data.shape[1]))
                ax.set_xticklabels(xticks_labels)
                plt.setp(
                    ax.get_xticklabels(),
                    rotation=xtick_rotation,
                    ha="right",
                    rotation_mode="anchor",
                )
            else:
                # another row below carries the labels
                ax.set_xticks([])
                ax.tick_params(which="both",
                               top=False,
                               right=False,
                               bottom=False,
                               left=False)

            ax.set_yticks([])

        def annotate_heatmap(im, valfmt: str = "{x:.2f}"):
            # modified from matplotlib's site

            data = im.get_array()
            kw = {"ha": "center", "va": "center"}
            kw.update(**text_kwargs)

            # Get the formatter in case a string is supplied
            if isinstance(valfmt, str):
                valfmt = mpl.ticker.StrMethodFormatter(valfmt)

            # Loop over the data and create a `Text` for each "pixel".
            # Change the text's color depending on the data.
            for i in range(data.shape[0]):
                for j in range(data.shape[1]):
                    kw.update(
                        color=_get_black_or_white(im.norm(data[i, j]), cmap))
                    im.axes.text(j, i, valfmt(data[i, j], None), **kw)

        def annotate_dist_ax(ax, data: np.ndarray, valfmt: str = "{x:.2f}"):
            # write the value of each entry of a distribution row into its cell
            if ax is None:
                return

            if isinstance(valfmt, str):
                valfmt = mpl.ticker.StrMethodFormatter(valfmt)

            kw = {"ha": "center", "va": "center"}
            kw.update(**text_kwargs)

            for i, val in enumerate(data):
                # `im` is the main heatmap, created before this helper is called
                kw.update(color=_get_black_or_white(im.norm(val), cmap))
                ax.text(
                    i,
                    0,
                    valfmt(val, None),
                    **kw,
                )

        coarse_T = self._get(P.COARSE_T)
        coarse_stat_d = self._get(P.COARSE_STAT_D)
        coarse_init_d = self._get(P.COARSE_INIT_D)

        if coarse_T is None:
            raise RuntimeError(
                "Compute coarse-grained transition matrix first as `.compute_macrostates()` with `n_states > 1`."
            )

        if show_stationary_dist and coarse_stat_d is None:
            logg.warning("Coarse stationary distribution is `None`, ignoring")
            show_stationary_dist = False
        if show_initial_dist and coarse_init_d is None:
            logg.warning("Coarse initial distribution is `None`, ignoring")
            show_initial_dist = False

        # one tall row for the matrix, a thin row per shown distribution, and a thin column for the colorbar
        hrs, wrs = [1], [1]
        if show_stationary_dist:
            hrs += [0.05]
        if show_initial_dist:
            hrs += [0.05]
        if show_cbar:
            wrs += [0.025]

        dont_show_dist = not show_initial_dist and not show_stationary_dist

        fig = plt.figure(constrained_layout=False, figsize=figsize, dpi=dpi)
        gs = plt.GridSpec(
            1 + show_stationary_dist + show_initial_dist,
            1 + show_cbar,
            height_ratios=hrs,
            width_ratios=wrs,
            wspace=0.05,
            hspace=0.05,
        )
        if isinstance(cmap, str):
            cmap = plt.get_cmap(cmap)

        ax = fig.add_subplot(gs[0, 0])
        cax = fig.add_subplot(gs[:1, -1]) if show_cbar else None
        init_ax, stat_ax = None, None

        labels = list(coarse_T.columns)

        # the color range must cover every value that is going to be displayed
        tmp = coarse_T
        if show_stationary_dist:
            tmp = np.c_[tmp, coarse_stat_d]
        if show_initial_dist:
            tmp = np.c_[tmp, coarse_init_d]

        minn, maxx = np.nanmin(tmp), np.nanmax(tmp)
        norm = mpl.colors.Normalize(vmin=minn, vmax=maxx)

        if show_stationary_dist:
            stat_ax = fig.add_subplot(gs[1, 0])
            stylize_dist(
                stat_ax,
                np.array(coarse_stat_d).reshape(1, -1),
                # only the bottom-most row gets the tick labels
                xticks_labels=labels if not show_initial_dist else None,
            )
            stat_ax.yaxis.set_label_position("right")
            stat_ax.set_ylabel("stationary dist",
                               rotation=0,
                               ha="left",
                               va="center")

        if show_initial_dist:
            init_ax = fig.add_subplot(gs[show_stationary_dist +
                                         show_initial_dist, 0])
            stylize_dist(init_ax,
                         np.array(coarse_init_d).reshape(1, -1),
                         xticks_labels=labels)

            init_ax.yaxis.set_label_position("right")
            init_ax.set_ylabel("initial dist",
                               rotation=0,
                               ha="left",
                               va="center")

        im = ax.imshow(coarse_T, aspect="auto", cmap=cmap, norm=norm, **kwargs)
        ax.set_title(
            "coarse-grained transition matrix" if title is None else title)

        if cax is not None:
            _ = mpl.colorbar.ColorbarBase(
                cax,
                cmap=cmap,
                norm=norm,
                ticks=np.linspace(minn, maxx, 10),
                format="%0.3f",
            )

        ax.set_yticks(np.arange(coarse_T.shape[0]))
        ax.set_yticklabels(labels)

        ax.tick_params(
            top=False,
            bottom=dont_show_dist,
            labeltop=False,
            labelbottom=dont_show_dist,
        )

        for spine in ax.spines.values():
            spine.set_visible(False)

        if dont_show_dist:
            ax.set_xticks(np.arange(coarse_T.shape[1]))
            ax.set_xticklabels(labels)
            plt.setp(
                ax.get_xticklabels(),
                rotation=xtick_rotation,
                ha="right",
                rotation_mode="anchor",
            )
        else:
            ax.set_xticks([])

        ax.set_yticks(np.arange(coarse_T.shape[0] + 1) - 0.5, minor=True)
        ax.tick_params(which="minor",
                       bottom=dont_show_dist,
                       left=False,
                       top=False)

        if annotate:
            annotate_heatmap(im)
            if show_stationary_dist:
                annotate_dist_ax(stat_ax, np.asarray(coarse_stat_d))
            if show_initial_dist:
                annotate_dist_ax(init_ax, np.asarray(coarse_init_d))

        if save:
            save_fig(fig, save)

    @d.dedent
    def plot_macrostate_composition(
        self,
        key: str,
        width: float = 0.8,
        title: Optional[str] = None,
        labelrot: float = 45,
        legend_loc: Optional[str] = "upper right out",
        figsize: Optional[Tuple[float, float]] = None,
        dpi: Optional[int] = None,
        save: Optional[Union[str, Path]] = None,
        show: bool = True,
    ) -> Optional[Axes]:
        """
        Plot stacked histogram of macrostates over categorical annotations.

        Only the cells assigned to a macrostate are counted.

        Parameters
        ----------
        %(adata)s
        key
            Key from :attr:`adata` ``.obs`` containing categorical annotations.
        width
            Bar width in `[0, 1]`.
        title
            Title of the figure. If `None`, create one automatically.
        labelrot
            Rotation of labels on x-axis.
        legend_loc
            Position of the legend. If `None`, don't show legend.
        %(plotting)s
        show
            If `False`, return :class:`matplotlib.pyplot.Axes`.

        Returns
        -------
        :class:`matplotlib.pyplot.Axes`
            The axis object if ``show=False``.
        %(just_plots)s
        """
        from cellrank.pl._utils import _position_legend

        macrostates = self._get(P.MACRO)
        if macrostates is None:
            raise RuntimeError(
                "Compute macrostates first as `.compute_macrostates()`.")
        if key not in self.adata.obs:
            raise KeyError(f"Key `{key}` not found in `adata.obs`.")
        if not is_categorical_dtype(self.adata.obs[key]):
            raise TypeError(
                f"Expected `adata.obs[{key!r}]` to be `categorical`, "
                f"found `{infer_dtype(self.adata.obs[key])}`.")

        mask = ~macrostates.isnull()
        # `observed=False` keeps every (category, macrostate) pair, including empty ones, so that
        # each category yields one count per macrostate regardless of the pandas default
        df = (pd.DataFrame({
            "macrostates": macrostates,
            key: self.adata.obs[key]
        })[mask].groupby([key, "macrostates"], observed=False).size())
        try:
            cats_colors = self.adata.uns[f"{key}_colors"]
        except KeyError:
            cats_colors = _create_categorical_colors(
                len(self.adata.obs[key].cat.categories))
        cat_color_mapper = dict(
            zip(self.adata.obs[key].cat.categories, cats_colors))
        macrostate_names = macrostates.cat.categories
        x_indices = np.arange(len(macrostate_names))
        bottom = np.zeros_like(x_indices, dtype=np.float32)

        # clip the bar width into [0, 1]
        width = min(1, max(0, width))
        fig, ax = plt.subplots(figsize=figsize, dpi=dpi, tight_layout=True)
        for cat, color in cat_color_mapper.items():
            frequencies = df.loc[cat]
            # do not add to legend if category is missing
            if np.sum(frequencies) > 0:
                ax.bar(
                    x_indices,
                    frequencies,
                    width,
                    label=cat,
                    color=color,
                    bottom=bottom,
                    ec="black",
                    lw=0.5,
                )
                # the next category is stacked on top of this one
                bottom += np.array(frequencies)

        ax.set_xticks(x_indices)
        ax.set_xticklabels(
            # taken from the macrostates directly, so this also works when there is no category to plot
            macrostate_names,
            rotation=labelrot,
            ha="center" if labelrot in (0, 90) else "right",
        )
        y_max = bottom.max()
        ax.set_ylim([0, y_max + 0.05 * y_max])
        ax.set_yticks(np.linspace(0, y_max, 5))
        ax.margins(0.05)

        ax.set_xlabel("macrostate")
        ax.set_ylabel("frequency")
        if title is None:
            title = f"distribution over {key}"
        ax.set_title(title)
        if legend_loc not in (None, "none"):
            _position_legend(ax, legend_loc=legend_loc)

        if save is not None:
            save_fig(fig, save)

        if not show:
            return ax

    def _compute_one_macrostate(
        self,
        n_cells: int,
        cluster_key: Optional[str],
        en_cutoff: Optional[float],
        p_thresh: float,
    ) -> None:
        """
        Create a single macrostate whose memberships are given by the stationary distribution.

        A cached eigendecomposition is reused if it contains the stationary distribution and was computed
        with ``which='LR'``, otherwise the eigendecomposition is recomputed. All the attributes which
        only exist for more than one macrostate are reset to `None`.

        Parameters
        ----------
        n_cells
            Passed to :meth:`_set_macrostates`.
        cluster_key
            Passed to :meth:`_set_macrostates`.
        en_cutoff
            Passed to :meth:`_set_macrostates`.
        p_thresh
            Passed to :meth:`_set_macrostates`.

        Returns
        -------
        None
            Nothing, just updates the macrostates and their memberships.
        """
        start = logg.warning(
            "For 1 macrostate, stationary distribution is computed")

        eig = self._get(P.EIG)
        if (eig is not None and "stationary_dist" in eig
                and eig["params"]["which"] == "LR"):
            stationary_dist = eig["stationary_dist"]
        else:
            self.compute_eigendecomposition(only_evals=False, which="LR")
            stationary_dist = self._get(P.EIG)["stationary_dist"]

        self._set_macrostates(
            # the memberships are passed as a matrix with a single column
            memberships=stationary_dist[:, None],
            n_cells=n_cells,
            cluster_key=cluster_key,
            p_thresh=p_thresh,
            en_cutoff=en_cutoff,
        )
        # overwrite the memberships written above with the 1-dimensional stationary distribution
        self._set(
            A.MACRO_MEMBER,
            Lineage(
                stationary_dist,
                names=list(self._get(A.MACRO).cat.categories),
                colors=self._get(A.MACRO_COLORS),
            ),
        )

        # reset all the things
        for key in (
                A.ABS_PROBS,
                A.PRIME_DEG,
                A.SCHUR,
                A.SCHUR_MAT,
                A.COARSE_T,
                A.COARSE_INIT_D,
                A.COARSE_STAT_D,
        ):
            self._set(key.s, None)

        logg.info(
            f"Adding `.{P.MACRO_MEMBER}`\n        `.{P.MACRO}`\n    Finish",
            time=start,
        )

    def _get_n_states_from_minchi(
            self, n_states: Union[Tuple[int, int], List[int],
                                  Dict[str, int]]) -> int:
        """
        Select the number of macrostates within an interval using the minChi criterion.

        Parameters
        ----------
        n_states
            The interval, either a sequence of 2 values or a dictionary with `'min'` and `'max'` keys.
            The bounds are swapped if given in the wrong order, a minimum of `2` is raised to `3` and
            the maximum is made strictly larger than the minimum.

        Returns
        -------
        int
            The number of states within the interval at which the minChi values attain their maximum.
        """
        if self._gpcca is None:
            raise RuntimeError(
                "Compute Schur decomposition first as `.compute_schur()` when `use_min_chi=True`."
            )

        # validate the container before unpacking it
        if not isinstance(n_states, (dict, tuple, list)):
            raise TypeError(
                f"Expected `n_states` to be either `dict`, `tuple` or a `list`, "
                f"found `{type(n_states).__name__}`.")
        if len(n_states) != 2:
            raise ValueError(
                f"Expected `n_states` to be of size `2`, found `{len(n_states)}`."
            )

        if not isinstance(n_states, dict):
            minn, maxx = n_states
        elif "min" in n_states and "max" in n_states:
            minn, maxx = n_states["min"], n_states["max"]
        else:
            raise KeyError(
                f"Expected the dictionary to have `'min'` and `'max'` keys, "
                f"found `{tuple(n_states.keys())}`.")

        if minn > maxx:
            logg.debug(
                f"Swapping minimum and maximum because `{minn}` > `{maxx}`")
            minn, maxx = maxx, minn

        if minn <= 1:
            raise ValueError(f"Minimum value must be > `1`, found `{minn}`.")
        if minn == 2:
            logg.warning(
                "In most cases, 2 clusters will always be optimal. "
                "If you really expect 2 clusters, use `n_states=2` and `use_minchi=False`. Setting minimum to `3`"
            )
            minn = 3

        # the interval has to contain at least 2 candidates
        if maxx <= minn:
            maxx = minn + 1
            logg.debug(
                f"Setting maximum to `{maxx}` because it was <= than minimum `{minn}`"
            )

        logg.info(f"Calculating minChi within interval `[{minn}, {maxx}]`")

        candidates = np.arange(minn, maxx + 1)
        scores = self._gpcca.minChi(minn, maxx)

        return int(candidates[np.argmax(scores)])

    @d.dedent
    def _set_macrostates(
        self,
        memberships: np.ndarray,
        n_cells: Optional[int] = 30,
        cluster_key: str = "clusters",
        en_cutoff: Optional[float] = 0.7,
        p_thresh: float = 1e-15,
        check_row_sums: bool = True,
    ) -> None:
        """
        Name and color the fuzzy macrostates using pre-computed cluster annotations.

        Each macrostate is first reduced to a discrete set of cells, either its top ``n_cells`` most likely cells
        or, when ``n_cells=None``, all the cells for which it has the highest membership. These cells are then
        compared with the categorical annotations to give the macrostate a name and a color.

        Parameters
        ----------
        memberships
            Fuzzy clustering.
        %(n_cells)s
        cluster_key
            Key from :attr:`adata` ``.obs`` to get reference cluster annotations.
        en_cutoff
            Threshold to decide when we want to warn the user about an uncertain name mapping. This happens when
            one fuzzy state overlaps with several reference clusters, and the most likely cells are distributed almost
            evenly across the reference clusters.
        p_thresh
            Only used to detect cell cycle stages. These have to be present in :attr:`adata` ``.obs`` as
            `'G2M_score'` and `'S_score'`.
        check_row_sums
            Check whether rows in `memberships` sum to `1`.

        Returns
        -------
        None
            Writes a :class:`cellrank.tl.Lineage` object with the mapped names and colors.
            Also writes a categorical :class:`pandas.Series` with the discrete assignment of cells to fuzzy states.
        """

        if n_cells is not None:
            logg.debug("Setting the macrostates using macrostates memberships")

            # select the most likely cells from each macrostate
            macrostates, not_enough_cells = self._create_states(
                memberships,
                n_cells=n_cells,
                check_row_sums=check_row_sums,
                return_not_enough_cells=True,
            )
            not_enough_cells = not_enough_cells.astype("str")
        else:
            logg.debug("Setting the macrostates using macrostate assignment")

            # assign every cell to the macrostate with its highest membership
            assignment = pd.Series(
                np.argmax(memberships, axis=1),
                index=self.adata.obs_names,
                dtype="category",
            )
            # sometimes, the assignment can have a missing category and the Lineage creation therefore fails
            # keep it as ints when `n_cells != None`
            assignment = assignment.cat.set_categories(
                list(range(memberships.shape[1])))
            macrostates = assignment.astype(str).astype("category").copy()
            not_enough_cells = []

        # _set_categorical_labels creates the names, we still need to remap the group names
        categories_before = macrostates.cat.categories
        self._set_categorical_labels(
            categories=macrostates,
            attr_key=A.MACRO.v,
            pretty_attr_key=P.MACRO.v,
            color_key=A.MACRO_COLORS.v,
            cluster_key=cluster_key,
            en_cutoff=en_cutoff,
            p_thresh=p_thresh,
            add_to_existing=False,
            add_to_existing_error_msg=
            "Compute macrostates first as `.compute_macrostates()`.",
        )

        # report the states with too few cells under their newly assigned names
        categories_after = self._get(P.MACRO).cat.categories
        renamed = dict(zip(categories_before, categories_after))
        _print_insufficient_number_of_cells(
            [renamed.get(state, state) for state in not_enough_cells],
            n_cells)
        logg.debug(
            "Setting macrostates memberships based on GPCCA membership vectors"
        )

        membership_lineage = Lineage(
            memberships,
            names=list(macrostates.cat.categories),
            colors=self._get(A.MACRO_COLORS),
        )
        self._set(A.MACRO_MEMBER, membership_lineage)

    def _create_states(
        self,
        probs: Union[np.ndarray, Lineage],
        n_cells: int,
        check_row_sums: bool = False,
        return_not_enough_cells: bool = False,
    ) -> pd.Series:
        """
        Convert fuzzy memberships into a discrete assignment of cells to states.

        Parameters
        ----------
        probs
            Fuzzy memberships. If a :class:`cellrank.tl.Lineage`, its names are used to name the states.
        n_cells
            Number of most likely cells to select for each state. Must be positive.
        check_row_sums
            Passed to the discretization routine.
        return_not_enough_cells
            Whether to also return the states reported by the discretization routine as having too few cells.

        Returns
        -------
        :class:`pandas.Series`
            The discrete states, indexed by :attr:`adata` ``.obs_names``. If ``return_not_enough_cells=True``,
            a tuple of this series and the states with too few cells is returned instead.
        """
        if n_cells <= 0:
            raise ValueError(
                f"Expected `n_cells` to be positive, found `{n_cells}`.")

        state_names = probs.names if isinstance(probs, Lineage) else None

        one_hot, too_few = _fuzzy_to_discrete(
            a_fuzzy=probs,
            n_most_likely=n_cells,
            remove_overlap=False,
            raise_threshold=0.2,
            check_row_sums=check_row_sums,
        )
        discrete = _series_from_one_hot_matrix(
            membership=one_hot,
            index=self.adata.obs_names,
            names=state_names,
        )

        if return_not_enough_cells:
            return discrete, too_few
        return discrete

    def _check_states_validity(self, n_states: int) -> int:
        """
        Increase the number of states by one if the requested number is known to be invalid.

        Parameters
        ----------
        n_states
            Requested number of macrostates.

        Returns
        -------
        int
            ``n_states`` if it is not among the recorded invalid values, otherwise ``n_states + 1``.
        """
        invalid = self._invalid_n_states
        if invalid is None or n_states not in invalid:
            return n_states

        logg.warning(
            f"Unable to compute macrostates with `n_states={n_states}` because it will "
            f"split the conjugate eigenvalues. Increasing `n_states` to `{n_states + 1}`"
        )
        n_states += 1  # cannot force recomputation of the Schur decomposition
        assert n_states not in self._invalid_n_states, "Sanity check failed."

        return n_states

    def _fit_terminal_states(
        self,
        n_lineages: Optional[int] = None,
        cluster_key: Optional[str] = None,
        method: str = "krylov",
        **kwargs,
    ) -> None:
        """
        Compute the macrostates and select the terminal states among them.

        Parameters
        ----------
        n_lineages
            Number of macrostates to compute. If `None`, it is determined from the eigengap and
            the terminal states are selected automatically using ``method='eigengap'``.
            Otherwise, all the computed macrostates become terminal states.
        cluster_key
            Passed to :meth:`compute_macrostates`.
        method
            Method used when computing the Schur decomposition.
        kwargs
            Keyword arguments for :meth:`compute_macrostates`. If present, ``n_cells`` is also
            used when setting the terminal states.

        Returns
        -------
        None
            Nothing, just computes the macrostates and the terminal states.
        """
        # `n_lineages` is overwritten below, remember whether the user left the choice to us
        infer_n_lineages = n_lineages is None

        if infer_n_lineages or n_lineages == 1:
            self.compute_eigendecomposition()
            if infer_n_lineages:
                n_lineages = self.eigendecomposition["eigengap"] + 1

        if n_lineages > 1:
            self.compute_schur(n_lineages, method=method)

        try:
            self.compute_macrostates(n_states=n_lineages,
                                     cluster_key=cluster_key,
                                     **kwargs)
        except ValueError:
            # NOTE(review): any `ValueError` is interpreted as splitting a block of complex conjugates
            logg.warning(
                f"Computing `{n_lineages}` macrostates cuts through a block of complex conjugates. "
                f"Increasing `n_lineages` to {n_lineages + 1}")
            self.compute_macrostates(n_states=n_lineages + 1,
                                     cluster_key=cluster_key,
                                     **kwargs)

        fs_kwargs = {
            "n_cells": kwargs["n_cells"]
        } if "n_cells" in kwargs else {}

        if infer_n_lineages:
            self.compute_terminal_states(method="eigengap", **fs_kwargs)
        else:
            self.set_terminal_states_from_macrostates(**fs_kwargs)

    # NOTE(review): `d.dedent` is applied twice on purpose ("because of fit"); presumably the text
    # injected for `%(fit)s` contains placeholders of its own - confirm before removing either one
    @d.dedent  # because of fit
    @d.dedent
    @inject_docs(
        ms=P.MACRO,
        msp=P.MACRO_MEMBER,
        fs=P.TERM,
        fsp=P.TERM_PROBS,
        ap=P.ABS_PROBS,
        pd=P.PRIME_DEG,
    )
    def fit(
        self,
        n_lineages: Optional[int] = None,
        cluster_key: Optional[str] = None,
        keys: Optional[Sequence[str]] = None,
        method: str = "krylov",
        compute_absorption_probabilities: bool = True,
        **kwargs,
    ) -> None:
        """
        Run the pipeline, computing the macrostates, %(initial_or_terminal)s states \
        and optionally the absorption probabilities.

        It is equivalent to running::

            if n_lineages is None or n_lineages == 1:
                compute_eigendecomposition(...)  # get the stationary distribution
            if n_lineages > 1:
                compute_schur(...)

            compute_macrostates(...)

            if n_lineages is None:
                compute_terminal_states(...)
            else:
                set_terminal_states_from_macrostates(...)

            if compute_absorption_probabilities:
                compute_absorption_probabilities(...)

        Parameters
        ----------
        %(fit)s
        method
            Method to use when computing the Schur decomposition. Valid options are: `'krylov'` or `'brandts'`.
        compute_absorption_probabilities
            Whether to compute the absorption probabilities or only the %(initial_or_terminal)s states.
        kwargs
            Keyword arguments for :meth:`cellrank.tl.estimators.GPCCA.compute_macrostates`.

        Returns
        -------
        None
            Nothing, just makes available the following fields:

                - :attr:`{msp}`
                - :attr:`{ms}`
                - :attr:`{fsp}`
                - :attr:`{fs}`
                - :attr:`{ap}`
                - :attr:`{pd}`
        """

        # this override only forwards all the arguments unchanged to the parent implementation
        super().fit(
            n_lineages=n_lineages,
            cluster_key=cluster_key,
            keys=keys,
            method=method,
            compute_absorption_probabilities=compute_absorption_probabilities,
            **kwargs,
        )

    @d.dedent
    def _compute_initial_states(self,
                                n_states: int = 1,
                                n_cells: int = 30) -> None:
        """
        Compute initial states from macrostates.

        The macrostates with the smallest coarse-grained stationary probability are marked as initial states.
        If only one macrostate is available, it is used directly.

        Parameters
        ----------
        n_states
            Number of initial states.
        %(n_cells)s

        Returns
        -------
        %(set_initial_states_from_macrostates.returns)s
        """

        if n_states <= 0:
            raise ValueError(
                f"Expected `n_states` to be positive, found `{n_states}`.")

        if n_cells <= 0:
            raise ValueError(
                f"Expected `n_cells` to be positive, found `{n_cells}`.")

        probs = self._get(P.MACRO_MEMBER)

        if probs is None:
            raise RuntimeError(
                "Compute macrostates first as `.compute_macrostates()`.")

        if n_states > probs.shape[1]:
            raise ValueError(
                f"Requested `{n_states}` initial states, but only `{probs.shape[1]}` macrostates have been computed."
            )

        if probs.shape[1] == 1:
            # a single macrostate has no coarse-grained stationary distribution to rank by
            self._set_initial_states_from_macrostates(n_cells=n_cells)
            return

        stat_dist = self._get(P.COARSE_STAT_D)
        if stat_dist is None:
            raise RuntimeError(
                "No coarse-grained stationary distribution found.")

        # Rank on the raw values and select the labels positionally. Indexing the label-indexed series
        # with integer keys (`stat_dist[np.argsort(stat_dist)]`) relies on a deprecated positional fallback.
        smallest = np.argsort(np.asarray(stat_dist))[:n_states]
        self._set_initial_states_from_macrostates(stat_dist.index[smallest],
                                                  n_cells=n_cells)

    @d.get_sections(base="set_initial_states_from_macrostates",
                    sections=["Returns"])
    @d.dedent
    @inject_docs(key=TermStatesKey.BACKWARD.s,
                 probs_key=_probs(TermStatesKey.BACKWARD.s))
    def _set_initial_states_from_macrostates(
        self,
        names: Optional[Union[Iterable[str], str]] = None,
        n_cells: int = 30,
    ) -> None:
        """
        Manually select initial states from macrostates.

        Note that no check is performed to ensure initial and terminal states are distinct.

        Parameters
        ----------
        names
            Names of the macrostates to be marked as initial states. Multiple states can be combined using `','`,
            such as `["Alpha, Beta", "Epsilon"]`.
        %(n_cells)s

        Returns
        -------
        None
            Nothing, just writes to :attr:`adata`:

                - ``.obs[{key!r}]`` - top ``n_cells`` from each initial state.
                - ``.obs[{probs_key!r}]`` - probability of being an initial state.
        """

        if not isinstance(n_cells, int):
            raise TypeError(
                f"Expected `n_cells` to be of type `int`, found `{type(n_cells).__name__}`."
            )

        if n_cells <= 0:
            raise ValueError(
                f"Expected `n_cells` to be positive, found `{n_cells}`.")

        probs = self._get(P.MACRO_MEMBER)

        if probs is None:
            raise RuntimeError(
                "Compute macrostates first as `.compute_macrostates()`.")
        elif probs.shape[1] == 1:
            # only one macrostate: use it as is and scale by its global maximum
            categorical = self._create_states(probs, n_cells=n_cells)
            scaled = probs / probs.max()
        else:
            if names is None:
                names = probs.names
            if isinstance(names, str):
                names = [names]

            probs = probs[list(names)]
            # the categorical assignment is derived before the per-state normalization below
            categorical = self._create_states(probs, n_cells=n_cells)
            probs /= probs.max(0)

            # compute the aggregated probability of being a initial/terminal state (no matter which)
            scaled = probs.X.max(1)

        self._write_initial_states(membership=probs,
                                   probs=scaled,
                                   cats=categorical)

    def _write_initial_states(self,
                              membership: Lineage,
                              probs: pd.Series,
                              cats: pd.Series,
                              time=None) -> None:
        """
        Store the initial states in :attr:`adata` and log which fields were added.

        The categorical assignment and the probabilities go to ``.obs``, whereas the colors and
        names taken from ``membership`` go to ``.uns``.
        """
        states_key = TermStatesKey.BACKWARD.s
        probs_key = _probs(states_key)

        # per-cell annotations
        self.adata.obs[states_key] = cats
        self.adata.obs[probs_key] = probs

        # per-state annotations
        self.adata.uns[_colors(states_key)] = membership.colors
        self.adata.uns[_lin_names(states_key)] = membership.names

        message = f"Adding `adata.obs[{probs_key!r}]`\n       `adata.obs[{states_key!r}]`\n"
        logg.info(message, time=time)

    def _write_terminal_states(self, time=None) -> None:
        """
        Write the terminal states and keep the stored terminal absorption probabilities consistent.

        After the parent implementation has run, the absorption probabilities (from this object or,
        failing that, from ``adata.obsm``) are kept only if their names equal the terminal state categories;
        otherwise they are removed from both this object and ``adata.obsm``.
        """
        super()._write_terminal_states(time=time)

        abs_probs = self._get(A.TERM_ABS_PROBS)
        if abs_probs is None:
            # look at what is stored in `adata`, so that an inconsistent previous value can be removed
            abs_probs = self.adata.obsm.get(self._term_abs_prob_key, None)

        if abs_probs is None:
            return

        new = list(self._get(P.TERM).cat.categories)
        old = list(abs_probs.names)

        if abs_probs.shape[1] == len(new) and new == old:
            self.adata.obsm[self._term_abs_prob_key] = abs_probs
            return

        logg.warning(
            f"Removing previously computed `adata.obsm[{self._term_abs_prob_key!r}]` because the "
            f"names mismatch `{new}` (new), `{old}` (old).")

        self._set(A.TERM_ABS_PROBS, None)
        self.adata.obsm.pop(self._term_abs_prob_key, None)
Beispiel #11
0
class GPCCA(BaseEstimator, MetaStates, Schur, Eigen):
    """
    Generalized Perron Cluster Cluster Analysis [GPCCA18]_.

    Parameters
    ----------
    %(base_estimator.parameters)s
    """

    # Attribute declarations specific to this estimator. Each `Metadata` entry pairs a private
    # attribute key (`attr`) with a public property key (`prop`, or `P.NO_PROPERTY`) and an expected `dtype`.
    # NOTE(review): how these entries are consumed is defined outside this chunk (presumably by the
    # base estimator machinery, which would also interpret `compute_fmt`/`plot_fmt`) - confirm.
    __prop_metadata__ = [
        Metadata(
            attr=A.COARSE_T,
            prop=P.COARSE_T,
            compute_fmt=F.NO_FUNC,
            plot_fmt=F.NO_FUNC,
            dtype=pd.DataFrame,
            doc="Coarse-grained transition matrix.",
        ),
        Metadata(attr=A.FIN_ABS_PROBS, prop=P.NO_PROPERTY, dtype=Lineage),
        Metadata(attr=A.COARSE_INIT_D, prop=P.COARSE_INIT_D, dtype=pd.Series),
        Metadata(attr=A.COARSE_STAT_D, prop=P.COARSE_STAT_D, dtype=pd.Series),
    ]

    def _read_from_adata(self) -> None:
        """Run the parent's reading logic, then reconstruct the lineage stored under the final absorption key."""
        super()._read_from_adata()

        lineage_attr, adata_key = A.FIN_ABS_PROBS, self._fin_abs_prob_key
        self._reconstruct_lineage(lineage_attr, adata_key)

    @inject_docs(
        ms=P.META,
        msp=P.META_PROBS,
        schur=P.SCHUR.s,
        coarse_T=P.COARSE_T,
        coarse_stat=P.COARSE_STAT_D,
    )
    @d.dedent
    def compute_metastable_states(
        self,
        n_states: Optional[Union[int, Tuple[int, int], List[int],
                                 Dict[str, int]]] = None,
        n_cells: Optional[int] = 30,
        use_min_chi: bool = False,
        cluster_key: str = None,
        en_cutoff: Optional[float] = 0.7,
        p_thresh: float = 1e-15,
    ):
        """
        Compute the metastable states.

        Parameters
        ----------
        n_states
            Number of metastable states. If `None`, use the `eigengap` heuristic.
        %(n_cells)s
        use_min_chi
            Whether to use :meth:`msmtools.analysis.dense.gpcca.GPCCA.minChi` to calculate the number of metastable
            states. If `True`, ``n_states`` corresponds to an interval `[min, max]` inside of which
            the potentially optimal number of metastable states is searched.
        cluster_key
            If a key to cluster labels is given, names and colors of the states will be associated with the clusters.
        en_cutoff
            If ``cluster_key`` is given, this parameter determines when an approximate recurrent class will
            be labelled as *'Unknown'*, based on the entropy of the distribution of cells over transcriptomic clusters.
        p_thresh
            If cell cycle scores were provided, a *Wilcoxon rank-sum test* is conducted to identify cell-cycle driven
            start- or endpoints.
            If the test returns a positive statistic and a p-value smaller than ``p_thresh``, a warning will be issued.

        Returns
        -------
        None
            Nothing, but updates the following fields:

                - :paramref:`{msp}`
                - :paramref:`{ms}`
                - :paramref:`{schur}`
                - :paramref:`{coarse_T}`
                - :paramref:`{coarse_stat}`
        """

        # options forwarded unchanged to whichever routine ends up assigning the states
        state_kwargs = dict(
            n_cells=n_cells,
            cluster_key=cluster_key,
            p_thresh=p_thresh,
            en_cutoff=en_cutoff,
        )

        if use_min_chi:
            n_states = self._get_n_states_from_minchi(n_states)

        from_eigengap = n_states is None
        if from_eigengap:
            eig = self._get(P.EIG)
            if eig is None:
                raise RuntimeError(
                    "Compute eigendecomposition first as `.compute_eigendecomposition()` or `.compute_schur()`."
                )
            n_states = eig["eigengap"] + 1
            logg.info(f"Using `{n_states}` states based on eigengap")
        elif not isinstance(n_states, int):
            raise ValueError(
                f"Expected `n_states` to be an integer when `use_min_chi=False`, "
                f"found `{type(n_states).__name__!r}`.")

        if n_states <= 0:
            raise ValueError(
                f"Expected `n_states` to be positive or `None`, found `{n_states}`."
            )

        n_states = self._check_states_validity(n_states)
        if n_states == 1:
            # a single state is handled separately and does not go through the GPCCA object
            self._compute_meta_for_one_state(**state_kwargs)
            return

        if self._gpcca is None:
            # only an automatically determined number of states justifies computing the decomposition implicitly
            if not from_eigengap:
                raise RuntimeError(
                    "Compute Schur decomposition first as `.compute_schur()`.")

            logg.warning(
                f"Number of states `{n_states}` was automatically determined by `eigengap` "
                "but no Schur decomposition was found. Computing with default parameters"
            )
            # this cannot fail if splitting occurs
            # if it were to split, it's automatically increased in `compute_schur`
            self.compute_schur(n_states + 1)

        n_vectors = self._gpcca.X.shape[1]
        if n_vectors < n_states:
            logg.warning(
                f"Requested more metastable states `{n_states}` than available "
                f"Schur vectors `{n_vectors}`. Recomputing the decomposition"
            )

        start = logg.info(f"Computing `{n_states}` metastable states")
        try:
            self._gpcca = self._gpcca.optimize(m=n_states)
        except ValueError as e:
            # this is the following case - we have 4 Schur vectors, user requests 5 states,
            # but that would split a pair of conjugate eigenvalues
            # in the try block, Schur decomposition with 5 vectors is computed, but it fails (no way of knowing)
            # so in this case, we increase the number of states by 1 and retry once
            n_states += 1
            logg.warning(f"{e}\nIncreasing `n_states` to `{n_states}`")
            self._gpcca = self._gpcca.optimize(m=n_states)

        self._set_meta_states(memberships=self._gpcca.memberships,
                              **state_kwargs)

        # cache the results and make sure we don't overwrite
        self._set(A.SCHUR, self._gpcca.X)
        self._set(A.SCHUR_MAT, self._gpcca.R)

        names = self._get(P.META_PROBS).names

        coarse_T = pd.DataFrame(
            self._gpcca.coarse_grained_transition_matrix,
            index=names,
            columns=names,
        )
        self._set(A.COARSE_T, coarse_T)

        coarse_init = pd.Series(self._gpcca.coarse_grained_input_distribution,
                                index=names)
        self._set(A.COARSE_INIT_D, coarse_init)

        # fields reported in the final log message; the stationary distribution is listed only if present
        added = [P.META_PROBS, P.META, P.SCHUR, P.COARSE_T]

        # careful here, in case computing the stat. dist failed
        stat_dist = self._gpcca.coarse_grained_stationary_probability
        if stat_dist is not None:
            self._set(A.COARSE_STAT_D, pd.Series(stat_dist, index=names))
            added.append(P.COARSE_STAT_D)
        else:
            logg.warning("No stationary distribution found in GPCCA object")

        listing = "\n       ".join(f"`.{field}`" for field in added)
        logg.info(f"Adding {listing}\n    Finish", time=start)

    @d.dedent
    @inject_docs(fs=P.FIN, fsp=P.FIN_PROBS)
    def set_final_states_from_metastable_states(
        self,
        names: Optional[Union[Iterable[str], str]] = None,
        n_cells: int = 30,
    ):
        """
        Manually select the main states from the metastable states.

        Parameters
        ----------
        names
            Names of the main states. Multiple states can be combined using `','`, such as `['Alpha, Beta', 'Epsilon']`.
        %(n_cells)s

        Returns
        -------
        None
            Nothing, just updates the following fields:

                - :paramref:`{fsp}`
                - :paramref:`{fs}`
        """

        if not isinstance(n_cells, int):
            raise TypeError(
                f"Expected `n_cells` to be of type `int`, found `{type(n_cells).__name__}`."
            )

        if n_cells <= 0:
            raise ValueError(
                f"Expected `n_cells` to be positive, found `{n_cells}`.")

        probs = self._get(P.META_PROBS)
        if probs is None:
            raise RuntimeError(
                "Compute metastable_states first as `.compute_metastable_states()`."
            )

        if probs.shape[1] == 1:
            # a single metastable state: nothing to select, reuse it (and its colors) as the final state
            self._set(A.FIN, self._create_states(probs, n_cells=n_cells))
            self._set(A.FIN_COLORS, self._get(A.META_COLORS))
            self._set(A.FIN_PROBS, probs / probs.max())
            self._set(A.FIN_ABS_PROBS, probs)
            self._write_final_states()
            return

        if names is None:
            names = probs.names
        if isinstance(names, str):
            names = [names]

        selected = probs[list(names)]

        # compute the aggregated probability of being a initial/terminal state (no matter which)
        # the column named "rest" is excluded from the scaled probabilities
        kept = [name for name in selected.names if name != "rest"]
        scaled = selected[kept].copy()
        scaled /= scaled.max(0)

        self._set(A.FIN, self._create_states(selected, n_cells))

        aggregated = pd.Series(scaled.X.max(1), index=self.adata.obs_names)
        self._set(A.FIN_PROBS, aggregated)

        # colors follow the order of the categories that were just assigned
        final_categories = list(self._get(P.FIN).cat.categories)
        self._set(A.FIN_COLORS, selected[final_categories].colors)

        self._set(A.FIN_ABS_PROBS, scaled)
        self._write_final_states()

    @inject_docs(fs=P.FIN, fsp=P.FIN_PROBS)
    @d.dedent
    def compute_final_states(
        self,
        method: str = "eigengap",
        n_cells: int = 30,
        alpha: Optional[float] = 1,
        min_self_prob: Optional[float] = None,
        n_final_states: Optional[int] = None,
    ):
        """
        Automatically select the main states from metastable states.

        Parameters
        ----------
        method
            One of following:

                - `'eigengap'` - select the number of states based on the eigengap of the transition matrix.
                - `'eigengap_coarse'` - select the number of states based on the eigengap of the diagonal
                    of the coarse-grained transition matrix.
                - `'top_n'` - select top ``n_final_states`` based on the probability of the diagonal \
                    of the coarse-grained transition matrix.
                - `'min_self_prob'` - select states which have the given minimum probability of the diagonal
                    of the coarse-grained transition matrix.
        %(n_cells)s
        alpha
            Weight given to the deviation of an eigenvalue from one. Used when ``method='eigengap'``
            or ``method='eigengap_coarse'``.
        min_self_prob
            Used when ``method='min_self_prob'``.
        n_final_states
            Used when ``method='top_n'``.

        Returns
        -------
        None
            Nothing, just updates the following fields:

                - :paramref:`{fsp}`
                - :paramref:`{fs}`
        """  # noqa

        missing_meta_msg = "Compute metastable states first as `.compute_metastable_states()`."

        meta = self._get(P.META)
        if meta is None:
            # fail with an actionable message instead of an `AttributeError` on `None.cat`
            raise RuntimeError(missing_meta_msg)

        if len(meta.cat.categories) == 1:
            logg.warning(
                "Found only one metastable state. Making it the single main state"
            )
            self.set_final_states_from_metastable_states(None, n_cells=n_cells)
            return

        coarse_T = self._get(P.COARSE_T)

        if method == "eigengap":
            if self._get(P.EIG) is None:
                raise RuntimeError(
                    "Compute eigendecomposition first as `.compute_eigendecomposition()`."
                )
            n_final_states = _eigengap(self._get(P.EIG)["D"], alpha=alpha) + 1
        elif method == "eigengap_coarse":
            if coarse_T is None:
                raise RuntimeError(missing_meta_msg)
            # sort first, then reverse, so that the self-transition probabilities are in descending order;
            # reversing before sorting (as done previously) had no effect and yielded ascending values
            # NOTE(review): unlike `'eigengap'`, no `+ 1` is applied here; a result of `0` makes the
            # `[-n_final_states:]` slice below select every state - confirm the intended semantics of `_eigengap`
            n_final_states = _eigengap(np.sort(np.diag(coarse_T))[::-1],
                                       alpha=alpha)
        elif method == "top_n":
            if n_final_states is None:
                raise ValueError(
                    "Argument `n_final_states` must be != `None` for `method='top_n'`."
                )
            elif n_final_states <= 0:
                raise ValueError(
                    f"Expected `n_final_states` to be positive, found `{n_final_states}`."
                )
        elif method == "min_self_prob":
            if min_self_prob is None:
                raise ValueError(
                    "Argument `min_self_prob` must be != `None` for `method='min_self_prob'`."
                )
            if coarse_T is None:
                raise RuntimeError(missing_meta_msg)
            self_probs = pd.Series(np.diag(coarse_T), index=coarse_T.columns)
            names = self_probs[self_probs.values >= min_self_prob].index
            self.set_final_states_from_metastable_states(names,
                                                         n_cells=n_cells)
            return
        else:
            raise ValueError(
                f"Invalid method `{method!r}`. Valid options are `'eigengap', 'eigengap_coarse', "
                f"'top_n' and 'min_self_prob'`.")

        # every remaining method ranks the states by the diagonal of the coarse-grained transition matrix
        if coarse_T is None:
            raise RuntimeError(missing_meta_msg)

        names = coarse_T.columns[np.argsort(
            np.diag(coarse_T))][-n_final_states:]
        self.set_final_states_from_metastable_states(names, n_cells=n_cells)

    def compute_gdpt(self,
                     n_components: int = 10,
                     key_added: str = "gdpt_pseudotime",
                     **kwargs):
        """
        Compute generalized Diffusion pseudotime from [Haghverdi16]_ making use of the real Schur decomposition.

        The root cell is taken from :paramref:`adata` ``.uns['iroot']``, given either as a position
        or as a cell name.

        Parameters
        ----------
        n_components
            Number of real Schur vectors to consider.
        key_added
            Key in :paramref:`adata` ``.obs`` where to save the pseudotime.
        **kwargs
            Keyword arguments for :meth:`cellrank.tl.GPCCA.compute_schur` if Schur decomposition is not found.

        Returns
        -------
        None
            Nothing, just updates :paramref:`adata` ``.obs[key_added]`` with the computed pseudotime.
        """
        def _get_dpt_row(e_vals: np.ndarray, e_vecs: np.ndarray, i: int):
            # accumulate the squared, eigenvalue-weighted differences between cell `i` and all cells
            total = 0
            for ix in range(0, e_vals.size):
                magnitude = np.abs(e_vals[ix])
                # account for float32 precision
                if not magnitude < 0.9994:
                    continue
                weight = magnitude / (1 - magnitude)
                total = total + (weight *
                                 (e_vecs[i, ix] - e_vecs[:, ix]))**2

            return np.sqrt(total)

        if "iroot" not in self.adata.uns:
            raise KeyError("Key `'iroot'` not found in `adata.uns`.")

        iroot = self.adata.uns["iroot"]
        if isinstance(iroot, str):
            # translate the cell name into its position
            matches = np.where(self.adata.obs_names == iroot)[0]
            if not len(matches):
                raise ValueError(
                    f"Unable to find cell with name `{self.adata.uns['iroot']!r}` in `adata.obs_names`."
                )
            iroot = matches[0]

        if n_components < 2:
            raise ValueError(
                f"Expected number of components >= 2, found `{n_components}`.")

        if self._get(P.SCHUR) is None:
            logg.warning("No Schur decomposition found. Computing")
            self.compute_schur(n_components, **kwargs)
        else:
            n_found = self._get(P.SCHUR_MAT).shape[1]
            if n_found < n_components:
                logg.warning(
                    f"Requested `{n_components}` components, but only `{n_found}` were found. "
                    f"Recomputing using default values")
                self.compute_schur(n_components)
            else:
                logg.debug("Using cached Schur decomposition")

        start = logg.info(
            f"Computing Generalized Diffusion Pseudotime using `n_components={n_components}`"
        )

        schur_vectors = self._get(P.SCHUR)
        eigenvalues = self._get(P.EIG)["D"]
        # may have to remove some values if too many converged
        schur_vectors = schur_vectors[:, :n_components]
        eigenvalues = eigenvalues[:n_components]

        distances = _get_dpt_row(eigenvalues, schur_vectors, i=iroot)
        # scale by the largest finite distance
        largest = np.max(distances[np.isfinite(distances)])
        self.adata.obs[key_added] = distances / largest

        logg.info(f"Adding `{key_added!r}` to `adata.obs`\n    Finish",
                  time=start)

    @d.dedent
    def plot_coarse_T(
        self,
        show_stationary_dist: bool = True,
        show_initial_dist: bool = False,
        cmap: Union[str, mcolors.ListedColormap] = "viridis",
        xtick_rotation: float = 45,
        annotate: bool = True,
        show_cbar: bool = True,
        title: Optional[str] = None,
        figsize: Tuple[float, float] = (8, 8),
        dpi: int = 80,
        save: Optional[Union[Path, str]] = None,
        text_kwargs: Mapping[str, Any] = MappingProxyType({}),
        **kwargs,
    ) -> None:
        """
        Plot the coarse-grained transition matrix between metastable states.

        Parameters
        ----------
        show_stationary_dist
            Whether to show the stationary distribution, if present.
        show_initial_dist
            Whether to show the initial distribution.
        cmap
            Colormap to use.
        xtick_rotation
            Rotation of ticks on the x-axis.
        annotate
            Whether to display the text on each cell.
        show_cbar
            Whether to show colorbar.
        title
            Title of the figure.
        %(plotting)s
        text_kwargs
            Keyword arguments for :func:`matplotlib.pyplot.text`.
        **kwargs
            Keyword arguments for :func:`matplotlib.pyplot.imshow`.

        Returns
        -------
        %(just_plots)s
        """
        def stylize_dist(ax,
                         data: np.ndarray,
                         xticks_labels: Union[List[str], Tuple[str]] = ()):
            # draw a single-row heatmap sharing the colormap and normalization of the main matrix
            _ = ax.imshow(data, aspect="auto", cmap=cmap, norm=norm)
            for spine in ax.spines.values():
                spine.set_visible(False)

            if xticks_labels is not None:
                # fix the tick positions first, the labels are then matched to them
                ax.set_xticks(np.arange(data.shape[1]))
                ax.set_xticklabels(xticks_labels)
                plt.setp(
                    ax.get_xticklabels(),
                    rotation=xtick_rotation,
                    ha="right",
                    rotation_mode="anchor",
                )
            else:
                ax.set_xticks([])
                ax.tick_params(which="both",
                               top=False,
                               right=False,
                               bottom=False,
                               left=False)

            ax.set_yticks([])

        def annotate_heatmap(im, valfmt: str = "{x:.2f}"):
            # modified from matplotlib's site

            data = im.get_array()
            kw = {"ha": "center", "va": "center"}
            kw.update(**text_kwargs)

            # Get the formatter in case a string is supplied
            if isinstance(valfmt, str):
                valfmt = mpl.ticker.StrMethodFormatter(valfmt)

            # Loop over the data and create a `Text` for each "pixel".
            # Change the text's color depending on the data.
            for i in range(data.shape[0]):
                for j in range(data.shape[1]):
                    kw.update(
                        color=_get_black_or_white(im.norm(data[i, j]), cmap))
                    im.axes.text(j, i, valfmt(data[i, j], None), **kw)

        def annotate_dist_ax(ax, data: np.ndarray, valfmt: str = "{x:.2f}"):
            if ax is None:
                return

            if isinstance(valfmt, str):
                valfmt = mpl.ticker.StrMethodFormatter(valfmt)

            kw = {"ha": "center", "va": "center"}
            kw.update(**text_kwargs)

            for i, val in enumerate(data):
                # `norm` is the normalization shared by all the heatmaps of this figure
                kw.update(color=_get_black_or_white(norm(val), cmap))
                ax.text(
                    i,
                    0,
                    valfmt(val, None),
                    **kw,
                )

        coarse_T = self._get(P.COARSE_T)
        coarse_stat_d = self._get(P.COARSE_STAT_D)
        coarse_init_d = self._get(P.COARSE_INIT_D)

        if coarse_T is None:
            raise RuntimeError(
                "Compute coarse-grained transition matrix first as `.compute_metastable_states()` with `n_states > 1`."
            )

        if show_stationary_dist and coarse_stat_d is None:
            logg.warning("Coarse stationary distribution is `None`, ignoring")
            show_stationary_dist = False
        if show_initial_dist and coarse_init_d is None:
            logg.warning("Coarse initial distribution is `None`, ignoring")
            show_initial_dist = False

        # one extra (thin) row per displayed distribution, one extra (thin) column for the colorbar
        hrs, wrs = [1], [1]
        if show_stationary_dist:
            hrs += [0.05]
        if show_initial_dist:
            hrs += [0.05]
        if show_cbar:
            wrs += [0.025]

        dont_show_dist = not show_initial_dist and not show_stationary_dist

        fig = plt.figure(constrained_layout=False, figsize=figsize, dpi=dpi)
        gs = plt.GridSpec(
            1 + show_stationary_dist + show_initial_dist,
            1 + show_cbar,
            height_ratios=hrs,
            width_ratios=wrs,
            wspace=0.05,
            hspace=0.05,
        )
        if isinstance(cmap, str):
            cmap = plt.get_cmap(cmap)

        ax = fig.add_subplot(gs[0, 0])
        cax = fig.add_subplot(gs[:1, -1]) if show_cbar else None
        init_ax, stat_ax = None, None

        # use the matrix fetched above, so the labels always describe the data being plotted
        labels = list(coarse_T.columns)

        # the color range must cover every array that is actually displayed
        tmp = coarse_T
        if show_stationary_dist:
            tmp = np.c_[tmp, coarse_stat_d]
        if show_initial_dist:
            tmp = np.c_[tmp, coarse_init_d]

        minn, maxx = np.nanmin(tmp), np.nanmax(tmp)
        norm = mpl.colors.Normalize(vmin=minn, vmax=maxx)

        if show_stationary_dist:
            stat_ax = fig.add_subplot(gs[1, 0])
            stylize_dist(
                stat_ax,
                np.array(coarse_stat_d).reshape(1, -1),
                # only the bottom-most row carries the x-axis labels
                xticks_labels=labels if not show_initial_dist else None,
            )
            stat_ax.yaxis.set_label_position("right")
            stat_ax.set_ylabel("stationary dist",
                               rotation=0,
                               ha="left",
                               va="center")

        if show_initial_dist:
            init_ax = fig.add_subplot(gs[show_stationary_dist +
                                         show_initial_dist, 0])
            stylize_dist(init_ax,
                         np.array(coarse_init_d).reshape(1, -1),
                         xticks_labels=labels)

            init_ax.yaxis.set_label_position("right")
            init_ax.set_ylabel("initial dist",
                               rotation=0,
                               ha="left",
                               va="center")

        im = ax.imshow(coarse_T, aspect="auto", cmap=cmap, norm=norm, **kwargs)
        ax.set_title(
            "coarse-grained transition matrix" if title is None else title)

        if cax is not None:
            _ = mpl.colorbar.ColorbarBase(
                cax,
                cmap=cmap,
                norm=norm,
                ticks=np.linspace(minn, maxx, 10),
                format="%0.3f",
            )

        ax.set_yticks(np.arange(coarse_T.shape[0]))
        ax.set_yticklabels(labels)

        ax.tick_params(
            top=False,
            bottom=dont_show_dist,
            labeltop=False,
            labelbottom=dont_show_dist,
        )

        for spine in ax.spines.values():
            spine.set_visible(False)

        if dont_show_dist:
            ax.set_xticks(np.arange(coarse_T.shape[1]))
            ax.set_xticklabels(labels)
            plt.setp(
                ax.get_xticklabels(),
                rotation=xtick_rotation,
                ha="right",
                rotation_mode="anchor",
            )
        else:
            ax.set_xticks([])

        ax.set_yticks(np.arange(coarse_T.shape[0] + 1) - 0.5, minor=True)
        ax.tick_params(which="minor",
                       bottom=dont_show_dist,
                       left=False,
                       top=False)

        if annotate:
            annotate_heatmap(im)
            # a distribution is only touched when its axis exists, i.e. when it is present and displayed;
            # the stationary distribution may legitimately be `None`
            if stat_ax is not None:
                annotate_dist_ax(stat_ax, coarse_stat_d.values)
            if init_ax is not None:
                annotate_dist_ax(init_ax, coarse_init_d)

        if save:
            save_fig(fig, save)

        fig.show()

    def _compute_meta_for_one_state(
        self,
        n_cells: int,
        cluster_key: Optional[str],
        en_cutoff: Optional[float],
        p_thresh: float,
    ) -> None:
        """
        Compute a single metastable state from the stationary distribution.

        The stationary distribution is taken from a cached eigendecomposition if that was computed
        with ``which='LR'`` and contains it, otherwise the eigendecomposition is recomputed.
        All quantities that only exist for more than one state are reset to `None`.

        Parameters
        ----------
        n_cells
            Passed to ``_set_meta_states``.
        cluster_key
            Passed to ``_set_meta_states``.
        en_cutoff
            Passed to ``_set_meta_states``.
        p_thresh
            Passed to ``_set_meta_states``.

        Returns
        -------
        None
            Nothing, just updates the metastable states and their membership.
        """
        start = logg.info("Computing metastable states")
        logg.warning("For `n_states=1`, stationary distribution is computed")

        eig = self._get(P.EIG)
        if (eig is not None and "stationary_dist" in eig
                and eig["params"]["which"] == "LR"):
            stationary_dist = eig["stationary_dist"]
        else:
            self.compute_eigendecomposition(only_evals=False, which="LR")
            stationary_dist = self._get(P.EIG)["stationary_dist"]

        self._set_meta_states(
            # memberships are passed with one column per state
            memberships=stationary_dist[:, None],
            n_cells=n_cells,
            cluster_key=cluster_key,
            p_thresh=p_thresh,
            en_cutoff=en_cutoff,
        )
        self._set(
            A.META_PROBS,
            Lineage(
                stationary_dist,
                names=list(self._get(A.META).cat.categories),
                colors=self._get(A.META_COLORS),
            ),
        )

        # reset all the things
        # both coarse-grained distributions are cleared; the initial one used to be missing
        # because the stationary one was listed twice
        for key in (
                A.ABS_PROBS,
                A.SCHUR,
                A.SCHUR_MAT,
                A.COARSE_T,
                A.COARSE_INIT_D,
                A.COARSE_STAT_D,
        ):
            self._set(key.s, None)

        logg.info(
            f"Adding `.{P.META_PROBS}`\n        `.{P.META}`\n    Finish",
            time=start,
        )

    def _get_n_states_from_minchi(
            self, n_states: Union[Tuple[int, int], List[int],
                                  Dict[str, int]]) -> int:
        """
        Determine the number of states within an interval using the `minChi` criterion.

        Parameters
        ----------
        n_states
            The interval to search, either a pair `(min, max)` or a dictionary with keys `'min'` and `'max'`.
            The bounds are swapped if given in the wrong order, a minimum of `2` is raised to `3`
            and the maximum is raised to `minimum + 1` if it does not exceed the minimum.

        Returns
        -------
        int
            The candidate at which the value returned by `minChi` is largest.

        Raises
        ------
        RuntimeError
            If no Schur decomposition has been computed.
        TypeError
            If ``n_states`` is not a `dict`, `tuple` or `list`.
        ValueError
            If ``n_states`` does not have exactly 2 elements or the minimum is `<= 1`.
        KeyError
            If ``n_states`` is a dictionary without the `'min'` and `'max'` keys.
        """
        if self._gpcca is None:
            raise RuntimeError(
                "Compute Schur decomposition first as `.compute_schur()` when `use_min_chi=True`."
            )

        if not isinstance(n_states, (dict, tuple, list)):
            raise TypeError(
                f"Expected `n_states` to be either `dict`, `tuple` or a `list`, "
                f"found `{type(n_states).__name__}`.")
        if len(n_states) != 2:
            raise ValueError(
                f"Expected `n_states` to be of size `2`, found `{len(n_states)}`."
            )

        if isinstance(n_states, dict):
            if "min" not in n_states or "max" not in n_states:
                raise KeyError(
                    f"Expected the dictionary to have `'min'` and `'max'` keys, "
                    f"found `{tuple(n_states.keys())}`.")
            minn, maxx = n_states["min"], n_states["max"]
        else:
            minn, maxx = n_states

        if minn > maxx:
            logg.debug(
                f"Swapping minimum and maximum because `{minn}` > `{maxx}`")
            minn, maxx = maxx, minn

        if minn <= 1:
            raise ValueError(f"Minimum value must be > `1`, found `{minn}`.")
        elif minn == 2:
            logg.warning(
                "In most cases, 2 clusters will always be optimal. "
                "If you really expect 2 clusters, use `n_states=2` and `use_min_chi=False`. Setting minimum to `3`"
            )
            minn = 3

        if minn >= maxx:
            maxx = minn + 1
            logg.debug(
                f"Setting maximum to `{maxx}` because it was <= than minimum `{minn}`"
            )

        logg.info(f"Calculating minChi within interval `[{minn}, {maxx}]`")

        # The i-th value returned by `minChi` belongs to the candidate `minn + i`. Offsetting the best
        # position directly stays valid when the closed interval's upper bound `maxx` is the optimum,
        # whereas indexing into `np.arange(minn, maxx)` (which excludes `maxx`) would fail there.
        best = np.argmax(self._gpcca.minChi(minn, maxx))
        return int(minn + best)

    @d.dedent
    def _set_meta_states(
        self,
        memberships: np.ndarray,
        n_cells: Optional[int] = 30,
        cluster_key: str = "clusters",
        en_cutoff: Optional[float] = 0.7,
        p_thresh: float = 1e-15,
        check_row_sums: bool = True,
    ) -> None:
        """
        Map a fuzzy clustering to pre-computed annotations to get names and colors.

        Given the fuzzy clustering we have computed, we would like to select the most likely cells from each state
        and use these to give each state a name and a color by comparing with pre-computed, categorical cluster
        annotations.

        Parameters
        ----------
        memberships
            Fuzzy clustering.
        %(n_cells)s
        cluster_key
            Key from :paramref:`adata` ``.obs`` to get reference cluster annotations.
        en_cutoff
            Threshold to decide when we want to warn the user about an uncertain name mapping. This happens when
            one fuzzy state overlaps with several reference clusters, and the most likely cells are distributed almost
            evenly across the reference clusters.
        p_thresh
            Only used to detect cell cycle stages. These have to be present in
            :paramref:`adata` ``.obs`` as `'G2M_score'` and `'S_score'`.
        check_row_sums
            Check whether rows in `memberships` sum to `1`.

        Returns
        -------
        None
            Writes a :class:`cellrank.tl.Lineage` object with mapped names and colors.
            Also writes a categorical :class:`pandas.Series`, where top ``n_cells`` cells represent each fuzzy state.
        """

        if n_cells is None:
            logg.debug(
                "Setting the metastable states using metastable assignment")

            # every cell is assigned to the state with the highest membership
            max_assignment = np.argmax(memberships, axis=1)
            _meta_assignment = pd.Series(index=self.adata.obs_names,
                                         data=max_assignment,
                                         dtype="category")
            # sometimes, the assignment can have a missing category and the Lineage creation therefore fails
            # keep it as ints when `n_cells != None`
            # the result is assigned because `inplace=True` is no longer supported by recent pandas versions
            _meta_assignment = _meta_assignment.cat.set_categories(
                list(range(memberships.shape[1])))

            # NOTE(review): the round-trip through `str` appears to re-infer the categories from the observed
            # values (sorted as strings) - confirm that this matches the column order of `memberships`
            metastable_states = _meta_assignment.astype(str).astype(
                "category").copy()
            not_enough_cells = []
        else:
            logg.debug(
                "Setting the metastable states using metastable memberships")

            # select the most likely cells from each metastable state
            metastable_states, not_enough_cells = self._create_states(
                memberships,
                n_cells=n_cells,
                check_row_sums=check_row_sums,
                return_not_enough_cells=True,
            )
            not_enough_cells = not_enough_cells.astype("str")

        # _set_categorical_labels creates the names, we still need to remap the group names
        orig_cats = metastable_states.cat.categories
        self._set_categorical_labels(
            attr_key=A.META.v,
            color_key=A.META_COLORS.v,
            pretty_attr_key=P.META.v,
            add_to_existing_error_msg=
            "Compute metastable states first as `.compute_metastable_states()`.",
            categories=metastable_states,
            cluster_key=cluster_key,
            en_cutoff=en_cutoff,
            p_thresh=p_thresh,
            add_to_existing=False,
        )

        # report the states with too few cells under their new names
        name_mapper = dict(zip(orig_cats, self._get(P.META).cat.categories))
        _print_insufficient_number_of_cells(
            [name_mapper.get(n, n) for n in not_enough_cells], n_cells)

        logg.debug(
            "Setting metastable lineage probabilities based on GPCCA membership vectors"
        )

        self._set(
            A.META_PROBS,
            Lineage(
                memberships,
                names=list(metastable_states.cat.categories),
                colors=self._get(A.META_COLORS),
            ),
        )

    def _create_states(
        self,
        probs: Union[np.ndarray, Lineage],
        n_cells: int,
        check_row_sums: bool = False,
        return_not_enough_cells: bool = False,
    ) -> pd.Series:
        """
        Discretize fuzzy memberships by picking the most likely cells of every state.

        Parameters
        ----------
        probs
            Fuzzy memberships. If a lineage object, the `'rest'` column is ignored.
        n_cells
            Number of most likely cells to select for each state, must be positive.
        check_row_sums
            Forwarded to the discretization routine.
        return_not_enough_cells
            Whether to also return the states for which not enough cells were found.

        Returns
        -------
        The discrete states, indexed by the observation names. If ``return_not_enough_cells=True``,
        a tuple of the states and the states with an insufficient number of cells.

        Raises
        ------
        ValueError
            If ``n_cells`` is not positive.
        """
        if n_cells <= 0:
            raise ValueError(
                f"Expected `n_cells` to be positive, found `{n_cells}`.")

        if isinstance(probs, Lineage):
            kept_names = [name for name in probs.names if name != "rest"]
            probs = probs[kept_names]

        one_hot, too_few = _fuzzy_to_discrete(
            a_fuzzy=probs,
            n_most_likely=n_cells,
            remove_overlap=False,
            raise_threshold=0.2,
            check_row_sums=check_row_sums,
        )

        # plain arrays carry no names
        state_names = probs.names if isinstance(probs, Lineage) else None
        discrete_states = _series_from_one_hot_matrix(
            membership=one_hot,
            index=self.adata.obs_names,
            names=state_names,
        )

        if return_not_enough_cells:
            return discrete_states, too_few

        return discrete_states

    def _check_states_validity(self, n_states: int) -> int:
        if self._invalid_n_states is not None and n_states in self._invalid_n_states:
            logg.warning(
                f"Unable to compute metastable states with `n_states={n_states}` because it will "
                f"split the conjugate eigenvalues. Increasing `n_states` to `{n_states + 1}`"
            )
            n_states += 1  # cannot force recomputation of Schur decomposition
            assert n_states not in self._invalid_n_states, "Sanity check failed."

        return n_states

    def _fit_final_states(
        self,
        n_lineages: Optional[int] = None,
        cluster_key: Optional[str] = None,
        method: str = "krylov",
        **kwargs,
    ) -> None:
        if n_lineages is None or n_lineages == 1:
            self.compute_eigendecomposition()
            if n_lineages is None:
                n_lineages = self.eigendecomposition["eigengap"] + 1

        if n_lineages > 1:
            self.compute_schur(n_lineages + 1, method=method)

        try:
            self.compute_metastable_states(n_states=n_lineages,
                                           cluster_key=cluster_key,
                                           **kwargs)
        except ValueError:
            logg.warning(
                f"Computing `{n_lineages}` metastable states cuts through a block of complex conjugates. "
                f"Increasing `n_lineages` to {n_lineages + 1}")
            self.compute_metastable_states(n_states=n_lineages + 1,
                                           cluster_key=cluster_key,
                                           **kwargs)

        fs_kwargs = {
            "n_cells": kwargs["n_cells"]
        } if "n_cells" in kwargs else {}

        if n_lineages is None:
            self.compute_final_states(method="eigengap", **fs_kwargs)
        else:
            self.set_final_states_from_metastable_states(**fs_kwargs)

    @d.dedent  # because of fit
    @d.dedent
    @inject_docs(
        ms=P.META,
        msp=P.META_PROBS,
        fs=P.FIN,
        fsp=P.FIN_PROBS,
        ap=P.ABS_PROBS,
        dp=P.DIFF_POT,
    )
    def fit(
        self,
        n_lineages: Optional[int] = None,
        cluster_key: Optional[str] = None,
        keys: Optional[Sequence[str]] = None,
        method: str = "krylov",
        compute_absorption_probabilities: bool = True,
        **kwargs,
    ):
        """
        Run the whole pipeline: compute the metastable states, %(final)s states and, optionally, the absorption probabilities.

        This is a shortcut for running::

            if n_lineages is None or n_lineages == 1:
                compute_eigendecomposition(...)  # get the stationary distribution
            if n_lineages > 1:
                compute_schur(...)

            compute_metastable_states(...)

            if n_lineages is None:
                compute_final_states(...)
            else:
                set_final_states_from_metastable_states(...)

            if compute_absorption_probabilities:
                compute_absorption_probabilities(...)

        Parameters
        ----------
        %(fit)s
        method
            Method used to compute the Schur decomposition. Valid options are: `'krylov'` or `'brandts'`.
        compute_absorption_probabilities
            If `False`, only the final states are computed and the absorption probabilities are skipped.
        **kwargs
            Keyword arguments for :meth:`cellrank.tl.estimators.GPCCA.compute_metastable_states`.

        Returns
        -------
        None
            Nothing, just makes available the following fields:

                - :paramref:`{msp}`
                - :paramref:`{ms}`
                - :paramref:`{fsp}`
                - :paramref:`{fs}`
                - :paramref:`{ap}`
                - :paramref:`{dp}`
        """

        # the parent class drives the pipeline, everything is passed through unchanged
        named_arguments = {
            "n_lineages": n_lineages,
            "cluster_key": cluster_key,
            "keys": keys,
            "method": method,
            "compute_absorption_probabilities": compute_absorption_probabilities,
        }
        super().fit(**named_arguments, **kwargs)
Beispiel #12
0
    def __new__(cls, clsname, superclasses, attributedict):
        """
        Create a new class, updating its attributes based on the property metadata.

        The metadata is popped from ``attributedict``. If there is none, the class is created unchanged.
        The first metadata entry additionally determines the names of the compute and the plot methods.

        Parameters
        ----------
        clsname
            Name of class to be constructed.
        superclasses
            List of superclasses.
        attributedict
            Dictionary of attributes.

        Returns
        -------
        The newly created class.

        Raises
        ------
        TypeError
            If the metadata is not a `str`, `list` or a `tuple`, or if the plotting method is required
            but not implemented by a non-abstract class.
        ValueError
            If the metadata is an empty sequence.
        """

        compute_md, metadata = attributedict.pop(META_KEY, None), []

        if compute_md is None:
            return super().__new__(cls, clsname, superclasses, attributedict)

        if isinstance(compute_md, str):
            compute_md = Metadata(attr=compute_md)
        elif not isinstance(compute_md, (tuple, list)):
            # the message was missing the separating space and wrapped a quoted name in backticks
            raise TypeError(
                f"Expected property metadata to be `list` or `tuple`, "
                f"found `{type(compute_md).__name__}`.")
        elif len(compute_md) == 0:
            raise ValueError("No metadata found.")
        else:
            # plain strings are shorthand for metadata with only the attribute name set
            compute_md, *metadata = (Metadata(
                attr=md) if isinstance(md, str) else md for md in compute_md)

        prop_name = PropertyMeta.update_attributes(compute_md, attributedict)
        plot_name = str(compute_md.plot_fmt).format(prop_name)

        # expose `_compute`, if defined, under the formatted name
        if compute_md.compute_fmt != F.NO_FUNC and "_compute" in attributedict:
            attributedict[str(compute_md.compute_fmt).format(
                prop_name)] = attributedict["_compute"]

        if (compute_md.plot_fmt != F.NO_FUNC
                and VectorPlottable in superclasses
                and plot_name not in attributedict
                and not is_abstract(clsname)):
            raise TypeError(
                f"Method `{plot_name}` is not implemented for class `{clsname}`."
            )

        for md in metadata:
            PropertyMeta.update_attributes(md, attributedict)

        res = super().__new__(cls, clsname, superclasses, attributedict)

        if compute_md.plot_fmt != F.NO_FUNC and Plottable in res.mro():
            # _this is intended singledispatchmethod
            # unfortunately, `_plot` is not always in attributedict, so we can't just check for it
            # and res._plot is just a regular function
            # if this gets buggy in the future, consider switching from singlemethoddispatch
            setattr(
                res,
                plot_name,
                _delegate_method_dispatch(res._plot,
                                          "_plot",
                                          prop_name,
                                          skip=2),
            )

        return res