Ejemplo n.º 1
0
    def _reconstruct_lineage(self, attr: PrettyEnum, obsm_key: str):
        """
        Rebuild the :class:`Lineage` kept under ``attr`` and write it back to ``adata``.

        ``_set_or_debug`` is asked to load ``adata.obsm[obsm_key]`` into ``attr`` and to
        fetch the lineage names and colors stored in ``adata.uns`` under the
        terminal-states key. If the names or colors do not match the number of lineages
        (or the colors are not color-like strings), they are taken from the existing
        :class:`Lineage` when there is one, otherwise regenerated with a warning.

        Parameters
        ----------
        attr
            Attribute of this object holding the lineage.
        obsm_key
            Key in ``adata.obsm`` under which the lineage is stored.

        Returns
        -------
        None
            Nothing. When the attribute is `None` after loading, nothing is written.
        """
        self._set_or_debug(obsm_key, self.adata.obsm, attr)
        names = self._set_or_debug(_lin_names(self._term_key), self.adata.uns)
        colors = self._set_or_debug(_colors(self._term_key), self.adata.uns)

        probs = self._get(attr)
        if probs is None:
            return

        n_lineages = probs.shape[1]

        if len(names) != n_lineages:
            if isinstance(probs, Lineage):
                names = probs.names
            else:
                logg.warning(
                    f"Expected lineage names to be of length `{n_lineages}`, found `{len(names)}`. "
                    f"Creating new names"
                )
                names = [f"Lineage {i}" for i in range(n_lineages)]

        # length is checked first so the per-color validation is skipped on a mismatch
        colors_ok = len(colors) == n_lineages and all(
            isinstance(c, str) and is_color_like(c) for c in colors
        )
        if not colors_ok:
            if isinstance(probs, Lineage):
                colors = probs.colors
            else:
                # report the length of the colors, not of the names
                logg.warning(
                    f"Expected lineage colors to be of length `{n_lineages}`, found `{len(colors)}`. "
                    f"Creating new colors"
                )
                colors = _create_categorical_colors(n_lineages)

        self._set(attr, Lineage(probs, names=names, colors=colors))

        self.adata.obsm[obsm_key] = self._get(attr)
        self.adata.uns[_lin_names(self._term_key)] = names
        self.adata.uns[_colors(self._term_key)] = colors
Ejemplo n.º 2
0
def _check_abs_probs(mc: cr.tl.estimators.GPCCA, has_main_states: bool = True):
    """
    Assert that the absorption probabilities of ``mc`` match what is stored in ``mc.adata``.

    Parameters
    ----------
    mc
        Estimator whose attributes are compared against its ``adata``.
    has_main_states
        Whether to additionally check the final states and their colors.
    """
    adata = mc.adata
    fin_key = FinalStatesKey.FORWARD
    abs_key = AbsProbKey.FORWARD

    if has_main_states:
        fin_states = mc._get(P.FIN)
        assert isinstance(fin_states, pd.Series)
        assert_array_nan_equal(adata.obs[str(fin_key)], fin_states)

        stored_colors = adata.uns[_colors(fin_key)]
        fin_categories = list(mc._get(P.FIN).cat.categories)
        np.testing.assert_array_equal(
            stored_colors, mc._get(A.FIN_ABS_PROBS)[fin_categories].colors
        )

    assert isinstance(mc._get(P.DIFF_POT), pd.Series)
    assert isinstance(mc._get(P.ABS_PROBS), cr.tl.Lineage)
    # every row of the absorption probabilities has to sum to one
    row_sums = mc._get(P.ABS_PROBS).sum(1)
    np.testing.assert_array_almost_equal(row_sums, 1.0)

    # values, names and colors stored in ``adata`` must mirror the lineage object
    np.testing.assert_array_equal(adata.obsm[str(abs_key)], mc._get(P.ABS_PROBS).X)
    np.testing.assert_array_equal(
        adata.uns[_lin_names(abs_key)], mc._get(P.ABS_PROBS).names
    )
    np.testing.assert_array_equal(
        adata.uns[_colors(abs_key)], mc._get(P.ABS_PROBS).colors
    )

    np.testing.assert_array_equal(adata.obs[_dp(abs_key)], mc._get(P.DIFF_POT))

    # checked regardless of ``has_main_states``
    assert_array_nan_equal(adata.obs[fin_key.s], mc._get(P.FIN))
    np.testing.assert_array_equal(
        adata.obs[_probs(fin_key)], mc._get(P.FIN_PROBS)
    )
Ejemplo n.º 3
0
    def test_compute_absorption_probabilities_normal_run(
            self, adata_large: AnnData):
        """Absorption probabilities computed by CFLARE are set on the estimator and written to ``adata``."""
        velocity = VelocityKernel(adata_large).compute_transition_matrix(
            softmax_scale=4)
        connectivity = ConnectivityKernel(adata_large).compute_transition_matrix()

        estimator = cr.tl.estimators.CFLARE(0.8 * velocity + 0.2 * connectivity)
        estimator.compute_eigendecomposition(k=5)
        estimator.compute_final_states(use=2)
        estimator.compute_absorption_probabilities()

        abs_key = f"{AbsProbKey.FORWARD}"
        dp_key = f"{AbsProbKey.FORWARD}_dp"
        names_key = _lin_names(AbsProbKey.FORWARD)
        colors_key = _colors(AbsProbKey.FORWARD)

        # differentiation potential
        assert isinstance(estimator._get(P.DIFF_POT), pd.Series)
        assert dp_key in estimator.adata.obs.keys()
        np.testing.assert_array_equal(estimator._get(P.DIFF_POT),
                                      estimator.adata.obs[dp_key])

        # absorption probabilities: one column per final state (``use=2``)
        assert isinstance(estimator._get(P.ABS_PROBS), cr.tl.Lineage)
        assert estimator._get(P.ABS_PROBS).shape == (estimator.adata.n_obs, 2)
        assert abs_key in estimator.adata.obsm.keys()
        np.testing.assert_array_equal(estimator._get(P.ABS_PROBS).X,
                                      estimator.adata.obsm[abs_key])

        # lineage names
        assert names_key in estimator.adata.uns.keys()
        np.testing.assert_array_equal(estimator._get(P.ABS_PROBS).names,
                                      estimator.adata.uns[names_key])

        # lineage colors
        assert colors_key in estimator.adata.uns.keys()
        np.testing.assert_array_equal(estimator._get(P.ABS_PROBS).colors,
                                      estimator.adata.uns[colors_key])

        np.testing.assert_allclose(estimator._get(P.ABS_PROBS).X.sum(1), 1)
