def test_compute_lin_probs_normal_run(self, adata_large: AnnData):
        """Check that `compute_lin_probs` fills the estimator and mirrors results into `adata`.

        A weighted sum of a velocity and a connectivity kernel drives a
        `MarkovChain` estimator; after computing lineage probabilities, the
        values exposed on the estimator must equal the ones written to
        `adata.obs`, `adata.obsm` and `adata.uns`.
        """
        velocity = VelocityKernel(adata_large).compute_transition_matrix()
        connectivity = ConnectivityKernel(adata_large).compute_transition_matrix()
        combined = 0.8 * velocity + 0.2 * connectivity

        mc = cr.tl.MarkovChain(combined)
        mc.compute_eig(k=5)
        mc.compute_approx_rcs(use=2)
        mc.compute_lin_probs()

        adata = mc.adata

        # differentiation potential
        dp_key = f"{LinKey.FORWARD}_dp"
        assert isinstance(mc.diff_potential, np.ndarray)
        assert dp_key in adata.obs.keys()
        np.testing.assert_array_equal(mc.diff_potential, adata.obs[dp_key])

        # lineage probabilities
        lin_key = f"{LinKey.FORWARD}"
        assert isinstance(mc.lineage_probabilities, cr.tl.Lineage)
        assert mc.lineage_probabilities.shape == (adata.n_obs, 2)
        assert lin_key in adata.obsm.keys()
        np.testing.assert_array_equal(mc.lineage_probabilities.X, adata.obsm[lin_key])

        # lineage names
        names_key = _lin_names(LinKey.FORWARD)
        assert names_key in adata.uns.keys()
        np.testing.assert_array_equal(
            mc.lineage_probabilities.names, adata.uns[names_key]
        )

        # lineage colors
        colors_key = _colors(LinKey.FORWARD)
        assert colors_key in adata.uns.keys()
        np.testing.assert_array_equal(
            mc.lineage_probabilities.colors, adata.uns[colors_key]
        )

        # every row of the probability matrix must sum to one
        row_sums = mc.lineage_probabilities.X.sum(1)
        np.testing.assert_allclose(row_sums, 1)
Exemple #2
0
def _assert_has_all_keys(adata: AnnData, direction: Direction):
    """Assert that `adata` holds every entry expected for the given direction.

    Checks the transition entry in `adata.uns`, the `Lineage` object in
    `adata.obsm` together with its colors and names in `adata.uns`, and the
    categorical state annotation plus its probabilities in `adata.obs`.
    Any direction other than `Direction.FORWARD` is checked against the
    backward keys.
    """
    assert _transition(direction) in adata.uns.keys()

    # Pick the key family once; the checks below are the same for both directions.
    is_forward = direction == Direction.FORWARD
    lin_key = LinKey.FORWARD if is_forward else LinKey.BACKWARD
    state_key = StateKey.FORWARD if is_forward else StateKey.BACKWARD

    lin_slot = str(lin_key)
    assert lin_slot in adata.obsm
    assert isinstance(adata.obsm[lin_slot], cr.tl.Lineage)

    assert _colors(lin_key) in adata.uns.keys()
    assert _lin_names(lin_key) in adata.uns.keys()

    state_slot = str(state_key)
    assert state_slot in adata.obs
    assert is_categorical_dtype(adata.obs[state_slot])

    assert _probs(state_key) in adata.obs
Exemple #3
0
def _check_main_states(mc: cr.tl.GPCCA, has_main_states: bool = True):
    """Assert that the estimator's forward results agree with what is stored in `mc.adata`.

    When `has_main_states` is true, the main-state annotation and its colors
    are compared first. The lineage probabilities, their names and colors,
    the differentiation potential and the main-state probabilities are
    compared unconditionally.
    """
    adata = mc.adata

    if has_main_states:
        main_states = mc.main_states
        assert isinstance(main_states, pd.Series)
        assert_array_nan_equal(adata.obs[str(StateKey.FORWARD)], main_states)

        # colors stored for the states must match the colors of the
        # correspondingly named lineages
        state_names = list(main_states.cat.categories)
        np.testing.assert_array_equal(
            adata.uns[_colors(StateKey.FORWARD)],
            mc.lineage_probabilities[state_names].colors,
        )

    assert isinstance(mc.diff_potential, np.ndarray)
    lin_probs = mc.lineage_probabilities
    assert isinstance(lin_probs, cr.tl.Lineage)

    np.testing.assert_array_equal(adata.obsm[str(LinKey.FORWARD)], lin_probs.X)
    np.testing.assert_array_equal(
        adata.uns[_lin_names(LinKey.FORWARD)], lin_probs.names
    )
    np.testing.assert_array_equal(adata.uns[_colors(LinKey.FORWARD)], lin_probs.colors)

    np.testing.assert_array_equal(adata.obs[_dp(LinKey.FORWARD)], mc.diff_potential)
    np.testing.assert_array_equal(
        adata.obs[_probs(StateKey.FORWARD)], mc.main_states_probabilities
    )
