def test_compute_lin_probs_normal_run(self, adata_large: AnnData):
    """Check that `compute_lin_probs` stores lineage probabilities and the diff. potential consistently in `adata`."""
    velo_kernel = VelocityKernel(adata_large).compute_transition_matrix()
    conn_kernel = ConnectivityKernel(adata_large).compute_transition_matrix()
    mc = cr.tl.MarkovChain(0.8 * velo_kernel + 0.2 * conn_kernel)

    mc.compute_eig(k=5)
    mc.compute_approx_rcs(use=2)
    mc.compute_lin_probs()

    adata = mc.adata
    dp_key = f"{LinKey.FORWARD}_dp"
    lin_key = f"{LinKey.FORWARD}"
    names_key = _lin_names(LinKey.FORWARD)
    colors_key = _colors(LinKey.FORWARD)

    # differentiation potential is stored in `.obs`
    assert isinstance(mc.diff_potential, np.ndarray)
    assert dp_key in adata.obs.keys()
    np.testing.assert_array_equal(mc.diff_potential, adata.obs[dp_key])

    # lineage probabilities: one column per approximate recurrent class (`use=2`)
    lin_probs = mc.lineage_probabilities
    assert isinstance(lin_probs, cr.tl.Lineage)
    assert lin_probs.shape == (adata.n_obs, 2)
    assert lin_key in adata.obsm.keys()
    np.testing.assert_array_equal(lin_probs.X, adata.obsm[lin_key])

    # names and colors are mirrored in `.uns`
    assert names_key in adata.uns.keys()
    np.testing.assert_array_equal(lin_probs.names, adata.uns[names_key])
    assert colors_key in adata.uns.keys()
    np.testing.assert_array_equal(lin_probs.colors, adata.uns[colors_key])

    # every row is a probability distribution
    np.testing.assert_allclose(lin_probs.X.sum(1), 1)
def _assert_has_all_keys(adata: AnnData, direction: Direction):
    """
    Assert that all keys written for `direction` are present in `adata`.

    Checks the transition matrix in ``.uns``, the :class:`Lineage` in ``.obsm`` together with its names/colors
    in ``.uns``, and the categorical states and their probabilities in ``.obs``.
    """
    assert _transition(direction) in adata.uns.keys()

    # both directions check the same layout, only the keys differ
    is_forward = direction == Direction.FORWARD
    lin_key = LinKey.FORWARD if is_forward else LinKey.BACKWARD
    state_key = StateKey.FORWARD if is_forward else StateKey.BACKWARD

    assert str(lin_key) in adata.obsm
    assert isinstance(adata.obsm[str(lin_key)], cr.tl.Lineage)
    assert _colors(lin_key) in adata.uns.keys()
    assert _lin_names(lin_key) in adata.uns.keys()

    assert str(state_key) in adata.obs
    assert is_categorical_dtype(adata.obs[str(state_key)])
    assert _probs(state_key) in adata.obs
def _check_main_states(mc: cr.tl.GPCCA, has_main_states: bool = True):
    """
    Assert that the results stored on `mc` agree with what was written to `mc.adata`.

    Params
    ------
    mc
        Estimator whose main states, lineage probabilities and differentiation potential are checked.
    has_main_states
        Whether to also check `mc.main_states` and their colors.
    """
    if has_main_states:
        # main states are mirrored in `.obs`; their colors must equal the colors of the matching lineages
        assert isinstance(mc.main_states, pd.Series)
        assert_array_nan_equal(mc.adata.obs[str(StateKey.FORWARD)], mc.main_states)
        np.testing.assert_array_equal(
            mc.adata.uns[_colors(StateKey.FORWARD)],
            mc.lineage_probabilities[list(mc.main_states.cat.categories)].colors,
        )

    # NOTE(review): the checks below are assumed to run regardless of `has_main_states` - confirm;
    # the last one (`main_states_probabilities`) may only hold when main states were computed.
    assert isinstance(mc.diff_potential, np.ndarray)
    assert isinstance(mc.lineage_probabilities, cr.tl.Lineage)
    # lineage probabilities, names and colors are mirrored in `.obsm` / `.uns`
    np.testing.assert_array_equal(
        mc.adata.obsm[str(LinKey.FORWARD)], mc.lineage_probabilities.X
    )
    np.testing.assert_array_equal(
        mc.adata.uns[_lin_names(LinKey.FORWARD)], mc.lineage_probabilities.names
    )
    np.testing.assert_array_equal(
        mc.adata.uns[_colors(LinKey.FORWARD)], mc.lineage_probabilities.colors
    )
    # differentiation potential and main-state probabilities are mirrored in `.obs`
    np.testing.assert_array_equal(mc.adata.obs[_dp(LinKey.FORWARD)], mc.diff_potential)
    np.testing.assert_array_equal(
        mc.adata.obs[_probs(StateKey.FORWARD)], mc.main_states_probabilities
    )
