def test_not_one_hot(self):
    """A boolean matrix whose first row has two set entries must raise ``ValueError``."""
    rows = [
        [1, 0, 1],  # two entries set in the same row
        [0, 0, 1],
        [0, 0, 0],
        [1, 0, 0],
        [1, 0, 0],
        [0, 1, 0],
    ]
    membership = np.array(rows, dtype="bool")

    with pytest.raises(ValueError):
        _series_from_one_hot_matrix(membership)
def test_dtype(self):
    """An ``int8`` matrix (here also containing a ``2``) must raise ``TypeError``."""
    rows = [
        [0, 0, 1],
        [0, 0, 2],
        [0, 0, 0],
        [1, 0, 0],
        [1, 0, 0],
        [0, 1, 0],
    ]
    membership = np.array(rows, dtype="int8")

    with pytest.raises(TypeError):
        _series_from_one_hot_matrix(membership)
def test_name_mismatch(self):
    """Passing two names for a three-column matrix must raise ``ValueError``."""
    rows = [
        [0, 0, 1],
        [0, 0, 1],
        [0, 0, 0],
        [1, 0, 0],
        [1, 0, 0],
        [0, 1, 0],
    ]
    membership = np.array(rows, dtype="bool")
    # Only two names are supplied although the matrix has three columns.
    too_few_names = ["0", "1"]

    with pytest.raises(ValueError):
        _series_from_one_hot_matrix(membership, names=too_few_names)
def _create_states(
    self,
    probs: Union[np.ndarray, Lineage],
    n_cells: int,
    check_row_sums: bool = False,
    return_not_enough_cells: bool = False,
) -> Union[pd.Series, tuple]:
    """
    Discretize fuzzy memberships into a series of state labels.

    Parameters
    ----------
    probs
        Fuzzy memberships, forwarded to ``_fuzzy_to_discrete`` as ``a_fuzzy``.
        If this is a ``Lineage``, its ``names`` are used to name the states.
    n_cells
        Forwarded to ``_fuzzy_to_discrete`` as ``n_most_likely``. Must be positive.
    check_row_sums
        Forwarded unchanged to ``_fuzzy_to_discrete``.
    return_not_enough_cells
        Whether to also return the second output of ``_fuzzy_to_discrete``.

    Returns
    -------
    The states as produced by ``_series_from_one_hot_matrix``, indexed by
    ``self.adata.obs_names``. If ``return_not_enough_cells`` is truthy, a tuple
    ``(states, not_enough_cells)`` is returned instead, where ``not_enough_cells``
    is the second value returned by ``_fuzzy_to_discrete``.

    Raises
    ------
    ValueError
        If ``n_cells`` is not positive.
    """
    if n_cells <= 0:
        raise ValueError(f"Expected `n_cells` to be positive, found `{n_cells}`.")

    a_discrete, not_enough_cells = _fuzzy_to_discrete(
        a_fuzzy=probs,
        n_most_likely=n_cells,
        remove_overlap=False,
        raise_threshold=0.2,
        check_row_sums=check_row_sums,
    )

    states = _series_from_one_hot_matrix(
        membership=a_discrete,
        index=self.adata.obs_names,
        # Only `Lineage` objects carry names; a plain array leaves naming to the helper.
        names=probs.names if isinstance(probs, Lineage) else None,
    )

    if return_not_enough_cells:
        return states, not_enough_cells
    return states
def test_normal_run(self):
    """Check the labels and categories produced for a valid one-hot matrix.

    The expected labels are the index of the set column in each row, with
    ``NaN`` for the all-zero row; the expected categories are ``"0"``, ``"1"``
    and ``"2"``.
    """
    rows = [
        [0, 0, 1],
        [0, 0, 1],
        [0, 0, 0],  # no entry set
        [1, 0, 0],
        [1, 0, 0],
        [0, 1, 0],
    ]
    membership = np.array(rows, dtype="bool")

    labels = _series_from_one_hot_matrix(membership)

    # Cast to float so that the missing entry can be compared as NaN.
    observed = np.array(labels).astype(np.float32)
    expected = np.array([2, 2, np.nan, 0, 0, 1], dtype=np.float32)
    assert_array_nan_equal(observed, expected)

    np.testing.assert_array_equal(labels.cat.categories, ["0", "1", "2"])
def test_normal_return(self):
    """Compare the helper's output with a hand-built categorical series."""
    rows = [
        [0, 0, 1],
        [0, 0, 1],
        [0, 0, 0],
        [1, 0, 0],
        [1, 0, 0],
        [0, 1, 0],
    ]
    membership = np.array(rows, dtype="bool")
    actual_series = _series_from_one_hot_matrix(membership)

    # Start from an empty categorical series, then fill in the expected labels;
    # position 2 is deliberately left unset.
    expected_series = pd.Series(index=range(6), dtype="category")
    expected_series = expected_series.cat.add_categories(["0", "1", "2"])
    label_by_position = {0: "2", 1: "2", 3: "0", 4: "0", 5: "1"}
    for position, label in label_by_position.items():
        expected_series[position] = label

    assert actual_series.equals(expected_series)

    same_categories = actual_series.cat.categories == expected_series.cat.categories
    assert same_categories.all()