Exemple #4
0
    def test_non_unique_names(self, adata: AnnData, path: Path, lin_key: str, _: int):
        """Reading a file whose lineage names contain a duplicate must raise `ValueError`."""
        stored_names = adata.uns[_lin_names(lin_key)]
        # overwrite the first name with the second one to create a duplicate
        stored_names[0] = stored_names[1]

        sc.write(path, adata)

        with pytest.raises(ValueError):
            cr.read(path)
Exemple #5
0
    def test_no_names(self, adata: AnnData, path: Path, lin_key: str, n_lins: int):
        """Without stored lineage names, reading must produce `Lineage {i}` names and store them."""
        names_key = _lin_names(lin_key)
        del adata.uns[names_key]

        sc.write(path, adata)
        reloaded = cr.read(path)
        lineages = reloaded.obsm[lin_key]

        expected_names = [f"Lineage {i}" for i in range(n_lins)]

        assert isinstance(lineages, Lineage)
        np.testing.assert_array_equal(lineages.names, expected_names)
        np.testing.assert_array_equal(lineages.names, reloaded.uns[names_key])
Exemple #6
0
    def test_normal_run(self, adata: AnnData, path: Path, lin_key: str, n_lins: int):
        """Valid names and colors stored in `adata.uns` must survive a write/read round trip."""
        expected_names = [f"foo {i}" for i in range(n_lins)]
        # take the last `n_lins` entries of a 10-color categorical palette
        expected_colors = _create_categorical_colors(10)[-n_lins:]

        adata.uns[_colors(lin_key)] = expected_colors
        adata.uns[_lin_names(lin_key)] = expected_names

        sc.write(path, adata)
        reloaded_lineages = cr.read(path).obsm[lin_key]

        np.testing.assert_array_equal(reloaded_lineages.colors, expected_colors)
        np.testing.assert_array_equal(reloaded_lineages.names, expected_names)
Exemple #7
0
    def maybe_create_lineage(direction: Direction):
        """Wrap `adata.obsm[<lineage key>]` in a `Lineage` object if that entry exists.

        `adata` is taken from the enclosing scope. The lineage key is the
        forward key for `Direction.FORWARD` and the backward key otherwise.

        Names are read from `adata.uns`; if they are missing or their number
        does not match the number of lineage columns, dummy names
        ``"Lineage 0", "Lineage 1", ...`` are used instead. Colors are read
        from `adata.uns` as well; if they are missing, have the wrong length,
        or contain an entry that is not color-like, a fresh categorical
        palette is created. The names and colors finally used are written
        back to `adata.uns`.

        Params
        ------
        direction
            Direction for which the lineage should be created.

        Returns
        -------
        None
            Nothing, but may update `adata.obsm` and `adata.uns` in place.
        """
        is_forward = direction == Direction.FORWARD
        direction_name = "forward" if is_forward else "backward"
        lin_key = str(LinKey.FORWARD if is_forward else LinKey.BACKWARD)
        names_key, colors_key = _lin_names(lin_key), _colors(lin_key)

        if lin_key not in adata.obsm:
            logg.debug(
                f"DEBUG: Unable to load {direction_name} "
                f"`Lineage` from `adata.obsm[{lin_key!r}]`"
            )
            return

        # the stored matrix has one column per lineage
        _, n_lineages = adata.obsm[lin_key].shape
        logg.info(f"Creating {direction_name} `Lineage` object")

        if names_key not in adata.uns:
            logg.warning(
                f"Lineage names not found in `adata.uns[{names_key!r}]`, creating dummy names"
            )
            names = [f"Lineage {i}" for i in range(n_lineages)]
        elif len(adata.uns[names_key]) != n_lineages:
            logg.warning(
                f"Lineage names don't have the required length ({n_lineages}), creating dummy names"
            )
            names = [f"Lineage {i}" for i in range(n_lineages)]
        else:
            logg.info("Successfully loaded names")
            names = adata.uns[names_key]

        if colors_key not in adata.uns:
            logg.warning(
                f"Lineage colors not found in `adata.uns[{colors_key!r}]`, creating new colors"
            )
            colors = _create_categorical_colors(n_lineages)
        elif len(adata.uns[colors_key]) != n_lineages or not all(
            is_color_like(c) for c in adata.uns[colors_key]
        ):
            logg.warning(
                f"Lineage colors don't have the required length ({n_lineages}) "
                f"or are not color-like, creating new colors"
            )
            colors = _create_categorical_colors(n_lineages)
        else:
            logg.info("Successfully loaded colors")
            colors = adata.uns[colors_key]

        adata.obsm[lin_key] = Lineage(adata.obsm[lin_key], names=names, colors=colors)
        # keep `adata.uns` consistent with what the `Lineage` object was built from
        adata.uns[colors_key] = colors
        adata.uns[names_key] = names