def test_non_unique_names(self, adata: AnnData, path: Path, lin_key: str, _: int):
    """Reading back an object whose lineage names contain a duplicate must raise `ValueError`."""
    lineage_names = adata.uns[_lin_names(lin_key)]
    # overwrite the first name with the second one, in place, to create a duplicate
    lineage_names[0] = lineage_names[1]
    sc.write(path, adata)

    with pytest.raises(ValueError):
        cr.read(path)
def test_no_names(self, adata: AnnData, path: Path, lin_key: str, n_lins: int):
    """Missing lineage names are replaced by ``"Lineage <i>"`` dummies, which are also written to `.uns`."""
    key = _lin_names(lin_key)
    del adata.uns[key]
    sc.write(path, adata)

    reloaded = cr.read(path)
    lineage = reloaded.obsm[lin_key]
    dummy_names = [f"Lineage {i}" for i in range(n_lins)]

    assert isinstance(lineage, Lineage)
    np.testing.assert_array_equal(lineage.names, dummy_names)
    np.testing.assert_array_equal(lineage.names, reloaded.uns[key])
def test_normal_run(self, adata: AnnData, path: Path, lin_key: str, n_lins: int):
    """Valid lineage names and colors stored in `.uns` are picked up when reading the object back."""
    expected_colors = _create_categorical_colors(10)[-n_lins:]
    expected_names = [f"foo {i}" for i in range(n_lins)]
    adata.uns[_colors(lin_key)] = expected_colors
    adata.uns[_lin_names(lin_key)] = expected_names
    sc.write(path, adata)

    lineage = cr.read(path).obsm[lin_key]

    np.testing.assert_array_equal(lineage.colors, expected_colors)
    np.testing.assert_array_equal(lineage.names, expected_names)
