예제 #1
0
def _assert_has_all_keys(adata: AnnData, direction: Direction):
    """
    Assert that every key written for ``direction`` is present in ``adata``.

    Checks the transition entry in ``.uns``, the :class:`cellrank.tl.Lineage`
    object in ``.obsm`` together with its colors and names in ``.uns``, and the
    categorical state annotation plus its probabilities in ``.obs``.
    """
    assert _transition(direction) in adata.uns.keys()

    is_forward = direction == Direction.FORWARD
    lin_key = LinKey.FORWARD if is_forward else LinKey.BACKWARD
    state_key = StateKey.FORWARD if is_forward else StateKey.BACKWARD

    # lineage object and its metadata
    lin_name = str(lin_key)
    assert lin_name in adata.obsm
    assert isinstance(adata.obsm[lin_name], cr.tl.Lineage)

    assert _colors(lin_key) in adata.uns.keys()
    assert _lin_names(lin_key) in adata.uns.keys()

    # categorical state annotation and its probabilities
    state_name = str(state_key)
    assert state_name in adata.obs
    assert is_categorical_dtype(adata.obs[state_name])

    assert _probs(state_key) in adata.obs
예제 #2
0
    def write_to_adata(self, key_added: Optional[str] = None):
        """
        Write the parameters and transition matrix to the underlying adata object.

        Parameters
        ----------
        key_added
            Postfix to be added to the key under which the data is stored in :paramref:`.adata` ``.uns``.

        Returns
        -------
        None
            Updates the underlying :paramref:`.adata` object with the following:

                - ``.uns[<transition key>[_<key_added>]]['T']`` - transition matrix
                - ``.uns[<transition key>[_<key_added>]]['params']`` - string representation of this object

            ``<transition key>`` is derived from the direction of this object. The ``_<key_added>``
            suffix is only appended when :paramref:`key_added` is not `None`.

        Raises
        ------
        ValueError
            If the transition matrix has not been computed yet.
        """

        if self.transition_matrix is None:
            raise ValueError(
                "Compute transition matrix first as `.compute_transition_matrix()`."
            )

        key = _transition(self._direction)
        if key_added is not None:
            key += f"_{key_added}"

        if self.adata.uns.get(key, None) is not None:
            logg.debug(f"DEBUG: Overwriting key `{key!r}` in `adata.uns`")

        # store everything in a single assignment so the entry is never left half-written
        self.adata.uns[key] = {
            "params": str(self),
            "T": self.transition_matrix,
        }

        logg.debug(f"DEBUG: Added `{key!r}` to `adata.uns`")
예제 #3
0
    def test_backward_manual_dense_norm(self, adata):
        # Combine the kernels without density normalization, then apply density
        # normalization and row normalization by hand. The result must match the
        # legacy `transition_matrix` called with `density_normalize=True`.
        backward = True

        velo_kernel = VelocityKernel(adata, backward=backward)
        velo_kernel = velo_kernel.compute_transition_matrix(density_normalize=False)
        conn_kernel = ConnectivityKernel(adata, backward=backward)
        conn_kernel = conn_kernel.compute_transition_matrix(density_normalize=False)

        combined = 0.8 * velo_kernel + 0.2 * conn_kernel
        expected = combined.transition_matrix
        connectivities = _get_neighs(adata, "connectivities")
        expected = _normalize(density_normalization(expected, connectivities))

        transition_matrix(
            adata,
            diff_kernel="sum",
            weight_diffusion=0.2,
            density_normalize=True,
            backward=backward,
        )
        actual = adata.uns[_transition(Direction.BACKWARD)]["T"]

        np.testing.assert_allclose(expected.A, actual.A, rtol=_rtol)
예제 #4
0
    def test_transition_backward_differ_dense_norm(self, adata):
        # Density normalization is enabled on the kernel side only, so the two
        # backward transition matrices are expected to differ.
        backward = True

        kernel = VelocityKernel(adata, backward=backward)
        kernel = kernel.compute_transition_matrix(density_normalize=True)
        T_kernel = kernel.transition_matrix

        transition_matrix(adata, density_normalize=False, backward=backward)
        T_legacy = adata.uns[_transition(Direction.BACKWARD)]["T"]

        matrices_match = np.allclose(T_kernel.A, T_legacy.A, rtol=_rtol)
        assert not matrices_match
