コード例 #1
0
    def write_to_adata(self, key: Optional[str] = None) -> None:
        """
        Store the transition matrix and the parameters used to compute it in the underlying :attr:`adata` object.

        Parameters
        ----------
        key
            Key under which the transition matrix is written to :attr:`adata`.
            When `None`, it defaults to `'T_bwd'` if :attr:`backward` is `True`, otherwise to `'T_fwd'`.

        Returns
        -------
        None
            %(write_to_adata)s
        """
        if self._transition_matrix is None:
            raise ValueError(
                "Compute transition matrix first as `.compute_transition_matrix()`."
            )

        tmat_key = _transition(self._direction) if key is None else key
        params_key = f"{tmat_key}_params"

        # Merge into the existing entry instead of overwriting it, so that other
        # metadata already stored there (e.g. the embedding info) survives.
        previous = self.adata.uns.get(params_key, {})
        self.adata.uns[params_key] = {**previous, "params": self.params}

        _write_graph_data(self.adata, self.transition_matrix, tmat_key)
コード例 #2
0
def _assert_has_all_keys(adata: AnnData, direction: Direction):
    """
    Assert that ``adata`` contains all keys produced by a full run in the given direction.

    The following are checked:

    - the transition matrix in ``adata.obsp``, and that it is not a dummy (all diagonal entries ~1);
    - its ``'<key>_params'`` entry in ``adata.uns``;
    - the absorption probabilities in ``adata.obsm``, stored as :class:`cellrank.tl.Lineage`;
    - the colors and lineage names of the absorption probabilities in ``adata.uns``;
    - the categorical terminal states and their probabilities in ``adata.obs``;
    - one ``adata.var`` column per lineage, named ``'<direction prefix> <lineage name>'``.

    Parameters
    ----------
    adata
        Annotated data object to inspect.
    direction
        Direction whose keys are checked. Anything other than ``Direction.FORWARD``
        is treated as the backward direction.

    Raises
    ------
    AssertionError
        If any expected key is missing or has the wrong type.
    """
    tmat_key = _transition(direction)
    assert tmat_key in adata.obsp.keys()
    # check if it's not a dummy transition matrix;
    # `.diagonal()` works for both dense arrays and sparse matrices
    assert not np.all(np.isclose(adata.obsp[tmat_key].diagonal(), 1.0))
    assert f"{tmat_key}_params" in adata.uns.keys()

    if direction == Direction.FORWARD:
        abs_key, ts_key, prefix = (
            AbsProbKey.FORWARD,
            TermStatesKey.FORWARD,
            DirPrefix.FORWARD,
        )
    else:
        abs_key, ts_key, prefix = (
            AbsProbKey.BACKWARD,
            TermStatesKey.BACKWARD,
            DirPrefix.BACKWARD,
        )

    assert str(abs_key) in adata.obsm
    assert isinstance(adata.obsm[str(abs_key)], cr.tl.Lineage)

    assert _colors(abs_key) in adata.uns.keys()
    assert _lin_names(abs_key) in adata.uns.keys()

    assert str(ts_key) in adata.obs
    assert is_categorical_dtype(adata.obs[str(ts_key)])

    assert _probs(ts_key) in adata.obs

    # check the correlations with all lineages have been computed
    lin_probs = adata.obsm[str(abs_key)]
    # `str(prefix)` is kept explicit: f-string formatting of an Enum may differ from `str()`
    expected_cols = [f"{str(prefix)} {name}" for name in lin_probs.names]
    assert np.isin(expected_cols, adata.var.keys()).all()
