예제 #1
0
def lineage_drivers(
    adata: AnnData,
    backward: bool = False,
    lineages: Optional[Union[Sequence, str]] = None,
    method: str = TestMethod.FISCHER.s,
    cluster_key: Optional[str] = None,
    clusters: Optional[Union[Sequence, str]] = None,
    layer: str = "X",
    use_raw: bool = False,
    confidence_level: float = 0.95,
    n_perms: int = 1000,
    seed: Optional[int] = None,
    return_drivers: bool = True,
    **kwargs,
) -> Optional[pd.DataFrame]:  # noqa
    """
    %(lineage_drivers.full_desc)s

    Parameters
    ----------
    %(adata)s
    %(backward)s
    %(lineage_drivers.parameters)s

    Returns
    -------
    %(lineage_drivers.returns)s

    Raises
    ------
    RuntimeError
        If the estimator built from ``adata`` has no absorption probabilities for the requested direction.

    References
    ----------
    %(lineage_drivers.references)s
    """
    # The estimator is only a vehicle for state already stored in ``adata``:
    # it is built on a DummyKernel with read_from_adata=True / write_to_adata=False.
    estimator = GPCCA(
        DummyKernel(adata, backward=backward),
        read_from_adata=True,
        write_to_adata=False,
    )

    absorption_probabilities = estimator._get(P.ABS_PROBS)
    if absorption_probabilities is None:
        raise RuntimeError(
            f"Compute absorption probabilities first as `cellrank.tl.lineages(..., backward={backward})`."
        )

    # Delegate the computation (and storage) of the drivers to the estimator method.
    return estimator.compute_lineage_drivers(
        lineages=lineages,
        method=method,
        cluster_key=cluster_key,
        clusters=clusters,
        layer=layer,
        use_raw=use_raw,
        confidence_level=confidence_level,
        n_perms=n_perms,
        seed=seed,
        return_drivers=return_drivers,
        **kwargs,
    )
예제 #2
0
def lineages(
    adata: AnnData,
    lineages: Optional[Union[str, Sequence[str]]] = None,
    backward: bool = False,
    cluster_key: Optional[str] = None,
    mode: str = "embedding",
    time_key: str = "latent_time",
    **kwargs,
) -> None:
    """
    Plot lineages that were uncovered using :func:`cellrank.tl.lineages`.

    For each lineage, all cells are shown in an embedding (UMAP by default) and coloured by their probability of
    belonging to that lineage. Committed cells have a probability of one for their own lineage and zero for the
    others, while naive cells show more balanced probabilities, reflecting their potential to develop towards
    several endpoints.

    Parameters
    ----------
    %(adata)s
    lineages
        Plot only these lineages. If `None`, plot all lineages.
    %(backward)s
    cluster_key
        If given, plot cluster annotations left of the lineage probabilities.
    %(time_mode)s
    time_key
        Key in ``adata.obs`` where the pseudotime is stored.
    %(basis)s
    **kwargs
        Keyword arguments for :meth:`cellrank.tl.estimators.BaseEstimator.plot_absorption_probabilities`.

    Returns
    -------
    %(just_plots)s

    Raises
    ------
    RuntimeError
        If no absorption probabilities for the requested direction can be read from ``adata``.
    """
    # Estimator configured to read from ``adata`` and not write back to it.
    estimator = GPCCA(
        DummyKernel(adata, backward=backward),
        read_from_adata=True,
        write_to_adata=False,
    )

    has_probabilities = estimator._get(P.ABS_PROBS) is not None
    if not has_probabilities:
        raise RuntimeError(
            f"Compute absorption probabilities first as `cellrank.tl.lineages(..., backward={backward})`."
        )

    estimator.plot_absorption_probabilities(
        lineages=lineages,
        cluster_key=cluster_key,
        mode=mode,
        time_key=time_key,
        **kwargs,
    )
