Example #1
        def get_scores(
            key: str,
            *,
            kind: Literal["proliferation", "apoptosis"],
            organism: Optional[Literal["human", "mouse"]],
        ) -> np.ndarray:
            try:
                return np.asarray(self.adata.obs[key])
            except KeyError:
                if organism is None:
                    raise KeyError(
                        f"Unable to find `{kind}` scores in `adata.obs[{kind!r}]`. Consider specifying `organism=...`."
                    ) from None

                logg.info(f"Computing `{kind}` scores")
                score_name = f"{kind}_score" if key is None else key

                sc.tl.score_genes(
                    self.adata,
                    gene_list=getattr(MarkerGenes, f"{kind}_markers")(organism),
                    score_name=score_name,
                    **kwargs,
                )

                return get_scores(score_name, kind=kind, organism=None)
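The helper above first tries a cached score column and, on a `KeyError`, computes the scores and retries. A minimal, self-contained sketch of that lookup-or-compute-then-recurse pattern, with a plain dict standing in for `adata.obs` and a stub in place of `sc.tl.score_genes` (the names `scores` and `compute_score` are hypothetical):

import numpy as np

scores = {}  # stands in for `adata.obs`

def compute_score(kind: str) -> np.ndarray:
    # stub for the real scoring step (`sc.tl.score_genes`)
    return np.random.rand(10)

def get_scores(key, *, kind: str) -> np.ndarray:
    try:
        return np.asarray(scores[key])
    except KeyError:
        score_name = f"{kind}_score" if key is None else key
        scores[score_name] = compute_score(kind)
        # retry now that the scores have been written
        return get_scores(score_name, kind=kind)

print(get_scores("proliferation_score", kind="proliferation").shape)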
Example #2
    def _write_terminal_states(self, time=None) -> None:
        self.adata.obs[self._term_key] = self._get(P.TERM)
        self.adata.obs[_probs(self._term_key)] = self._get(P.TERM_PROBS)

        self.adata.uns[_colors(self._term_key)] = self._get(A.TERM_COLORS)
        self.adata.uns[_lin_names(self._term_key)] = np.array(
            self._get(P.TERM).cat.categories
        )

        extra_msg = ""
        if getattr(self, A.TERM_ABS_PROBS.s, None) is not None and hasattr(
            self, "_term_abs_prob_key"
        ):
            # checking for None because terminal states can be set using `set_terminal_states`
            # without the probabilities in GPCCA
            self.adata.obsm[self._term_abs_prob_key] = self._get(A.TERM_ABS_PROBS)
            extra_msg = f"       `adata.obsm[{self._term_abs_prob_key!r}]`\n"

        logg.info(
            f"Adding `adata.obs[{_probs(self._term_key)!r}]`\n"
            f"       `adata.obs[{self._term_key!r}]`\n"
            f"{extra_msg}"
            f"       `.{P.TERM_PROBS}`\n"
            f"       `.{P.TERM}`\n"
            "    Finish",
            time=time,
        )
Example #3
    def test_timing(self, monkeypatch, capsys, logging_state):
        import cellrank.logging._logging as logg

        settings.logfile = sys.stderr
        counter = 0

        class IncTime:
            @staticmethod
            def now(tz):
                nonlocal counter
                counter += 1
                return datetime(2000,
                                1,
                                1,
                                second=counter,
                                microsecond=counter,
                                tzinfo=tz)

        monkeypatch.setattr(logg, "datetime", IncTime)
        settings.verbosity = Verbosity.debug

        logg.hint("1")
        assert counter == 1 and capsys.readouterr().err == "--> 1\n"
        start = logg.info("2")
        assert counter == 2 and capsys.readouterr().err == "2\n"
        logg.hint("3")
        assert counter == 3 and capsys.readouterr().err == "--> 3\n"
        logg.info("4", time=start)
        assert counter == 4 and capsys.readouterr().err == "4 (0:00:02)\n"
        logg.info("5 {time_passed}", time=start)
        assert counter == 5 and capsys.readouterr().err == "5 0:00:03\n"
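A stand-alone sketch of the fake-clock idea behind this test: a `now()` that advances a counter on every call makes elapsed-time output deterministic. `elapsed` below is a hypothetical helper, not part of the logging module:

from datetime import datetime, timezone

counter = 0

class IncTime:
    @staticmethod
    def now(tz):
        global counter
        counter += 1
        return datetime(2000, 1, 1, second=counter, tzinfo=tz)

def elapsed(start):
    # mimics how the logger formats the time passed since `start`
    return str(IncTime.now(timezone.utc) - start)

start = IncTime.now(timezone.utc)  # first call, counter == 1
print(elapsed(start))              # "0:00:01", deterministically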
Example #4
    def plot(
        self,
        min_flow: float = 0,
        remove_empty_clusters: bool = True,
        ascending: Optional[bool] = False,
        alpha: float = 0.8,
        xticks_step_size: Optional[int] = 1,
        legend_loc: Optional[str] = "upper right out",
        figsize: Optional[Tuple[float, float]] = None,
        dpi: Optional[int] = None,
    ) -> plt.Axes:
        """
        Plot outgoing flow.

        Parameters
        ----------
        min_flow
            Only show flow edges with flow greater than this value. Flow values are always in `[0, 1]`.
        remove_empty_clusters
            Whether to remove clusters with no incoming flow edges.
        ascending
            Whether to sort the clusters by ascending or descending incoming flow.
            If `None`, use the order defined by ``clusters``.
        alpha
            Alpha value for cell proportions.
        xticks_step_size
            Show only every n-th tick on the x-axis. If `None`, don't show any ticks.
        legend_loc
            Position of the legend. If `None`, do not show the legend.

        Returns
        -------
        The axes object.
        """
        if self._flow is None or self._cmat is None:
            raise RuntimeError(
                "Compute flow and contingency matrix first as `.prepare()`.")

        flow, cmat = self._flow, self._cmat
        try:
            if remove_empty_clusters:
                self._remove_min_clusters(min_flow)
            logg.info(
                f"Plotting flow from `{self._cluster}` into `{len(self._flow.columns) - 1}` cluster(s) "
                f"in `{len(self._cmat.columns) - 1}` time points")
            return self._plot(
                self._rename_times(),
                ascending=ascending,
                min_flow=min_flow,
                alpha=alpha,
                xticks_step_size=xticks_step_size,
                legend_loc=legend_loc,
                figsize=figsize,
                dpi=dpi,
            )
        finally:
            self._flow = flow
            self._cmat = cmat
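The `try`/`finally` above temporarily mutates `self._flow`/`self._cmat` for plotting and restores them even if plotting fails. A generic sketch of that save-and-restore pattern (the `Plotter` class is hypothetical):

class Plotter:
    def __init__(self, data):
        self._data = data

    def plot(self, drop_empty: bool = True):
        data = self._data  # keep a reference to the original
        try:
            if drop_empty:
                self._data = [x for x in self._data if x]  # mutate only for plotting
            return list(self._data)
        finally:
            self._data = data  # always restore, even on error

p = Plotter([1, 0, 2])
print(p.plot())  # [1, 2]
print(p._data)   # [1, 0, 2] -- the original state is restored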
Example #5
    def test_formats(self, capsys, logging_state):
        settings.logfile = sys.stderr
        settings.verbosity = Verbosity.debug
        logg.error("0")
        assert capsys.readouterr().err == "ERROR: 0\n"
        logg.warning("1")
        assert capsys.readouterr().err == "WARNING: 1\n"
        logg.info("2")
        assert capsys.readouterr().err == "2\n"
        logg.hint("3")
        assert capsys.readouterr().err == "--> 3\n"
Example #6
    def _reuse_cache(self,
                     expected_params: Dict[str, Any],
                     *,
                     time: Optional[Any] = None) -> bool:
        if expected_params == self._params:
            assert self.transition_matrix is not None, _ERROR_EMPTY_CACHE_MSG
            logg.debug(_LOG_USING_CACHE)
            logg.info("    Finish", time=time)
            return True

        self._params = expected_params
        return False
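A stripped-down sketch of the parameter-based cache check above: reuse the stored result only when the requested parameters match the stored ones, otherwise remember the new parameters and recompute. The `Kernel` class below is hypothetical:

class Kernel:
    def __init__(self):
        self._params = {}
        self.transition_matrix = None

    def _reuse_cache(self, expected_params: dict) -> bool:
        if expected_params == self._params:
            assert self.transition_matrix is not None
            return True  # cached result is still valid
        self._params = expected_params
        return False     # caller must recompute

    def compute(self, density_normalize: bool = True):
        if self._reuse_cache({"dnorm": density_normalize}):
            return self.transition_matrix
        # stand-in for the actual computation
        self.transition_matrix = f"tmat(dnorm={density_normalize})"
        return self.transition_matrix

k = Kernel()
print(k.compute(True))  # computed
print(k.compute(True))  # served from the cache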
Example #7
    def _maybe_compute_cond_num(self):
        if self._compute_cond_num and self._cond_num is None:
            logg.debug(f"Computing condition number of `{repr(self)}`")
            self._cond_num = np.linalg.cond(
                self._transition_matrix.toarray()
                if issparse(self._transition_matrix)
                else self._transition_matrix
            )
            if self._cond_num > _cond_num_tolerance:
                logg.warning(
                    f"`{repr(self)}` may be ill-conditioned, its condition number is `{self._cond_num:.2e}`"
                )
            else:
                logg.info(f"Condition number is `{self._cond_num:.2e}`")
Example #8
    def _write_eig_to_adata(
        self, eig: Mapping[str, Any], start=None, extra_msg: Optional[str] = None
    ):
        setattr(self, A.EIG.s, eig)
        self.adata.uns[f"eig_{self._direction}"] = eig

        msg = f"Adding `adata.uns['eig_{self._direction}']`\n       `.{P.EIG}`"
        if extra_msg is None:
            extra_msg = "\n    Finish"

        msg += extra_msg

        logg.info(msg, time=start)
Example #9
    def compute_lineage_priming(
        self,
        method: Literal["kl_divergence", "entropy"] = "kl_divergence",
        early_cells: Optional[Union[Mapping[str, Sequence[str]], Sequence[str]]] = None,
    ) -> pd.Series:
        """
        %(lin_pd.full_desc)s

        Parameters
        ----------
        %(lin_pd.parameters)s
            Cell ids or a mask marking early cells. If `None`, use all cells. Only used when ``method='kl_divergence'``.
            If a :class:`dict`, the key specifies a cluster key in :attr:`anndata.AnnData.obs` and the values
            specify cluster labels containing early cells.

        Returns
        -------
        %(lin_pd.returns)s
        """  # noqa: D400
        abs_probs: Optional[Lineage] = self._get(P.ABS_PROBS)
        if abs_probs is None:
            raise RuntimeError(
                "Compute absorption probabilities first as `.compute_absorption_probabilities()`."
            )
        if isinstance(early_cells, dict):
            if len(early_cells) != 1:
                raise ValueError(
                    f"Expected a dictionary with only 1 key, found `{len(early_cells)}`."
                )
            key = next(iter(early_cells.keys()))
            if key not in self.adata.obs:
                raise KeyError(f"Unable to find clustering in `adata.obs[{key!r}]`.")
            early_cells = self.adata.obs[key].isin(early_cells[key])
        elif early_cells is not None:
            early_cells = np.asarray(early_cells)
            if not np.issubdtype(early_cells.dtype, np.bool_):
                early_cells = np.isin(self.adata.obs_names, early_cells)

        values = pd.Series(
            abs_probs.priming_degree(method, early_cells), index=self.adata.obs_names
        )

        self._set(A.PRIME_DEG, values)
        self.adata.obs[_pd(self._abs_prob_key)] = values

        logg.info(
            f"Adding `adata.obs[{_pd(self._abs_prob_key)!r}]`\n"
            f"       `.{P.PRIME_DEG}`"
        )

        return values
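The `early_cells` handling above accepts a `{cluster_key: labels}` mapping, a list of cell names, or a boolean mask, and turns all three into a boolean mask over `adata.obs_names`. A small pandas/numpy sketch of that normalization on toy data (the names are hypothetical):

import numpy as np
import pandas as pd

obs = pd.DataFrame(
    {"clusters": pd.Categorical(["early", "late", "early", "late"])},
    index=["c0", "c1", "c2", "c3"],
)

def to_mask(early_cells):
    if isinstance(early_cells, dict):
        (key, labels), = early_cells.items()  # expect exactly one key
        return obs[key].isin(labels).values
    early_cells = np.asarray(early_cells)
    if np.issubdtype(early_cells.dtype, np.bool_):
        return early_cells
    return np.isin(obs.index, early_cells)    # a list of cell names

print(to_mask({"clusters": ["early"]}))       # [ True False  True False]
print(to_mask(["c1", "c3"]))                  # [False  True False  True]
print(to_mask(np.array([True, True, False, False])))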
Example #10
    def _get_n_states_from_minchi(
            self, n_states: Union[Tuple[int, int], List[int],
                                  Dict[str, int]]) -> int:
        if self._gpcca is None:
            raise RuntimeError(
                "Compute Schur decomposition first as `.compute_schur()` when `use_min_chi=True`."
            )

        if not isinstance(n_states, (dict, tuple, list)):
            raise TypeError(
                f"Expected `n_states` to be either `dict`, `tuple` or a `list`, "
                f"found `{type(n_states).__name__}`.")
        if len(n_states) != 2:
            raise ValueError(
                f"Expected `n_states` to be of size `2`, found `{len(n_states)}`."
            )

        if isinstance(n_states, dict):
            if "min" not in n_states or "max" not in n_states:
                raise KeyError(
                    f"Expected the dictionary to have `'min'` and `'max'` keys, "
                    f"found `{tuple(n_states.keys())}`.")
            minn, maxx = n_states["min"], n_states["max"]
        else:
            minn, maxx = n_states

        if minn > maxx:
            logg.debug(
                f"Swapping minimum and maximum because `{minn}` > `{maxx}`")
            minn, maxx = maxx, minn

        if minn <= 1:
            raise ValueError(f"Minimum value must be > `1`, found `{minn}`.")
        elif minn == 2:
            logg.warning(
                "In most cases, 2 clusters will always be optimal. "
                "If you really expect 2 clusters, use `n_states=2` and `use_minchi=False`. Setting minimum to `3`"
            )
            minn = 3

        if minn >= maxx:
            maxx = minn + 1
            logg.debug(
                f"Setting maximum to `{maxx}` because it was <= than minimum `{minn}`"
            )

        logg.info(f"Calculating minChi within interval `[{minn}, {maxx}]`")

        return int(
            np.arange(minn, maxx + 1)[np.argmax(self._gpcca.minChi(minn, maxx))]
        )
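A plain-Python sketch of the interval normalization performed above before calling `minChi`: swap a reversed interval, require the minimum to exceed 1 (bumping 2 to 3), and keep the maximum strictly above the minimum:

def normalize_interval(minn: int, maxx: int):
    if minn > maxx:
        minn, maxx = maxx, minn  # swap a reversed interval
    if minn <= 1:
        raise ValueError(f"Minimum value must be > 1, found {minn}.")
    if minn == 2:
        minn = 3                 # 2 clusters are nearly always "optimal"
    if minn >= maxx:
        maxx = minn + 1          # keep the interval non-degenerate
    return minn, maxx

print(normalize_interval(6, 3))  # (3, 6): reversed interval is swapped
print(normalize_interval(2, 2))  # (3, 4): minimum bumped, maximum kept above it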
Example #11
    def _compute_one_macrostate(
        self,
        n_cells: int,
        cluster_key: Optional[str],
        en_cutoff: Optional[float],
        p_thresh: float,
    ) -> None:
        start = logg.warning(
            "For 1 macrostate, stationary distribution is computed")

        eig = self._get(P.EIG)
        if (eig is not None and "stationary_dist" in eig
                and eig["params"]["which"] == "LR"):
            stationary_dist = eig["stationary_dist"]
        else:
            self.compute_eigendecomposition(only_evals=False, which="LR")
            stationary_dist = self._get(P.EIG)["stationary_dist"]

        self._set_macrostates(
            memberships=stationary_dist[:, None],
            n_cells=n_cells,
            cluster_key=cluster_key,
            p_thresh=p_thresh,
            en_cutoff=en_cutoff,
        )
        self._set(
            A.MACRO_MEMBER,
            Lineage(
                stationary_dist,
                names=list(self._get(A.MACRO).cat.categories),
                colors=self._get(A.MACRO_COLORS),
            ),
        )

        # reset all the things
        for key in (
                A.ABS_PROBS,
                A.PRIME_DEG,
                A.SCHUR,
                A.SCHUR_MAT,
                A.COARSE_T,
                A.COARSE_STAT_D,
        ):
            self._set(key.s, None)

        logg.info(
            f"Adding `.{P.MACRO_MEMBER}`\n        `.{P.MACRO}`\n    Finish",
            time=start,
        )
Example #12
    def simulate_many(
        self,
        n_sims: int,
        max_iter: Union[int, float] = 0.25,
        seed: Optional[int] = None,
        successive_hits: int = 0,
        n_jobs: Optional[int] = None,
        backend: str = "loky",
        show_progress_bar: bool = True,
    ) -> List[np.ndarray]:
        """
        Simulate many random walks.

        Parameters
        ----------
        n_sims
            Number of random walks to simulate.
        %(rw_sim.params)s
        %(parallel)s

        Returns
        -------
        List of arrays of shape ``(max_iter + 1,)`` of states that have been visited. If ``stop_ixs`` was specified,
        the arrays may have smaller shape.
        """
        if n_sims <= 0:
            raise ValueError(
                f"Expected number of simulations to be positive, found `{n_sims}`."
            )
        max_iter = self._max_iter(max_iter)
        start = logg.info(
            f"Simulating `{n_sims}` random walks of maximum length `{max_iter}`"
        )

        simss = parallelize(
            self._simulate_many,
            collection=np.arange(n_sims),
            n_jobs=n_jobs,
            backend=backend,
            show_progress_bar=show_progress_bar,
            as_array=False,
            unit="sim",
        )(max_iter=max_iter, seed=seed, successive_hits=successive_hits)
        simss = list(chain.from_iterable(simss))

        logg.info("    Finish", time=start)

        return simss
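`simulate_many` parallelizes independent random walks; the walk itself is not shown here. A minimal numpy sketch of a single walk on a row-stochastic matrix (the `stop_ixs`/`successive_hits` stopping logic is omitted):

import numpy as np

def random_walk(T: np.ndarray, start: int, max_iter: int, seed: int = 0) -> np.ndarray:
    rng = np.random.default_rng(seed)
    states = [start]
    for _ in range(max_iter):
        # sample the next state from the current row of the transition matrix
        states.append(rng.choice(T.shape[0], p=T[states[-1]]))
    return np.array(states)  # shape (max_iter + 1,)

T = np.array([[0.5, 0.5, 0.0],
              [0.1, 0.8, 0.1],
              [0.0, 0.2, 0.8]])
print(random_walk(T, start=0, max_iter=10))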
Example #13
    def _write_absorption_probabilities(
        self, time: datetime, extra_msg: str = ""
    ) -> None:
        self.adata.obsm[self._abs_prob_key] = self._get(P.ABS_PROBS)

        abs_prob = self._get(P.ABS_PROBS)

        self.adata.uns[_lin_names(self._abs_prob_key)] = abs_prob.names
        self.adata.uns[_colors(self._abs_prob_key)] = abs_prob.colors

        logg.info(
            f"Adding `adata.obsm[{self._abs_prob_key!r}]`\n"
            f"{extra_msg}"
            f"       `.{P.ABS_PROBS}`\n"
            "    Finish",
            time=time,
        )
Example #14
    def _write_initial_states(self,
                              membership: Lineage,
                              probs: pd.Series,
                              cats: pd.Series,
                              time=None) -> None:
        key = TermStatesKey.BACKWARD.s

        self.adata.obs[key] = cats
        self.adata.obs[_probs(key)] = probs

        self.adata.uns[_colors(key)] = membership.colors
        self.adata.uns[_lin_names(key)] = membership.names

        logg.info(
            f"Adding `adata.obs[{_probs(key)!r}]`\n       `adata.obs[{key!r}]`\n",
            time=time,
        )
Example #15
    def compute_partition(self) -> None:
        """
        Compute communication classes for the Markov chain.

        Returns
        -------
        None
            Nothing, but updates the following fields:

                - :attr:`recurrent_classes`
                - :attr:`transient_classes`
                - :attr:`is_irreducible`
        """

        start = logg.info("Computing communication classes")
        n_states = len(self)

        rec_classes, trans_classes = _partition(self.transition_matrix)

        self._is_irreducible = len(rec_classes) == 1 and len(
            trans_classes) == 0

        if not self._is_irreducible:
            self._trans_classes = _make_cat(trans_classes, n_states,
                                            self.adata.obs_names)
            self._rec_classes = _make_cat(rec_classes, n_states,
                                          self.adata.obs_names)
            logg.info(
                f"Found `{(len(rec_classes))}` recurrent and `{len(trans_classes)}` transient classes\n"
                f"Adding `.recurrent_classes`\n"
                f"       `.transient_classes`\n"
                f"       `.is_irreducible`\n"
                f"    Finish",
                time=start,
            )
        else:
            logg.warning(
                "The transition matrix is irreducible, cannot further partition it\n    Finish",
                time=start,
            )
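One standard way to obtain the recurrent/transient split that `_partition` computes is via strongly connected components: a class is recurrent iff no probability mass leaves it. A scipy sketch of that idea (it mirrors the concept, not necessarily cellrank's exact implementation):

import numpy as np
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import connected_components

T = csr_matrix(np.array([[0.0, 1.0, 0.0],
                         [0.0, 0.5, 0.5],
                         [0.0, 0.0, 1.0]]))
n_comps, labels = connected_components(T, directed=True, connection="strong")

rec, trans = [], []
for c in range(n_comps):
    members = np.where(labels == c)[0]
    outside = np.setdiff1d(np.arange(T.shape[0]), members)
    # recurrent iff all outgoing probability mass stays inside the class
    leaks = T[members][:, outside].sum()
    (rec if leaks == 0 else trans).append(members.tolist())

# state 2 is absorbing, hence recurrent; states 0 and 1 are transient
print("recurrent:", rec)
print("transient:", trans)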
Example #16
    def maybe_create_lineage(
        direction: Union[str, Direction], pretty_name: Optional[str] = None
    ):
        if isinstance(direction, Direction):
            lin_key = str(
                AbsProbKey.FORWARD
                if direction == Direction.FORWARD
                else AbsProbKey.BACKWARD
            )
        else:
            lin_key = direction

        pretty_name = "" if pretty_name is None else (pretty_name + " ")
        names_key, colors_key = _lin_names(lin_key), _colors(lin_key)

        if lin_key in adata.obsm.keys():
            n_cells, n_lineages = adata.obsm[lin_key].shape
            logg.info(f"Creating {pretty_name}`Lineage` from `adata.obsm[{lin_key!r}]`")

            if names_key not in adata.uns.keys():
                logg.warning(
                    f"    Lineage names not found in `adata.uns[{names_key!r}]`, creating new names"
                )
                names = [f"Lineage {i}" for i in range(n_lineages)]
            elif len(adata.uns[names_key]) != n_lineages:
                logg.warning(
                    f"    Lineage names are don't have the required length ({n_lineages}), creating new names"
                )
                names = [f"Lineage {i}" for i in range(n_lineages)]
            else:
                logg.info("    Successfully loaded names")
                names = adata.uns[names_key]

            if colors_key not in adata.uns.keys():
                logg.warning(
                    f"    Lineage colors not found in `adata.uns[{colors_key!r}]`, creating new colors"
                )
                colors = _create_categorical_colors(n_lineages)
            elif len(adata.uns[colors_key]) != n_lineages or not all(
                map(lambda c: is_color_like(c), adata.uns[colors_key])
            ):
                logg.warning(
                    f"    Lineage colors don't have the required length ({n_lineages}) "
                    f"or are not color-like, creating new colors"
                )
                colors = _create_categorical_colors(n_lineages)
            else:
                logg.info("    Successfully loaded colors")
                colors = adata.uns[colors_key]

            adata.obsm[lin_key] = Lineage(
                adata.obsm[lin_key], names=names, colors=colors
            )
            adata.uns[colors_key] = colors
            adata.uns[names_key] = names
        else:
            logg.debug(
                f"Unable to load {pretty_name}`Lineage` from `adata.obsm[{lin_key!r}]`"
            )
Example #17
    def _compute_pairwise_tmaps(
        self,
        adata: AnnData,
        cost_matrices: Optional[
            Union[str, Mapping[Tuple[float, float], np.ndarray]]
        ] = None,
        solver: Literal["fixed_iters", "duality_gap"] = "duality_gap",
        growth_rate_field: Optional[str] = None,
        **kwargs,
    ) -> Dict[Tuple[float, float], AnnData]:
        self._ot_model = wot.ot.OTModel(
            adata,
            day_field=self._time_key,
            covariate_field=None,
            growth_rate_field=growth_rate_field,
            solver=solver,
            **kwargs,
        )

        self._tmaps: Dict[Tuple[float, float], AnnData] = {}
        start = logg.info(
            f"Computing transport maps for `{len(cost_matrices)}` time pairs"
        )
        for tpair, cost_matrix in tqdm(cost_matrices.items(), unit="time pair"):
            tmap: AnnData = self._ot_model.compute_transport_map(
                *tpair, cost_matrix=cost_matrix
            )
            tmap.X = tmap.X.astype(np.float64)
            nans = int(np.sum(~np.isfinite(tmap.X)))
            if nans:
                raise ValueError(
                    f"Encountered `{nans}` non-finite values for time pair `{tpair}`."
                )
            self._tmaps[tpair] = tmap

        logg.info("    Finish", time=start)

        return self._tmaps
Example #18
    def compute_transition_matrix(
        self, density_normalize: bool = True
    ) -> "ConnectivityKernel":
        """
        Compute transition matrix based on transcriptomic similarity.

        Uses a symmetric, weighted KNN graph to compute a symmetric transition matrix. The connectivities are computed
        using :func:`scanpy.pp.neighbors`. Depending on the parameters used there, they can be UMAP connectivities or
        gaussian-kernel-based connectivities with adaptive kernel width.

        Parameters
        ----------
        density_normalize
            Whether or not to use the underlying KNN graph for density normalization.

        Returns
        -------
        :class:`cellrank.tl.kernels.ConnectivityKernel`
            Makes :paramref:`transition_matrix` available.
        """

        start = logg.info("Computing transition matrix based on connectivities")

        params = {"dnorm": density_normalize}
        if params == self.params:
            assert self.transition_matrix is not None, _ERROR_EMPTY_CACHE_MSG
            logg.debug(_LOG_USING_CACHE)
            logg.info("    Finish", time=start)
            return self

        self._params = params
        self._compute_transition_matrix(
            matrix=self._conn.copy(), density_normalize=density_normalize
        )

        logg.info("    Finish", time=start)

        return self
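Density normalization of a symmetric KNN/connectivity matrix typically divides each entry by the node degrees before row-normalizing, as in diffusion maps. A generic numpy sketch of that idea (this is an assumption about the general technique, not necessarily the exact formula used by `_compute_transition_matrix`):

import numpy as np

K = np.array([[0.0, 0.8, 0.2],
              [0.8, 0.0, 0.5],
              [0.2, 0.5, 0.0]])  # symmetric connectivities

q = K.sum(axis=1)                               # node "density" (degree)
K_dens = K / np.outer(q, q)                     # density-normalized kernel
T = K_dens / K_dens.sum(axis=1, keepdims=True)  # row-stochastic transition matrix

print(np.allclose(T.sum(axis=1), 1.0))          # True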
Example #19
    def _write_final_states(self, time=None) -> None:
        self.adata.obs[self._fs_key] = self._get(P.FIN)
        self.adata.obs[_probs(self._fs_key)] = self._get(P.FIN_PROBS)

        self.adata.uns[_colors(self._fs_key)] = self._get(A.FIN_COLORS)
        self.adata.uns[_lin_names(self._fs_key)] = list(self._get(P.FIN).cat.categories)

        extra_msg = ""
        if getattr(self, A.FIN_ABS_PROBS.s, None) is not None and hasattr(
            self, "_fin_abs_prob_key"
        ):
            # checking for None because final states can be set using `set_final_states`
            # without the probabilities in GPCCA
            self.adata.obsm[self._fin_abs_prob_key] = self._get(A.FIN_ABS_PROBS)
            extra_msg = f"       `adata.obsm[{self._fin_abs_prob_key!r}]`\n"

        logg.info(
            f"Adding `adata.obs[{_probs(self._fs_key)!r}]`\n"
            f"       `adata.obs[{self._fs_key!r}]`\n"
            f"{extra_msg}"
            f"       `.{P.FIN_PROBS}`\n"
            f"       `.{P.FIN}`",
            time=time,
        )
Example #20
def cluster_lineage(
    adata: AnnData,
    model: _input_model_type,
    genes: Sequence[str],
    lineage: str,
    backward: bool = False,
    time_range: _time_range_type = None,
    clusters: Optional[Sequence[str]] = None,
    n_points: int = 200,
    time_key: str = "latent_time",
    norm: bool = True,
    recompute: bool = False,
    callback: _callback_type = None,
    ncols: int = 3,
    sharey: Union[str, bool] = False,
    key: Optional[str] = None,
    random_state: Optional[int] = None,
    use_leiden: bool = False,
    show_progress_bar: bool = True,
    n_jobs: Optional[int] = 1,
    backend: str = _DEFAULT_BACKEND,
    figsize: Optional[Tuple[float, float]] = None,
    dpi: Optional[int] = None,
    save: Optional[Union[str, Path]] = None,
    pca_kwargs: Dict = MappingProxyType({"svd_solver": "arpack"}),
    neighbors_kwargs: Dict = MappingProxyType({"use_rep": "X"}),
    clustering_kwargs: Dict = MappingProxyType({}),
    return_models: bool = False,
    **kwargs,
) -> Optional[_return_model_type]:
    """
    Cluster gene expression trends within a lineage and plot the clusters.

    This function is based on Palantir, see [Setty19]_. It can be used to discover modules of genes that drive
    development along a given lineage. Consider running this function on a subset of genes which are potential
    lineage drivers, identified e.g. by running :func:`cellrank.tl.lineage_drivers`.

    Parameters
    ----------
    %(adata)s
    %(model)s
    %(genes)s
    lineage
        Name of the lineage for which to cluster the genes.
    %(backward)s
    %(time_ranges)s
    clusters
        Cluster identifiers to plot. If `None`, all clusters will be considered. Useful when
        plotting previously computed clusters.
    n_points
        Number of points used for prediction.
    time_key
        Key in ``adata.obs`` where the pseudotime is stored.
    norm
        Whether to z-normalize each trend to have zero mean, unit variance.
    recompute
        If `True`, recompute the clustering, otherwise try to find an already existing one.
    %(model_callback)s
    ncols
        Number of columns for the plot.
    sharey
        Whether to share y-axis across multiple plots.
    key
        Key in ``adata.uns`` where to save the results. If `None`, it will be saved as ``lineage_{lineage}_trend``.
    random_state
        Random seed for reproducibility.
    use_leiden
        Whether to use :func:`scanpy.tl.leiden` for clustering or :func:`scanpy.tl.louvain`.
    %(parallel)s
    %(plotting)s
    pca_kwargs
        Keyword arguments for :func:`scanpy.pp.pca`.
    neighbors_kwargs
        Keyword arguments for :func:`scanpy.pp.neighbors`.
    clustering_kwargs
        Keyword arguments for :func:`scanpy.tl.louvain` or :func:`scanpy.tl.leiden`.
    %(return_models)s
    **kwargs:
        Keyword arguments for :meth:`cellrank.ul.models.BaseModel.prepare`.

    Returns
    -------
    %(plots_or_returns_models)s

        Also updates ``adata.uns`` with the following:

            - ``key`` or ``lineage_{lineage}_trend`` - an :class:`anndata.AnnData` object of
              shape `(n_genes, n_points)` containing the clustered genes.
    """

    import scanpy as sc
    from anndata import AnnData as _AnnData

    lineage_key = str(AbsProbKey.BACKWARD if backward else AbsProbKey.FORWARD)
    if lineage_key not in adata.obsm:
        raise KeyError(
            f"Lineages key `{lineage_key!r}` not found in `adata.obsm`.")

    _ = adata.obsm[lineage_key][lineage]

    genes = _unique_order_preserving(genes)
    _check_collection(adata, genes, "var_names", kwargs.get("use_raw", False))

    if key is None:
        key = f"lineage_{lineage}_trend"

    if recompute or key not in adata.uns:
        kwargs["backward"] = backward
        kwargs["time_key"] = time_key
        kwargs["n_test_points"] = n_points
        models = _create_models(model, genes, [lineage])
        all_models, models, genes, _ = _fit_bulk(
            models,
            _create_callbacks(adata, callback, genes, [lineage], **kwargs),
            genes,
            lineage,
            time_range,
            return_models=True,  # always return (better error messages)
            filter_all_failed=True,
            parallel_kwargs={
                "show_progress_bar": show_progress_bar,
                "n_jobs": _get_n_cores(n_jobs, len(genes)),
                "backend": _get_backend(models, backend),
            },
            **kwargs,
        )

        # `n_genes, n_test_points`
        trends = np.vstack(
            [model[lineage].y_test for model in models.values()]).T

        if norm:
            logg.debug("Normalizing trends")
            _ = StandardScaler(copy=False).fit_transform(trends)

        trends = _AnnData(trends.T)
        trends.obs_names = genes

        # sanity check
        if trends.n_obs != len(genes):
            raise RuntimeError(
                f"Expected to find `{len(genes)}` genes, found `{trends.n_obs}`."
            )
        if trends.n_vars != n_points:
            raise RuntimeError(
                f"Expected to find `{n_points}` points, found `{trends.n_vars}`."
            )

        random_state = np.random.mtrand.RandomState(random_state).randint(
            2**16)

        pca_kwargs = dict(pca_kwargs)
        pca_kwargs.setdefault("n_comps", min(50, n_points, len(genes)) - 1)
        pca_kwargs.setdefault("random_state", random_state)
        sc.pp.pca(trends, **pca_kwargs)

        neighbors_kwargs = dict(neighbors_kwargs)
        neighbors_kwargs.setdefault("random_state", random_state)
        sc.pp.neighbors(trends, **neighbors_kwargs)

        clustering_kwargs = dict(clustering_kwargs)
        clustering_kwargs["key_added"] = "clusters"
        clustering_kwargs.setdefault("random_state", random_state)
        try:
            if use_leiden:
                sc.tl.leiden(trends, **clustering_kwargs)
            else:
                sc.tl.louvain(trends, **clustering_kwargs)
        except ImportError as e:
            logg.warning(str(e))
            if use_leiden:
                sc.tl.louvain(trends, **clustering_kwargs)
            else:
                sc.tl.leiden(trends, **clustering_kwargs)

        logg.info(f"Saving data to `adata.uns[{key!r}]`")
        adata.uns[key] = trends
    else:
        all_models = None
        logg.info(f"Loading data from `adata.uns[{key!r}]`")
        trends = adata.uns[key]

    if "clusters" not in trends.obs:
        raise KeyError(
            "Unable to find the clustering in `trends.obs['clusters']`.")

    if clusters is None:
        clusters = trends.obs["clusters"].cat.categories
    for c in clusters:
        if c not in trends.obs["clusters"].cat.categories:
            raise ValueError(
                f"Invalid cluster name `{c!r}`. "
                f"Valid options are `{list(trends.obs['clusters'].cat.categories)}`."
            )

    nrows = int(np.ceil(len(clusters) / ncols))
    fig, axes = plt.subplots(
        nrows,
        ncols,
        figsize=(ncols * 10, nrows * 10) if figsize is None else figsize,
        sharey=sharey,
        dpi=dpi,
    )

    if not isinstance(axes, Iterable):
        axes = [axes]
    axes = np.ravel(axes)

    j = 0
    for j, (ax, c) in enumerate(zip(axes, clusters)):  # noqa
        data = trends[trends.obs["clusters"] == c].X
        mean, sd = np.mean(data, axis=0), np.std(data, axis=0)

        for i in range(data.shape[0]):
            ax.plot(data[i], color="gray", lw=0.5)

        ax.plot(mean, lw=2, color="black")
        ax.plot(mean - sd, lw=1.5, color="black", linestyle="--")
        ax.plot(mean + sd, lw=1.5, color="black", linestyle="--")
        ax.fill_between(range(len(mean)),
                        mean - sd,
                        mean + sd,
                        color="black",
                        alpha=0.1)

        ax.set_title(f"Cluster {c}")
        ax.set_xticks([])

        if not sharey:
            ax.set_yticks([])

    for j in range(j + 1, len(axes)):
        axes[j].remove()

    if save is not None:
        save_fig(fig, save)

    if return_models:
        return all_models
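The core of `cluster_lineage` is: stack the per-gene fitted trends into an `(n_genes, n_points)` matrix, z-normalize each trend, then cluster. A compact sketch of that pipeline on fake trends, using scikit-learn's KMeans in place of the scanpy PCA/neighbors/louvain steps (the cluster count is arbitrary):

import numpy as np
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(0)
n_genes, n_points = 30, 200
trends = np.cumsum(rng.normal(size=(n_genes, n_points)), axis=1)  # fake smoothed trends

# z-normalize each gene's trend (zero mean, unit variance), as when `norm=True`
trends_z = StandardScaler().fit_transform(trends.T).T

labels = KMeans(n_clusters=3, n_init=10, random_state=0).fit_predict(trends_z)
print(labels.shape, np.bincount(labels))  # (30,) and the cluster sizes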
Example #21
    def compute_terminal_states(
        self,
        use: Optional[Union[int, Tuple[int], List[int], range]] = None,
        percentile: Optional[int] = 98,
        method: str = "kmeans",
        cluster_key: Optional[str] = None,
        n_clusters_kmeans: Optional[int] = None,
        n_neighbors: int = 20,
        resolution: float = 0.1,
        n_matches_min: Optional[int] = 0,
        n_neighbors_filtering: int = 15,
        basis: Optional[str] = None,
        n_comps: int = 5,
        scale: bool = False,
        en_cutoff: Optional[float] = 0.7,
        p_thresh: float = 1e-15,
    ) -> None:
        """
        Find approximate recurrent classes of the Markov chain.

        Filter to obtain recurrent states in left eigenvectors.
        Cluster to obtain approximate recurrent classes in right eigenvectors.

        Parameters
        ----------
        use
            Which or how many first eigenvectors to use as features for clustering/filtering.
            If `None`, use the `eigengap` statistic.
        percentile
            Threshold used for filtering out cells which are most likely transient states. Cells which are in the
            lower ``percentile`` percent of each eigenvector will be removed from the data matrix.
        method
            Method to be used for clustering. Must be one of `'louvain'`, `'leiden'` or `'kmeans'`.
        cluster_key
            If a key to cluster labels is given, :attr:`{fs}` will get associated with these for naming and colors.
        n_clusters_kmeans
            If `None`, this is set to ``use + 1``.
        n_neighbors
            If we use `'louvain'` or `'leiden'` for clustering cells, we need to build a KNN graph.
            This is the :math:`K` parameter for that, the number of neighbors for each cell.
        resolution
            Resolution parameter for `'louvain'` or `'leiden'` clustering. Should be chosen relatively small.
        n_matches_min
            Filters out cells which don't have at least ``n_matches_min`` neighbors from the same class.
            This filters out some cells which are transient but have been misassigned.
        n_neighbors_filtering
            Parameter for filtering cells. Cells are filtered out if they don't have at least ``n_matches_min``
            neighbors among their ``n_neighbors_filtering`` nearest cells.
        basis
            Key from :paramref:`adata` ``.obsm`` to be used as additional features for the clustering.
        n_comps
            Number of embedding components to be used when ``basis`` is not `None`.
        scale
            Scale to z-scores. Consider using this if appending embedding to features.
        %(en_cutoff_p_thresh)s

        Returns
        -------
        None
            Nothing, but updates the following fields:

                - :attr:`{fsp}`
                - :attr:`{fs}`
        """

        def _compute_macrostates_prob() -> Series:
            """Compute a global score of being an approximate recurrent class."""

            # get the truncated eigendecomposition
            V, evals = eig["V_l"].real[:, use], eig["D"].real[use]

            # shift and scale
            V_pos = np.abs(V)
            V_shifted = V_pos - np.min(V_pos, axis=0)
            V_scaled = V_shifted / np.max(V_shifted, axis=0)

            # check the ranges are correct
            assert np.allclose(np.min(V_scaled, axis=0), 0), "Lower limit is not zero."
            assert np.allclose(np.max(V_scaled, axis=0), 1), "Upper limit is not one."

            # further scale by the eigenvalues
            V_eigs = V_scaled / evals

            # sum over cols and scale
            c_ = np.sum(V_eigs, axis=1)
            c = c_ / np.max(c_)

            return Series(c, index=self.adata.obs_names)

        def check_use(use) -> List[int]:
            if method not in ["kmeans", "louvain", "leiden"]:
                raise ValueError(
                    f"Invalid method `{method!r}`. Valid options are `'louvain'`, `'leiden'` or `'kmeans'`."
                )

            if use is None:
                use = eig["eigengap"] + 1  # add one b/c indexing starts at 0
            if isinstance(use, int):
                use = list(range(use))
            elif not isinstance(use, (tuple, list, range)):
                raise TypeError(
                    f"Argument `use` must be either `int`, `tuple`, `list` or `range`, "
                    f"found `{type(use).__name__!r}`."
                )
            else:
                if not all(map(lambda u: isinstance(u, int), use)):
                    raise TypeError("Not all values in `use` argument are integers.")
            use = list(use)

            if len(use) == 0:
                raise ValueError(
                    f"Number of eigenvector must be larger than `0`, found `{len(use)}`."
                )

            muse = max(use)
            if muse >= eig["V_l"].shape[1] or muse >= eig["V_r"].shape[1]:
                raise ValueError(
                    f"Maximum specified eigenvector `{muse}` is larger "
                    f'than the number of computed eigenvectors `{eig["V_l"].shape[1]}`. '
                    f"Use `.compute_eigendecomposition(k={muse})` to recompute the eigendecomposition."
                )

            return use

        eig = self._get(P.EIG)
        if eig is None:
            raise RuntimeError(
                "Compute eigendecomposition first as `.compute_eigendecomposition()`."
            )
        use = check_use(use)

        start = logg.info("Computing approximate recurrent classes")
        # we check for complex values only in the left, that's okay because the complex pattern
        # will be identical for left and right
        V_l, V_r = eig["V_l"][:, use], eig["V_r"].real[:, use]
        V_l = _complex_warning(V_l, use, use_imag=False)

        # compute a rc probability
        logg.debug("Computing probabilities of approximate recurrent classes")
        self._set(A.TERM_PROBS, _compute_macrostates_prob())

        # retrieve embedding and concatenate
        if basis is not None:
            bkey = f"X_{basis}"
            if bkey not in self.adata.obsm.keys():
                raise KeyError(f"Basis key `{bkey!r}` not found in `adata.obsm`")

            X_em = self.adata.obsm[bkey][:, :n_comps]
            X = np.concatenate([V_r, X_em], axis=1)
        else:
            logg.debug("Basis is `None`. Setting X equal to the right eigenvectors")
            X = V_r

        # filter out cells which are in the lowest q percentile in abs value in each eigenvector
        if percentile is not None:
            logg.debug("Filtering out cells according to percentile")
            if percentile < 0 or percentile > 100:
                raise ValueError(
                    f"Percentile must be in interval `[0, 100]`, found `{percentile}`."
                )
            cutoffs = np.percentile(np.abs(V_l), percentile, axis=0)
            ixs = np.sum(np.abs(V_l) < cutoffs, axis=1) < V_l.shape[1]
            X = X[ixs, :]

        # scale
        if scale:
            X = zscore(X, axis=0)

        # cluster X
        if method == "kmeans" and n_clusters_kmeans is None:
            n_clusters_kmeans = len(use) + (percentile is None)
            if X.shape[0] < n_clusters_kmeans:
                raise ValueError(
                    f"Filtering resulted in only {X.shape[0]} cell(s), insufficient to cluster into "
                    f"`{n_clusters_kmeans}` clusters. Consider decreasing the value of `percentile`."
                )

        logg.debug(
            f"Using `{use}` eigenvectors, basis `{basis!r}` and method `{method!r}` for clustering"
        )
        labels = _cluster_X(
            X,
            method=method,
            n_clusters=n_clusters_kmeans,
            n_neighbors=n_neighbors,
            resolution=resolution,
        )

        # fill in the labels in case we filtered out cells before
        if percentile is not None:
            rc_labels = np.repeat(None, self.adata.n_obs)
            rc_labels[ixs] = labels
        else:
            rc_labels = labels

        rc_labels = Series(rc_labels, index=self.adata.obs_names, dtype="category")
        rc_labels.cat.categories = list(rc_labels.cat.categories.astype("str"))

        # filtering to get rid of some of the left over transient states
        if n_matches_min > 0:
            logg.debug(f"Filtering according to `n_matches_min={n_matches_min}`")
            distances = _get_connectivities(
                self.adata, mode="distances", n_neighbors=n_neighbors_filtering
            )
            rc_labels = _filter_cells(
                distances, rc_labels=rc_labels, n_matches_min=n_matches_min
            )

        self.set_terminal_states(
            labels=rc_labels,
            cluster_key=cluster_key,
            en_cutoff=en_cutoff,
            p_thresh=p_thresh,
            add_to_existing=False,
            time=start,
        )
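The percentile filter in the method above keeps a cell only if, in at least one left eigenvector, its absolute loading reaches that eigenvector's percentile cutoff. A small numpy illustration of the same masking logic on random data:

import numpy as np

rng = np.random.default_rng(0)
V_l = rng.normal(size=(100, 3))  # cells x eigenvectors
percentile = 98

cutoffs = np.percentile(np.abs(V_l), percentile, axis=0)
# keep a cell unless it is below the cutoff in every eigenvector
ixs = np.sum(np.abs(V_l) < cutoffs, axis=1) < V_l.shape[1]

print(ixs.sum(), "of", V_l.shape[0], "cells kept")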
Example #22
    def compute_eigendecomposition(
        self,
        k: int = 20,
        which: str = "LR",
        alpha: float = 1,
        only_evals: bool = False,
        ncv: Optional[int] = None,
    ) -> None:
        """
        Compute eigendecomposition of transition matrix.

        Uses a sparse implementation, if possible, and only computes the top :math:`k` eigenvectors
        to speed up the computation. Computes both left and right eigenvectors.

        Parameters
        ----------
        k
            Number of eigenvalues/vectors to compute.
        %(eigen)s
        only_evals
            Compute only eigenvalues.
        ncv
            Number of Lanczos vectors generated.

        Returns
        -------
        None
            Nothing, but updates the following field:

                - :paramref:`{prop}`
        """
        def get_top_k_evals():
            return D[np.flip(np.argsort(D.real))][:k]

        start = logg.info(
            "Computing eigendecomposition of the transition matrix")

        if self.issparse:
            logg.debug(f"Computing top `{k}` eigenvalues for sparse matrix")
            D, V_l = eigs(self.transition_matrix.T, k=k, which=which, ncv=ncv)
            if only_evals:
                self._write_eig_to_adata({
                    "D": get_top_k_evals(),
                    "eigengap": _eigengap(get_top_k_evals().real, alpha),
                    "params": {"which": which, "k": k, "alpha": alpha},
                })
                return
            _, V_r = eigs(self.transition_matrix, k=k, which=which, ncv=ncv)
        else:
            logg.warning(
                "This transition matrix is not sparse, computing full eigendecomposition"
            )
            D, V_l = np.linalg.eig(self.transition_matrix.T)
            if only_evals:
                self._write_eig_to_adata({
                    "D": get_top_k_evals(),
                    "eigengap": _eigengap(D.real, alpha),
                    "params": {
                        "which": which,
                        "k": k,
                        "alpha": alpha
                    },
                })
                return
            _, V_r = np.linalg.eig(self.transition_matrix)

        # Sort the eigenvalues and eigenvectors and take the real part
        logg.debug("Sorting eigenvalues by their real part")
        p = np.flip(np.argsort(D.real))
        D, V_l, V_r = D[p], V_l[:, p], V_r[:, p]
        e_gap = _eigengap(D.real, alpha)

        pi = np.abs(V_l[:, 0].real)
        pi /= np.sum(pi)

        self._write_eig_to_adata(
            {
                "D": D,
                "stationary_dist": pi,
                "V_l": V_l,
                "V_r": V_r,
                "eigengap": e_gap,
                "params": {
                    "which": which,
                    "k": k,
                    "alpha": alpha
                },
            },
            start=start,
        )
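A self-contained version of the dense branch above: eigendecompose a small row-stochastic matrix, sort by the real part of the eigenvalues, and read the stationary distribution off the leading left eigenvector.

import numpy as np

T = np.array([[0.90, 0.10, 0.00],
              [0.05, 0.90, 0.05],
              [0.00, 0.10, 0.90]])

D, V_l = np.linalg.eig(T.T)      # left eigenvectors of T are right eigenvectors of T.T
_, V_r = np.linalg.eig(T)

p = np.flip(np.argsort(D.real))  # sort by real part, descending
D, V_l, V_r = D[p], V_l[:, p], V_r[:, p]

pi = np.abs(V_l[:, 0].real)
pi /= pi.sum()                   # stationary distribution

print(np.allclose(pi @ T, pi))   # True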
Example #23
    def compute_transition_matrix(
        self,
        mode: str = VelocityMode.DETERMINISTIC.s,
        backward_mode: str = BackwardMode.TRANSPOSE.s,
        softmax_scale: Optional[float] = None,
        n_samples: int = 1000,
        seed: Optional[int] = None,
        **kwargs,
    ) -> "VelocityKernel":
        """
        Compute transition matrix based on velocity directions on the local manifold.

        For each cell, infer transition probabilities based on the correlation of the cell's
        velocity-extrapolated cell state with cell states of its *K* nearest neighbors.

        Parameters
        ----------
        %(velocity_mode)s
        %(velocity_backward_mode)s
        %(softmax_scale)s
        n_samples
            Number of bootstrap samples when ``mode={m.MONTE_CARLO.s!r}``.
        seed
            Set the seed for random state when the method requires ``n_samples``.
        %(parallel)s

        Returns
        -------
        :class:`cellrank.tl.kernels.VelocityKernel`
            Makes available the following fields:

                - :paramref:`transition_matrix`
                - :paramref:`pearson_correlations`
        """

        mode = VelocityMode(mode)
        backward_mode = BackwardMode(backward_mode)

        if self.backward and mode != VelocityMode.DETERMINISTIC:
            logg.warning(
                f"Mode `{mode.s!r}` is currently not supported for the backward process. "
                f"Defaulting to mode `{VelocityMode.DETERMINISTIC.s!r}`")
            mode = VelocityMode.DETERMINISTIC
        if mode == VelocityMode.STOCHASTIC and not _HAS_JAX:
            logg.warning(
                f"Unable to detect `jax` installation. Consider installing it as `pip install jax jaxlib`.\n"
                f"Defaulting to mode `{VelocityMode.MONTE_CARLO.s!r}`")
            mode = VelocityMode.MONTE_CARLO

        start = logg.info(
            f"Computing transition matrix based on velocity correlations using `{mode.s!r}` mode"
        )

        if seed is None:
            seed = np.random.randint(0, 2**16)
        params = dict(softmax_scale=softmax_scale, mode=mode,
                      seed=seed)  # noqa
        if self.backward:
            params["bwd_mode"] = backward_mode.s

        # check whether we already computed such a transition matrix. If yes, load from cache
        if params == self._params:
            assert self._transition_matrix is not None, _ERROR_EMPTY_CACHE_MSG
            logg.debug(_LOG_USING_CACHE, time=start)
            logg.info("    Finish", time=start)
            return self

        self._params = params

        # compute first and second order moments to model the distribution of the velocity vector
        np.random.seed(seed)
        velocity_expectation = get_moments(self.adata,
                                           self._velocity,
                                           second_order=False).astype(
                                               np.float64)
        velocity_variance = get_moments(self.adata,
                                        self._velocity,
                                        second_order=True).astype(np.float64)

        if mode == VelocityMode.MONTE_CARLO and n_samples == 1:
            logg.debug("Setting mode to sampling because `n_samples=1`")
            mode = VelocityMode.SAMPLING

        backend = kwargs.pop("backend", _DEFAULT_BACKEND)
        if version_info[:2] <= (3, 6):
            logg.warning(
                "For Python3.6, only `'threading'` backend is supported")
            backend = "threading"
        elif mode != VelocityMode.STOCHASTIC and backend == "multiprocessing":
            # this is because of jitting and pickling (cloudpickle, used by loky, handles it correctly)
            logg.warning(
                f"Multiprocessing backend is supported only for mode `{VelocityMode.STOCHASTIC.s!r}`. "
                f"Defaulting to `{_DEFAULT_BACKEND}`")
            backend = _DEFAULT_BACKEND

        if softmax_scale is None:
            logg.info(
                f"Estimating `softmax_scale` using `{VelocityMode.DETERMINISTIC.s!r}` mode"
            )
            _, cmat = _dispatch_computation(
                VelocityMode.DETERMINISTIC,
                conn=self._conn,
                expression=self._gene_expression,
                velocity=self._velocity,
                expectation=velocity_expectation,
                variance=velocity_variance,
                softmax_scale=1.0,
                backward=self.backward,
                backward_mode=backward_mode,
                n_samples=n_samples,
                seed=seed,
                backend=backend,
                **kwargs,
            )
            softmax_scale = 1.0 / np.median(np.abs(cmat.data))
            params["softmax_scale"] = softmax_scale
            logg.info(f"Setting `softmax_scale={softmax_scale:.4f}`")

        tmat, cmat = _dispatch_computation(
            mode,
            conn=self._conn,
            expression=self._gene_expression,
            velocity=self._velocity,
            expectation=velocity_expectation,
            variance=velocity_variance,
            softmax_scale=softmax_scale,
            backward=self.backward,
            backward_mode=backward_mode,
            n_samples=n_samples,
            seed=seed,
            backend=backend,
            **kwargs,
        )
        self._compute_transition_matrix(tmat, density_normalize=False)
        self._pearson_correlations = cmat

        logg.info("    Finish", time=start)

        return self
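The `softmax_scale` estimation above runs the deterministic mode once with scale `1.0` and then sets the scale to `1 / median(|correlations|)`; the correlations of a cell to its neighbors are subsequently pushed through a scaled softmax to obtain transition probabilities. A toy numpy sketch of that scale-then-softmax step (a simplification of what the dispatched computation does internally):

import numpy as np

rng = np.random.default_rng(0)
corrs = rng.uniform(-1, 1, size=30)             # correlations to the neighbors of one cell

softmax_scale = 1.0 / np.median(np.abs(corrs))  # heuristic used above

def softmax(x, scale):
    e = np.exp(scale * (x - x.max()))
    return e / e.sum()

probs = softmax(corrs, softmax_scale)
print(round(softmax_scale, 4), probs.sum())     # probabilities sum to 1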
Example #24
def graph(
        data: Union[AnnData, np.ndarray, spmatrix],
        graph_key: Optional[str] = None,
        ixs: Optional[np.ndarray] = None,
        layout: Union[str, Dict, Callable] = "umap",
        keys: Sequence[KEYS] = ("incoming", ),
        keylocs: Union[KEYLOCS, Sequence[KEYLOCS]] = "uns",
        node_size: float = 400,
        labels: Optional[Union[Sequence[str], Sequence[Sequence[str]]]] = None,
        top_n_edges: Optional[Union[int, Tuple[int, bool, str]]] = None,
        self_loops: bool = True,
        self_loop_radius_frac: Optional[float] = None,
        filter_edges: Optional[Tuple[float, float]] = None,
        edge_reductions: Union[Callable, Sequence[Callable]] = np.sum,
        edge_weight_scale: float = 10,
        edge_width_limit: Optional[float] = None,
        edge_alpha: float = 1.0,
        edge_normalize: bool = False,
        edge_use_curved: bool = True,
        show_arrows: bool = True,
        font_size: int = 12,
        font_color: str = "black",
        color_nodes: bool = True,
        cat_cmap: ListedColormap = cm.Set3,
        cont_cmap: ListedColormap = cm.viridis,
        legend_loc: Optional[str] = "best",
        figsize: Optional[Tuple[float, float]] = None,
        dpi: Optional[int] = None,
        save: Optional[Union[str, Path]] = None,
        layout_kwargs: Dict = MappingProxyType({}),
) -> None:
    """
    Plot a graph, visualizing incoming and outgoing edges or self-transitions.

    This is a utility function to look in more detail at the transition matrix in areas of interest, e.g. around an
    endpoint of development. This function is meant to visualize a small subset of nodes (~100-500) and the most likely
    transitions between them. Note that limiting the edges visualized using ``top_n_edges`` will speed things up,
    as well as reduce the visual clutter.

    Parameters
    ----------
    data
        The graph data to be plotted.
    graph_key
        Key in ``adata.obsp`` or ``adata.uns`` where the graph is stored. Only used
        when ``data`` is an :class:`~anndata.AnnData` object.
    ixs
        Subset of indices of the graph to visualize.
    layout
        Layout to use for graph drawing.

        - If :class:`str`, search for embedding in ``adata.obsm['X_{layout}']``.
          Use ``layout_kwargs={'components': [0, 1]}`` to select components.
        - If :class:`dict`, keys should be values in interval ``[0, len(ixs))``
          and values `(x, y)` pairs corresponding to node positions.
    keys
        Keys in ``adata.obs``, ``adata.obsm`` or ``adata.obsp`` to color the nodes.

        - If `'incoming'`, `'outgoing'` or `'self_loops'`, visualize the reduction (see ``edge_reductions``)
          for each node based on its incoming edges, outgoing edges or self-transitions, respectively.
    keylocs
        Locations of ``keys``. Can be any attribute of ``data`` if it's :class:`anndata.AnnData` object.
    node_size
        Size of the nodes.
    labels
        Labels of the nodes.
    top_n_edges
        Either top N outgoing edges in descending order or a tuple
        ``(top_n_edges, in_ascending_order, {'incoming', 'outgoing'})``. If `None`, show all edges.
    self_loops
        Whether to visualize self transitions and also to consider them in ``top_n_edges``.
    self_loop_radius_frac
        Fraction of a unit circle to visualize self transitions. If `None`, use ``node_size / 1000``.
    filter_edges
        Whether to remove all edges not in `[min, max]` interval.
    edge_reductions
        Aggregation function to use when coloring nodes by edge weights.
    edge_weight_scale
        Number by which to scale the width of the edges. Useful when the weights are small.
    edge_width_limit
        Upper bound for the width of the edges. Useful when weights are unevenly distributed.
    edge_alpha
        Alpha channel value for edges and arrows.
    edge_normalize
        If `True`, normalize edges to `[0, 1]` interval prior to applying any scaling or truncation.
    edge_use_curved
        If `True`, use curved edges. This can improve visualization at a small performance cost.
    show_arrows
        Whether to show the arrows. Setting this to `False` may dramatically speed things up.
    font_size
        Font size for node labels.
    font_color
        Label color of the nodes.
    color_nodes
        Whether to color the nodes.
    cat_cmap
        Categorical colormap used when ``keys`` contain categorical variables.
    cont_cmap
        Continuous colormap used when ``keys`` contain continuous variables.
    legend_loc
        Location of the legend.
    %(plotting)s
    layout_kwargs
        Additional kwargs for ``layout``.

    Returns
    -------
    %(just_plots)s
    """

    from anndata import AnnData as _AnnData

    import networkx as nx

    def plot_arrows(curves, G, pos, ax, edge_weight_scale):
        for line, (edge, val) in zip(curves, G.edges.items()):
            if edge[0] == edge[1]:
                continue

            mask = (~np.isnan(line)).all(axis=1)
            line = line[mask, :]
            if not len(line):  # can be all NaNs
                continue

            line = line.reshape((-1, 2))
            X, Y = line[:, 0], line[:, 1]

            node_start = pos[edge[0]]
            # reverse
            if np.where(np.isclose(node_start - line,
                                   [0, 0]).all(axis=1))[0][0]:
                X, Y = X[::-1], Y[::-1]

            mid = len(X) // 2
            posA, posB = zip(X[mid:mid + 2], Y[mid:mid + 2])  # noqa

            arrow = FancyArrowPatch(
                posA=posA,
                posB=posB,
                # we clip because too small values
                # cause it to crash
                arrowstyle=ArrowStyle.CurveFilledB(
                    head_length=np.clip(
                        val["weight"] * edge_weight_scale * 4,
                        _min_edge_weight,
                        edge_width_limit,
                    ),
                    head_width=np.clip(
                        val["weight"] * edge_weight_scale * 2,
                        _min_edge_weight,
                        edge_width_limit,
                    ),
                ),
                color="k",
                zorder=float("inf"),
                alpha=edge_alpha,
                linewidth=0,
            )
            ax.add_artist(arrow)

    def normalize_weights():
        weights = np.array([v["weight"] for v in G.edges.values()])
        minn = np.min(weights)
        weights = (weights - minn) / (np.max(weights) - minn)
        for v, w in zip(G.edges.values(), weights):
            v["weight"] = w

    def remove_top_n_edges():
        if top_n_edges is None:
            return

        if isinstance(top_n_edges, (tuple, list)):
            to_keep, ascending, group_by = top_n_edges
        else:
            to_keep, ascending, group_by = top_n_edges, False, "out"

        if group_by not in ("incoming", "outgoing"):
            raise ValueError(
                "Argument `groupby` in `top_n_edges` must be either `'incoming`' or `'outgoing'`."
            )

        source, target = zip(*G.edges)
        weights = [v["weight"] for v in G.edges.values()]
        tmp = pd.DataFrame({
            "outgoing": source,
            "incoming": target,
            "w": weights
        })

        if not self_loops:
            # remove self loops
            tmp = tmp[tmp["incoming"] != tmp["outgoing"]]

        # for each node, keep only the `to_keep` strongest (or weakest, if `ascending=True`)
        # edges, grouped by their source ("outgoing") or target ("incoming") node
        to_keep = set(
            map(
                tuple,
                tmp.groupby(group_by).apply(
                    lambda g: g.sort_values("w", ascending=ascending).take(
                        range(min(to_keep, len(g)))))[["outgoing",
                                                       "incoming"]].values,
            ))

        for e in list(G.edges):
            if e not in to_keep:
                G.remove_edge(*e)

    def remove_low_weight_edges():
        if filter_edges is None or filter_edges == (None, None):
            return

        minn, maxx = filter_edges
        minn = minn if minn is not None else -np.inf
        maxx = maxx if maxx is not None else np.inf

        for e, attr in list(G.edges.items()):
            if attr["weight"] < minn or attr["weight"] > maxx:
                G.remove_edge(*e)

    _min_edge_weight = 0.00001

    if edge_width_limit is None:
        logg.debug("Not limiting width of edges")
        edge_width_limit = float("inf")

    if self_loop_radius_frac is None:
        self_loop_radius_frac = (node_size /
                                 2000 if node_size >= 200 else node_size /
                                 1000)
        logg.debug(
            f"Setting self loop radius fraction to `{self_loop_radius_frac}`")

    if not isinstance(keylocs, (tuple, list)):
        keylocs = [keylocs] * len(keys)
    elif len(keylocs) == 1:
        keylocs = keylocs * len(keys)
    elif all(map(lambda k: k in ("incoming", "outgoing", "self_loops"), keys)):
        # don't care about keylocs since they are irrelevant
        logg.debug("Ignoring key locations")
        keylocs = [None] * len(keys)

    if not isinstance(edge_reductions, (tuple, list)):
        edge_reductions = [edge_reductions] * len(keys)
    if not all(map(callable, edge_reductions)):
        raise ValueError("Not all `edge_reductions` functions are callable.")

    if not isinstance(labels, (tuple, list)):
        labels = [labels] * len(keys)
    elif not len(labels):
        labels = [None] * len(keys)
    elif not isinstance(labels[0], (tuple, list)):
        labels = [labels] * len(keys)

    if len(keys) != len(labels):
        raise ValueError(
            f"`Keys` and `labels` must be of the same shape, found `{len(keys)}` and `{len(labels)}`."
        )

    if isinstance(data, _AnnData):
        if graph_key is None:
            raise ValueError(
                "Argument `graph_key` cannot be `None` when `data` is an `anndata.AnnData` object."
            )
        gdata = _read_graph_data(data, graph_key)
    elif isinstance(data, (np.ndarray, spmatrix)):
        gdata = data
    else:
        raise TypeError(
            f"Expected argument `data` to be one of `anndata.AnnData`, `numpy.ndarray`, `scipy.sparse.spmatrix`, "
            f"found `{type(data).__name__!r}`.")
    is_sparse = issparse(gdata)

    if ixs is not None:
        gdata = gdata[ixs, :][:, ixs]
    else:
        ixs = list(range(gdata.shape[0]))

    start = logg.info("Creating graph")
    G = (nx.from_scipy_sparse_matrix(gdata, create_using=nx.DiGraph)
         if is_sparse else nx.from_numpy_array(gdata, create_using=nx.DiGraph))

    remove_low_weight_edges()
    remove_top_n_edges()
    if edge_normalize:
        normalize_weights()
    logg.info("    Finish", time=start)

    # do NOT recreate `gdata` from the filtered graph - the edge reductions below use the original data
    # gdata = nx.to_numpy_array(G)

    if figsize is None:
        figsize = (12, 8 * len(keys))

    fig, axes = plt.subplots(nrows=len(keys),
                             ncols=1,
                             figsize=figsize,
                             dpi=dpi)
    if not isinstance(axes, np.ndarray):
        axes = np.array([axes])
    axes = np.ravel(axes)

    if isinstance(layout, str):
        if f"X_{layout}" not in data.obsm:
            raise KeyError(
                f"Unable to find embedding `'X_{layout}'` in `adata.obsm`.")
        components = layout_kwargs.get("components", [0, 1])
        if len(components) != 2:
            raise ValueError(
                f"Components in `layout_kwargs` must be of length `2`, found `{len(components)}`."
            )
        emb = data.obsm[f"X_{layout}"][:, components]
        pos = {i: emb[ix, :] for i, ix in enumerate(ixs)}
        logg.info(f"Embedding graph using `{layout!r}` layout")
    elif isinstance(layout, dict):
        rng = range(len(ixs))
        for k, v in layout.items():
            if k not in rng:
                raise ValueError(
                    f"Key in `layout` must be in `range(len(ixs))`, found `{k}`."
                )
            if len(v) != 2:
                raise ValueError(
                    f"Value in `layout` must be a `tuple` or a `list` of length 2, found `{len(v)}`."
                )
        pos = layout
        logg.debug("Using precomputed layout")
    elif callable(layout):
        start = logg.info(
            f"Embedding graph using `{layout.__name__!r}` layout")
        pos = layout(G, **layout_kwargs)
        logg.info("    Finish", time=start)
    else:
        raise TypeError(f"Argument `layout` must be either a `string`, "
                        f"a `dict` or a `callable`, found `{type(layout)}`.")

    curves, lc = None, None
    if edge_use_curved:
        try:
            from ._utils import _curved_edges

            logg.debug("Creating curved edges")
            curves = _curved_edges(G,
                                   pos,
                                   self_loop_radius_frac,
                                   polarity="directed")
            lc = LineCollection(
                curves,
                colors="black",
                linewidths=np.clip(
                    np.ravel([v["weight"]
                              for v in G.edges.values()]) * edge_weight_scale,
                    0,
                    edge_width_limit,
                ),
                alpha=edge_alpha,
            )
        except ImportError as e:
            global _msg_shown
            if not _msg_shown:
                print(
                    str(e)[:-1],
                    "in order to use curved edges or specify `edge_use_curved=False`.",
                )
                _msg_shown = True

    for ax, keyloc, key, labs, er in zip(axes, keylocs, keys, labels,
                                         edge_reductions):
        label_col = {}  # dummy value

        if key in ("incoming", "outgoing", "self_loops"):
            if key in ("incoming", "outgoing"):
                vals = er(gdata, axis=int(key == "outgoing"))
                if issparse(vals):
                    vals = vals.A
                vals = vals.flatten()
            else:
                vals = gdata.diagonal() if is_sparse else np.diag(gdata)
            node_v = dict(zip(pos.keys(), vals))
        else:
            label_col = getattr(data, keyloc)
            if key in label_col:
                node_v = dict(zip(pos.keys(), label_col[key]))
            else:
                raise RuntimeError(
                    f"Key `{key!r}` not found in `adata.{keyloc}`.")

        if labs is not None:
            if len(labs) != len(pos):
                raise RuntimeError(
                    f"Number of labels ({len(labels)}) and nodes ({len(pos)}) mismatch."
                )
            nx.draw_networkx_labels(
                G,
                pos,
                labels=labs if isinstance(labs, dict) else dict(
                    zip(pos.keys(), labs)),
                ax=ax,
                font_color=font_color,
                font_size=font_size,
            )

        if lc is not None and curves is not None:
            ax.add_collection(deepcopy(lc))  # copying necessary
            if show_arrows:
                plot_arrows(curves, G, pos, ax, edge_weight_scale)
        else:
            nx.draw_networkx_edges(
                G,
                pos,
                width=[
                    np.clip(
                        v["weight"] * edge_weight_scale,
                        _min_edge_weight,
                        edge_width_limit,
                    ) for _, v in G.edges.items()
                ],
                alpha=edge_alpha,
                edge_color="black",
                arrows=True,
                arrowstyle="-|>",
            )

        if key in label_col and is_categorical_dtype(label_col[key]):
            values = label_col[key]
            if keyloc in ("obs", "obsm"):
                values = values[ixs]
            categories = values.cat.categories
            color_key = _colors(key)
            if color_key in data.uns:
                mapper = dict(zip(categories, data.uns[color_key]))
            else:
                mapper = dict(
                    zip(categories, map(cat_cmap.get, range(len(categories)))))

            colors = []
            seen = set()

            for v in values:
                colors.append(mapper[v])
                seen.add(v)

            nodes_kwargs = dict(cmap=cat_cmap, node_color=colors)  # noqa
            if legend_loc is not None:
                x, y = pos[0]
                for label in sorted(seen):
                    ax.plot([x], [y], label=label, color=mapper[label])
                ax.legend(loc=legend_loc)
        else:
            values = list(node_v.values())
            vmin, vmax = np.min(values), np.max(values)
            nodes_kwargs = dict(  # noqa
                cmap=cont_cmap,
                node_color=values,
                vmin=vmin,
                vmax=vmax)

            divider = make_axes_locatable(ax)
            cax = divider.append_axes("right", size="1.5%", pad=0.05)
            _ = mpl.colorbar.ColorbarBase(cax,
                                          cmap=cont_cmap,
                                          norm=mpl.colors.Normalize(vmin=vmin,
                                                                    vmax=vmax))

        if color_nodes is False:
            nodes_kwargs = {}

        nx.draw_networkx_nodes(G,
                               pos,
                               node_size=node_size,
                               ax=ax,
                               **nodes_kwargs)

        ax.set_title(key)
        ax.axis("off")

    if save is not None:
        save_fig(fig, save)

    fig.show()
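
# --- Usage sketch (added; not part of the original function) ---
# A minimal, hedged example of calling a graph-plotting helper like the one above.
# The public name `cellrank.pl.graph`, the example dataset loader and the `graph_key`
# value `"T_fwd"` are assumptions based on context, not guaranteed by this snippet.
import cellrank as cr

adata = cr.datasets.pancreas_preprocessed()  # example AnnData object (assumed loader)
kernel = cr.tl.transition_matrix(adata)      # cell-cell transition matrix kernel
kernel.write_to_adata()                      # store the matrix in `adata` (key assumed to be "T_fwd")
cr.pl.graph(
    adata,
    graph_key="T_fwd",                       # where the transition matrix was written (assumed key)
    ixs=list(range(100)),                    # plot only the first 100 cells to keep the graph readable
    keys=("outgoing",),                      # color nodes by their aggregated outgoing weights
    top_n_edges=(3, False, "outgoing"),      # keep only the 3 strongest outgoing edges per node
    edge_use_curved=True,
)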
Example #25
0
    def compute_schur(
        self,
        n_components: int = 10,
        initial_distribution: Optional[np.ndarray] = None,
        method: str = "krylov",
        which: str = "LR",
        alpha: float = 1,
    ):
        """
        Compute the Schur decomposition.

        Parameters
        ----------
        n_components
            Number of vectors to compute.
        initial_distribution
            Input probability distribution over all cells. If `None`, a uniform distribution is used.
        method
            Method for calculating the Schur vectors. Valid options are: `'krylov'` or `'brandts'`.
            For benefits of each method, see :class:`msmtools.analysis.dense.gpcca.GPCCA`. The former is
            an iterative procedure that computes a partial, sorted Schur decomposition for large, sparse
            matrices whereas the latter computes a full sorted Schur decomposition of a dense matrix.
        %(eigen)s

        Returns
        -------
        None
            Nothing, but updates the following fields:

                - :paramref:`{schur_vectors}`
                - :paramref:`{schur_matrix}`
                - :paramref:`{eigendec}`
        """

        if n_components < 2:
            raise ValueError(
                f"Number of components must be `>=2`, found `{n_components}`.")

        self._gpcca = _GPCCA(self.transition_matrix,
                             eta=initial_distribution,
                             z=which,
                             method=method)
        start = logg.info("Computing Schur decomposition")

        try:
            self._gpcca._do_schur_helper(n_components)
        except ValueError:
            logg.warning(
                f"Using `{n_components}` components would split a block of complex conjugates. "
                f"Increasing `n_components` to `{n_components + 1}`")
            self._gpcca._do_schur_helper(n_components + 1)

        # make it available for pl
        setattr(self, A.SCHUR.s, self._gpcca.X)
        setattr(self, A.SCHUR_MAT.s, self._gpcca.R)

        self._invalid_n_states = np.array([
            i for i in range(2, len(self._gpcca.eigenvalues))
            if _check_conj_split(self._gpcca.eigenvalues[:i])
        ])
        if len(self._invalid_n_states):
            logg.info(
                f"When computing macrostates, choose a number of states NOT in `{list(self._invalid_n_states)}`"
            )

        self._write_eig_to_adata(
            {
                "D": self._gpcca.eigenvalues,
                "eigengap": _eigengap(self._gpcca.eigenvalues, alpha),
                "params": {
                    "which": which,
                    "k": len(self._gpcca.eigenvalues),
                    "alpha": alpha,
                },
            },
            start=start,
            extra_msg=
            f"\n       `.{P.SCHUR}`\n       `.{P.SCHUR_MAT}`\n    Finish",
        )
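
# --- Usage sketch (added; assumes the cellrank 1.x GPCCA estimator API) ---
# A hedged example of computing a sorted, partial Schur decomposition on a GPCCA
# estimator built from a kernel; the dataset and parameter values are illustrative.
import cellrank as cr
from cellrank.tl.estimators import GPCCA

adata = cr.datasets.pancreas_preprocessed()
kernel = cr.tl.transition_matrix(adata)            # combined velocity/connectivity kernel
g = GPCCA(kernel)
g.compute_schur(n_components=20, method="krylov")  # iterative solver for large, sparse matrices
g.plot_spectrum(real_only=True)                    # inspect eigenvalues, e.g. to spot the eigengap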
Example #26
0
def lineages(
    adata: AnnData,
    backward: bool = False,
    copy: bool = False,
    return_estimator: bool = False,
    **kwargs,
) -> Optional[AnnData]:
    """
    Compute probabilistic lineage assignment using RNA velocity.

    For each cell `i` in :math:`{1, ..., N}` and %(initial_or_terminal)s state `j` in :math:`{1, ..., M}`,
    the probability is computed that cell `i` is either going to %(terminal)s state `j` (``backward=False``)
    or is coming from %(initial)s state `j` (``backward=True``).

    This function computes the absorption probabilities of a Markov chain towards the %(initial_or_terminal)s states
    uncovered by :func:`cellrank.tl.initial_states` or :func:`cellrank.tl.terminal_states` using a highly efficient
    implementation that scales to large cell numbers.

    It is also possible to calculate the mean and variance of the time until absorption for all or just a subset
    of the %(initial_or_terminal)s states. This can be seen as a pseudotemporal measure, either towards any terminal
    population of the state change trajectory, or towards specific ones.

    Parameters
    ----------
    %(adata)s
    %(backward)s
    copy
        Whether to update the existing ``adata`` object or to return a copy.
    return_estimator
        Whether to return the estimator. Only available when ``copy=False``.
    **kwargs
        Keyword arguments for :meth:`cellrank.tl.estimators.BaseEstimator.compute_absorption_probabilities`.

    Returns
    -------
    :class:`anndata.AnnData`, :class:`cellrank.tl.estimators.BaseEstimator` or :obj:`None`
        Depending on ``copy`` and ``return_estimator``, either updates the existing ``adata`` object,
        returns its copy or returns the estimator.
    """

    if backward:
        lin_key = AbsProbKey.BACKWARD
        fs_key = TermStatesKey.BACKWARD
        fs_key_pretty = TerminalStatesPlot.BACKWARD
    else:
        lin_key = AbsProbKey.FORWARD
        fs_key = TermStatesKey.FORWARD
        fs_key_pretty = TerminalStatesPlot.FORWARD

    try:
        pk = PrecomputedKernel(adata=adata, backward=backward)
    except KeyError as e:
        raise RuntimeError(
            f"Compute transition matrix first as `cellrank.tl.transition_matrix(..., backward={backward})`."
        ) from e

    start = logg.info(
        f"Computing lineage probabilities towards {fs_key_pretty.s}")
    mc = GPCCA(
        pk, read_from_adata=True, inplace=not copy
    )  # GPCCA is more general than CFLARE in terms of what it saves
    if mc._get(P.TERM) is None:
        raise RuntimeError(
            f"Compute the states first as `cellrank.tl.{fs_key.s}(..., backward={backward})`."
        )

    # compute the absorption probabilities
    mc.compute_absorption_probabilities(**kwargs)

    logg.info(f"Adding lineages to `adata.obsm[{lin_key.s!r}]`\n    Finish",
              time=start)

    return mc.adata if copy else mc if return_estimator else None
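
# --- Usage sketch (added; assumes the cellrank 1.x high-level API) ---
# A hedged example of the typical call order: identify terminal states first,
# then compute absorption probabilities ("lineages") towards them.
# The example dataset and cluster key are illustrative assumptions.
import cellrank as cr

adata = cr.datasets.pancreas_preprocessed()
cr.tl.terminal_states(adata, cluster_key="clusters")  # writes terminal states into `adata`
cr.tl.lineages(adata)                                 # absorption probabilities, stored in `adata.obsm`
cr.pl.lineages(adata, same_plot=False)                # plotting helper (name assumed)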
Example #27
0
    def compute_macrostates(
        self,
        n_states: Optional[Union[int, Tuple[int, int], List[int],
                                 Dict[str, int]]] = None,
        n_cells: Optional[int] = 30,
        use_min_chi: bool = False,
        cluster_key: Optional[str] = None,
        en_cutoff: Optional[float] = 0.7,
        p_thresh: float = 1e-15,
    ):
        """
        Compute the macrostates.

        Parameters
        ----------
        n_states
            Number of macrostates. If `None`, use the `eigengap` heuristic.
        %(n_cells)s
        use_min_chi
            Whether to use :meth:`pygpcca.GPCCA.minChi` to calculate the number of macrostates.
            If `True`, ``n_states`` corresponds to a closed interval `[min, max]` inside of which the potentially
            optimal number of macrostates is searched.
        cluster_key
            If a key to cluster labels is given, names and colors of the states will be associated with the clusters.
        %(en_cutoff_p_thresh)s

        Returns
        -------
        None
            Nothing, but updates the following fields:

                - :attr:`{msp}`
                - :attr:`{ms}`
                - :attr:`{schur}`
                - :attr:`{coarse_T}`
                - :attr:`{coarse_stat}`
        """

        was_from_eigengap = False

        if use_min_chi:
            n_states = self._get_n_states_from_minchi(n_states)

        if n_states is None:
            if self._get(P.EIG) is None:
                raise RuntimeError(
                    "Compute eigendecomposition first as `.compute_eigendecomposition()` or `.compute_schur()`."
                )
            was_from_eigengap = True
            n_states = self._get(P.EIG)["eigengap"] + 1
            logg.info(f"Using `{n_states}` states based on eigengap")
        elif not isinstance(n_states, int):
            raise ValueError(
                f"Expected `n_states` to be an integer when `use_min_chi=False`, "
                f"found `{type(n_states).__name__!r}`.")

        if n_states <= 0:
            raise ValueError(
                f"Expected `n_states` to be positive or `None`, found `{n_states}`."
            )

        n_states = self._check_states_validity(n_states)
        if n_states == 1:
            self._compute_one_macrostate(
                n_cells=n_cells,
                cluster_key=cluster_key,
                p_thresh=p_thresh,
                en_cutoff=en_cutoff,
            )
            return

        if self._gpcca is None:
            if not was_from_eigengap:
                raise RuntimeError(
                    "Compute Schur decomposition first as `.compute_schur()`.")

            logg.warning(
                f"Number of states `{n_states}` was automatically determined by `eigengap` "
                "but no Schur decomposition was found. Computing with default parameters"
            )
            # this cannot fail because of a split block of complex conjugates:
            # if it were to split, the number of components is automatically increased in `compute_schur`
            self.compute_schur(n_states)

        # check the number of pre-computed Schur vectors
        if self._gpcca._p_X.shape[1] < n_states:
            logg.warning(
                f"Requested more macrostates `{n_states}` than available "
                f"Schur vectors `{self._gpcca._p_X.shape[1]}`. Recomputing the decomposition"
            )

        start = logg.info(f"Computing `{n_states}` macrostates")
        try:
            self._gpcca = self._gpcca.optimize(m=n_states)
        except ValueError as e:
            # example case: 4 Schur vectors are cached and the user requests 5 states, but using 5
            # would split a block of complex conjugate eigenvalues; the Schur decomposition with
            # 5 vectors computed inside the try block then fails (there is no way of knowing this
            # beforehand), so increase `n_states` by 1 and retry
            n_states += 1
            logg.warning(f"{e}\nIncreasing `n_states` to `{n_states}`")
            self._gpcca = self._gpcca.optimize(m=n_states)

        self._set_macrostates(
            memberships=self._gpcca.memberships,
            n_cells=n_cells,
            cluster_key=cluster_key,
            p_thresh=p_thresh,
            en_cutoff=en_cutoff,
        )

        # cache the results and make sure we don't overwrite
        self._set(A.SCHUR, self._gpcca._p_X)
        self._set(A.SCHUR_MAT, self._gpcca._p_R)

        names = self._get(P.MACRO_MEMBER).names

        self._set(
            A.COARSE_T,
            pd.DataFrame(
                self._gpcca.coarse_grained_transition_matrix,
                index=names,
                columns=names,
            ),
        )
        self._set(
            A.COARSE_INIT_D,
            pd.Series(self._gpcca.coarse_grained_input_distribution,
                      index=names),
        )

        # careful here: computing the stationary distribution may have failed
        if self._gpcca.coarse_grained_stationary_probability is not None:
            self._set(
                A.COARSE_STAT_D,
                pd.Series(
                    self._gpcca.coarse_grained_stationary_probability,
                    index=names,
                ),
            )
            logg.info(
                f"Adding `.{P.MACRO_MEMBER}`\n"
                f"       `.{P.MACRO}`\n"
                f"       `.{P.SCHUR}`\n"
                f"       `.{P.COARSE_T}`\n"
                f"       `.{P.COARSE_STAT_D}`\n"
                f"    Finish",
                time=start,
            )
        else:
            logg.warning("No stationary distribution found in GPCCA object")
            logg.info(
                f"Adding `.{P.MACRO_MEMBER}`\n"
                f"       `.{P.MACRO}`\n"
                f"       `.{P.SCHUR}`\n"
                f"       `.{P.COARSE_T}`\n"
                f"    Finish",
                time=start,
            )
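
# --- Usage sketch (added; assumes the cellrank 1.x GPCCA estimator API) ---
# A hedged example of computing macrostates once a Schur decomposition is available;
# the number of states, dataset and cluster key are illustrative assumptions.
import cellrank as cr
from cellrank.tl.estimators import GPCCA

adata = cr.datasets.pancreas_preprocessed()
g = GPCCA(cr.tl.transition_matrix(adata))
g.compute_schur(n_components=20)                           # needed unless the eigengap path is used
g.compute_macrostates(n_states=3, cluster_key="clusters")  # names/colors taken from `adata.obs["clusters"]`
g.plot_coarse_T()                                          # coarse-grained transition matrix between macrostates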
Example #28
0
    def compute_gdpt(self,
                     n_components: int = 10,
                     key_added: str = "gdpt_pseudotime",
                     **kwargs):
        """
        Compute generalized diffusion pseudotime (GDPT) from :cite:`haghverdi:16`, using the real Schur decomposition.

        Parameters
        ----------
        n_components
            Number of real Schur vectors to consider.
        key_added
            Key in :attr:`adata` ``.obs`` where to save the pseudotime.
        kwargs
            Keyword arguments for :meth:`cellrank.tl.GPCCA.compute_schur` if Schur decomposition is not found.

        Returns
        -------
        None
            Nothing, just updates :attr:`adata` ``.obs[key_added]`` with the computed pseudotime.
        """
        def _get_dpt_row(e_vals: np.ndarray, e_vecs: np.ndarray, i: int):
            # generalized DPT distance of root cell `i` to all cells:
            # d(i, j) = sqrt( sum_k [ |lambda_k| / (1 - |lambda_k|) * (psi_k(i) - psi_k(j)) ]^2 ),
            # where eigenvalues numerically close to 1 are dropped so that the
            # 1 / (1 - |lambda_k|) factor does not blow up
            row = sum(
                (np.abs(e_vals[eval_ix]) / (1 - np.abs(e_vals[eval_ix])) *
                 (e_vecs[i, eval_ix] - e_vecs[:, eval_ix]))**2
                # account for float32 precision
                for eval_ix in range(0, e_vals.size)
                if np.abs(e_vals[eval_ix]) < 0.9994)

            return np.sqrt(row)

        if "iroot" not in self.adata.uns.keys():
            raise KeyError("Key `'iroot'` not found in `adata.uns`.")

        iroot = self.adata.uns["iroot"]
        if isinstance(iroot, str):
            iroot = np.where(self.adata.obs_names == iroot)[0]
            if not len(iroot):
                raise ValueError(
                    f"Unable to find cell with name `{self.adata.uns['iroot']!r}` in `adata.obs_names`."
                )
            iroot = iroot[0]

        if n_components < 2:
            raise ValueError(
                f"Expected number of components >= 2, found `{n_components}`.")

        if self._get(P.SCHUR) is None:
            logg.warning("No Schur decomposition found. Computing")
            self.compute_schur(n_components, **kwargs)
        elif self._get(P.SCHUR_MAT).shape[1] < n_components:
            logg.warning(
                f"Requested `{n_components}` components, but only `{self._get(P.SCHUR_MAT).shape[1]}` were found. "
                f"Recomputing using default values")
            self.compute_schur(n_components)
        else:
            logg.debug("Using cached Schur decomposition")

        start = logg.info(
            f"Computing Generalized Diffusion Pseudotime using `n_components={n_components}`"
        )

        Q, eigenvalues = (
            self._get(P.SCHUR),
            self._get(P.EIG)["D"],
        )
        # may have to remove some values if too many converged
        Q, eigenvalues = Q[:, :n_components], eigenvalues[:n_components]
        D = _get_dpt_row(eigenvalues, Q, i=iroot)
        pseudotime = D / np.max(D[np.isfinite(D)])

        self.adata.obs[key_added] = pseudotime

        logg.info(f"Adding `{key_added!r}` to `adata.obs`\n    Finish",
                  time=start)
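
# --- Usage sketch (added; assumes the cellrank 1.x GPCCA estimator API) ---
# A hedged example of computing generalized diffusion pseudotime; a root cell must be
# stored in `adata.uns["iroot"]` beforehand, either as an integer index or a cell name.
import cellrank as cr
from cellrank.tl.estimators import GPCCA

adata = cr.datasets.pancreas_preprocessed()
adata.uns["iroot"] = 0                       # illustrative choice of the root cell
g = GPCCA(cr.tl.transition_matrix(adata))
g.compute_gdpt(n_components=10)              # writes `adata.obs["gdpt_pseudotime"]`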
Example #29
0
def _initial_terminal(
    adata: AnnData,
    estimator: type(BaseEstimator) = GPCCA,
    backward: bool = False,
    mode: str = VelocityMode.DETERMINISTIC.s,
    backward_mode: str = BackwardMode.TRANSPOSE.s,
    n_states: Optional[int] = None,
    cluster_key: Optional[str] = None,
    key: Optional[str] = None,
    show_plots: bool = False,
    copy: bool = False,
    return_estimator: bool = False,
    fit_kwargs: Mapping = MappingProxyType({}),
    **kwargs,
) -> Optional[Union[AnnData, BaseEstimator]]:

    _check_estimator_type(estimator)

    try:
        kernel = PrecomputedKernel(key, adata=adata, backward=backward)
        write_to_adata = False  # no need to write
        logg.info("Using precomputed transition matrix")
    except KeyError:
        # compute kernel object
        kernel = transition_matrix(
            adata,
            backward=backward,
            mode=mode,
            backward_mode=backward_mode,
            **kwargs,
        )
        write_to_adata = True

    # create estimator object
    mc = estimator(
        kernel,
        read_from_adata=False,
        inplace=not copy,
        key=key,
        write_to_adata=write_to_adata,
    )

    if cluster_key is None:
        _info_if_obs_keys_categorical_present(
            adata,
            keys=["louvain", "leiden", "clusters"],
            msg_fmt=
            "Found categorical observation in `adata.obs[{!r}]`. Consider specifying it as `cluster_key`.",
        )

    mc.fit(
        n_lineages=n_states,
        cluster_key=cluster_key,
        compute_absorption_probabilities=False,
        **fit_kwargs,
    )

    if show_plots:
        mc.plot_spectrum(real_only=True)
        if isinstance(mc, CFLARE):
            mc.plot_eigendecomposition(abs_value=True,
                                       perc=[0, 98],
                                       use=n_states)
            mc.plot_terminal_states(discrete=True, same_plot=False)
        elif isinstance(mc, GPCCA):
            n_states = len(mc._get(P.MACRO).cat.categories)
            if n_states > 1:
                mc.plot_schur()
            mc.plot_terminal_states(discrete=True, same_plot=False)
            if n_states > 1:
                mc.plot_coarse_T()
        else:
            raise NotImplementedError(
                f"Pipeline not implemented for `{type(mc).__name__!r}`.")

    return mc.adata if copy else mc if return_estimator else None
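
# --- Sketch (added) ---
# `_initial_terminal` is a private helper; the public entry points presumably wrap it by
# fixing the direction, roughly as sketched below. The wrapper names here are hypothetical,
# the real public functions being `cellrank.tl.initial_states` and `cellrank.tl.terminal_states`.
def _terminal_states_sketch(adata, **kwargs):
    # forward direction: macrostates are interpreted as terminal states
    return _initial_terminal(adata, backward=False, **kwargs)


def _initial_states_sketch(adata, **kwargs):
    # backward direction: macrostates are interpreted as initial states
    return _initial_terminal(adata, backward=True, **kwargs)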
Example #30
0
def gene_trends(
    adata: AnnData,
    model: _input_model_type,
    genes: Union[str, Sequence[str]],
    lineages: Optional[Union[str, Sequence[str]]] = None,
    backward: bool = False,
    data_key: str = "X",
    time_key: str = "latent_time",
    transpose: bool = False,
    time_range: Optional[Union[_time_range_type,
                               List[_time_range_type]]] = None,
    callback: _callback_type = None,
    conf_int: Union[bool, float] = True,
    same_plot: bool = False,
    hide_cells: bool = False,
    perc: Optional[Union[Tuple[float, float], Sequence[Tuple[float,
                                                             float]]]] = None,
    lineage_cmap: Optional[matplotlib.colors.ListedColormap] = None,
    abs_prob_cmap: matplotlib.colors.ListedColormap = cm.viridis,
    cell_color: Optional[str] = None,
    cell_alpha: float = 0.6,
    lineage_alpha: float = 0.2,
    size: float = 15,
    lw: float = 2,
    cbar: bool = True,
    margins: float = 0.015,
    sharex: Optional[Union[str, bool]] = None,
    sharey: Optional[Union[str, bool]] = None,
    gene_as_title: Optional[bool] = None,
    legend_loc: Optional[str] = "best",
    obs_legend_loc: Optional[str] = "best",
    ncols: int = 2,
    suptitle: Optional[str] = None,
    return_models: bool = False,
    n_jobs: Optional[int] = 1,
    backend: str = _DEFAULT_BACKEND,
    show_progress_bar: bool = True,
    figsize: Optional[Tuple[float, float]] = None,
    dpi: Optional[int] = None,
    save: Optional[Union[str, Path]] = None,
    plot_kwargs: Mapping = MappingProxyType({}),
    **kwargs,
) -> Optional[_return_model_type]:
    """
    Plot gene expression trends along lineages.

    Each lineage is defined via its lineage weights, which we compute using :func:`cellrank.tl.lineages`. This
    function accepts any model based on :class:`cellrank.ul.models.BaseModel` to fit gene expression,
    where we take the lineage weights into account in the loss function.

    Parameters
    ----------
    %(adata)s
    %(model)s
    %(genes)s
    lineages
        Names of the lineages to plot. If `None`, plot all lineages.
    %(backward)s
    data_key
        Key in ``adata.layers`` or `'X'` for ``adata.X`` where the data is stored.
    time_key
        Key in ``adata.obs`` where the pseudotime is stored.
    %(time_ranges)s
    transpose
        If ``same_plot=True``, group the trends by ``lineages`` instead of ``genes``. This enforces ``hide_cells=True``.
        If ``same_plot=False``, show ``lineages`` in rows and ``genes`` in columns.
    %(model_callback)s
    conf_int
        Whether to compute and show the confidence interval. If the :paramref:`model` is
        :class:`cellrank.ul.models.GAMR`, this can also specify the confidence level; the default is `0.95`.
    same_plot
        Whether to plot all lineages for each gene in the same plot.
    hide_cells
        If `True`, hide all cells.
    perc
        Percentile for colors. Valid values are in the interval `[0, 100]`.
        This can improve visualization and can be specified individually for each lineage.
    lineage_cmap
        Categorical colormap to use when coloring in the lineages. If `None` and ``same_plot``,
        use the corresponding colors in ``adata.uns``, otherwise use `'black'`.
    abs_prob_cmap
        Continuous colormap to use when visualizing the absorption probabilities for each lineage.
        Only used when ``same_plot=False``.
    cell_color
        Key in :attr:`anndata.AnnData.obs` or :attr:`anndata.AnnData.var_names` used for coloring the cells.
    cell_alpha
        Alpha channel for cells.
    lineage_alpha
        Alpha channel for lineage confidence intervals.
    size
        Size of the points.
    lw
        Line width of the smoothed values.
    cbar
        Whether to show the colorbar. It is always shown when the percentiles for lineages differ.
        Only used when ``same_plot=False``.
    margins
        Margins around the plot.
    sharex
        Whether to share x-axis. Valid options are `'row'`, `'col'` or `'none'`.
    sharey
        Whether to share y-axis. Valid options are `'row'`, `'col'` or `'none'`.
    gene_as_title
        Whether to show gene names as titles instead of on the y-axis.
    legend_loc
        Location of the legend displaying lineages. Only used when ``same_plot=True``.
    obs_legend_loc
        Location of the legend when ``cell_color`` corresponds to a categorical variable.
    ncols
        Number of columns of the plot when plotting multiple genes. Only used when ``same_plot=True``.
    suptitle
        Suptitle of the figure.
    %(return_models)s
    %(parallel)s
    %(plotting)s
    plot_kwargs
        Keyword arguments for :meth:`cellrank.ul.models.BaseModel.plot`.
    kwargs
        Keyword arguments for :meth:`cellrank.ul.models.BaseModel.prepare`.

    Returns
    -------
    %(plots_or_returns_models)s
    """

    if isinstance(genes, str):
        genes = [genes]
    genes = _unique_order_preserving(genes)
    if data_key != "obs":
        _check_collection(adata,
                          genes,
                          "var_names",
                          use_raw=kwargs.get("use_raw", False))
    else:
        _check_collection(adata,
                          genes,
                          "obs",
                          use_raw=kwargs.get("use_raw", False))

    ln_key = str(AbsProbKey.BACKWARD if backward else AbsProbKey.FORWARD)
    if ln_key not in adata.obsm:
        raise KeyError(f"Lineages key `{ln_key!r}` not found in `adata.obsm`.")

    if lineages is None:
        lineages = adata.obsm[ln_key].names
    elif isinstance(lineages, str):
        lineages = [lineages]
    elif all(ln is None
             for ln in lineages):  # no lineage, all the weights are 1
        lineages = [None]
        cbar = False
        logg.debug("All lineages are `None`, setting the weights to `1`")
    lineages = _unique_order_preserving(lineages)

    if isinstance(time_range, (tuple, float, int, type(None))):
        time_range = [time_range] * len(lineages)
    elif len(time_range) != len(lineages):
        raise ValueError(
            f"Expected time ranges to be of length `{len(lineages)}`, found `{len(time_range)}`."
        )

    kwargs["time_key"] = time_key
    kwargs["data_key"] = data_key
    kwargs["backward"] = backward
    kwargs["conf_int"] = conf_int  # prepare doesnt take or need this
    models = _create_models(model, genes, lineages)

    all_models, models, genes, lineages = _fit_bulk(
        models,
        _create_callbacks(adata, callback, genes, lineages, **kwargs),
        genes,
        lineages,
        time_range,
        return_models=True,
        filter_all_failed=False,
        parallel_kwargs={
            "show_progress_bar": show_progress_bar,
            "n_jobs": _get_n_cores(n_jobs, len(genes)),
            "backend": _get_backend(models, backend),
        },
        **kwargs,
    )

    lineages = sorted(lineages)
    tmp = adata.obsm[ln_key][lineages].colors
    if lineage_cmap is None and not transpose:
        lineage_cmap = tmp

    plot_kwargs = dict(plot_kwargs)
    plot_kwargs["obs_legend_loc"] = obs_legend_loc
    if transpose:
        all_models = pd.DataFrame(all_models).T.to_dict()
        models = pd.DataFrame(models).T.to_dict()
        genes, lineages = lineages, genes
        hide_cells = same_plot or hide_cells
    else:
        # information overload otherwise
        plot_kwargs["lineage_probability"] = False
        plot_kwargs["lineage_probability_conf_int"] = False

    tmp = pd.DataFrame(models).T.astype(bool)
    start_rows = np.argmax(tmp.values, axis=0)
    end_rows = tmp.shape[0] - np.argmax(tmp[::-1].values, axis=0) - 1

    if same_plot:
        gene_as_title = True if gene_as_title is None else gene_as_title
        sharex = "all" if sharex is None else sharex
        if sharey is None:
            sharey = "row" if plot_kwargs.get("lineage_probability",
                                              False) else "none"
        ncols = len(genes) if ncols >= len(genes) else ncols
        nrows = int(np.ceil(len(genes) / ncols))
    else:
        gene_as_title = False if gene_as_title is None else gene_as_title
        sharex = "col" if sharex is None else sharex
        if sharey is None:
            sharey = ("row" if not hide_cells or plot_kwargs.get(
                "lineage_probability", False) else "none")
        nrows = len(genes)
        ncols = len(lineages)

    plot_kwargs = dict(plot_kwargs)
    if plot_kwargs.get("xlabel", None) is None:
        plot_kwargs["xlabel"] = time_key

    fig, axes = plt.subplots(
        nrows=nrows,
        ncols=ncols,
        sharex=sharex,
        sharey=sharey,
        figsize=(6 * ncols, 4 * nrows) if figsize is None else figsize,
        tight_layout=True,
        dpi=dpi,
    )
    axes = np.reshape(axes, (nrows, ncols))

    cnt = 0
    plot_kwargs["obs_legend_loc"] = None if same_plot else obs_legend_loc

    logg.info("Plotting trends")
    for row in range(len(axes)):
        for col in range(len(axes[row])):
            if cnt >= len(genes):
                break
            gene = genes[cnt]
            if (same_plot and plot_kwargs.get("lineage_probability", False)
                    and transpose):
                lpc = adata.obsm[ln_key][gene].colors[0]
            else:
                lpc = None

            if same_plot:
                plot_kwargs["obs_legend_loc"] = (obs_legend_loc if row == 0
                                                 and col == len(axes[0]) - 1
                                                 else None)

            _trends_helper(
                models,
                gene=gene,
                lineage_names=lineages,
                transpose=transpose,
                same_plot=same_plot,
                hide_cells=hide_cells,
                perc=perc,
                lineage_cmap=lineage_cmap,
                abs_prob_cmap=abs_prob_cmap,
                lineage_probability_color=lpc,
                cell_color=cell_color,
                alpha=cell_alpha,
                lineage_alpha=lineage_alpha,
                size=size,
                lw=lw,
                cbar=cbar,
                margins=margins,
                sharey=sharey,
                gene_as_title=gene_as_title,
                legend_loc=legend_loc,
                figsize=figsize,
                fig=fig,
                axes=axes[row, col] if same_plot else axes[cnt],
                show_ylabel=col == 0,
                show_lineage=same_plot or (cnt == start_rows),
                show_xticks_and_label=((row + 1) * ncols + col >= len(genes))
                if same_plot else (cnt == end_rows),
                **plot_kwargs,
            )
            # plot legend on the 1st plot
            cnt += 1

            if not same_plot:
                plot_kwargs["obs_legend_loc"] = None

    if same_plot and (col != ncols):
        for ax in np.ravel(axes)[cnt:]:
            ax.remove()

    fig.suptitle(suptitle, y=1.05)

    if save is not None:
        save_fig(fig, save)

    if return_models:
        return all_models
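
# --- Usage sketch (added; assumes the cellrank 1.x plotting API) ---
# A hedged example of plotting gene trends along lineages with a GAM model.
# The dataset, gene names, cluster key and pseudotime key are illustrative assumptions;
# `latent_time` would typically be computed beforehand (e.g. via scVelo's dynamical model).
import cellrank as cr

adata = cr.datasets.pancreas_preprocessed()
cr.tl.terminal_states(adata, cluster_key="clusters")
cr.tl.lineages(adata)
model = cr.ul.models.GAM(adata)      # any model derived from `BaseModel`
cr.pl.gene_trends(
    adata,
    model=model,
    genes=["Pdx1", "Neurog3"],       # illustrative gene names
    time_key="latent_time",          # must exist in `adata.obs`
    same_plot=True,
    ncols=2,
)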