def maybe_create_lineage(direction: Direction):
    """
    Turn the raw lineage probabilities for `direction` stored in ``adata.obsm`` into a :class:`Lineage`.

    ``adata`` is not a parameter; it is taken from the enclosing scope. Names and colors are read from
    ``adata.uns``. If they are missing, have the wrong length, or (colors only) are not color-like, dummy
    names ``"Lineage 0"``, ``"Lineage 1"``, ... or freshly generated colors are used instead. The final
    names and colors are written back to ``adata.uns``. If ``adata.obsm`` has no entry for `direction`,
    only a debug message is logged.

    Params
    ------
    direction
        Direction of the process; selects the forward or the backward lineage key.
    """
    is_forward = direction == Direction.FORWARD
    dir_name = "forward" if is_forward else "backward"
    lin_key = str(LinKey.FORWARD if is_forward else LinKey.BACKWARD)
    names_key, colors_key = _lin_names(lin_key), _colors(lin_key)

    if lin_key not in adata.obsm.keys():
        logg.debug(
            f"DEBUG: Unable to load {dir_name} "
            f"`Lineage` from `adata.obsm[{lin_key!r}]`"
        )
        return

    # unpacking also enforces that the stored probabilities are 2-dimensional
    _, n_lineages = adata.obsm[lin_key].shape
    logg.info(f"Creating {dir_name} `Lineage` object")

    if names_key not in adata.uns.keys():
        logg.warning(
            f"Lineage names not found in `adata.uns[{names_key!r}]`, creating dummy names"
        )
        names = [f"Lineage {i}" for i in range(n_lineages)]
    elif len(adata.uns[names_key]) != n_lineages:
        logg.warning(
            f"Lineage names don't have the required length ({n_lineages}), creating dummy names"
        )
        names = [f"Lineage {i}" for i in range(n_lineages)]
    else:
        logg.info("Successfully loaded names")
        names = adata.uns[names_key]

    if colors_key not in adata.uns.keys():
        logg.warning(
            f"Lineage colors not found in `adata.uns[{colors_key!r}]`, creating new colors"
        )
        colors = _create_categorical_colors(n_lineages)
    elif len(adata.uns[colors_key]) != n_lineages or not all(
        map(is_color_like, adata.uns[colors_key])
    ):
        logg.warning(
            f"Lineage colors don't have the required length ({n_lineages}) "
            f"or are not color-like, creating new colors"
        )
        colors = _create_categorical_colors(n_lineages)
    else:
        logg.info("Successfully loaded colors")
        colors = adata.uns[colors_key]

    adata.obsm[lin_key] = Lineage(adata.obsm[lin_key], names=names, colors=colors)
    # keep `.uns` in sync with what the `Lineage` object was built from
    adata.uns[colors_key] = colors
    adata.uns[names_key] = names
def test_wrong_names_length(
    self, adata: AnnData, path: Path, lin_key: str, n_lins: int
):
    """Lineage names of the wrong length are replaced by ``"Lineage <i>"`` dummies, also in `.uns`."""
    key = _lin_names(lin_key)
    # append extra names so that their count no longer matches the number of lineages
    adata.uns[key] = list(adata.uns[key]) + ["foo", "bar", "baz"]
    sc.write(path, adata)

    reloaded = cr.read(path)
    lineage = reloaded.obsm[lin_key]

    assert isinstance(lineage, Lineage)
    np.testing.assert_array_equal(
        lineage.names, [f"Lineage {i}" for i in range(n_lins)]
    )
    np.testing.assert_array_equal(lineage.names, reloaded.uns[key])