예제 #5
0
    def test_transition_forward_dense_norm(self, adata):
        # The kernel-based and legacy forward transition matrices must agree.
        # NOTE(review): despite `dense_norm` in the test name, both sides use
        # `density_normalize=False`; confirm whether `True` was intended.
        backward = False
        density_normalize = False

        kernel = VelocityKernel(adata, backward=backward)
        kernel = kernel.compute_transition_matrix(density_normalize=density_normalize)
        T_kernel = kernel.transition_matrix

        transition_matrix(adata, density_normalize=density_normalize, backward=backward)
        T_legacy = adata.uns[_transition(Direction.FORWARD)]["T"]

        np.testing.assert_allclose(T_kernel.A, T_legacy.A, rtol=_rtol)
예제 #6
0
    def test_foward(self, adata):
        # A 0.8/0.2 weighted sum of the velocity and connectivity kernels must
        # equal the legacy `transition_matrix` with `diff_kernel="sum"`.
        density_normalize = False

        velo_kernel = VelocityKernel(adata)
        velo_kernel = velo_kernel.compute_transition_matrix(
            density_normalize=density_normalize)
        conn_kernel = ConnectivityKernel(adata)
        conn_kernel = conn_kernel.compute_transition_matrix(
            density_normalize=density_normalize)

        T_combined = (0.8 * velo_kernel + 0.2 * conn_kernel).transition_matrix

        transition_matrix(
            adata,
            diff_kernel="sum",
            weight_diffusion=0.2,
            density_normalize=density_normalize,
        )
        T_legacy = adata.uns[_transition(Direction.FORWARD)]["T"]

        np.testing.assert_allclose(T_combined.A, T_legacy.A, rtol=_rtol)
예제 #7
0
    def test_backward_negate(self, adata):
        # In the "negate" backward mode, the kernel and the legacy function
        # must produce the same transition matrix.
        backward, dense_norm, backward_mode = True, False, "negate"

        kernel = VelocityKernel(adata, backward=backward)
        kernel.compute_transition_matrix(
            density_normalize=dense_norm, backward_mode=backward_mode
        )
        T_kernel = kernel.transition_matrix

        transition_matrix(
            adata,
            density_normalize=dense_norm,
            backward=backward,
            backward_mode=backward_mode,
        )
        T_legacy = adata.uns[_transition(Direction.BACKWARD)]["T"]

        np.testing.assert_allclose(T_kernel.A, T_legacy.A, rtol=_rtol)
예제 #8
0
    def test_backward(self, adata):
        # Backward version of the combined-kernel check: a 0.8/0.2 weighted sum
        # must equal the legacy `transition_matrix` with `diff_kernel="sum"`.
        density_norm = False
        backward = True

        velo_kernel = VelocityKernel(adata, backward=backward)
        velo_kernel = velo_kernel.compute_transition_matrix(
            density_normalize=density_norm)
        conn_kernel = ConnectivityKernel(adata, backward=backward)
        conn_kernel = conn_kernel.compute_transition_matrix(
            density_normalize=density_norm)

        T_combined = (0.8 * velo_kernel + 0.2 * conn_kernel).transition_matrix

        transition_matrix(
            adata,
            diff_kernel="sum",
            weight_diffusion=0.2,
            density_normalize=density_norm,
            backward=backward,
        )
        T_legacy = adata.uns[_transition(Direction.BACKWARD)]["T"]

        np.testing.assert_allclose(T_combined.A, T_legacy.A, rtol=_rtol)