コード例 #3
0
    def __init__(
        self,
        transition_matrix: Optional[Union[np.ndarray, spmatrix, str]] = None,
        adata: Optional[AnnData] = None,
        backward: bool = False,
        compute_cond_num: bool = False,
    ):
        """
        Initialize from a precomputed transition matrix.

        Parameters
        ----------
        transition_matrix
            Square matrix whose rows sum to 1, or a key in :attr:`anndata.AnnData.obsp` from which to read it.
            If `None`, the default transition-matrix key for the direction selected by ``backward`` is used.
        adata
            Annotated data object. Required when ``transition_matrix`` is a key. If `None` otherwise,
            an empty object with one observation per row of the matrix is created.
        backward
            Whether to use the backward direction; also selects the default key.
        compute_cond_num
            Forwarded to the base class; presumably controls whether ``_maybe_compute_cond_num``
            computes the condition number (TODO confirm).

        Raises
        ------
        ValueError
            If ``transition_matrix`` is a key but ``adata`` is `None`, if the matrix is not a
            2-dimensional square matrix, or if not all of its rows sum to 1.
        TypeError
            If the resolved matrix is neither a :class:`numpy.ndarray` nor a :class:`scipy.sparse.spmatrix`.
        """
        from anndata import AnnData as _AnnData

        if transition_matrix is None:
            transition_matrix = _transition(
                Direction.BACKWARD if backward else Direction.FORWARD)
            logg.debug(
                f"Setting transition matrix key to `{transition_matrix!r}`")

        if isinstance(transition_matrix, str):
            if adata is None:
                raise ValueError(
                    "When `transition_matrix` specifies a key to `adata.obsp`, `adata` cannot be None."
                )
            transition_matrix = _read_graph_data(adata, transition_matrix)

        if not isinstance(transition_matrix, (np.ndarray, spmatrix)):
            raise TypeError(
                f"Expected transition matrix to be of type `numpy.ndarray` or `scipy.sparse.spmatrix`, "
                f"found `{type(transition_matrix).__name__!r}`.")

        # check the dimensionality first: a 1-D array has no `shape[1]` and would
        # otherwise raise an `IndexError` instead of this `ValueError`
        shape = transition_matrix.shape
        if len(shape) != 2 or shape[0] != shape[1]:
            raise ValueError(
                f"Expected transition matrix to be square, found `{shape}`."
            )

        # rows must sum to 1 within the module-level relative tolerance `_RTOL`
        if not np.allclose(np.sum(transition_matrix, axis=1), 1.0, rtol=_RTOL):
            raise ValueError(
                "Not a valid transition matrix: not all rows sum to 1.")

        if adata is None:
            logg.warning("Creating empty `AnnData` object")
            # placeholder: one observation per state and a single all-zero variable
            adata = _AnnData(
                csr_matrix((transition_matrix.shape[0], 1), dtype=np.float32))

        super().__init__(adata,
                         backward=backward,
                         compute_cond_num=compute_cond_num)
        # always stored in CSR format, regardless of the input format
        self._transition_matrix = csr_matrix(transition_matrix)
        self._maybe_compute_cond_num()
コード例 #4
0
    def __init__(
        self,
        transition_matrix: Optional[
            Union[np.ndarray, spmatrix, KernelExpression, str]
        ] = None,
        adata: Optional[AnnData] = None,
        backward: bool = False,
        compute_cond_num: bool = False,
        **kwargs: Any,
    ):
        """
        Initialize from a precomputed transition matrix or from an already computed kernel.

        Parameters
        ----------
        transition_matrix
            Either a square :class:`numpy.ndarray` or :class:`scipy.sparse.spmatrix` whose rows sum to 1,
            a key in :attr:`anndata.AnnData.obsp` from which to read such a matrix, or a
            :class:`KernelExpression` whose transition matrix has already been computed.
            For a kernel, its ``adata``, ``backward`` and ``params`` take precedence over the
            corresponding arguments here.
            If `None`, the default transition-matrix key for the direction selected by ``backward`` is used.
        adata
            Annotated data object. Required when ``transition_matrix`` is a key. If `None` otherwise,
            an empty object with one observation per row of the matrix is created.
        backward
            Whether to use the backward direction; also selects the default key.
        compute_cond_num
            Forwarded to the base class; presumably controls whether ``_maybe_compute_cond_num``
            computes the condition number (TODO confirm).
        kwargs
            Keyword arguments for the base class initializer.

        Raises
        ------
        ValueError
            If ``transition_matrix`` is a key but ``adata`` is `None`, if a kernel's transition matrix
            has not been computed, if the matrix is not a 2-dimensional square matrix, or if not
            all of its rows sum to 1.
        TypeError
            If the resolved matrix is neither a :class:`numpy.ndarray` nor a :class:`scipy.sparse.spmatrix`.
        """
        from anndata import AnnData as _AnnData

        # description of where the matrix came from: an array, an `adata.obsp` key or a kernel
        self._origin = "'array'"
        params = {}

        if transition_matrix is None:
            transition_matrix = _transition(
                Direction.BACKWARD if backward else Direction.FORWARD
            )
            logg.debug(f"Setting transition matrix key to `{transition_matrix!r}`")

        if isinstance(transition_matrix, str):
            if adata is None:
                raise ValueError(
                    "When `transition_matrix` specifies a key to `adata.obsp`, `adata` cannot be None."
                )
            self._origin = f"adata.obsp[{transition_matrix!r}]"
            transition_matrix = _read_graph_data(adata, transition_matrix)

        elif isinstance(transition_matrix, KernelExpression):
            if transition_matrix._transition_matrix is None:
                raise ValueError(
                    "Compute transition matrix first as `.compute_transition_matrix()`."
                )
            if adata is not None and adata is not transition_matrix.adata:
                logg.warning(
                    "Ignoring supplied `adata` object because it differs from the kernel's `adata` object."
                )

            # use `str` because it captures the params
            self._origin = str(transition_matrix).strip("~<>")
            params = transition_matrix.params.copy()
            backward = transition_matrix.backward
            adata = transition_matrix.adata
            transition_matrix = transition_matrix.transition_matrix

        if not isinstance(transition_matrix, (np.ndarray, spmatrix)):
            raise TypeError(
                f"Expected transition matrix to be of type `numpy.ndarray` or `scipy.sparse.spmatrix`, "
                f"found `{type(transition_matrix).__name__!r}`."
            )

        # check the dimensionality first: a 1-D array has no `shape[1]` and would
        # otherwise raise an `IndexError` instead of this `ValueError`
        shape = transition_matrix.shape
        if len(shape) != 2 or shape[0] != shape[1]:
            raise ValueError(
                f"Expected transition matrix to be square, found `{shape}`."
            )

        # rows must sum to 1 within the module-level relative tolerance `_RTOL`
        if not np.allclose(np.sum(transition_matrix, axis=1), 1.0, rtol=_RTOL):
            raise ValueError("Not a valid transition matrix, not all rows sum to 1")

        if adata is None:
            logg.warning("Creating empty `AnnData` object")
            # placeholder: one observation per state and a single all-zero variable
            adata = _AnnData(
                csr_matrix((transition_matrix.shape[0], 1), dtype=np.float32)
            )

        super().__init__(
            adata, backward=backward, compute_cond_num=compute_cond_num, **kwargs
        )

        self._params = params
        # always stored in CSR format, regardless of the input format
        self._transition_matrix = csr_matrix(transition_matrix)
        self._maybe_compute_cond_num()