def _read_from_adata(
    self, g2m_key: Optional[str] = None, s_key: Optional[str] = None, **kwargs
) -> None:
    """
    Restore previously computed results stored in the underlying :class:`anndata.AnnData` object.

    Each field is only updated when its key is present. For a missing key, only a debug message is logged.

    Params
    ------
    g2m_key
        Key in ``adata.obs`` with G2M scores, if any.
    s_key
        Key in ``adata.obs`` with S scores, if any.
    kwargs
        Unused.

    Returns
    -------
    None
        Nothing, but may update the eigendecomposition, the metastable states (with their colors and
        probabilities), the lineage probabilities, the differentiation potential and the cell-cycle scores.
    """
    eig_key = f"eig_{self._direction}"
    if eig_key in self._adata.uns.keys():
        self._eig = self._adata.uns[eig_key]
    else:
        logg.debug(f"DEBUG: `{eig_key}` not found. Setting `.eig` to `None`")

    if self._rc_key in self._adata.obs.keys():
        self._meta_states = self._adata.obs[self._rc_key]
    else:
        logg.debug(
            f"DEBUG: `{self._rc_key}` not found in `adata.obs`. Setting `.metastable_states` to `None`"
        )

    rc_colors_key = _colors(self._rc_key)
    if rc_colors_key in self._adata.uns.keys():
        self._meta_states_colors = self._adata.uns[rc_colors_key]
    else:
        logg.debug(
            f"DEBUG: `{rc_colors_key}` not found in `adata.uns`. "
            f"Setting `.metastable_states_colors` to `None`"
        )

    if self._lin_key in self._adata.obsm.keys():
        n_lineages = self._adata.obsm[self._lin_key].shape[1]
        # default names and colors; replaced below by the ones stored in `adata.uns`, if any
        self._lin_probs = Lineage(
            self._adata.obsm[self._lin_key],
            names=[f"Lineage {i + 1}" for i in range(n_lineages)],
            colors=_create_categorical_colors(n_lineages),
        )
        self._adata.obsm[self._lin_key] = self._lin_probs
    else:
        logg.debug(
            f"DEBUG: `{self._lin_key}` not found in `adata.obsm`. Setting `.lin_probs` to `None`"
        )

    dp_key = _dp(self._lin_key)
    if dp_key in self._adata.obs.keys():
        self._dp = self._adata.obs[dp_key]
    else:
        logg.debug(
            f"DEBUG: `{dp_key}` not found in `adata.obs`. Setting `.diff_potential` to `None`"
        )

    if g2m_key and g2m_key in self._adata.obs.keys():
        self._G2M_score = self._adata.obs[g2m_key]
    else:
        logg.debug(
            f"DEBUG: `{g2m_key}` not found in `adata.obs`. Setting `.G2M_score` to `None`"
        )

    if s_key and s_key in self._adata.obs.keys():
        self._S_score = self._adata.obs[s_key]
    else:
        logg.debug(
            f"DEBUG: `{s_key}` not found in `adata.obs`. Setting `.S_score` to `None`"
        )

    probs_key = _probs(self._rc_key)
    if probs_key in self._adata.obs.keys():
        self._meta_states_probs = self._adata.obs[probs_key]
    else:
        logg.debug(
            f"DEBUG: `{probs_key}` not found in `adata.obs`. "
            f"Setting `.metastable_states_probs` to `None`"
        )

    if self._lin_probs is not None:
        names_key, colors_key = _lin_names(self._lin_key), _colors(self._lin_key)
        names, colors = self._lin_probs.names, self._lin_probs.colors
        found_in_uns = False

        if names_key in self._adata.uns.keys():
            names, found_in_uns = self._adata.uns[names_key], True
        else:
            logg.debug(
                f"DEBUG: `{names_key}` not found in `adata.uns`. "
                f"Using default names"
            )

        if colors_key in self._adata.uns.keys():
            colors, found_in_uns = self._adata.uns[colors_key], True
        else:
            logg.debug(
                f"DEBUG: `{colors_key}` not found in `adata.uns`. "
                f"Using default colors"
            )

        # rebuild the `Lineage` at most once, even when both names and colors were found
        if found_in_uns:
            self._lin_probs = Lineage(
                np.array(self._lin_probs), names=names, colors=colors
            )
            self._adata.obsm[self._lin_key] = self._lin_probs