Exemple #8
0
    def test_wrong_names_length(
        self, adata: AnnData, path: Path, lin_key: str, n_lins: int
    ):
        """With too many stored names, reading must fall back to `Lineage {i}` names."""
        names_key = _lin_names(lin_key)
        # append three surplus names so the count no longer matches the lineages
        adata.uns[names_key] = [*adata.uns[names_key], "foo", "bar", "baz"]

        sc.write(path, adata)
        reloaded = cr.read(path)
        lineages = reloaded.obsm[lin_key]

        expected_names = [f"Lineage {i}" for i in range(n_lins)]

        assert isinstance(lineages, Lineage)
        np.testing.assert_array_equal(lineages.names, expected_names)
        np.testing.assert_array_equal(lineages.names, reloaded.uns[names_key])
Exemple #9
0
    def _read_from_adata(
        self, g2m_key: Optional[str] = None, s_key: Optional[str] = None, **kwargs
    ) -> None:
        """
        Load previously stored results from the underlying `adata` into this object.

        For each entry that is present, the corresponding private attribute is
        assigned; for each entry that is absent, only a debug message is logged.

        NOTE(review): the debug messages say an attribute is being set to
        `None`, but this method never assigns `None` itself - presumably the
        attributes are initialised elsewhere; confirm against the constructor.

        Params
        ------
        g2m_key
            Key in `adata.obs` from which the G2M score is read.
        s_key
            Key in `adata.obs` from which the S score is read.
        kwargs
            Accepted but not used by this method.

        Returns
        -------
        None
            Nothing, but may update the private attributes and replace
            `adata.obsm[<lineage key>]` with a `Lineage` object.
        """
        adata = self._adata

        # eigendecomposition
        eig_key = f"eig_{self._direction}"
        if eig_key not in adata.uns.keys():
            logg.debug(f"DEBUG: `{eig_key}` not found. Setting `.eig` to `None`")
        else:
            self._eig = adata.uns[eig_key]

        # metastable states
        rc_key = self._rc_key
        if rc_key not in adata.obs.keys():
            logg.debug(
                f"DEBUG: `{rc_key}` not found in `adata.obs`. Setting `.metastable_states` to `None`"
            )
        else:
            self._meta_states = adata.obs[rc_key]

        # metastable state colors
        rc_colors_key = _colors(rc_key)
        if rc_colors_key not in adata.uns.keys():
            logg.debug(
                f"DEBUG: `{rc_colors_key}` not found in `adata.uns`. "
                f"Setting `.metastable_states_colors`to `None`"
            )
        else:
            self._meta_states_colors = adata.uns[rc_colors_key]

        # lineage probabilities, initially with 1-based placeholder names and a fresh palette
        lin_key = self._lin_key
        if lin_key not in adata.obsm.keys():
            logg.debug(
                f"DEBUG: `{lin_key}` not found in `adata.obsm`. Setting `.lin_probs` to `None`"
            )
        else:
            lineage_ixs = range(adata.obsm[lin_key].shape[1])
            default_colors = _create_categorical_colors(len(lineage_ixs))
            default_names = [f"Lineage {i + 1}" for i in lineage_ixs]
            self._lin_probs = Lineage(
                adata.obsm[lin_key], names=default_names, colors=default_colors
            )
            adata.obsm[lin_key] = self._lin_probs

        # differentiation potential
        dp_key = _dp(lin_key)
        if dp_key not in adata.obs.keys():
            logg.debug(
                f"DEBUG: `{dp_key}` not found in `adata.obs`. Setting `.diff_potential` to `None`"
            )
        else:
            self._dp = adata.obs[dp_key]

        # cell-cycle scores; an empty/`None` key counts as "not found"
        if not g2m_key or g2m_key not in adata.obs.keys():
            logg.debug(
                f"DEBUG: `{g2m_key}` not found in `adata.obs`. Setting `.G2M_score` to `None`"
            )
        else:
            self._G2M_score = adata.obs[g2m_key]

        if not s_key or s_key not in adata.obs.keys():
            logg.debug(
                f"DEBUG: `{s_key}` not found in `adata.obs`. Setting `.S_score` to `None`"
            )
        else:
            self._S_score = adata.obs[s_key]

        # metastable state probabilities
        rc_probs_key = _probs(rc_key)
        if rc_probs_key not in adata.obs.keys():
            logg.debug(
                f"DEBUG: `{rc_probs_key}` not found in `adata.obs`. "
                f"Setting `.metastable_states_probs` to `None`"
            )
        else:
            self._meta_states_probs = adata.obs[rc_probs_key]

        if self._lin_probs is None:
            return

        # replace the placeholder names by the stored ones, keeping the current colors
        names_key = _lin_names(lin_key)
        if names_key not in adata.uns.keys():
            logg.debug(
                f"DEBUG: `{names_key}` not found in `adata.uns`. "
                f"Using default names"
            )
        else:
            self._lin_probs = Lineage(
                np.array(self._lin_probs),
                names=adata.uns[names_key],
                colors=self._lin_probs.colors,
            )
            adata.obsm[lin_key] = self._lin_probs

        # replace the colors by the stored ones, keeping the current names
        lin_colors_key = _colors(lin_key)
        if lin_colors_key not in adata.uns.keys():
            logg.debug(
                f"DEBUG: `{lin_colors_key}` not found in `adata.uns`. "
                f"Using default colors"
            )
        else:
            self._lin_probs = Lineage(
                np.array(self._lin_probs),
                names=self._lin_probs.names,
                colors=adata.uns[lin_colors_key],
            )
            adata.obsm[lin_key] = self._lin_probs