Ejemplo n.º 4
0
def _check_abs_probs(mc: cr.tl.estimators.GPCCA, has_main_states: bool = True):
    """
    Assert that the absorption probabilities of ``mc`` match what is stored in ``mc.adata``.

    Parameters
    ----------
    mc
        Estimator whose attributes are compared against its ``adata``.
    has_main_states
        Whether to additionally check the terminal states and their colors.
    """
    adata = mc.adata
    term_key = TermStatesKey.FORWARD
    abs_key = AbsProbKey.FORWARD

    if has_main_states:
        term_states = mc._get(P.TERM)
        assert isinstance(term_states, pd.Series)
        assert_array_nan_equal(adata.obs[str(term_key)], term_states)

        stored_colors = adata.uns[_colors(term_key)]
        term_categories = list(mc._get(P.TERM).cat.categories)
        np.testing.assert_array_equal(
            stored_colors, mc._get(A.TERM_ABS_PROBS)[term_categories].colors
        )

    assert isinstance(mc._get(P.PRIME_DEG), pd.Series)
    assert isinstance(mc._get(P.ABS_PROBS), cr.tl.Lineage)
    # every row of the absorption probabilities has to sum to one
    row_sums = mc._get(P.ABS_PROBS).sum(1)
    np.testing.assert_array_almost_equal(row_sums, 1.0)

    # values, names and colors stored in ``adata`` must mirror the lineage object
    np.testing.assert_array_equal(adata.obsm[str(abs_key)], mc._get(P.ABS_PROBS).X)
    np.testing.assert_array_equal(
        adata.uns[_lin_names(abs_key)], mc._get(P.ABS_PROBS).names
    )
    np.testing.assert_array_equal(
        adata.uns[_colors(abs_key)], mc._get(P.ABS_PROBS).colors
    )

    np.testing.assert_array_equal(adata.obs[_pd(abs_key)], mc._get(P.PRIME_DEG))

    # checked regardless of ``has_main_states``
    assert_array_nan_equal(adata.obs[term_key.s], mc._get(P.TERM))
    np.testing.assert_array_equal(
        adata.obs[_probs(term_key)], mc._get(P.TERM_PROBS)
    )
Ejemplo n.º 5
0
    def _reconstruct_lineage(self, attr: PrettyEnum, obsm_key: str):
        """
        Rebuild the :class:`Lineage` kept under ``attr`` and write it back to ``adata``.

        ``_set_or_debug`` is asked to load ``adata.obsm[obsm_key]`` into ``attr`` and to
        fetch the lineage names and colors stored in ``adata.uns`` under the
        terminal-states key. Names or colors that do not match the number of lineages
        (or colors that are not color-like strings) are regenerated.

        Parameters
        ----------
        attr
            Attribute of this object holding the lineage.
        obsm_key
            Key in ``adata.obsm`` under which the lineage is stored.

        Returns
        -------
        None
            Nothing. When the attribute is `None` after loading, nothing is written.
        """
        self._set_or_debug(obsm_key, self.adata.obsm, attr)
        names = self._set_or_debug(_lin_names(self._term_key), self.adata.uns)
        colors = self._set_or_debug(_colors(self._term_key), self.adata.uns)

        # choosing this instead of property because GPCCA doesn't have property for FIN_ABS_PROBS
        probs = self._get(attr)
        if probs is None:
            return

        n_lineages = probs.shape[1]

        if len(names) != n_lineages:
            logg.debug(
                f"Expected lineage names to be of length `{n_lineages}`, found `{len(names)}`. "
                f"Creating new names")
            names = [f"Lineage {i}" for i in range(n_lineages)]

        # length is checked first so the per-color validation is skipped on a mismatch
        colors_ok = len(colors) == n_lineages and all(
            isinstance(c, str) and is_color_like(c) for c in colors)
        if not colors_ok:
            # report the length of the colors, not of the names
            logg.debug(
                f"Expected lineage colors to be of length `{n_lineages}`, found `{len(colors)}`. "
                f"Creating new colors")
            colors = _create_categorical_colors(n_lineages)

        self._set(attr, Lineage(probs, names=names, colors=colors))

        self.adata.obsm[obsm_key] = self._get(attr)
        self.adata.uns[_lin_names(self._term_key)] = names
        self.adata.uns[_colors(self._term_key)] = colors
Ejemplo n.º 6
0
    def test_compute_initial_states_from_forward_normal_run(
            self, adata_large: AnnData):
        """Initial states computed from a forward GPCCA are written under the backward key."""
        velocity = VelocityKernel(
            adata_large, backward=False
        ).compute_transition_matrix(softmax_scale=4)
        connectivity = ConnectivityKernel(
            adata_large, backward=False
        ).compute_transition_matrix()

        estimator = cr.tl.estimators.GPCCA(0.8 * velocity + 0.2 * connectivity)
        estimator.compute_schur(n_components=10, method="krylov")
        estimator.compute_macrostates(n_states=2, n_cells=5)

        obsm_keys_before = set(estimator.adata.obsm.keys())
        # the expected state is the one with the smallest value in COARSE_STAT_D
        stat_dist = estimator._get(P.COARSE_STAT_D)
        expected = stat_dist.index[np.argmin(estimator._get(P.COARSE_STAT_D))]

        estimator._compute_initial_states(1)

        key = TermStatesKey.BACKWARD.s
        adata = estimator.adata

        assert key in adata.obs
        np.testing.assert_array_equal(adata.obs[key].cat.categories, [expected])
        assert _probs(key) in adata.obs
        assert _colors(key) in adata.uns
        assert _lin_names(key) in adata.uns

        # make sure that we don't write anything there - it's useless
        assert set(adata.obsm.keys()) == obsm_keys_before
