def write_to_adata(self, key: Optional[str] = None) -> None:
    """
    Store the transition matrix and its computation parameters in :attr:`adata`.

    Parameters
    ----------
    key
        Key under which the transition matrix is written to :attr:`adata`. When `None`,
        it defaults to `'T_bwd'` if :attr:`backward` is `True` and to `'T_fwd'` otherwise.

    Returns
    -------
    None
        %(write_to_adata)s
    """
    if self._transition_matrix is None:
        raise ValueError(
            "Compute transition matrix first as `.compute_transition_matrix()`."
        )

    target = _transition(self._direction) if key is None else key
    params_key = f"{target}_params"

    # merge into the existing entry instead of replacing it, so that information
    # stored there earlier (such as the embedding info) is retained
    self.adata.uns[params_key] = {
        **self.adata.uns.get(params_key, {}),
        "params": self.params,
    }

    _write_graph_data(self.adata, self.transition_matrix, target)
def _assert_has_all_keys(adata: AnnData, direction: Direction):
    """
    Assert that ``adata`` holds all results of a full analysis in one direction.

    Checks the following:

    - the transition matrix, and that not all of its diagonal entries are close to 1;
    - its parameters entry;
    - the absorption probabilities, with their colors and lineage names;
    - the terminal states, with their probabilities;
    - one ``adata.var`` column per lineage.

    Parameters
    ----------
    adata
        Annotated data object to inspect.
    direction
        Direction whose keys are checked. Anything other than ``Direction.FORWARD``
        is treated as backward.
    """
    tmat_key = _transition(direction)
    assert tmat_key in adata.obsp.keys()
    # check if it's not a dummy transition matrix; `.diagonal()` works for sparse and
    # dense storage alike and avoids densifying (sparse `.A` is deprecated in SciPy)
    assert not np.all(np.isclose(adata.obsp[tmat_key].diagonal(), 1.0))
    assert f"{tmat_key}_params" in adata.uns.keys()

    if direction == Direction.FORWARD:
        abs_key = AbsProbKey.FORWARD
        ts_key = TermStatesKey.FORWARD
        prefix = DirPrefix.FORWARD
    else:
        abs_key = AbsProbKey.BACKWARD
        ts_key = TermStatesKey.BACKWARD
        prefix = DirPrefix.BACKWARD

    assert str(abs_key) in adata.obsm
    assert isinstance(adata.obsm[str(abs_key)], cr.tl.Lineage)
    assert _colors(abs_key) in adata.uns.keys()
    assert _lin_names(abs_key) in adata.uns.keys()

    assert str(ts_key) in adata.obs
    assert is_categorical_dtype(adata.obs[str(ts_key)])
    assert _probs(ts_key) in adata.obs

    # check the correlations with all lineages have been computed; the membership
    # test must be asserted, otherwise it can never fail
    lin_probs = adata.obsm[str(abs_key)]
    var_keys = set(adata.var.keys())
    expected = [f"{str(prefix)} {name}" for name in lin_probs.names]
    missing = [col for col in expected if col not in var_keys]
    assert not missing, f"Missing lineage columns in `adata.var`: {missing}"
def __init__(
    self,
    transition_matrix: Optional[Union[np.ndarray, spmatrix, str]] = None,
    adata: Optional[AnnData] = None,
    backward: bool = False,
    compute_cond_num: bool = False,
):
    """
    Initialize the kernel from a precomputed transition matrix.

    Parameters
    ----------
    transition_matrix
        Square matrix whose rows sum to 1, or a key in ``adata.obsp``. If `None`,
        the key is derived from ``backward`` via ``_transition``.
    adata
        Annotated data object. Required when ``transition_matrix`` is a key. If `None`
        otherwise, an empty object with one observation per matrix row is created.
    backward
        Direction of the process; also selects the default key.
    compute_cond_num
        Forwarded to the parent constructor. Presumably controls whether
        ``_maybe_compute_cond_num`` computes anything; confirm in the base class.

    Raises
    ------
    ValueError
        If a key is given without ``adata``, if the matrix is not 2-dimensional or not
        square, or if its rows do not all sum to 1 (relative tolerance ``_RTOL``).
    TypeError
        If the resolved matrix is neither a :class:`numpy.ndarray` nor a sparse matrix.
    """
    from anndata import AnnData as _AnnData

    if transition_matrix is None:
        transition_matrix = _transition(
            Direction.BACKWARD if backward else Direction.FORWARD
        )
        logg.debug(f"Setting transition matrix key to `{transition_matrix!r}`")

    if isinstance(transition_matrix, str):
        if adata is None:
            raise ValueError(
                "When `transition_matrix` specifies a key to `adata.obsp`, `adata` cannot be None."
            )
        transition_matrix = _read_graph_data(adata, transition_matrix)

    if not isinstance(transition_matrix, (np.ndarray, spmatrix)):
        raise TypeError(
            f"Expected transition matrix to be of type `numpy.ndarray` or `scipy.sparse.spmatrix`, "
            f"found `{type(transition_matrix).__name__!r}`."
        )
    # validate the rank before indexing `shape[1]`: a 1-D array would otherwise fail
    # with an opaque `IndexError`, and a 3-D one could pass the square check
    if transition_matrix.ndim != 2:
        raise ValueError(
            f"Expected transition matrix to be 2-dimensional, found `{transition_matrix.ndim}` dimension(s)."
        )
    if transition_matrix.shape[0] != transition_matrix.shape[1]:
        raise ValueError(
            f"Expected transition matrix to be square, found `{transition_matrix.shape}`."
        )
    if not np.allclose(np.sum(transition_matrix, axis=1), 1.0, rtol=_RTOL):
        raise ValueError("Not a valid transition matrix: not all rows sum to 1.")

    if adata is None:
        logg.warning("Creating empty `AnnData` object")
        adata = _AnnData(
            csr_matrix((transition_matrix.shape[0], 1), dtype=np.float32)
        )

    super().__init__(adata, backward=backward, compute_cond_num=compute_cond_num)
    self._transition_matrix = csr_matrix(transition_matrix)
    self._maybe_compute_cond_num()