コード例 #5
0
    def test_states_no_precomputed_transition_matrix(self, adata: AnnData):
        # As the test name says, `adata` is assumed to carry no precomputed transition
        # matrix; computing terminal states must leave the forward one in `.obsp`.
        cr.tl.terminal_states(adata, key="foo")
        expected_key = str(_transition(Direction.FORWARD))
        assert expected_key in adata.obsp
コード例 #6
0
    def compute_projection(
        self,
        basis: str = "umap",
        key_added: Optional[str] = None,
        copy: bool = False,
    ) -> Optional[np.ndarray]:
        """
        Compute a projection of the transition matrix in the embedding.

        The projected matrix can be then visualized as::

            scvelo.pl.velocity_embedding(adata, vkey='T_fwd', basis='umap')

        Parameters
        ----------
        basis
            Basis in :attr:`adata` ``.obsm`` for which to compute the projection.
        key_added
            Prefix of the key used when ``copy=False``; the result is saved to
            :attr:`adata` ``.obsm['{key_added}_{basis}']``. If `None`, the prefix is
            `'T_fwd'` or `'T_bwd'`, depending on the direction.
        copy
            Whether to return the projection or modify :attr:`adata` inplace.

        Returns
        -------
        If ``copy=True``, the projection array of shape `(n_cells, n_components)`.
        Otherwise, returns `None`, stores the projection in :attr:`anndata.AnnData.obsm` under
        ``'{key}_{basis}'`` (``key`` being the prefix described in ``key_added``), and appends
        ``basis`` to :attr:`adata` ``.uns['{key}_params']['embeddings']`` if not already present.

        Raises
        ------
        RuntimeError
            If the transition matrix has not been computed yet.
        """
        # modified from: https://github.com/theislab/scvelo/blob/master/scvelo/tools/velocity_embedding.py
        from scvelo.tools.velocity_embedding import quiver_autoscale

        if self._transition_matrix is None:
            raise RuntimeError(
                "Compute transition matrix first as `.compute_transition_matrix()`."
            )

        start = logg.info(f"Projecting transition matrix onto `{basis}`")
        emb = _get_basis(self.adata, basis)
        # an integer-valued embedding (e.g. pixel coordinates) would make `T_emb` an integer
        # array and break the in-place division below, so work in floating point
        if not np.issubdtype(emb.dtype, np.floating):
            emb = emb.astype(np.float64)
        T_emb = np.empty_like(emb)

        with warnings.catch_warnings():
            # suppress numerical warnings (e.g. 0/0 in the normalization); NaNs are zeroed below
            warnings.simplefilter("ignore")
            for i, row in enumerate(self.transition_matrix):
                # displacement vectors from cell `i` to each of its successors
                dX = emb[row.indices] - emb[i, None]
                if np.any(np.isnan(dX)):
                    # any NaN coordinate makes this cell's projection NaN
                    T_emb[i] = np.nan
                else:
                    # unit directions; a zero displacement (e.g. a self-loop) gives 0/0 -> NaN -> 0
                    dX /= np.linalg.norm(dX, axis=1)[:, None]
                    dX = np.nan_to_num(dX)
                    probs = row.data
                    # probability-weighted direction minus the direction obtained with equal
                    # weights (the mean probability), so a uniform row projects to zero
                    T_emb[i] = probs.dot(dX) - probs.mean() * dX.sum(0)

        T_emb /= 3 * quiver_autoscale(np.nan_to_num(emb), T_emb)

        if copy:
            return T_emb

        key = _transition(self._direction) if key_added is None else key_added
        ukey = f"{key}_params"

        # record the embeddings for which a projection has been computed
        embs = self.adata.uns.get(ukey, {}).get("embeddings", [])
        if basis not in embs:
            embs = list(embs) + [basis]
            self.adata.uns[ukey] = self.adata.uns.get(ukey, {})
            self.adata.uns[ukey]["embeddings"] = embs

        key = key + "_" + basis
        logg.info(
            f"Adding `adata.obsm[{key!r}]`\n    Finish",
            time=start,
        )
        self.adata.obsm[key] = T_emb