예제 #9
0
def transition_matrix(
    adata: AnnData,
    vkey: str = "velocity",
    backward: bool = False,
    self_transitions: Optional[str] = None,
    sigma_corr: Optional[float] = None,
    diff_kernel: Optional[str] = None,
    weight_diffusion: float = 0.2,
    density_normalize: bool = True,
    backward_mode: str = "transpose",
    inplace: bool = True,
) -> csr_matrix:
    """
    Computes transition probabilities from velocity graph.

    THIS FUNCTION HAS BEEN DEPRECATED.
    Interact with kernels via the Kernel class or via cellrank.tools_transition_matrix.transition_matrix

    Employs ideas of both scvelo as well as velocyto.

    Parameters
    ----------
    adata : :class:`anndata.AnnData`
        Annotated Data Matrix
    vkey
        Name of the velocity estimates to be used. ``adata.uns`` must contain ``'{vkey}_graph'``
        and ``'{vkey}_graph_neg'``.
    backward
        Whether to use the transition matrix to push forward (`False`) or to pull backward (`True`)
    self_transitions
        How to fill the diagonal. Can be either 'velocyto' or 'scvelo'. Two different
        heuristics are used. Can prevent dividing by zero in unlucky situations for the
        reverse process
    sigma_corr
        Kernel width for exp kernel to be used to compute transition probabilities
        from the velocity graph. If None, the inverse of the median absolute cosine correlation
        of the stored (non-zero) correlations will be used.
    diff_kernel
        Whether to multiply the velocity connectivities with transcriptomic distances to make them more robust.
        Options are ('sum', 'mult', 'both')
    weight_diffusion
        Relative weight given to the diffusion kernel. Must be in [0, 1]. Only matters when using 'sum' or 'both'
        for the diffusion kernel.
    density_normalize
        Whether to use the transcriptomic KNN graph for density normalization as performed in scanpy when
        computing diffusion maps
    backward_mode
        Options are ['transpose', 'negate']. Only used when :paramref:`backward` is `True`.
    inplace
        If True, adds to adata. Otherwise returns.

    Returns
    -------
    :class:`scipy.sparse.csr_matrix` or :class:`NoneType`
        If :paramref:`inplace` is `False`, the transition matrix. Otherwise `None`. In that case the matrix
        and the parameters used are stored in ``adata.uns[<transition key>]`` as ``{'T': ..., 'params': ...}``.

    Raises
    ------
    ValueError
        If the velocity graphs are missing from ``adata.uns``, or if :paramref:`backward` is `True` and
        :paramref:`backward_mode` is unknown.
    """
    logg.info("Computing transition probability from velocity graph")

    # get the direction of the process
    direction = Direction.BACKWARD if backward else Direction.FORWARD

    # get the velocity correlations
    if (vkey + "_graph" not in adata.uns.keys()) or (vkey + "_graph_neg"
                                                     not in adata.uns.keys()):
        raise ValueError(
            "You need to run `tl.velocity_graph` first to compute cosine correlations"
        )
    velo_corr, velo_corr_neg = (
        csr_matrix(adata.uns[vkey + "_graph"]).copy(),
        csr_matrix(adata.uns[vkey + "_graph_neg"]).copy(),
    )
    velo_corr_comb_ = (velo_corr + velo_corr_neg).astype(np.float64)
    if backward:
        if backward_mode == "negate":
            velo_corr_comb = velo_corr_comb_.multiply(-1)
        elif backward_mode == "transpose":
            velo_corr_comb = velo_corr_comb_.T
        else:
            raise ValueError(f"Unknown backward_mode `{backward_mode}`.")
    else:
        velo_corr_comb = velo_corr_comb_
    med_corr = np.median(np.abs(velo_corr_comb.data))

    # compute the raw transition matrix. At the moment, this is just an exponential kernel
    logg.debug("DEBUG: Computing the raw transition matrix")
    if sigma_corr is None:
        sigma_corr = 1 / med_corr
    velo_graph = velo_corr_comb.copy()
    velo_graph.data = np.exp(velo_graph.data * sigma_corr)

    # the transcriptomic KNN graph is needed both for smoothing and for density normalization
    if diff_kernel is not None or density_normalize:
        params = _get_neighs_params(adata)
        logg.debug(
            f'DEBUG: Using KNN graph computed in basis {params.get("use_rep", "Unknown")!r} '
            f'with {params.get("n_neighbors", "Unknown")} neighbors')
        trans_graph = _get_neighs(adata, "connectivities")
        dev = norm((trans_graph - trans_graph.T), ord="fro")
        if dev > 1e-4:
            logg.warning(f"KNN base graph not symmetric, `dev={dev}`")

    # KNN smoothing
    if diff_kernel is not None:
        logg.debug("DEBUG: Smoothing KNN graph with diffusion kernel")
        velo_graph = _knn_smooth(diff_kernel, velo_graph, trans_graph,
                                 weight_diffusion)

    # set the diagonal elements. This is important especially for the backwards direction
    logg.debug("DEBUG: Setting diagonal elements")
    velo_graph = _self_loops(self_transitions, velo_graph)

    # density normalisation - taken from scanpy
    if density_normalize:
        logg.debug("DEBUG: Density correcting the velocity graph")
        velo_graph = density_normalization(velo_graph, trans_graph)

    # normalize
    T = _normalize(velo_graph)

    if not inplace:
        logg.info("Computed transition matrix")
        return T

    key = _transition(direction)
    if key in adata.uns.keys():
        logg.warning(
            f"`.uns` already contains a field `{key!r}`. Overwriting"
        )

    params = {
        "backward": backward,
        "self_transitions": self_transitions,
        "sigma_corr": np.round(sigma_corr, 3),
        "diff_kernel": diff_kernel,
        "weight_diffusion": weight_diffusion,
        "density_normalize": density_normalize,
    }

    adata.uns[key] = {"T": T, "params": params}
    logg.info(
        f"Computed transition matrix and added the key `{key!r}` to `adata.uns`"
    )