def __init__(
    self,
    transition_matrix: Optional[
        Union[np.ndarray, spmatrix, KernelExpression, str]
    ] = None,
    adata: Optional[AnnData] = None,
    backward: bool = False,
    compute_cond_num: bool = False,
    **kwargs: Any,
):
    """
    Initialize the kernel from a precomputed transition matrix.

    Parameters
    ----------
    transition_matrix
        One of the following:

        - a square matrix whose rows sum to 1;
        - a key in ``adata.obsp``;
        - a kernel expression whose transition matrix has already been computed.

        If `None`, the key is derived from ``backward`` via ``_transition``.
    adata
        Annotated data object. Required when ``transition_matrix`` is a key. When a
        kernel expression is passed, the expression's object is used instead, with a
        warning if a different one was supplied. If `None` otherwise, an empty object
        with one observation per matrix row is created.
    backward
        Direction of the process. When a kernel expression is passed, it is replaced
        by the expression's direction.
    compute_cond_num
        Forwarded to the parent constructor. Presumably controls whether
        ``_maybe_compute_cond_num`` computes anything; confirm in the base class.
    kwargs
        Additional keyword arguments forwarded to the parent constructor.

    Raises
    ------
    ValueError
        Raised in any of these cases:

        - a key is given without ``adata``;
        - the kernel expression has no transition matrix yet;
        - the matrix is not 2-dimensional, or not square;
        - its rows do not all sum to 1 (relative tolerance ``_RTOL``).
    TypeError
        If the resolved matrix is neither a :class:`numpy.ndarray` nor a sparse matrix.
    """
    from anndata import AnnData as _AnnData

    # human-readable description of where the matrix came from
    self._origin = "'array'"
    params = {}

    if transition_matrix is None:
        transition_matrix = _transition(
            Direction.BACKWARD if backward else Direction.FORWARD
        )
        logg.debug(f"Setting transition matrix key to `{transition_matrix!r}`")

    if isinstance(transition_matrix, str):
        if adata is None:
            raise ValueError(
                "When `transition_matrix` specifies a key to `adata.obsp`, `adata` cannot be None."
            )
        self._origin = f"adata.obsp[{transition_matrix!r}]"
        transition_matrix = _read_graph_data(adata, transition_matrix)
    elif isinstance(transition_matrix, KernelExpression):
        if transition_matrix._transition_matrix is None:
            raise ValueError(
                "Compute transition matrix first as `.compute_transition_matrix()`."
            )
        if adata is not None and adata is not transition_matrix.adata:
            logg.warning(
                "Ignoring supplied `adata` object because it differs from the kernel's `adata` object."
            )
        # use `str` because it captures the params
        self._origin = str(transition_matrix).strip("~<>")
        params = transition_matrix.params.copy()
        backward = transition_matrix.backward
        adata = transition_matrix.adata
        transition_matrix = transition_matrix.transition_matrix

    if not isinstance(transition_matrix, (np.ndarray, spmatrix)):
        raise TypeError(
            f"Expected transition matrix to be of type `numpy.ndarray` or `scipy.sparse.spmatrix`, "
            f"found `{type(transition_matrix).__name__!r}`."
        )
    # validate the rank before indexing `shape[1]`: a 1-D array would otherwise fail
    # with an opaque `IndexError`, and a 3-D one could pass the square check
    if transition_matrix.ndim != 2:
        raise ValueError(
            f"Expected transition matrix to be 2-dimensional, found `{transition_matrix.ndim}` dimension(s)."
        )
    if transition_matrix.shape[0] != transition_matrix.shape[1]:
        raise ValueError(
            f"Expected transition matrix to be square, found `{transition_matrix.shape}`."
        )
    if not np.allclose(np.sum(transition_matrix, axis=1), 1.0, rtol=_RTOL):
        raise ValueError("Not a valid transition matrix, not all rows sum to 1")

    if adata is None:
        logg.warning("Creating empty `AnnData` object")
        adata = _AnnData(
            csr_matrix((transition_matrix.shape[0], 1), dtype=np.float32)
        )

    super().__init__(
        adata, backward=backward, compute_cond_num=compute_cond_num, **kwargs
    )
    self._params = params
    self._transition_matrix = csr_matrix(transition_matrix)
    self._maybe_compute_cond_num()