Ejemplo n.º 7
0
    def _write_terminal_states(self, time=None) -> None:
        """
        Write the terminal states, their probabilities, colors and names to ``adata``.

        Parameters
        ----------
        time
            Passed through to ``logg.info`` as ``time``.
        """
        adata = self.adata
        term_key = self._term_key
        probs_key = _probs(term_key)

        adata.obs[term_key] = self._get(P.TERM)
        adata.obs[probs_key] = self._get(P.TERM_PROBS)

        adata.uns[_colors(term_key)] = self._get(A.TERM_COLORS)
        categories = self._get(P.TERM).cat.categories
        adata.uns[_lin_names(term_key)] = np.array(categories)

        # checking for None because terminal states can be set using `set_terminal_states`
        # without the probabilities in GPCCA
        has_abs_probs = getattr(self, A.TERM_ABS_PROBS.s, None) is not None
        if has_abs_probs and hasattr(self, "_term_abs_prob_key"):
            adata.obsm[self._term_abs_prob_key] = self._get(A.TERM_ABS_PROBS)
            extra_msg = f"       `adata.obsm[{self._term_abs_prob_key!r}]`\n"
        else:
            extra_msg = ""

        message = (
            f"Adding `adata.obs[{probs_key!r}]`\n"
            f"       `adata.obs[{term_key!r}]`\n"
            f"{extra_msg}"
            f"       `.{P.TERM_PROBS}`\n"
            f"       `.{P.TERM}`\n"
            "    Finish"
        )
        logg.info(message, time=time)
Ejemplo n.º 8
0
    def maybe_create_lineage(
        direction: Union[str, Direction], pretty_name: Optional[str] = None
    ):
        """
        Wrap ``adata.obsm[lin_key]`` in a :class:`Lineage` if that key is present.

        Names and colors are taken from ``adata.uns`` when they exist and match the
        number of lineages (colors must also be color-like); otherwise new ones are
        created with a warning. The lineage, names and colors are written back to
        ``adata``. ``adata`` comes from the enclosing scope.

        Parameters
        ----------
        direction
            Either a :class:`Direction`, mapped to the forward or backward absorption
            probabilities key, or the key in ``adata.obsm`` itself.
        pretty_name
            Optional prefix used in the log messages.

        Returns
        -------
        None
            Nothing; only logs a debug message when the key is missing.
        """
        if isinstance(direction, Direction):
            lin_key = str(
                AbsProbKey.FORWARD
                if direction == Direction.FORWARD
                else AbsProbKey.BACKWARD
            )
        else:
            lin_key = direction

        pretty_name = "" if pretty_name is None else (pretty_name + " ")
        names_key, colors_key = _lin_names(lin_key), _colors(lin_key)

        if lin_key not in adata.obsm.keys():
            logg.debug(
                f"Unable to load {pretty_name}`Lineage` from `adata.obsm[{lin_key!r}]`"
            )
            return

        # the unpacking also requires the stored array to be 2-dimensional
        _, n_lineages = adata.obsm[lin_key].shape
        logg.info(f"Creating {pretty_name}`Lineage` from `adata.obsm[{lin_key!r}]`")

        if names_key not in adata.uns.keys():
            logg.warning(
                f"    Lineage names not found in `adata.uns[{names_key!r}]`, creating new names"
            )
            names = [f"Lineage {i}" for i in range(n_lineages)]
        elif len(adata.uns[names_key]) != n_lineages:
            logg.warning(
                f"    Lineage names don't have the required length ({n_lineages}), creating new names"
            )
            names = [f"Lineage {i}" for i in range(n_lineages)]
        else:
            logg.info("    Successfully loaded names")
            names = adata.uns[names_key]

        if colors_key not in adata.uns.keys():
            logg.warning(
                f"    Lineage colors not found in `adata.uns[{colors_key!r}]`, creating new colors"
            )
            colors = _create_categorical_colors(n_lineages)
        elif len(adata.uns[colors_key]) != n_lineages or not all(
            is_color_like(c) for c in adata.uns[colors_key]
        ):
            logg.warning(
                f"    Lineage colors don't have the required length ({n_lineages}) "
                f"or are not color-like, creating new colors"
            )
            colors = _create_categorical_colors(n_lineages)
        else:
            logg.info("    Successfully loaded colors")
            colors = adata.uns[colors_key]

        adata.obsm[lin_key] = Lineage(adata.obsm[lin_key], names=names, colors=colors)
        adata.uns[colors_key] = colors
        adata.uns[names_key] = names
Ejemplo n.º 9
0
def _check_renaming_no_write_terminal(mc: cr.tl.estimators.GPCCA) -> None:
    """Assert that ``mc`` holds no terminal-state attributes and ``mc.adata`` no related keys."""
    for attribute in (P.TERM, P.TERM_PROBS, A.TERM_ABS_PROBS):
        assert mc._get(attribute) is None

    key = TermStatesKey.FORWARD.s
    assert key not in mc.adata.obs
    assert _probs(key) not in mc.adata.obs
    assert _colors(key) not in mc.adata.uns
    assert _lin_names(key) not in mc.adata.uns