예제 #10
0
def lineages(
    adata: AnnData,
    final: bool = True,
    keys: Optional[Sequence[str]] = None,
    copy: bool = False,
) -> Optional[AnnData]:
    """
    Computes probabilistic lineage assignment using RNA velocity.

    For each cell i in {1, ..., N} and start/endpoint j in {1, ..., M}, the probability is computed that cell i
    is either going to j (end point) or coming from j (start point). Mathematically, this computes absorption
    probabilities to approximate recurrent classes using an RNA velocity based Markov chain.

    Note that absorption probabilities have been used in the single cell context to infer lineage probabilities e.g.
    in [Setty19]_ or [Weinreb18]_ and we took inspiration from there.

    Before running this function, compute start/endpoints using :func:`cellrank.tl.root_final`.

    Parameters
    ----------
    adata : :class:`anndata.AnnData`
        Annotated data object
    final
        If `True`, computes final cells, i.e. end points. Otherwise, computes root cells, i.e. starting points.
    keys
        Determines which end/start-points to use by passing their names. Further, start/end-points can be combined.
        If e.g. the endpoints are ['Neuronal_1', 'Neuronal_2', 'Astrocytes', 'OPC'], then passing
        keys=['Neuronal_1, Neuronal_2', 'OPC'] means that the two neuronal endpoints are treated as one and
        Astrocytes are excluded.
    copy
        Whether to update the existing AnnData object or to return a copy.

    Returns
    -------
    :class:`anndata.AnnData` or :class:`NoneType`
        Depending on :paramref:`copy`, either updates the existing :paramref:`adata` object or returns a copy.

    Raises
    ------
    ValueError
        If no transition matrix for the requested direction is stored in ``adata.uns``.
    """
    adata = adata.copy() if copy else adata

    # pick all keys for the requested direction at once
    direction, lin_key, rc_key = (
        (Direction.FORWARD, LinKey.FORWARD, RcKey.FORWARD)
        if final
        else (Direction.BACKWARD, LinKey.BACKWARD, RcKey.BACKWARD)
    )

    transition_key = _transition(direction)
    if transition_key not in adata.uns.keys():
        kind = "final" if final else "root"
        raise ValueError(
            f"Compute {kind} cells first as `cellrank.tl.find_{kind}`."
        )

    start = logg.info(f"Computing lineage probabilities towards `{rc_key}`")

    # re-use the stored transition matrix instead of recomputing it
    kernel = VelocityKernel(adata, backward=not final)
    kernel.transition_matrix = adata.uns[transition_key]["T"]
    markov_chain = MarkovChain(kernel)

    # absorption probabilities towards the requested states
    markov_chain.compute_lin_probs(keys=keys)

    logg.info(f"Added key `{lin_key!r}` to `adata.obsm`\n    Finish",
              time=start)

    if copy:
        return adata
    return None