예제 #3
0
    def __init__(
        self,
        adata: AnnData,
        g: Union[str, np.ndarray],
        terminal_states: Optional[Union[str, pd.Series]] = None,
        cluster_key: Optional[str] = None,
        **kwargs: Any,
    ):
        """
        Initialize the kernel, optionally registering terminal states in ``adata`` first.

        Parameters
        ----------
        adata
            Annotated data object.
        g
            Forwarded unchanged to the parent class' ``__init__``.
        terminal_states
            If not `None`, passed to ``GPCCA.set_terminal_states`` on a forward-direction estimator created with
            ``write_to_adata=True`` before the parent initializer runs.
        cluster_key
            Passed to ``set_terminal_states`` together with ``terminal_states``.
        **kwargs
            Keyword arguments for the parent class' ``__init__``.

        Raises
        ------
        RuntimeError
            If the parent initializer raises; the original exception is chained as the cause.
        """
        if terminal_states is not None:
            # Only the side effect matters here: the estimator is created with
            # write_to_adata=True (presumably persisting the states into ``adata``)
            # and is discarded right after the call.
            GPCCA(
                DummyKernel(adata, backward=False), write_to_adata=True
            ).set_terminal_states(terminal_states, cluster_key=cluster_key)

        try:
            super().__init__(adata, g=g, **kwargs)
        except Exception as exc:  # noqa: B902
            raise RuntimeError("Unable to initialize the kernel.") from exc
예제 #4
0
def _initial_terminal(
    adata: AnnData,
    backward: bool = False,
    discrete: bool = False,
    states: Optional[Union[str, Sequence[str]]] = None,
    cluster_key: Optional[str] = None,
    mode: str = "embedding",
    time_key: str = "latent_time",
    **kwargs,
) -> None:
    """
    Shared implementation for plotting initial (``backward=True``) or terminal (``backward=False``) states.

    A GPCCA estimator is built from ``adata`` (``read_from_adata=True``, ``write_to_adata=False``), default
    ``same_plot`` / ``title`` keyword arguments are derived from how many states are shown, and everything is
    forwarded to ``plot_final_states``.

    Parameters
    ----------
    adata
        Annotated data object the states are read from.
    backward
        Whether to plot initial (``True``) or terminal (``False``) states.
    discrete
        Forwarded to ``plot_final_states``; also selects which default-title rule applies.
    states
        Subset of states to plot, forwarded as ``lineages``. A single state forces ``same_plot=True``.
    cluster_key, mode, time_key
        Forwarded to ``plot_final_states``.
    **kwargs
        Extra keyword arguments for ``plot_final_states``. Any ``lineages`` entry is dropped in favour of
        ``states``; ``same_plot`` may be forced to ``True`` and ``title`` may be filled in when not given.

    Raises
    ------
    RuntimeError
        If the estimator has no states stored for the requested direction.
    """
    estimator = GPCCA(
        DummyKernel(adata=adata, backward=backward),
        read_from_adata=True,
        write_to_adata=False,
    )

    stored_states = estimator._get(P.FIN)
    if stored_states is None:
        raise RuntimeError(
            f"Compute {_initial if backward else _terminal} states first as "
            f"`cellrank.tl.compute_{FinalStatesKey.BACKWARD if backward else FinalStatesKey.FORWARD}()`."
        )

    n_states = len(stored_states.cat.categories)

    # A single state (either overall or requested) always goes into one plot.
    only_one = n_states == 1 or (
        states is not None and (isinstance(states, str) or len(states) == 1)
    )
    if only_one:
        kwargs["same_plot"] = True

    # Fill in a default title only when the caller did not supply one.
    if kwargs.get("title", None) is None:
        if discrete:
            wants_title = kwargs.get("same_plot", True)
        else:
            wants_title = (
                mode == "embedding"
                and kwargs.get("same_plot", True)
                and n_states > 1
                and (
                    states is None
                    or (not isinstance(states, str) and len(states) > 1)
                )
            )
        if wants_title:
            kwargs["title"] = (
                FinalStatesPlot.BACKWARD.s if backward else FinalStatesPlot.FORWARD.s
            )

    # ``states`` is passed as ``lineages`` below, so a stray ``lineages`` kwarg is discarded.
    kwargs.pop("lineages", None)

    estimator.plot_final_states(
        lineages=states,
        cluster_key=cluster_key,
        mode=mode,
        time_key=time_key,
        discrete=discrete,
        **kwargs,
    )