Ejemplo n.º 10
0
def _assert_has_all_keys(adata: AnnData, direction: Direction):
    """
    Assert that ``adata`` contains the keys expected for the given ``direction``.

    Checks the transition matrix and its parameters, the absorption probabilities
    (as a :class:`cr.tl.Lineage`) with their names and colors, and the terminal
    states with their probabilities.
    """
    transition_key = _transition(direction)

    assert transition_key in adata.obsp.keys()
    # check if it's not a dummy transition matrix
    diagonal = np.diag(adata.obsp[transition_key].A)
    assert not np.all(np.isclose(diagonal, 1.0))
    assert f"{transition_key}_params" in adata.uns.keys()

    # both directions run the same checks, only the keys differ
    if direction == Direction.FORWARD:
        abs_key, term_key = AbsProbKey.FORWARD, TermStatesKey.FORWARD
    else:
        abs_key, term_key = AbsProbKey.BACKWARD, TermStatesKey.BACKWARD

    assert str(abs_key) in adata.obsm
    assert isinstance(adata.obsm[str(abs_key)], cr.tl.Lineage)

    assert _colors(abs_key) in adata.uns.keys()
    assert _lin_names(abs_key) in adata.uns.keys()

    assert str(term_key) in adata.obs
    assert is_categorical_dtype(adata.obs[str(term_key)])

    assert _probs(term_key) in adata.obs

    # check the correlations with all lineages have been computed
    # NOTE(review): the result of this expression is discarded, so nothing is
    # actually asserted here - confirm whether an `assert` was intended.
    lin_probs = adata.obsm[str(abs_key)]
    prefix = DirPrefix.FORWARD if direction == Direction.FORWARD else DirPrefix.BACKWARD
    np.in1d(
        [f"{str(prefix)} {key}" for key in lin_probs.names],
        adata.var.keys(),
    ).all()
Ejemplo n.º 11
0
    def test_no_colors(self, adata: AnnData, path: Path, lin_key: str, n_lins: int):
        """Reading a file without stored lineage colors yields newly created categorical colors."""
        colors_key = _colors(lin_key)
        del adata.uns[colors_key]

        sc.write(path, adata)
        restored = cr.read(path)
        lineage = restored.obsm[lin_key]

        assert isinstance(lineage, Lineage)
        expected_colors = _create_categorical_colors(n_lins)
        np.testing.assert_array_equal(lineage.colors, expected_colors)
        # the colors have to be written back to ``uns`` as well
        np.testing.assert_array_equal(lineage.colors, restored.uns[colors_key])
Ejemplo n.º 12
0
    def _read_from_adata(self) -> None:
        """Load stored results from ``adata`` onto this object via ``_set_or_debug``."""
        adata = self.adata
        term_key = self._term_key
        abs_prob_key = self._abs_prob_key

        # eigendecomposition
        self._set_or_debug(f"eig_{self._direction}", adata.uns, "_eig")

        # G2M and S scores
        self._set_or_debug(self._g2m_key, adata.obs, "_G2M_score")
        self._set_or_debug(self._s_key, adata.obs, "_S_score")

        # terminal states, their probabilities and colors
        self._set_or_debug(term_key, adata.obs, A.TERM.s)
        self._set_or_debug(_probs(term_key), adata.obs, A.TERM_PROBS)
        self._set_or_debug(_colors(term_key), adata.uns, A.TERM_COLORS)

        # absorption probabilities and the value stored under the `_pd` key
        self._reconstruct_lineage(A.ABS_PROBS, abs_prob_key)
        self._set_or_debug(_pd(abs_prob_key), adata.obs, A.PRIME_DEG)
Ejemplo n.º 13
0
    def _read_from_adata(self) -> None:
        """Load stored results from ``adata`` onto this object via ``_set_or_debug``."""
        adata = self.adata
        fs_key = self._fs_key
        abs_prob_key = self._abs_prob_key

        # eigendecomposition
        self._set_or_debug(f"eig_{self._direction}", adata.uns, "_eig")

        # G2M and S scores
        self._set_or_debug(self._g2m_key, adata.obs, "_G2M_score")
        self._set_or_debug(self._s_key, adata.obs, "_S_score")

        # final states, their probabilities and colors
        self._set_or_debug(fs_key, adata.obs, A.FIN.s)
        self._set_or_debug(_probs(fs_key), adata.obs, A.FIN_PROBS)
        self._set_or_debug(_colors(fs_key), adata.uns, A.FIN_COLORS)

        # absorption probabilities and the value stored under the `_dp` key
        self._reconstruct_lineage(A.ABS_PROBS, abs_prob_key)
        self._set_or_debug(_dp(abs_prob_key), adata.obs, A.DIFF_POT)
Ejemplo n.º 14
0
    def test_normal_run(self, adata: AnnData, path: Path, lin_key: str, n_lins: int):
        """Lineage names and colors stored in ``uns`` survive a write/read round trip."""
        expected_colors = _create_categorical_colors(10)[-n_lins:]
        expected_names = ["foo {}".format(i) for i in range(n_lins)]

        adata.uns[_colors(lin_key)] = expected_colors
        adata.uns[_lin_names(lin_key)] = expected_names

        sc.write(path, adata)
        restored = cr.read(path)
        lineage = restored.obsm[lin_key]

        np.testing.assert_array_equal(lineage.colors, expected_colors)
        np.testing.assert_array_equal(lineage.names, expected_names)
Ejemplo n.º 15
0
    def test_compute_approx_normal_run(self, adata_large: AnnData):
        """Terminal states computed by CFLARE are set on the estimator and written to ``adata``."""
        velocity = VelocityKernel(adata_large).compute_transition_matrix(softmax_scale=4)
        connectivity = ConnectivityKernel(adata_large).compute_transition_matrix()

        estimator = cr.tl.estimators.CFLARE(0.8 * velocity + 0.2 * connectivity)
        estimator.compute_eigendecomposition(k=5)
        estimator.compute_terminal_states(use=2)

        assert is_categorical_dtype(estimator._get(P.TERM))
        assert estimator._get(P.TERM_PROBS) is not None

        obs_keys = estimator.adata.obs.keys()
        assert TermStatesKey.FORWARD.s in obs_keys
        assert _probs(TermStatesKey.FORWARD) in obs_keys
        assert _colors(TermStatesKey.FORWARD) in estimator.adata.uns.keys()
