def test_empty_keys(self):
    x = pd.Series(["a", "b", np.nan, "b", np.nan]).astype("category")

    res = _process_series(x, [])

    assert res.shape == x.shape
    assert np.all(pd.isnull(res))
def test_repeat_key(self):
    x = pd.Series(["a", "b", np.nan, "b", np.nan]).astype("category")
    expected = pd.Series(["a"] + [np.nan] * 4).astype("category")

    res = _process_series(x, keys=["a, a, a"])

    assert_array_nan_equal(res, expected)
def test_normal_run(self):
    x = pd.Series(["a", "b", np.nan, "b", np.nan]).astype("category")
    expected = pd.Series(["a"] + [np.nan] * 4).astype("category")

    res = _process_series(x, keys=["a"])

    assert_array_nan_equal(expected, res)
def test_no_keys_colors(self):
    x = pd.Series(["a", "b", np.nan, "b", np.nan]).astype("category")
    colors = ["foo"]

    res, res_colors = _process_series(x, keys=None, colors=colors)

    assert x is res
    assert colors is res_colors
def test_reorder_keys(self):
    x = pd.Series(["b", "c", "a", "d", "a"]).astype("category")
    expected = pd.Series(["a or b or d", np.nan] + ["a or b or d"] * 3).astype(
        "category"
    )

    res = _process_series(x, keys=["b, a, d"])

    assert_array_nan_equal(res, expected)
def test_return_colors(self):
    x = pd.Series(["b", "c", "a", "d", "a"]).astype("category")
    expected = pd.Series(["a or b", "c or d", "a or b", "c or d", "a or b"]).astype(
        "category"
    )

    res, colors = _process_series(
        x, keys=["b, a", "d, c"], colors=["red", "green", "blue", "white"]
    )

    assert isinstance(res, pd.Series)
    assert is_categorical_dtype(res)
    assert isinstance(colors, list)
    np.testing.assert_array_equal(res.values, expected.values)
    assert set(colors) == {"#804000", "#8080ff"}
def test_no_keys(self):
    x = pd.Series(["a", "b", np.nan, "b", np.nan]).astype("category")

    res = _process_series(x, keys=None)

    assert x is res
def test_keys_overlap(self):
    x = pd.Series(["a", "b", np.nan, "b", np.nan]).astype("category")

    with pytest.raises(ValueError):
        _ = _process_series(x, ["a", "b, a"])
def test_keys_are_not_proper_categories(self):
    x = pd.Series(["a", "b", np.nan, "b", np.nan]).astype("category")

    with pytest.raises(ValueError):
        _ = _process_series(x, ["foo"])
def test_colors_not_colorlike(self):
    x = pd.Series(["a", "b", np.nan, "b", np.nan]).astype("category")

    with pytest.raises(ValueError):
        _ = _process_series(x, ["foo"], colors=["bar"])
def test_colors_wrong_number_of_colors(self):
    x = pd.Series(["a", "b", np.nan, "b", np.nan]).astype("category")

    with pytest.raises(ValueError):
        _ = _process_series(x, ["foo"], colors=["red"])
def test_not_categorical(self):
    x = pd.Series(["a", "b", np.nan, "b", np.nan])

    with pytest.raises(TypeError):
        _ = _process_series(x, ["foo"])
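# The tests above exercise the key-merging semantics of `_process_series`: each entry of `keys`
# is a comma-separated group of existing categories, the groups are collapsed into new categories
# named by joining the sorted members with " or ", and cells outside every group become NaN.
# The helper below is a minimal sketch of that merging step in plain pandas, for illustration
# only; `_merge_categories_sketch` is a hypothetical name, it is not the library's implementation,
# and it ignores the color mixing checked in `test_return_colors`.
def _merge_categories_sketch(x, keys):
    mapping = {}
    for key in keys:
        members = sorted({k.strip() for k in key.split(",")})
        new_name = " or ".join(members)
        for member in members:
            mapping[member] = new_name
    # categories not mentioned in any key are mapped to NaN, i.e. dropped
    return x.astype(object).map(mapping).astype("category")


# Example: _merge_categories_sketch(pd.Series(["b", "c", "a", "d", "a"]).astype("category"), ["b, a, d"])
# reproduces the "a or b or d"/NaN pattern expected by test_reorder_keys above.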
def compute_lin_probs(
    self,
    keys: Optional[Sequence[str]] = None,
    check_irred: bool = False,
    norm_by_frequ: bool = False,
) -> None:
    """
    Compute absorption probabilities for a Markov chain.

    For each cell, this computes the probability of it reaching any of the approximate recurrent classes.
    This also computes the entropy over absorption probabilities, which is a measure of cell plasticity,
    see [Setty19]_.

    Params
    ------
    keys
        Comma-separated sequence of keys defining the recurrent classes.
    check_irred
        Check whether the matrix restricted to the given transient states is irreducible.
    norm_by_frequ
        Divide absorption probabilities for `rc_i` by `|rc_i|`.

    Returns
    -------
    None
        Nothing, but updates the following fields: :paramref:`lineage_probabilities`, :paramref:`diff_potential`.
    """
    if self._meta_states is None:
        raise RuntimeError(
            "Compute approximate recurrent classes first as `.compute_metastable_states()`"
        )
    if keys is not None:
        keys = sorted(set(keys))

    # Note: There are three relevant data structures here
    # - self.metastable_states: pd.Series which contains annotations for approx rcs. Associated colors in
    #   self.metastable_states_colors
    # - self.lin_probs: Lineage object which contains the lineage probabilities with associated names and colors
    # - metastable_states_: pd.Series, temporary copy of self.metastable_states used in the context of this
    #   function. In this copy, some metastable states may be removed or combined with others
    start = logg.info("Computing absorption probabilities")

    # we don't expect the abs. probs. to be sparse, therefore, make T dense. See scipy docs about sparse lin solve.
    t = self._T.A if self._is_sparse else self._T

    # colors are created in `compute_metastable_states`, this is just in case
    self._check_and_create_colors()

    # process the current annotations according to `keys`
    metastable_states_, colors_ = _process_series(
        series=self._meta_states, keys=keys, colors=self._meta_states_colors
    )

    # create empty lineage object
    if self._lin_probs is not None:
        logg.debug("DEBUG: Overwriting `.lin_probs`")
    self._lin_probs = Lineage(
        np.empty((1, len(colors_))),
        names=metastable_states_.cat.categories,
        colors=colors_,
    )

    # warn in case only one state is left
    keys = list(metastable_states_.cat.categories)
    if len(keys) == 1:
        logg.warning(
            "There is only one recurrent class, all cells will have probability 1 of going there"
        )

    # create arrays of all recurrent and transient indices
    mask = np.repeat(False, len(metastable_states_))
    for cat in metastable_states_.cat.categories:
        mask = np.logical_or(mask, metastable_states_ == cat)
    rec_indices, trans_indices = np.where(mask)[0], np.where(~mask)[0]

    # create Q (restriction transient-transient), S (restriction transient-recurrent) and I (Q-sized identity)
    q = t[trans_indices, :][:, trans_indices]
    s = t[trans_indices, :][:, rec_indices]
    eye = np.eye(len(trans_indices))

    if check_irred:
        if self._is_irreducible is None:
            self.compute_partition()
        if not self._is_irreducible:
            logg.warning("Restriction Q is not irreducible")

    # compute abs probs. Since we don't expect a sparse solution, dense computation is faster.
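    # Mathematically: in the standard absorbing-chain decomposition, reordering the states so that
    # T = [[Q, S], [0, R]] (transient block first, recurrent block second), the absorption
    # probabilities of the transient states are A = (I - Q)^{-1} S. Below we solve the linear
    # system (I - Q) A = S rather than forming the inverse explicitly.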
logg.debug("DEBUG: Solving the linear system to find absorption probabilities") abs_states = solve(eye - q, s) # aggregate to class level by summing over columns belonging to the same metastable_states approx_rc_red = metastable_states_[mask] rec_classes_red = { key: np.where(approx_rc_red == key)[0] for key in approx_rc_red.cat.categories } _abs_classes = np.concatenate( [ np.sum(abs_states[:, rec_classes_red[key]], axis=1)[:, None] for key in approx_rc_red.cat.categories ], axis=1, ) if norm_by_frequ: logg.debug("DEBUG: Normalizing by frequency") _abs_classes /= [len(value) for value in rec_classes_red.values()] _abs_classes = _normalize(_abs_classes) # for recurrent states, set their self-absorption probability to one abs_classes = np.zeros((self._n_states, len(rec_classes_red))) rec_classes_full = { cl: np.where(metastable_states_ == cl) for cl in metastable_states_.cat.categories } for col, cl_indices in enumerate(rec_classes_full.values()): abs_classes[trans_indices, col] = _abs_classes[:, col] abs_classes[cl_indices, col] = 1 self._dp = entropy(abs_classes.T) self._lin_probs = Lineage( abs_classes, names=list(self._lin_probs.names), colors=list(self._lin_probs.colors), ) self._adata.obsm[self._lin_key] = self._lin_probs self._adata.obs[_dp(self._lin_key)] = self._dp self._adata.uns[_lin_names(self._lin_key)] = self._lin_probs.names self._adata.uns[_colors(self._lin_key)] = self._lin_probs.colors logg.info(" Finish", time=start)