예제 #5
0
def lineage_drivers(
    adata: AnnData,
    lineage: str,
    backward: bool = False,
    n_genes: int = 8,
    ncols: Optional[int] = None,
    use_raw: bool = False,
    title_fmt: str = "{gene} qval={qval:.4e}",
    **kwargs,
) -> None:
    """
    Plot lineage drivers that were uncovered using :func:`cellrank.tl.lineage_drivers`.

    The correlation and q-value columns for ``lineage`` are read from ``adata.var`` (or ``adata.raw.var`` when
    ``use_raw=True``), attached to a GPCCA estimator and plotted via its ``plot_lineage_drivers`` method.

    Parameters
    ----------
    %(adata)s
    %(backward)s
    %(plot_lineage_drivers.parameters)s

    Returns
    -------
    %(just_plots)s

    Raises
    ------
    RuntimeError
        If the correlation or the q-value column for ``lineage`` is missing.
    """
    pk = DummyKernel(adata, backward=backward)
    mc = GPCCA(pk, read_from_adata=True, write_to_adata=False)

    # Fall back to ``adata.var`` instead of failing when no raw data is present.
    if use_raw and adata.raw is None:
        logg.warning("No raw attribute set. Using `adata.var` instead")
        use_raw = False

    direction = DirPrefix.BACKWARD if backward else DirPrefix.FORWARD
    corr_key = f"{direction} {lineage} corr"
    qval_key = f"{direction} {lineage} qval"

    haystack = adata.raw.var if use_raw else adata.var

    # Validate both columns up front; previously only the correlation column was
    # checked and a missing q-value column surfaced as a bare pandas KeyError.
    missing = [key for key in (corr_key, qval_key) if key not in haystack]
    if missing:
        raise RuntimeError(
            "Unable to find lineage drivers in "
            f"`{'adata.raw.var' if use_raw else 'adata.var'}[{missing[0]!r}]`. "
            f"Compute lineage drivers first as `cellrank.tl.lineage_drivers(lineages={lineage!r}, "
            f"use_raw={use_raw}, backward={backward})`."
        )

    # Strip the direction prefix so the estimator sees "<lineage> corr" / "<lineage> qval".
    drivers = pd.DataFrame(haystack[[corr_key, qval_key]])
    drivers.columns = [f"{lineage} corr", f"{lineage} qval"]
    mc._set(A.LIN_DRIVERS, drivers)

    mc.plot_lineage_drivers(
        lineage,
        n_genes=n_genes,
        use_raw=use_raw,
        ncols=ncols,
        title_fmt=title_fmt,
        **kwargs,
    )
예제 #6
0
def lineages(
    adata: AnnData,
    backward: bool = False,
    copy: bool = False,
    return_estimator: bool = False,
    **kwargs,
) -> Optional[AnnData]:
    """
    Compute probabilistic lineage assignment using RNA velocity.

    For each cell `i` in :math:`{1, ..., N}` and %(initial_or_terminal)s state `j` in :math:`{1, ..., M}`,
    the probability is computed that cell `i` is either going to %(terminal)s state `j` (``backward=False``)
    or is coming from %(initial)s state `j` (``backward=True``).

    This function computes the absorption probabilities of a Markov chain towards the %(initial_or_terminal)s states
    uncovered by :func:`cellrank.tl.initial_states` or :func:`cellrank.tl.terminal_states` using a highly efficient
    implementation that scales to large cell numbers.

    It's also possible to calculate mean and variance of the time until absorption for all or just a subset
    of the %(initial_or_terminal)s states. This can be seen as a pseudotemporal measure, either towards any terminal
    population of the state change trajectory, or towards specific ones.

    Parameters
    ----------
    %(adata)s
    %(backward)s
    copy
        Whether to update the existing ``adata`` object or to return a copy.
    return_estimator
        Whether to return the estimator. Only available when ``copy=False``.
    **kwargs
        Keyword arguments for :meth:`cellrank.tl.estimators.BaseEstimator.compute_absorption_probabilities`.

    Returns
    -------
    :class:`anndata.AnnData`, :class:`cellrank.tl.estimators.BaseEstimator` or :obj:`None`
        Depending on ``copy`` and ``return_estimator``, either updates the existing ``adata`` object,
        returns its copy or returns the estimator.

    Raises
    ------
    RuntimeError
        If no transition matrix for the requested direction is stored in ``adata``,
        or if the states have not been computed yet.
    """

    if backward:
        lin_key = AbsProbKey.BACKWARD
        fs_key = TermStatesKey.BACKWARD
        fs_key_pretty = TerminalStatesPlot.BACKWARD
    else:
        lin_key = AbsProbKey.FORWARD
        fs_key = TermStatesKey.FORWARD
        fs_key_pretty = TerminalStatesPlot.FORWARD

    # A missing precomputed transition matrix surfaces as a KeyError from the kernel.
    try:
        pk = PrecomputedKernel(adata=adata, backward=backward)
    except KeyError as e:
        raise RuntimeError(
            f"Compute transition matrix first as `cellrank.tl.transition_matrix(..., backward={backward})`."
        ) from e

    start = logg.info(f"Computing lineage probabilities towards {fs_key_pretty.s}")
    # GPCCA is more general than CFLARE, in terms of what it saves.
    mc = GPCCA(pk, read_from_adata=True, inplace=not copy)
    if mc._get(P.TERM) is None:
        raise RuntimeError(
            f"Compute the states first as `cellrank.tl.{fs_key.s}(..., backward={backward})`."
        )

    # compute the absorption probabilities
    mc.compute_absorption_probabilities(**kwargs)

    logg.info(f"Adding lineages to `adata.obsm[{lin_key.s!r}]`\n    Finish", time=start)

    # ``copy`` takes precedence over ``return_estimator``.
    if copy:
        return mc.adata
    if return_estimator:
        return mc
    return None