Ejemplo n.º 16
0
    def _write_initial_states(self,
                              membership: Lineage,
                              probs: pd.Series,
                              cats: pd.Series,
                              time=None) -> None:
        """
        Write the initial states to ``adata`` under the backward terminal-states key.

        Parameters
        ----------
        membership
            Lineage whose ``colors`` and ``names`` are stored in ``adata.uns``.
        probs
            Values stored in ``adata.obs`` under the probabilities key.
        cats
            Values stored in ``adata.obs`` under the states key.
        time
            Passed through to ``logg.info`` as ``time``.
        """
        adata = self.adata
        states_key = TermStatesKey.BACKWARD.s
        probs_key = _probs(states_key)

        adata.obs[states_key] = cats
        adata.obs[probs_key] = probs

        adata.uns[_colors(states_key)] = membership.colors
        adata.uns[_lin_names(states_key)] = membership.names

        message = f"Adding `adata.obs[{probs_key!r}]`\n       `adata.obs[{states_key!r}]`\n"
        logg.info(message, time=time)
Ejemplo n.º 17
0
    def _write_absorption_probabilities(
        self, time: datetime, extra_msg: str = ""
    ) -> None:
        """
        Write the absorption probabilities, their names and colors to ``adata``.

        Parameters
        ----------
        time
            Passed through to ``logg.info`` as ``time``.
        extra_msg
            Additional text inserted into the log message.
        """
        adata = self.adata
        key = self._abs_prob_key
        abs_prob = self._get(P.ABS_PROBS)

        adata.obsm[key] = abs_prob
        adata.uns[_lin_names(key)] = abs_prob.names
        adata.uns[_colors(key)] = abs_prob.colors

        message = (
            f"Adding `adata.obsm[{key!r}]`\n"
            f"{extra_msg}"
            f"       `.{P.ABS_PROBS}`\n"
            "    Finish"
        )
        logg.info(message, time=time)
Ejemplo n.º 18
0
    def _write_final_states(self, time=None) -> None:
        """
        Write the final states, their probabilities, colors and names to ``adata``.

        Parameters
        ----------
        time
            Passed through to ``logg.info`` as ``time``.
        """
        adata = self.adata
        fs_key = self._fs_key
        probs_key = _probs(fs_key)

        adata.obs[fs_key] = self._get(P.FIN)
        adata.obs[probs_key] = self._get(P.FIN_PROBS)

        adata.uns[_colors(fs_key)] = self._get(A.FIN_COLORS)
        categories = self._get(P.FIN).cat.categories
        adata.uns[_lin_names(fs_key)] = list(categories)

        # checking for None because final states can be set using `set_final_states`
        # without the probabilities in GPCCA
        has_abs_probs = getattr(self, A.FIN_ABS_PROBS.s, None) is not None
        if has_abs_probs and hasattr(self, "_fin_abs_prob_key"):
            adata.obsm[self._fin_abs_prob_key] = self._get(A.FIN_ABS_PROBS)
            extra_msg = f"       `adata.obsm[{self._fin_abs_prob_key!r}]`\n"
        else:
            extra_msg = ""

        message = (
            f"Adding `adata.obs[{probs_key!r}]`\n"
            f"       `adata.obs[{fs_key!r}]`\n"
            f"{extra_msg}"
            f"       `.{P.FIN_PROBS}`\n"
            f"       `.{P.FIN}`"
        )
        logg.info(message, time=time)