예제 #11
0
def lineages(
    adata: AnnData,
    estimator: type(BaseEstimator) = GPCCA,
    final: bool = True,
    cluster_key: Optional[str] = None,
    keys: Optional[Sequence[str]] = None,
    n_lineages: Optional[int] = None,
    method: str = "krylov",
    copy: bool = False,
    return_estimator: bool = False,
    **kwargs,
) -> Optional[AnnData]:
    """
    Compute probabilistic lineage assignment using RNA velocity.

    For each cell i in {1, ..., N} and root/final state j in {1, ..., M}, the probability is computed that cell i
    is either going to final state j (`final=True`) or coming from root state j (`final=False`). We provide two
    estimators for computing these probabilities:

    For the estimator GPCCA, we perform Generalized Perron Cluster Cluster Analysis [GPCCA18]_.  Cells are mapped to a
    simplex where each corner represents a final/root state, and the position of a cell in the simplex determines its
    probability of going to a final states/coming from a root state.

    For the estimator CFLARE, we compute absorption probabilities towards the root/final states of the Markov chain.
    For related approaches in the single cell context that utilise absorption probabilities to map cells to lineages,
    see [Setty19]_ or [Weinreb18]_.

    Before running this function, compute root/final states using :func:`cellrank.tl.root_states` or
    :func:`cellrank.tl.final_states`, respectively.

    Parameters
    ----------
    adata : :class:`anndata.AnnData`
        Annotated data object
    estimator
        Estimator class to use to compute the lineage probabilities.
    final
        If `True`, computes final states. Otherwise, computes root states.
    cluster_key
        Match computed final/root states against pre-computed clusters to annotate the final/root states.
        For this, provide a key from :paramref:`adata` `.obs` where cluster labels have been computed.
    keys
        Determines which root/final states to use by passing their names. Further, root/final states can be combined.
        If e.g. the final states are ['Neuronal_1', 'Neuronal_2', 'Astrocytes', 'OPC'], then passing
        keys=['Neuronal_1, Neuronal_2', 'OPC'] means that the two neuronal final states are treated as one and the
        Astrocyte state is excluded.
    n_lineages
        Number of lineages when :paramref:`estimator` `=GPCCA`. If `None`, it will be based on `eigengap`.
    method
        Method to use when computing the Schur decomposition. Only needed when :paramref:`estimator`
        is :class:`cellrank.tl.GPCCA`.
        Valid options are: `'krylov'`, `'brandts'`.
    copy
        Whether to update the existing AnnData object or to return a copy.
    return_estimator
        Whether to return the estimator. Only available when :paramref:`copy=False`.
    kwargs
        Keyword arguments for :meth:`cellrank.tl.estimators.BaseEstimator.compute_metastable_states`.

    Returns
    -------
    :class:`anndata.AnnData`, :class:`cellrank.tools.estimators.BaseEstimator` or :class:`NoneType`
        Depending on :paramref:`copy`, either updates the existing :paramref:`adata` object or returns a copy or
        returns the estimator.

    Raises
    ------
    TypeError
        If :paramref:`estimator` is not a subclass of :class:`BaseEstimator`.
    ValueError
        If no transition matrix for the requested direction is stored in ``adata.uns``.
    NotImplementedError
        If the estimator is neither a :class:`CFLARE` nor a :class:`GPCCA` instance.
    """

    if not isinstance(estimator, type):
        raise TypeError(
            f"Expected estimator to be a class, found `{type(estimator).__name__}`."
        )

    if not issubclass(estimator, BaseEstimator):
        # `estimator` is a class here, so report its own name, not its metaclass's
        raise TypeError(
            f"Expected estimator to be a subclass of `BaseEstimator`, found `{estimator.__name__}`"
        )

    # Set the keys and print info
    adata = adata.copy() if copy else adata

    if final:
        direction = Direction.FORWARD
        lin_key = LinKey.FORWARD
        rc_key = StateKey.FORWARD
    else:
        direction = Direction.BACKWARD
        lin_key = LinKey.BACKWARD
        rc_key = StateKey.BACKWARD

    transition_key = _transition(direction)
    if transition_key not in adata.uns.keys():
        key = "final" if final else "root"
        raise ValueError(
            f"Compute {key} states first as `cellrank.tl.find_{key}`.")

    start = logg.info(f"Computing lineage probabilities towards `{rc_key}`")

    # get the transition matrix from the AnnData object and initialise MC object
    vk = VelocityKernel(adata, backward=not final)
    vk.transition_matrix = adata.uns[transition_key]["T"]
    mc = estimator(vk, read_from_adata=False)

    if cluster_key is None:
        _info_if_obs_keys_categorical_present(
            adata,
            keys=["louvain", "clusters"],
            msg_fmt="Found categorical observation in `adata.obs[{!r}]`. "
            "Consider specifying it as `cluster_key`.",
        )

    # compute the absorption probabilities
    if isinstance(mc, CFLARE):
        mc.compute_eig()
        mc.compute_metastable_states(cluster_key=cluster_key, **kwargs)
        mc.compute_lin_probs(keys=keys)
    elif isinstance(mc, GPCCA):
        if n_lineages is None or n_lineages == 1:
            # the eigendecomposition supplies the eigengap used to pick `n_lineages`
            mc.compute_eig()
            if n_lineages is None:
                n_lineages = mc.eigendecomposition["eigengap"] + 1

        if n_lineages > 1:
            mc.compute_schur(n_lineages + 1, method=method)

        try:
            mc.compute_metastable_states(n_states=n_lineages,
                                         cluster_key=cluster_key,
                                         **kwargs)
        except ValueError:
            logg.warning(
                f"Computing {n_lineages} metastable states cuts through a block of complex conjugates. "
                f"Increasing `n_lineages` to {n_lineages + 1}")
            mc.compute_metastable_states(n_states=n_lineages + 1,
                                         cluster_key=cluster_key,
                                         **kwargs)
        mc.set_main_states(names=keys)
    else:
        # report the actual estimator type (previously always printed `type`)
        raise NotImplementedError(
            f"Pipeline not implemented for `{type(mc).__name__}`")

    logg.info(f"Added key `{lin_key!r}` to `adata.obsm`\n    Finish",
              time=start)

    return adata if copy else mc if return_estimator else None