Exemple #10
0
    def compute_lin_probs(
        self,
        keys: Optional[Sequence[str]] = None,
        check_irred: bool = False,
        norm_by_frequ: bool = False,
    ) -> None:
        """
        Compute absorption probabilities for a Markov chain.

        For each cell, this computes the probability of it reaching any of the approximate recurrent classes.
        This also computes the entropy over absorption probabilities, which is a measure of cell plasticity, see
        [Setty19]_.

        With `Q` the transient-to-transient and `S` the transient-to-recurrent restriction of the transition
        matrix, the linear system `(I - Q) X = S` is solved densely. The columns of `X` are then summed per
        recurrent class and passed through `_normalize`. Cells belonging to a recurrent class get probability 1
        for their own class and 0 for all others.

        Params
        ------
        keys
            Comma separated sequence of keys defining the recurrent classes. Duplicates are removed and the keys
            are sorted before use.
        check_irred
            Check whether the matrix restricted to the given transient states is irreducible.
        norm_by_frequ
            Divide absorption probabilities for `rc_i` by `|rc_i|`.

        Returns
        -------
        None
            Nothing, but updates the following fields: :paramref:`lineage_probabilities`, :paramref:`diff_potential`.
            The results are also written to `adata.obsm`, `adata.obs` and `adata.uns`.

        Raises
        ------
        RuntimeError
            If the metastable states have not been computed yet.
        """

        if self._meta_states is None:
            raise RuntimeError(
                "Compute approximate recurrent classes first as `.compute_metastable_states()`"
            )
        if keys is not None:
            keys = sorted(set(keys))

        # Note: There are three relevant data structures here
        # - self.metastable_states: pd.Series which contains annotations for approx rcs. Associated colors in
        #   self.metastable_states_colors
        # - self.lin_probs: Lineage object which contains the lineage probabilities with associated names and colors
        # - metastable_states_: pd.Series, temporary copy of self.metastable_states used in the context of this
        #   function. In this copy, some metastable states may be removed or combined with others
        start = logg.info("Computing absorption probabilities")

        # we don't expect the abs. probs. to be sparse, therefore, make T dense. See scipy docs about sparse lin solve.
        # `.toarray()` is used rather than the deprecated `.A` shorthand
        t = self._T.toarray() if self._is_sparse else self._T

        # colors are created in `compute_metastable_states`, this is just in case
        self._check_and_create_colors()

        # process the current annotations according to `keys`
        metastable_states_, colors_ = _process_series(
            series=self._meta_states, keys=keys, colors=self._meta_states_colors
        )

        # create a placeholder lineage object; only its names and colors are reused below
        if self._lin_probs is not None:
            logg.debug("DEBUG: Overwriting `.lin_probs`")
        self._lin_probs = Lineage(
            np.empty((1, len(colors_))),
            names=metastable_states_.cat.categories,
            colors=colors_,
        )

        # warn in case only one state is left
        keys = list(metastable_states_.cat.categories)
        if len(keys) == 1:
            logg.warning(
                "There is only one recurrent class, all cells will have probability 1 of going there"
            )

        # create arrays of all recurrent and transient indices
        mask = np.zeros(len(metastable_states_), dtype=bool)
        for cat in keys:
            mask = np.logical_or(mask, metastable_states_ == cat)
        rec_indices, trans_indices = np.where(mask)[0], np.where(~mask)[0]

        # create Q (restriction transient-transient), S (restriction transient-recurrent) and I (Q-sized identity)
        # `np.ix_` selects the sub-blocks directly, without an intermediate row-sliced copy
        q = t[np.ix_(trans_indices, trans_indices)]
        s = t[np.ix_(trans_indices, rec_indices)]
        eye = np.eye(len(trans_indices))

        if check_irred:
            # NOTE(review): this consults `self._is_irreducible` as populated by `compute_partition()`;
            # whether that flag describes Q or the full transition matrix cannot be told from here - confirm
            if self._is_irreducible is None:
                self.compute_partition()
            if not self._is_irreducible:
                logg.warning("Restriction Q is not irreducible")

        # compute abs probs. Since we don't expect sparse solution, dense computation is faster.
        logg.debug("DEBUG: Solving the linear system to find absorption probabilities")
        abs_states = solve(eye - q, s)

        # aggregate to class level by summing over columns belonging to the same metastable state
        # the positions below index into the recurrent-only subset, i.e. into the columns of `abs_states`
        approx_rc_red = metastable_states_[mask]
        rec_classes_red = {
            key: np.where(approx_rc_red == key)[0]
            for key in approx_rc_red.cat.categories
        }
        _abs_classes = np.stack(
            [
                abs_states[:, member_ixs].sum(axis=1)
                for member_ixs in rec_classes_red.values()
            ],
            axis=1,
        )

        if norm_by_frequ:
            logg.debug("DEBUG: Normalizing by frequency")
            _abs_classes /= [len(value) for value in rec_classes_red.values()]
        _abs_classes = _normalize(_abs_classes)

        # transient rows receive the computed probabilities ...
        abs_classes = np.zeros((self._n_states, len(rec_classes_red)))
        abs_classes[trans_indices] = _abs_classes

        # ... while recurrent states get self-absorption probability one
        for col, cl in enumerate(keys):
            cl_indices = np.where(metastable_states_ == cl)[0]
            abs_classes[cl_indices, col] = 1

        self._dp = entropy(abs_classes.T)
        self._lin_probs = Lineage(
            abs_classes,
            names=list(self._lin_probs.names),
            colors=list(self._lin_probs.colors),
        )

        self._adata.obsm[self._lin_key] = self._lin_probs
        self._adata.obs[_dp(self._lin_key)] = self._dp
        self._adata.uns[_lin_names(self._lin_key)] = self._lin_probs.names
        self._adata.uns[_colors(self._lin_key)] = self._lin_probs.colors

        logg.info("    Finish", time=start)