Ejemplo n.º 19
0
def graph(
        data: Union[AnnData, np.ndarray, spmatrix],
        graph_key: Optional[str] = None,
        ixs: Optional[np.array] = None,
        layout: Union[str, Dict, Callable] = "umap",
        keys: Sequence[KEYS] = ("incoming", ),
        keylocs: Union[KEYLOCS, Sequence[KEYLOCS]] = "uns",
        node_size: float = 400,
        labels: Optional[Union[Sequence[str], Sequence[Sequence[str]]]] = None,
        top_n_edges: Optional[Union[int, Tuple[int, bool, str]]] = None,
        self_loops: bool = True,
        self_loop_radius_frac: Optional[float] = None,
        filter_edges: Optional[Tuple[float, float]] = None,
        edge_reductions: Union[Callable, Sequence[Callable]] = np.sum,
        edge_weight_scale: float = 10,
        edge_width_limit: Optional[float] = None,
        edge_alpha: float = 1.0,
        edge_normalize: bool = False,
        edge_use_curved: bool = True,
        show_arrows: bool = True,
        font_size: int = 12,
        font_color: str = "black",
        color_nodes: bool = True,
        cat_cmap: ListedColormap = cm.Set3,
        cont_cmap: ListedColormap = cm.viridis,
        legend_loc: Optional[str] = "best",
        figsize: Optional[Tuple[float, float]] = None,
        dpi: Optional[int] = None,
        save: Optional[Union[str, Path]] = None,
        layout_kwargs: Dict = MappingProxyType({}),
) -> None:
    """
    Plot a graph, visualizing incoming and outgoing edges or self-transitions.

    This is a utility function to look in more detail at the transition matrix in areas of interest, e.g. around an
    endpoint of development. This function is meant to visualise a small subset of nodes (~100-500) and the most likely
    transitions between them. Note that limiting edges visualized using ``top_n_edges`` will speed things up,
    as well as reduce the visual clutter.

    Parameters
    ----------
    data
        The graph data to be plotted.
    graph_key
        Key in ``adata.obsp`` or ``adata.uns`` where the graph is stored. Only used
        when ``data`` is :class:`~anndata.Anndata` object.
    ixs
        Subset of indices of the graph to visualize.
    layout
        Layout to use for graph drawing.

        - If :class:`str`, search for embedding in ``adata.obsm['X_{layout}']``.
          Use ``layout_kwargs={'components': [0, 1]}`` to select components.
        - If :class:`dict`, keys should be values in interval ``[0, len(ixs))``
          and values `(x, y)` pairs corresponding to node positions.
    keys
        Keys in ``adata.obs``, ``adata.obsm`` or ``adata.obsp`` to color the nodes.

        - If `'incoming'`, `'outgoing'` or `'self_loops'`, visualize reduction (see ``edge_reductions``)
          for each node based on incoming or outgoing edges, respectively.
    keylocs
        Locations of ``keys``. Can be any attribute of ``data`` if it's :class:`anndata.AnnData` object.
    node_size
        Size of the nodes.
    labels
        Labels of the nodes.
    top_n_edges
        Either top N outgoing edges in descending order or a tuple
        ``(top_n_edges, in_ascending_order, {'incoming', 'outgoing'})``. If `None`, show all edges.
    self_loops
        Whether visualize self transitions and also to consider them in ``top_n_edges``.
    self_loop_radius_frac
        Fraction of a unit circle to visualize self transitions. If `None`, use ``node_size / 1000``.
    filter_edges
        Whether to remove all edges not in `[min, max]` interval.
    edge_reductions
        Aggregation function to use when coloring nodes by edge weights.
    edge_weight_scale
        Number by which to scale the width of the edges. Useful when the weights are small.
    edge_width_limit
        Upper bound for the width of the edges. Useful when weights are unevenly distributed.
    edge_alpha
        Alpha channel value for edges and arrows.
    edge_normalize
        If `True`, normalize edges to `[0, 1]` interval prior to applying any scaling or truncation.
    edge_use_curved
        If `True`, use curved edges. This can improve visualization at a small performance cost.
    show_arrows
        Whether to show the arrows. Setting this to `False` may dramatically speed things up.
    font_size
        Font size for node labels.
    font_color
        Label color of the nodes.
    color_nodes
        Whether to color the nodes
    cat_cmap
        Categorical colormap used when ``keys`` contain categorical variables.
    cont_cmap
        Continuous colormap used when ``keys`` contain continuous variables.
    legend_loc
        Location of the legend.
    %(plotting)s
    layout_kwargs
        Additional kwargs for ``layout``.

    Returns
    -------
    %(just_plots)s
    """

    from anndata import AnnData as _AnnData

    import networkx as nx

    def plot_arrows(curves, G, pos, ax, edge_weight_scale):
        """
        Add one filled arrow head per non-self-loop edge to ``ax``.

        ``curves`` is iterated in lockstep with ``G.edges``, so the i-th curve is taken
        to belong to the i-th edge. ``edge_width_limit``, ``edge_alpha`` and
        ``_min_edge_weight`` are read from the enclosing scope; the
        ``edge_weight_scale`` parameter shadows the outer one.
        """
        for line, (edge, val) in zip(curves, G.edges.items()):
            # self loops get no arrow
            if edge[0] == edge[1]:
                continue

            # keep only the curve points that contain no NaN coordinate
            mask = (~np.isnan(line)).all(axis=1)
            line = line[mask, :]
            if not len(line):  # can be all NaNs
                continue

            line = line.reshape((-1, 2))
            X, Y = line[:, 0], line[:, 1]

            node_start = pos[edge[0]]
            # Index of the first curve point that coincides with the start node; a
            # non-zero index triggers reversing the point order (presumably because the
            # curve is then stored end-to-start). Raises IndexError when no point matches.
            if np.where(np.isclose(node_start - line,
                                   [0, 0]).all(axis=1))[0][0]:
                X, Y = X[::-1], Y[::-1]

            # the arrow spans the two consecutive points starting at the middle of the curve
            mid = len(X) // 2
            posA, posB = zip(X[mid:mid + 2], Y[mid:mid + 2])  # noqa

            arrow = FancyArrowPatch(
                posA=posA,
                posB=posB,
                # we clip because too small values
                # cause it to crash
                arrowstyle=ArrowStyle.CurveFilledB(
                    head_length=np.clip(
                        val["weight"] * edge_weight_scale * 4,
                        _min_edge_weight,
                        edge_width_limit,
                    ),
                    head_width=np.clip(
                        val["weight"] * edge_weight_scale * 2,
                        _min_edge_weight,
                        edge_width_limit,
                    ),
                ),
                color="k",
                zorder=float("inf"),
                alpha=edge_alpha,
                linewidth=0,
            )
            ax.add_artist(arrow)

    def normalize_weights():
        """Min-max rescale the ``"weight"`` attribute of every edge of ``G`` in place."""
        edge_attrs = list(G.edges.values())
        raw = np.array([attr["weight"] for attr in edge_attrs])

        lowest = np.min(raw)
        span = np.max(raw) - lowest
        scaled = (raw - lowest) / span

        for attr, weight in zip(edge_attrs, scaled):
            attr["weight"] = weight

    def remove_top_n_edges():
        """
        Keep only the top N edges per node of ``G`` and remove all others in place.

        ``top_n_edges`` and ``self_loops`` are read from the enclosing scope.
        ``top_n_edges`` is either a number, meaning the top N outgoing edges in
        descending order of weight, or a ``(n, ascending, group_by)`` sequence where
        ``group_by`` is `'incoming'` or `'outgoing'`. Does nothing when it is `None`.
        When ``self_loops`` is `False`, self loops are not among the kept edges.

        Raises
        ------
        ValueError
            If ``group_by`` is neither `'incoming'` nor `'outgoing'`.
        """
        if top_n_edges is None:
            return

        if isinstance(top_n_edges, (tuple, list)):
            n_keep, ascending, group_by = top_n_edges
        else:
            # a bare number means the top N outgoing edges; the group has to be one of
            # the names accepted by the validation below
            n_keep, ascending, group_by = top_n_edges, False, "outgoing"

        if group_by not in ("incoming", "outgoing"):
            raise ValueError(
                "Argument `groupby` in `top_n_edges` must be either `'incoming`' or `'outgoing'`."
            )

        if G.number_of_edges() == 0:
            # nothing to rank; unpacking `zip(*G.edges)` below would fail
            return

        source, target = zip(*G.edges)
        weights = [v["weight"] for v in G.edges.values()]
        tmp = pd.DataFrame({
            "outgoing": source,
            "incoming": target,
            "w": weights
        })

        if not self_loops:
            # remove self loops
            tmp = tmp[tmp["incoming"] != tmp["outgoing"]]

        # per group, sort by weight and take at most `n_keep` rows
        kept = set(
            map(
                tuple,
                tmp.groupby(group_by).apply(
                    lambda g: g.sort_values("w", ascending=ascending).take(
                        range(min(n_keep, len(g)))))[["outgoing",
                                                      "incoming"]].values,
            ))

        for e in list(G.edges):
            if e not in kept:
                G.remove_edge(*e)

    def remove_low_weight_edges():
        """
        Remove the edges of ``G`` whose weight lies outside ``filter_edges``, in place.

        ``filter_edges`` is read from the enclosing scope as a ``(min, max)`` pair; a
        `None` bound is treated as unbounded. Does nothing when ``filter_edges`` is
        `None` or ``(None, None)``.
        """
        if filter_edges is None or filter_edges == (None, None):
            return

        lower, upper = filter_edges
        if lower is None:
            lower = -np.inf
        if upper is None:
            upper = np.inf

        # collect first, then remove, so the edge view is not modified while iterating
        out_of_range = [
            e for e, attr in G.edges.items()
            if attr["weight"] < lower or attr["weight"] > upper
        ]
        for e in out_of_range:
            G.remove_edge(*e)

    _min_edge_weight = 0.00001

    if edge_width_limit is None:
        logg.debug("Not limiting width of edges")
        edge_width_limit = float("inf")

    if self_loop_radius_frac is None:
        self_loop_radius_frac = (node_size /
                                 2000 if node_size >= 200 else node_size /
                                 1000)
        logg.debug(
            f"Setting self loop radius fraction to `{self_loop_radius_frac}`")

    if not isinstance(keylocs, (tuple, list)):
        keylocs = [keylocs] * len(keys)
    elif len(keylocs) == 1:
        keylocs = keylocs * 3
    elif all(map(lambda k: k in ("incoming", "outgoing", "self_loops"), keys)):
        # don't care about keylocs since they are irrelevant
        logg.debug("Ignoring key locations")
        keylocs = [None] * len(keys)

    if not isinstance(edge_reductions, (tuple, list)):
        edge_reductions = [edge_reductions] * len(keys)
    if not all(map(callable, edge_reductions)):
        raise ValueError("Not all `edge_reductions` functions are callable.")

    if not isinstance(labels, (tuple, list)):
        labels = [labels] * len(keys)
    elif not len(labels):
        labels = [None] * len(keys)
    elif not isinstance(labels[0], (tuple, list)):
        labels = [labels] * len(keys)

    if len(keys) != len(labels):
        raise ValueError(
            f"`Keys` and `labels` must be of the same shape, found `{len(keys)}` and `{len(labels)}`."
        )

    if isinstance(data, _AnnData):
        if graph_key is None:
            raise ValueError(
                "Argument `graph_key` cannot be `None` when `data` is `anndata.Anndata` object."
            )
        gdata = _read_graph_data(data, graph_key)
    elif isinstance(data, (np.ndarray, spmatrix)):
        gdata = data
    else:
        raise TypeError(
            f"Expected argument `data` to be one of `anndata.AnnData`, `numpy.ndarray`, `scipy.sparse.spmatrix`, "
            f"found `{type(data).__name__!r}`.")
    is_sparse = issparse(gdata)

    if ixs is not None:
        gdata = gdata[ixs, :][:, ixs]
    else:
        ixs = list(range(gdata.shape[0]))

    start = logg.info("Creating graph")
    G = (nx.from_scipy_sparse_matrix(gdata, create_using=nx.DiGraph)
         if is_sparse else nx.from_numpy_array(gdata, create_using=nx.DiGraph))

    remove_low_weight_edges()
    remove_top_n_edges()
    if edge_normalize:
        normalize_weights()
    logg.info("    Finish", time=start)

    # do NOT recreate the graph, for the edge reductions
    # gdata = nx.to_numpy_array(G)

    if figsize is None:
        figsize = (12, 8 * len(keys))

    fig, axes = plt.subplots(nrows=len(keys),
                             ncols=1,
                             figsize=figsize,
                             dpi=dpi)
    if not isinstance(axes, np.ndarray):
        axes = np.array([axes])
    axes = np.ravel(axes)

    if isinstance(layout, str):
        if f"X_{layout}" not in data.obsm:
            raise KeyError(
                f"Unable to find embedding `'X_{layout}'` in `adata.obsm`.")
        components = layout_kwargs.get("components", [0, 1])
        if len(components) != 2:
            raise ValueError(
                f"Components in `layout_kwargs` must be of length `2`, found `{len(components)}`."
            )
        emb = data.obsm[f"X_{layout}"][:, components]
        pos = {i: emb[ix, :] for i, ix in enumerate(ixs)}
        logg.info(f"Embedding graph using `{layout!r}` layout")
    elif isinstance(layout, dict):
        rng = range(len(ixs))
        for k, v in layout.items():
            if k not in rng:
                raise ValueError(
                    f"Key in `layout` must be in `range(len(ixs))`, found `{k}`."
                )
            if len(v) != 2:
                raise ValueError(
                    f"Value in `layout` must be a `tuple` or a `list` of length 2, found `{len(v)}`."
                )
        pos = layout
        logg.debug("Using precomputed layout")
    elif callable(layout):
        start = logg.info(
            f"Embedding graph using `{layout.__name__!r}` layout")
        pos = layout(G, **layout_kwargs)
        logg.info("    Finish", time=start)
    else:
        raise TypeError(f"Argument `layout` must be either a `string`, "
                        f"a `dict` or a `callable`, found `{type(layout)}`.")

    curves, lc = None, None
    if edge_use_curved:
        try:
            from ._utils import _curved_edges

            logg.debug("Creating curved edges")
            curves = _curved_edges(G,
                                   pos,
                                   self_loop_radius_frac,
                                   polarity="directed")
            lc = LineCollection(
                curves,
                colors="black",
                linewidths=np.clip(
                    np.ravel([v["weight"]
                              for v in G.edges.values()]) * edge_weight_scale,
                    0,
                    edge_width_limit,
                ),
                alpha=edge_alpha,
            )
        except ImportError as e:
            global _msg_shown
            if not _msg_shown:
                print(
                    str(e)[:-1],
                    "in order to use curved edges or specify `edge_use_curved=False`.",
                )
                _msg_shown = True

    for ax, keyloc, key, labs, er in zip(axes, keylocs, keys, labels,
                                         edge_reductions):
        label_col = {}  # dummy value

        if key in ("incoming", "outgoing", "self_loops"):
            if key in ("incoming", "outgoing"):
                vals = er(gdata, axis=int(key == "outgoing"))
                if issparse(vals):
                    vals = vals.A
                vals = vals.flatten()
            else:
                vals = gdata.diagonal() if is_sparse else np.diag(gdata)
            node_v = dict(zip(pos.keys(), vals))
        else:
            label_col = getattr(data, keyloc)
            if key in label_col:
                node_v = dict(zip(pos.keys(), label_col[key]))
            else:
                raise RuntimeError(
                    f"Key `{key!r}` not found in `adata.{keyloc}`.")

        if labs is not None:
            if len(labs) != len(pos):
                raise RuntimeError(
                    f"Number of labels ({len(labels)}) and nodes ({len(pos)}) mismatch."
                )
            nx.draw_networkx_labels(
                G,
                pos,
                labels=labs if isinstance(labs, dict) else dict(
                    zip(pos.keys(), labs)),
                ax=ax,
                font_color=font_color,
                font_size=font_size,
            )

        if lc is not None and curves is not None:
            ax.add_collection(deepcopy(lc))  # copying necessary
            if show_arrows:
                plot_arrows(curves, G, pos, ax, edge_weight_scale)
        else:
            nx.draw_networkx_edges(
                G,
                pos,
                width=[
                    np.clip(
                        v["weight"] * edge_weight_scale,
                        _min_edge_weight,
                        edge_width_limit,
                    ) for _, v in G.edges.items()
                ],
                alpha=edge_alpha,
                edge_color="black",
                arrows=True,
                arrowstyle="-|>",
            )

        if key in label_col and is_categorical_dtype(label_col[key]):
            values = label_col[key]
            if keyloc in ("obs", "obsm"):
                values = values[ixs]
            categories = values.cat.categories
            color_key = _colors(key)
            if color_key in data.uns:
                mapper = dict(zip(categories, data.uns[color_key]))
            else:
                mapper = dict(
                    zip(categories, map(cat_cmap.get, range(len(categories)))))

            colors = []
            seen = set()

            for v in values:
                colors.append(mapper[v])
                seen.add(v)

            nodes_kwargs = dict(cmap=cat_cmap, node_color=colors)  # noqa
            if legend_loc is not None:
                x, y = pos[0]
                for label in sorted(seen):
                    ax.plot([x], [y], label=label, color=mapper[label])
                ax.legend(loc=legend_loc)
        else:
            values = list(node_v.values())
            vmin, vmax = np.min(values), np.max(values)
            nodes_kwargs = dict(  # noqa
                cmap=cont_cmap,
                node_color=values,
                vmin=vmin,
                vmax=vmax)

            divider = make_axes_locatable(ax)
            cax = divider.append_axes("right", size="1.5%", pad=0.05)
            _ = mpl.colorbar.ColorbarBase(cax,
                                          cmap=cont_cmap,
                                          norm=mpl.colors.Normalize(vmin=vmin,
                                                                    vmax=vmax))

        if color_nodes is False:
            nodes_kwargs = {}

        nx.draw_networkx_nodes(G,
                               pos,
                               node_size=node_size,
                               ax=ax,
                               **nodes_kwargs)

        ax.set_title(key)
        ax.axis("off")

    if save is not None:
        save_fig(fig, save)

    fig.show()