def compute_lin_probs(
    self,
    keys: Optional[Sequence[str]] = None,
    check_irred: bool = False,
    norm_by_frequ: bool = False,
) -> None:
    """
    Compute absorption probabilities for a Markov chain.

    For each cell, this computes the probability of it reaching any of the approximate recurrent classes.
    This also computes the entropy over absorption probabilities, which is a measure of cell plasticity, see
    [Setty19]_.

    Params
    ------
    keys
        Comma separated sequence of keys defining the recurrent classes. Duplicates are removed and the keys are
        sorted before use.
    check_irred
        Check whether the matrix restricted to the given transient states is irreducible.
    norm_by_frequ
        Divide absorption probabilities for `rc_i` by `|rc_i|`.

    Returns
    -------
    None
        Nothing, but updates the following fields: :paramref:`lineage_probabilities`, :paramref:`diff_potential`.
        These fields are only replaced once the computation has succeeded.

    Raises
    ------
    RuntimeError
        If the metastable states have not been computed yet.
    """
    if self._meta_states is None:
        raise RuntimeError(
            "Compute approximate recurrent classes first as `.compute_metastable_states()`"
        )
    if keys is not None:
        keys = sorted(set(keys))

    # Note: There are three relevant data structures here
    # - self._meta_states: pd.Series which contains annotations for approx. rcs. Associated colors in
    #   self._meta_states_colors
    # - self._lin_probs: Lineage object which contains the lineage probabilities with associated names and colors
    # - metastable_states_: pd.Series, temporary copy of self._meta_states used in the context of this function.
    #   In this copy, some metastable states may be removed or combined with others
    start = logg.info("Computing absorption probabilities")

    # we don't expect the abs. probs. to be sparse, therefore, make T dense. See scipy docs about sparse lin solve.
    # `.toarray()` instead of `.A`, which newer SciPy versions no longer provide
    t = self._T.toarray() if self._is_sparse else self._T

    # colors are created in `compute_metastable_states`, this is just in case
    self._check_and_create_colors()

    # process the current annotations according to `keys`
    metastable_states_, colors_ = _process_series(
        series=self._meta_states, keys=keys, colors=self._meta_states_colors
    )

    # keep lineage names/colors local: `self._lin_probs` is only replaced after a successful computation,
    # so a failure (e.g. a singular linear system) does not leave a placeholder object behind
    lin_names = list(metastable_states_.cat.categories)
    lin_colors = list(colors_)
    if self._lin_probs is not None:
        logg.debug("DEBUG: Overwriting `.lin_probs`")

    # warn in case only one state is left
    if len(lin_names) == 1:
        logg.warning(
            "There is only one recurrent class, all cells will have probability 1 of going there"
        )

    # create arrays of all recurrent and transient indices; for a categorical series, a cell belongs to
    # some recurrent class exactly when its annotation is not NaN
    mask = np.asarray(metastable_states_.notna())
    rec_indices, trans_indices = np.where(mask)[0], np.where(~mask)[0]

    # create Q (restriction transient-transient), S (restriction transient-recurrent) and I (Q-sized identity)
    q = t[trans_indices, :][:, trans_indices]
    s = t[trans_indices, :][:, rec_indices]
    eye = np.eye(len(trans_indices))

    if check_irred:
        if self._is_irreducible is None:
            self.compute_partition()
        if not self._is_irreducible:
            logg.warning("Restriction Q is not irreducible")

    # for each class, the column indices of its states within the recurrent block `s`
    approx_rc_red = metastable_states_[mask]
    rec_classes_red = {
        key: np.where(approx_rc_red == key)[0] for key in approx_rc_red.cat.categories
    }

    if len(trans_indices):
        # compute abs probs. Since we don't expect sparse solution, dense computation is faster.
        logg.debug("DEBUG: Solving the linear system to find absorption probabilities")
        abs_states = solve(eye - q, s)

        # aggregate to class level by summing over columns belonging to the same metastable state
        _abs_classes = np.concatenate(
            [
                np.sum(abs_states[:, rec_classes_red[key]], axis=1)[:, None]
                for key in approx_rc_red.cat.categories
            ],
            axis=1,
        )
        if norm_by_frequ:
            logg.debug("DEBUG: Normalizing by frequency")
            _abs_classes /= [len(value) for value in rec_classes_red.values()]
        _abs_classes = _normalize(_abs_classes)
    else:
        # every cell is recurrent, there is no linear system to solve
        _abs_classes = np.empty((0, len(rec_classes_red)))

    # transient states get the computed probabilities, recurrent states absorb into their own class w.p. 1
    abs_classes = np.zeros((self._n_states, len(rec_classes_red)))
    for col, cl in enumerate(metastable_states_.cat.categories):
        abs_classes[trans_indices, col] = _abs_classes[:, col]
        abs_classes[np.where(metastable_states_ == cl)[0], col] = 1

    self._dp = entropy(abs_classes.T)
    self._lin_probs = Lineage(abs_classes, names=lin_names, colors=lin_colors)

    self._adata.obsm[self._lin_key] = self._lin_probs
    self._adata.obs[_dp(self._lin_key)] = self._dp
    self._adata.uns[_lin_names(self._lin_key)] = self._lin_probs.names
    self._adata.uns[_colors(self._lin_key)] = self._lin_probs.colors

    logg.info("    Finish", time=start)