def test_states_no_precomputed_transition_matrix(self, adata: AnnData):
    """After `terminal_states` runs, the forward transition matrix is present in `adata.obsp`."""
    cr.tl.terminal_states(adata, key="foo")

    expected_key = str(_transition(Direction.FORWARD))
    assert expected_key in adata.obsp
def compute_projection(
    self,
    basis: str = "umap",
    key_added: Optional[str] = None,
    copy: bool = False,
) -> Optional[np.ndarray]:
    """
    Compute a projection of the transition matrix in the embedding.

    The projected matrix can be then visualized as::

        scvelo.pl.velocity_embedding(adata, vkey='T_fwd', basis='umap')

    Parameters
    ----------
    basis
        Basis in :attr:`adata` ``.obsm`` for which to compute the projection.
    key_added
        Key prefix used when ``copy=False``. The projection is saved to :attr:`adata`
        ``.obsm['{key_added}_{basis}']``. The basis is appended to the list in
        ``.uns['{key_added}_params']['embeddings']``. If `None`, the prefix is `'T_fwd'`
        or `'T_bwd'`, depending on the direction.
    copy
        Whether to return the projection or modify :attr:`adata` inplace.

    Returns
    -------
    If ``copy=True``, the projection array of shape `(n_cells, n_components)`.
    Otherwise, it modifies :attr:`anndata.AnnData.obsm` with a key based on ``key_added``
    and ``basis``.

    Raises
    ------
    RuntimeError
        If the transition matrix has not been computed yet.
    """
    # modified from: https://github.com/theislab/scvelo/blob/master/scvelo/tools/velocity_embedding.py
    from scvelo.tools.velocity_embedding import quiver_autoscale

    if self._transition_matrix is None:
        raise RuntimeError(
            "Compute transition matrix first as `.compute_transition_matrix()`."
        )

    start = logg.info(f"Projecting transition matrix onto `{basis}`")
    emb = _get_basis(self.adata, basis)
    # Integer coordinates would break the in-place division and NaN assignment below.
    # Float embeddings keep their original precision.
    if not np.issubdtype(emb.dtype, np.floating):
        emb = emb.astype(np.float64)
    T_emb = np.empty_like(emb)

    with warnings.catch_warnings():
        # A neighbour with identical coordinates has a zero-length displacement,
        # which gives 0/0. Silence the warning; `np.nan_to_num` zeroes the NaNs.
        warnings.simplefilter("ignore")
        for i, row in enumerate(self.transition_matrix):
            dX = emb[row.indices] - emb[i, None]
            if np.any(np.isnan(dX)):
                T_emb[i] = np.nan
            else:
                dX /= np.linalg.norm(dX, axis=1)[:, None]
                dX = np.nan_to_num(dX)
                probs = row.data
                # probability-weighted unit displacements, centred by the mean
                # probability (same formula as scVelo's velocity embedding)
                T_emb[i] = probs.dot(dX) - probs.mean() * dX.sum(0)

    T_emb /= 3 * quiver_autoscale(np.nan_to_num(emb), T_emb)

    if copy:
        return T_emb

    key = _transition(self._direction) if key_added is None else key_added
    ukey = f"{key}_params"

    # record which embeddings a projection has been computed for
    embs = self.adata.uns.get(ukey, {}).get("embeddings", [])
    if basis not in embs:
        embs = list(embs) + [basis]
        self.adata.uns[ukey] = self.adata.uns.get(ukey, {})
        self.adata.uns[ukey]["embeddings"] = embs

    key = key + "_" + basis
    logg.info(
        f"Adding `adata.obsm[{key!r}]`\n Finish",
        time=start,
    )
    self.adata.obsm[key] = T_emb