Ejemplo n.º 20
0
    def _set_categorical_labels(
        self,
        attr_key: str,
        color_key: str,
        pretty_attr_key: str,
        categories: Union[Series, Dict[Any, Any]],
        add_to_existing_error_msg: Optional[str] = None,
        cluster_key: Optional[str] = None,
        en_cutoff: Optional[float] = None,
        p_thresh: Optional[float] = None,
        add_to_existing: bool = False,
    ) -> None:
        if isinstance(categories, dict):
            categories = _convert_to_categorical_series(
                categories, list(self.adata.obs_names)
            )
        if not is_categorical_dtype(categories):
            raise TypeError(
                f"Object must be `categorical`, found `{infer_dtype(categories)}`."
            )

        if add_to_existing:
            if getattr(self, attr_key) is None:
                raise RuntimeError(add_to_existing_error_msg)
            categories = _merge_categorical_series(
                getattr(self, attr_key), categories, inplace=False
            )

        if cluster_key is not None:
            logg.debug(f"Creating colors based on `{cluster_key}`")

            # check that we can load the reference series from adata
            if cluster_key not in self.adata.obs:
                raise KeyError(
                    f"Cluster key `{cluster_key!r}` not found in `adata.obs`."
                )
            series_query, series_reference = categories, self.adata.obs[cluster_key]

            # load the reference colors if they exist
            if _colors(cluster_key) in self.adata.uns.keys():
                colors_reference = _convert_to_hex_colors(
                    self.adata.uns[_colors(cluster_key)]
                )
            else:
                colors_reference = _create_categorical_colors(
                    len(series_reference.cat.categories)
                )

            approx_rcs_names, colors = _map_names_and_colors(
                series_reference=series_reference,
                series_query=series_query,
                colors_reference=colors_reference,
                en_cutoff=en_cutoff,
            )
            setattr(self, color_key, colors)
            # if approx_rcs_names is categorical, the info is take from .cat.categories
            categories.cat.categories = approx_rcs_names
        else:
            setattr(
                self,
                color_key,
                _create_categorical_colors(len(categories.cat.categories)),
            )

        if p_thresh is not None:
            self._detect_cc_stages(categories, p_thresh=p_thresh)

        # write to class and adata
        if getattr(self, attr_key) is not None:
            logg.debug(f"Overwriting `.{pretty_attr_key}`")

        setattr(self, attr_key, categories)