예제 #7
0
def _create_gpcca(*, backward: bool = False) -> Tuple[AnnData, GPCCA]:
    """
    Build a fully fitted GPCCA estimator on a copy of ``_adata_medium``.

    The estimator runs on a 0.8 / 0.2 mix of a velocity kernel and a connectivity kernel. It goes through partition,
    eigendecomposition, Krylov Schur decomposition, 2 macrostates, terminal states, absorption probabilities and
    lineage drivers. The fixture asserts that the estimator shares the AnnData object, that the absorption
    probabilities for the chosen direction are in ``adata.obsm``, and that each row sums to 1.

    Parameters
    ----------
    backward
        Direction used for both kernels.

    Returns
    -------
    The (mutated) AnnData copy and the fitted estimator.
    """
    adata = _adata_medium.copy()
    sc.tl.paga(adata, groups="clusters")

    velocity_kernel = VelocityKernel(adata, backward=backward)
    velocity_kernel = velocity_kernel.compute_transition_matrix(softmax_scale=4)
    connectivity_kernel = ConnectivityKernel(adata, backward=backward)
    connectivity_kernel = connectivity_kernel.compute_transition_matrix()
    combined_kernel = 0.8 * velocity_kernel + 0.2 * connectivity_kernel

    mc = GPCCA(combined_kernel)
    mc.compute_partition()
    mc.compute_eigendecomposition()
    mc.compute_schur(method="krylov")
    mc.compute_macrostates(n_states=2)
    mc.set_terminal_states_from_macrostates()
    mc.compute_absorption_probabilities()
    mc.compute_lineage_drivers(cluster_key="clusters", use_raw=False)

    assert adata is mc.adata
    expected_key = AbsProbKey.BACKWARD if backward else AbsProbKey.FORWARD
    assert str(expected_key) in adata.obsm

    row_sums = mc.absorption_probabilities.X.sum(1)
    np.testing.assert_allclose(row_sums, 1.0, rtol=1e-6)

    return adata, mc
예제 #8
0
# Compute the velocity graph for every sample; the return value is not
# captured, so the result is presumably stored on each sample in place.
for key in samples:
    scv.tl.velocity_graph(samples[key],
                          mode_neighbors='connectivities',
                          compute_uncertainties=True)

# Forward direction (final states)
outdir = 'results/trajectory/cellrank/forward'
# exist_ok=True replaces the racy "check with os.path.exists, then create" pattern.
os.makedirs(outdir, exist_ok=True)

# Route scvelo figure output into the forward-direction results directory.
scv.settings.figdir = outdir

for key in samples:
    vk = VelocityKernel(samples[key])
    vk.compute_transition_matrix(softmax_scale=None)
    g = GPCCA(vk)
    g.compute_schur(n_components=20)
    g.plot_spectrum(real_only=False, save="{}_eigenvalues.png".format(key))
    if key == "H508_EV" or key == "HT29_EV":
        g.plot_schur(use=4,
                     cluster_key=clusters,
                     show=False,
                     dpi=300,
                     save='{}_schur.png'.format(key))
        g.compute_metastable_states(n_states=4, cluster_key=clusters)
        g.plot_metastable_states(show=False,
                                 dpi=300,
                                 save='{}_metastable.png'.format(key))
        g.plot_metastable_states(
            same_plot=False,
            show=False,