コード例 #1
0
def _assert_has_all_keys(adata: AnnData, direction: Direction):
    """
    Assert that `adata` contains every key written for the given direction.

    Params
    ------
    adata
        Annotated data object to inspect.
    direction
        Direction whose keys are checked. `Direction.FORWARD` selects the forward
        keys; any other value selects the backward keys.

    Returns
    -------
    None
        Raises :class:`AssertionError` on the first missing or malformed entry.
    """
    assert _transition(direction) in adata.uns.keys()

    forward = direction == Direction.FORWARD
    lin_key = LinKey.FORWARD if forward else LinKey.BACKWARD
    state_key = StateKey.FORWARD if forward else StateKey.BACKWARD

    # lineage probabilities in `.obsm` plus their colors/names in `.uns`
    assert str(lin_key) in adata.obsm
    assert isinstance(adata.obsm[str(lin_key)], cr.tl.Lineage)
    assert _colors(lin_key) in adata.uns.keys()
    assert _lin_names(lin_key) in adata.uns.keys()

    # categorical states in `.obs` plus their probabilities
    assert str(state_key) in adata.obs
    assert is_categorical_dtype(adata.obs[str(state_key)])
    assert _probs(state_key) in adata.obs
コード例 #2
0
    def test_compute_lin_probs_normal_run(self, adata_large: AnnData):
        """Lineage probabilities, potential, names and colors are mirrored into `.adata`."""
        vk = VelocityKernel(adata_large).compute_transition_matrix()
        ck = ConnectivityKernel(adata_large).compute_transition_matrix()

        mc = cr.tl.MarkovChain(0.8 * vk + 0.2 * ck)
        mc.compute_eig(k=5)
        mc.compute_approx_rcs(use=2)
        mc.compute_lin_probs()
        adata = mc.adata

        # differentiation potential
        dp_key = f"{LinKey.FORWARD}_dp"
        assert isinstance(mc.diff_potential, np.ndarray)
        assert dp_key in adata.obs.keys()
        np.testing.assert_array_equal(mc.diff_potential, adata.obs[dp_key])

        # lineage probabilities
        lin_key = f"{LinKey.FORWARD}"
        probs = mc.lineage_probabilities
        assert isinstance(probs, cr.tl.Lineage)
        assert probs.shape == (adata.n_obs, 2)
        assert lin_key in adata.obsm.keys()
        np.testing.assert_array_equal(probs.X, adata.obsm[lin_key])

        # lineage names
        names_key = _lin_names(LinKey.FORWARD)
        assert names_key in adata.uns.keys()
        np.testing.assert_array_equal(probs.names, adata.uns[names_key])

        # lineage colors and row normalization
        colors_key = _colors(LinKey.FORWARD)
        assert colors_key in adata.uns.keys()
        np.testing.assert_array_equal(probs.colors, adata.uns[colors_key])
        np.testing.assert_allclose(probs.X.sum(1), 1)
コード例 #3
0
ファイル: test_gpcca.py プロジェクト: dpeerlab/cellrank
def _check_main_states(mc: cr.tl.GPCCA, has_main_states: bool = True):
    """
    Assert that the main-state and lineage results of `mc` are mirrored in `mc.adata`.

    Params
    ------
    mc
        Estimator whose results are checked.
    has_main_states
        If `True`, also check `mc.main_states` against `.obs` and the main-state
        colors in `.uns` against the matching lineage colors.

    Returns
    -------
    None
        Raises :class:`AssertionError` on the first mismatch.
    """
    obs, obsm, uns = mc.adata.obs, mc.adata.obsm, mc.adata.uns

    if has_main_states:
        assert isinstance(mc.main_states, pd.Series)
        assert_array_nan_equal(obs[str(StateKey.FORWARD)], mc.main_states)
        stored_state_colors = uns[_colors(StateKey.FORWARD)]
        main_lineages = mc.lineage_probabilities[list(mc.main_states.cat.categories)]
        np.testing.assert_array_equal(stored_state_colors, main_lineages.colors)

    assert isinstance(mc.diff_potential, np.ndarray)
    lin_probs = mc.lineage_probabilities
    assert isinstance(lin_probs, cr.tl.Lineage)

    np.testing.assert_array_equal(obsm[str(LinKey.FORWARD)], lin_probs.X)
    np.testing.assert_array_equal(uns[_lin_names(LinKey.FORWARD)], lin_probs.names)
    np.testing.assert_array_equal(uns[_colors(LinKey.FORWARD)], lin_probs.colors)

    np.testing.assert_array_equal(obs[_dp(LinKey.FORWARD)], mc.diff_potential)
    np.testing.assert_array_equal(
        obs[_probs(StateKey.FORWARD)], mc.main_states_probabilities
    )
コード例 #4
0
ファイル: test_gpcca.py プロジェクト: dpeerlab/cellrank
def _check_compute_meta(mc: cr.tl.GPCCA) -> None:
    """
    Assert the state of `mc` right after metastable states were computed.

    Lineage probabilities must not be set yet; metastable states, their lineage
    object and colors must agree with `mc.adata`. The coarse-grained attributes
    must be unset when the eigendecomposition contains `"stationary_dist"`, and
    populated otherwise.

    Params
    ------
    mc
        Estimator to check.

    Returns
    -------
    None
        Raises :class:`AssertionError` on the first mismatch.
    """
    assert mc.lineage_probabilities is None
    meta_lin_probs = mc._meta_lin_probs
    assert isinstance(meta_lin_probs, cr.tl.Lineage)

    assert isinstance(mc.metastable_states, pd.Series)
    assert_array_nan_equal(mc.metastable_states, mc.adata.obs[str(MetaKey.FORWARD)])

    np.testing.assert_array_equal(mc._meta_states_colors, meta_lin_probs.colors)
    np.testing.assert_array_equal(
        mc._meta_states_colors, mc.adata.uns[_colors(str(MetaKey.FORWARD))]
    )

    if "stationary_dist" in mc.eigendecomposition:
        for attr in (
            "_coarse_init_dist",
            "_schur_matrix",
            "coarse_stationary_distribution",
            "coarse_T",
            "schur_vectors",
        ):
            assert getattr(mc, attr) is None
        return

    assert isinstance(mc._coarse_init_dist, pd.Series)
    assert isinstance(mc._schur_matrix, np.ndarray)
    coarse_stat = mc.coarse_stationary_distribution
    assert coarse_stat is None or isinstance(coarse_stat, pd.Series)
    assert isinstance(mc.coarse_T, pd.DataFrame)
    assert isinstance(mc.schur_vectors, np.ndarray)
コード例 #5
0
ファイル: test_read.py プロジェクト: dpeerlab/cellrank
    def test_no_colors(self, adata: AnnData, path: Path, lin_key: str, n_lins: int):
        """Reading an object whose lineage colors were removed recreates default colors."""
        colors_key = _colors(lin_key)
        del adata.uns[colors_key]

        sc.write(path, adata)
        reloaded = cr.read(path)
        lineages = reloaded.obsm[lin_key]

        assert isinstance(lineages, Lineage)
        expected = _create_categorical_colors(n_lins)
        np.testing.assert_array_equal(lineages.colors, expected)
        # the recreated colors are also written back to `.uns`
        np.testing.assert_array_equal(lineages.colors, reloaded.uns[colors_key])
コード例 #6
0
    def test_compute_approx_normal_run(self, adata_large: AnnData):
        """Approximate recurrent classes, their probabilities and colors are stored."""
        velo = VelocityKernel(adata_large).compute_transition_matrix()
        conn = ConnectivityKernel(adata_large).compute_transition_matrix()
        kernel = 0.8 * velo + 0.2 * conn

        chain = cr.tl.MarkovChain(kernel)
        chain.compute_eig(k=5)
        chain.compute_approx_rcs(use=2)

        assert is_categorical_dtype(chain.approx_recurrent_classes)
        assert chain.approx_recurrent_classes_probabilities is not None
        assert _colors(RcKey.FORWARD) in chain.adata.uns.keys()
        assert _probs(RcKey.FORWARD) in chain.adata.obs.keys()
コード例 #7
0
    def plot_metastable_states(
        self, cluster_key: Optional[str] = None, **kwargs
    ) -> None:
        """
        Plot the metastable states in a given embedding.

        The states are first written to `.adata.obs` under the estimator's state key.
        If colors stored in `.adata.uns` do not match the number of state categories,
        they are deleted before plotting.

        Params
        ------
        cluster_key
            Key from `.obs` to plot clusters next to the metastable states.
        kwargs
            Keyword arguments for :func:`scvelo.pl.scatter`.

        Returns
        -------
        None
            Nothing, just plots the metastable states. Colors present in `.adata.uns`
            after plotting are cached in `._meta_states_colors`.

        Raises
        ------
        RuntimeError
            If the metastable states have not been computed yet.
        """
        meta_states = self._meta_states
        if meta_states is None:
            raise RuntimeError(
                "Compute approximate recurrent classes first as `.compute_metastable_states()`"
            )

        adata, state_key = self._adata, self._rc_key
        adata.obs[state_key] = meta_states

        # a palette whose length differs from the number of categories is stale
        color_key = _colors(state_key)
        if color_key in adata.uns:
            if len(adata.uns[color_key]) != len(meta_states.cat.categories):
                del adata.uns[color_key]
                self._meta_states_colors = None

        if cluster_key is not None:
            color = [cluster_key, state_key]
        else:
            color = state_key
        scv.pl.scatter(adata, color=color, **kwargs)

        if color_key in adata.uns:
            self._meta_states_colors = adata.uns[color_key]
コード例 #8
0
ファイル: test_read.py プロジェクト: dpeerlab/cellrank
    def test_normal_run(self, adata: AnnData, path: Path, lin_key: str, n_lins: int):
        """Lineage colors and names stored in `.uns` survive a write/read round trip."""
        expected_colors = _create_categorical_colors(10)[-n_lins:]
        expected_names = [f"foo {i}" for i in range(n_lins)]

        adata.uns[_colors(lin_key)] = expected_colors
        adata.uns[_lin_names(lin_key)] = expected_names

        sc.write(path, adata)
        lineages = cr.read(path).obsm[lin_key]

        np.testing.assert_array_equal(lineages.colors, expected_colors)
        np.testing.assert_array_equal(lineages.names, expected_names)
コード例 #9
0
    def test_check_and_create_colors(self, adata_large):
        """Deleted metastable-state colors are regenerated in the estimator and `.adata`."""
        vk = VelocityKernel(adata_large).compute_transition_matrix()
        ck = ConnectivityKernel(adata_large).compute_transition_matrix()

        estimator = cr.tl.CFLARE(0.8 * vk + 0.2 * ck)
        estimator.compute_partition()
        estimator.compute_eig()
        estimator.compute_metastable_states(use=3)

        # wipe the colors from both the estimator and the underlying AnnData
        colors_key = _colors(StateKey.FORWARD)
        estimator._meta_states_colors = None
        del estimator.adata.uns[colors_key]

        estimator._check_and_create_colors()

        uns = estimator.adata.uns
        assert colors_key in uns
        np.testing.assert_array_equal(uns[colors_key], _create_categorical_colors(3))
        np.testing.assert_array_equal(uns[colors_key], estimator._meta_states_colors)
コード例 #10
0
    def maybe_create_lineage(direction: Direction):
        """
        Wrap `adata.obsm[lin_key]` in a :class:`Lineage` object, if it is present.

        Lineage names and colors are read from `adata.uns`. Dummy names are used when
        the names are missing or have the wrong length. New colors are created when
        the colors are missing, have the wrong length, or are not color-like. The
        names and colors actually used are written back to `adata.uns`.

        Params
        ------
        direction
            `Direction.FORWARD` selects the forward lineage key; any other value
            selects the backward one.

        Returns
        -------
        None
            Modifies `adata` (from the enclosing scope) in place.
        """
        is_forward = direction == Direction.FORWARD
        dir_name = "forward" if is_forward else "backward"
        lin_key = str(LinKey.FORWARD if is_forward else LinKey.BACKWARD)
        names_key, colors_key = _lin_names(lin_key), _colors(lin_key)

        if lin_key not in adata.obsm.keys():
            logg.debug(
                f"DEBUG: Unable to load {dir_name} "
                f"`Lineage` from `adata.obsm[{lin_key!r}]`")
            return

        _, n_lineages = adata.obsm[lin_key].shape
        logg.info(f"Creating {dir_name} `Lineage` object")

        if names_key not in adata.uns.keys():
            logg.warning(
                f"Lineage names not found in `adata.uns[{names_key!r}]`, creating dummy names"
            )
            names = [f"Lineage {i}" for i in range(n_lineages)]
        elif len(adata.uns[names_key]) != n_lineages:
            logg.warning(
                f"Lineage names don't have the required length ({n_lineages}), creating dummy names"
            )
            names = [f"Lineage {i}" for i in range(n_lineages)]
        else:
            logg.info("Successfully loaded names")
            names = adata.uns[names_key]

        if colors_key not in adata.uns.keys():
            logg.warning(
                f"Lineage colors not found in `adata.uns[{colors_key!r}]`, creating new colors"
            )
            colors = _create_categorical_colors(n_lineages)
        elif len(adata.uns[colors_key]) != n_lineages or not all(
                map(is_color_like, adata.uns[colors_key])):
            logg.warning(
                f"Lineage colors don't have the required length ({n_lineages}) "
                f"or are not color-like, creating new colors")
            colors = _create_categorical_colors(n_lineages)
        else:
            logg.info("Successfully loaded colors")
            colors = adata.uns[colors_key]

        adata.obsm[lin_key] = Lineage(adata.obsm[lin_key],
                                      names=names,
                                      colors=colors)
        # keep `.uns` consistent with whatever was actually used
        adata.uns[colors_key] = colors
        adata.uns[names_key] = names
コード例 #11
0
    def _check_and_create_colors(self):
        """
        Ensure `._meta_states_colors` holds one color per metastable-state category.

        - If no colors are cached, load them (converted to hex) from `.adata.uns` when
          their length matches; otherwise create new colors and store them in `.adata.uns`.
        - If cached colors have the wrong length, create new ones and store them in
          `.adata.uns`.
        - Cached colors of the right length are left untouched.

        Returns
        -------
        None
            Updates `._meta_states_colors` and possibly `.adata.uns` in place.
        """
        n_cats = len(self._meta_states.cat.categories)
        color_key = _colors(self._rc_key)
        uns = self._adata.uns

        if self._meta_states_colors is not None:
            if len(self._meta_states_colors) != n_cats:
                self._meta_states_colors = _create_categorical_colors(n_cats)
                uns[color_key] = self._meta_states_colors
            return

        if color_key in uns and len(uns[color_key]) == n_cats:
            logg.debug("DEBUG: Loading colors from `.adata` object")
            self._meta_states_colors = _convert_to_hex_colors(uns[color_key])
        else:
            self._meta_states_colors = _create_categorical_colors(n_cats)
            uns[color_key] = self._meta_states_colors
コード例 #12
0
ファイル: _graph.py プロジェクト: dpeerlab/cellrank
def graph(
        data: Union[AnnData, np.ndarray, spmatrix],
        graph_key: Optional[str] = None,
        ixs: Optional[np.array] = None,
        layout: Union[str, Dict, Callable] = nx.kamada_kawai_layout,
        keys: Sequence[KEYS] = ("incoming", ),
        keylocs: Union[KEYLOCS, Sequence[KEYLOCS]] = "uns",
        node_size: float = 400,
        labels: Optional[Union[Sequence[str], Sequence[Sequence[str]]]] = None,
        top_n_edges: Optional[Union[int, Tuple[int, bool, str]]] = None,
        self_loops: bool = True,
        self_loop_radius_frac: Optional[float] = None,
        filter_edges: Optional[Tuple[float, float]] = None,
        edge_reductions: Union[Callable, Sequence[Callable]] = np.sum,
        edge_weight_scale: float = 10,
        edge_width_limit: Optional[float] = None,
        edge_alpha: float = 1.0,
        edge_normalize: bool = False,
        edge_use_curved: bool = True,
        show_arrows: bool = True,
        font_size: int = 12,
        font_color: str = "black",
        cat_cmap: ListedColormap = cm.Set3,
        cont_cmap: ListedColormap = cm.viridis,
        legend_loc: Optional[str] = "best",
        figsize: Tuple[float, float] = (15, 10),
        dpi: Optional[int] = None,
        save: Optional[Union[str, Path]] = None,
        layout_kwargs: Dict = MappingProxyType({}),
) -> None:
    """
    Plot a graph, visualizing incoming and outgoing edges or self-transitions.

    This is a utility function to look in more detail at the transition matrix in areas of interest, e.g. around an
    endpoint of development. This function is meant to visualise a small subset of nodes (~100-500) and the most likely
    transitions between them. Note that limiting edges visualized using :paramref:`top_n_edges` will speed things up,
    as well as reduce the visual clutter.

    .. image:: https://raw.githubusercontent.com/theislab/cellrank/master/resources/images/graph.png
       :width: 400px
       :align: center

    Params
    ------
    data
        The graph data, stored either in `.uns` [ :paramref:`graph_key` ], or as a sparse or a dense matrix.
    graph_key
        Key in :paramref:`data` `.uns` where the graph is stored.
        Only used when :paramref:`data` is :class:`Anndata` object.
    ixs
        Subset of indices of the graph to visualize.
    layout
        Layout to use for graph drawing.

        - If :class:`str`, search for embedding in :paramref:`data` `.obsm[X_` :paramref:`layout` `]`.
          Use :paramref:`layout_kwargs` = `{'components': [0, 1]}` to select components.
          Requires :paramref:`data` to be an :class:`AnnData` object.
        - If :class:`dict`, keys should be values in interval [0, len(:paramref:`ixs`))
          and values `(x, y)` pairs corresponding to node positions.
    keys
        Keys in :paramref:`data` `.obs`, :paramref:`data` `.obsm` or :paramref:`data` `.uns` to color the nodes.
        Values from `.obs` / `.obsm` are restricted to :paramref:`ixs` (positionally).

        - If `'incoming'`, `'outgoing'` or `'self_loops'` to
          visualize reduction (see :paramref:`edge_reductions`) for each node based
          on incoming or outgoing edges, respectively.
    keylocs
        Locations of :paramref:`keys`, can be `'obs'`, `'obsm'` or `'uns'`.
        A single location is used for every key.
    node_size
        Size of the nodes.
    labels
        Labels of the nodes.
    top_n_edges
        Either an :class:`int` N, keeping the top N outgoing edges of each node in descending order of weight,
        or a tuple `(top_n_edges, in_ascending_order, {'incoming', 'outgoing'})`.
        If `None`, show all edges.
    self_loops
        Whether to visualize self transitions and also to consider them in :paramref:`top_n_edges`.
    self_loop_radius_frac
        Fraction of a unit circle to visualize self transitions.

        If `None`, use :paramref:`node_size` / 2000 when :paramref:`node_size` >= 200,
        otherwise :paramref:`node_size` / 1000.
    filter_edges
        Whether to remove all edges not in `[min, max]` interval.
    edge_reductions
        Aggregation function to use when coloring nodes by edge weights.
    edge_weight_scale
        Number by which to scale the width of the edges. Useful when the weights are small.
    edge_width_limit
        Upper bound for the width of the edges. Useful when weights are unevenly distributed.
    edge_alpha
        Alpha channel value for edges and arrows.
    edge_normalize
        If `True`, normalize edges to `[0, 1]` interval prior to applying any scaling or truncation.
        Weights are left unchanged if they are all equal.
    edge_use_curved
        If `True`, use curved edges. This can improve visualization at a small performance cost.
    show_arrows
        Whether to show the arrows (for both curved and straight edges).
        Setting this to `False` may dramatically speed things up.
    font_size
        Font size for node labels.
    font_color
        Label color of the nodes.
    cat_cmap
        Categorical colormap used when :paramref:`keys` contain categorical variables
        without matching colors in :paramref:`data` `.uns`.
    cont_cmap
        Continuous colormap used when :paramref:`keys` contain continuous variables.
    legend_loc
        Location of the legend.
    figsize
        Size of the figure.
    dpi
        Dots per inch.
    save
        Filename where to save the plots.
        If `None`, just shows the plot.
    layout_kwargs
        Additional kwargs for :paramref:`layout`.

    Returns
    -------
    None
        Nothing, just plots the graph.
        Optionally saves the figure based on :paramref:`save`.

    Raises
    ------
    ValueError
        On invalid argument combinations (e.g. bad `top_n_edges` grouping, mismatched `keys`/`labels`).
    TypeError
        If :paramref:`data` or :paramref:`layout` has an unsupported type.
    KeyError
        If the embedding requested by a string :paramref:`layout` is missing.
    RuntimeError
        If a key is missing or the number of labels does not match the number of nodes.
    """
    def plot_arrows(curves, G, pos, ax, edge_weight_scale):
        """Draw one arrow head at the middle of each curved, non-self-loop edge."""
        for line, (edge, val) in zip(curves, G.edges.items()):
            if edge[0] == edge[1]:
                continue

            mask = (~np.isnan(line)).all(axis=1)
            line = line[mask, :]
            if not len(line):  # can be all NaNs
                continue

            line = line.reshape((-1, 2))
            X, Y = line[:, 0], line[:, 1]

            node_start = pos[edge[0]]
            # reverse
            if np.where(np.isclose(node_start - line,
                                   [0, 0]).all(axis=1))[0][0]:
                X, Y = X[::-1], Y[::-1]

            mid = len(X) // 2
            posA, posB = zip(X[mid:mid + 2], Y[mid:mid + 2])  # noqa

            arrow = FancyArrowPatch(
                posA=posA,
                posB=posB,
                # we clip because too small values
                # cause it to crash
                arrowstyle=ArrowStyle.CurveFilledB(
                    head_length=np.clip(
                        val["weight"] * edge_weight_scale * 4,
                        _min_edge_weight,
                        edge_width_limit,
                    ),
                    head_width=np.clip(
                        val["weight"] * edge_weight_scale * 2,
                        _min_edge_weight,
                        edge_width_limit,
                    ),
                ),
                color="k",
                zorder=float("inf"),
                alpha=edge_alpha,
                linewidth=0,
            )
            ax.add_artist(arrow)

    def normalize_weights():
        """Min-max scale the edge weights of `G` to `[0, 1]` in place."""
        weights = np.array([v["weight"] for v in G.edges.values()])
        if not weights.size:
            return
        minn = np.min(weights)
        span = np.max(weights) - minn
        if span == 0:
            # all weights equal: scaling would divide by zero and produce NaNs
            logg.debug("DEBUG: All edge weights are equal, skipping normalization")
            return
        weights = (weights - minn) / span
        for v, w in zip(G.edges.values(), weights):
            v["weight"] = w

    def remove_self_loops():
        """Drop self transitions from `G` when `self_loops=False`."""
        if self_loops:
            return
        G.remove_edges_from([e for e in G.edges if e[0] == e[1]])

    def remove_top_n_edges():
        """Keep only the top `to_keep` edges of each node, grouped by source or target."""
        if top_n_edges is None:
            return

        if isinstance(top_n_edges, (tuple, list)):
            to_keep, ascending, group_by = top_n_edges
        else:
            # plain int: top N outgoing edges, largest weights first
            to_keep, ascending, group_by = top_n_edges, False, "outgoing"

        if group_by not in ("incoming", "outgoing"):
            raise ValueError(
                "Argument `groupby` in `top_n_edges` must be either `'incoming'` or `'outgoing'`."
            )

        if not G.number_of_edges():
            return

        # self loops were already removed by `remove_self_loops` if `self_loops=False`
        source, target = zip(*G.edges)
        weights = [v["weight"] for v in G.edges.values()]
        tmp = pd.DataFrame({
            "outgoing": source,
            "incoming": target,
            "w": weights
        })

        to_keep = set(
            map(
                tuple,
                tmp.groupby(group_by).apply(
                    lambda g: g.sort_values("w", ascending=ascending).take(
                        range(min(to_keep, len(g)))))[["outgoing",
                                                       "incoming"]].values,
            ))

        for e in list(G.edges):
            if e not in to_keep:
                G.remove_edge(*e)

    def remove_low_weight_edges():
        """Remove edges whose weight lies outside the `filter_edges` interval."""
        if filter_edges is None or filter_edges == (None, None):
            return

        minn, maxx = filter_edges
        minn = minn if minn is not None else -np.inf
        maxx = maxx if maxx is not None else np.inf

        for e, attr in list(G.edges.items()):
            if attr["weight"] < minn or attr["weight"] > maxx:
                G.remove_edge(*e)

    def subset_to_ixs(values):
        """Positionally restrict per-cell `values` (from `.obs`/`.obsm`) to the plotted cells `ixs`."""
        if isinstance(values, (pd.Series, pd.DataFrame)):
            return values.iloc[ixs]
        return np.asarray(values)[ixs]

    _min_edge_weight = 0.00001

    if edge_width_limit is None:
        logg.debug("DEBUG: Not limiting width of edges")
        edge_width_limit = float("inf")

    if self_loop_radius_frac is None:
        self_loop_radius_frac = (node_size /
                                 2000 if node_size >= 200 else node_size /
                                 1000)
        logg.debug(
            f"Setting self loop radius fraction to `{self_loop_radius_frac}`")

    if not isinstance(keylocs, (tuple, list)):
        keylocs = [keylocs] * len(keys)
    elif len(keylocs) == 1:
        # broadcast the single location to every key
        keylocs = list(keylocs) * len(keys)
    elif all(map(lambda k: k in ("incoming", "outgoing", "self_loops"), keys)):
        # don't care about keylocs since they are irrelevant
        logg.debug("DEBUG: Ignoring key locations")
        keylocs = [None] * len(keys)

    for k in ("obs", "obsm"):
        if k in keylocs and ixs is None:
            raise ValueError(
                f"Invalid combination: `ixs` is None and found `{k!r}` in `keylocs`."
            )

    if not isinstance(edge_reductions, (tuple, list)):
        edge_reductions = [edge_reductions] * len(keys)
    if not all(map(callable, edge_reductions)):
        raise ValueError("Not all edge_reductions functions are callable.")

    if not isinstance(labels, (tuple, list)):
        labels = [labels] * len(keys)
    elif not len(labels):
        labels = [None] * len(keys)
    elif not isinstance(labels[0], (tuple, list)):
        labels = [labels] * len(keys)

    if len(labels) != len(keys):
        raise ValueError("`Keys` and `labels` must be of the same shape.")

    if isinstance(data, AnnData):
        if graph_key is None:
            raise ValueError(
                "Argument `graph_key` cannot be `None` when `adata` is `anndata.Anndata` object."
            )
        gdata = data.uns[graph_key]["T"]
    elif isinstance(data, (np.ndarray, spmatrix)):
        gdata = data
    else:
        raise TypeError(
            f"Expected argument `data` to be one of `AnnData`, `numpy.ndarray`, `scipy.sparse.spmatrix`, "
            f"found `{type(data).__name__}`")
    is_sparse = issparse(gdata)

    if ixs is not None:
        gdata = gdata[ixs, :][:, ixs]
    else:
        ixs = list(range(gdata.shape[0]))

    start = logg.info("Creating graph")
    G = (nx.from_scipy_sparse_matrix(gdata, create_using=nx.DiGraph)
         if is_sparse else nx.from_numpy_array(gdata, create_using=nx.DiGraph))

    remove_low_weight_edges()
    remove_self_loops()
    remove_top_n_edges()
    if edge_normalize:
        normalize_weights()
    logg.info("    Finish", time=start)

    # NOTE: the edge reductions below are computed on `gdata` (the subsetted matrix),
    # not on the filtered graph `G` - do NOT recreate `gdata` from `G`

    fig, axes = plt.subplots(nrows=len(keys),
                             ncols=1,
                             figsize=figsize,
                             dpi=dpi)
    if not isinstance(axes, np.ndarray):
        axes = np.array([axes])
    axes = np.ravel(axes)

    if isinstance(layout, str):
        if not isinstance(data, AnnData):
            raise TypeError(
                f"Argument `layout` can only be a `str` when `data` is `anndata.AnnData`, "
                f"found `{type(data).__name__}`.")
        if f"X_{layout}" not in data.obsm:
            raise KeyError(
                f"Unable to find embedding `'X_{layout}'` in `adata.obsm`.")
        components = layout_kwargs.get("components", [0, 1])
        if len(components) != 2:
            raise ValueError(
                f"Components in `layout_kwargs` must be of length `2`, found `{len(components)}`."
            )
        emb = data.obsm[f"X_{layout}"][:, components]
        pos = {i: emb[ix, :] for i, ix in enumerate(ixs)}
        logg.info(f"Embedding graph using `{layout!r}` layout")
    elif isinstance(layout, dict):
        rng = range(len(ixs))
        for k, v in layout.items():
            if k not in rng:
                raise ValueError(
                    f"Key in `layout` must be in `range(len(ixs))`, found `{k}`."
                )
            if len(v) != 2:
                raise ValueError(
                    f"Value in `layout` must be a tuple or list of length 2, found `{v}`."
                )
        pos = layout
        logg.debug("DEBUG: Using pre-specified layout")
    elif callable(layout):
        start = logg.info(
            f"Embedding graph using `{layout.__name__!r}` layout")
        pos = layout(G, **layout_kwargs)
        logg.info("    Finish", time=start)
    else:
        raise TypeError(f"Argument `layout` must be either a `string`, "
                        f"a `dict` or a `callable`, found `{type(layout)}`.")

    curves, lc = None, None
    if edge_use_curved:
        try:
            from ._utils import curved_edges

            logg.debug("DEBUG: Creating curved edges")
            curves = curved_edges(G,
                                  pos,
                                  self_loop_radius_frac,
                                  polarity="directed")
            lc = LineCollection(
                curves,
                colors="black",
                linewidths=np.clip(
                    np.ravel([v["weight"]
                              for v in G.edges.values()]) * edge_weight_scale,
                    0,
                    edge_width_limit,
                ),
                alpha=edge_alpha,
            )
        except ImportError as e:
            global _msg_shown
            if not _msg_shown:
                print(
                    str(e)[:-1],
                    "in order to use curved edges or specify `edge_use_curved=False`.",
                )
                _msg_shown = True

    for ax, keyloc, key, labs, er in zip(axes, keylocs, keys, labels,
                                         edge_reductions):
        label_col = {}  # dummy value; replaced below for `.obs`/`.obsm`/`.uns` keys

        if key in ("incoming", "outgoing", "self_loops"):
            if key in ("incoming", "outgoing"):
                vals = np.array(er(gdata,
                                   axis=int(key == "outgoing"))).flatten()
            else:
                vals = gdata.diagonal() if is_sparse else np.diag(gdata)
        else:
            label_col = getattr(data, keyloc)
            if key not in label_col:
                raise RuntimeError(
                    f"Key `{key!r}` not found in `adata.{keyloc!r}`.")
            vals = label_col[key]
            if keyloc in ("obs", "obsm"):
                # per-cell annotations must be restricted to the plotted cells
                vals = subset_to_ixs(vals)
            vals = np.asarray(vals)
        # graph node `i` corresponds to row/column `i` of `gdata`; keying by `G.nodes`
        # keeps the values aligned with the node order used by `draw_networkx_nodes`
        node_v = {node: vals[node] for node in G.nodes}

        if labs is not None:
            if len(labs) != len(pos):
                raise RuntimeError(
                    f"Number of labels ({len(labs)}) and nodes ({len(pos)}) mismatch."
                )
            nx.draw_networkx_labels(
                G,
                pos,
                labels=labs if isinstance(labs, dict) else dict(
                    zip(pos.keys(), labs)),
                ax=ax,
                font_color=font_color,
                font_size=font_size,
            )

        if lc is not None and curves is not None:
            ax.add_collection(deepcopy(lc))  # copying necessary
            if show_arrows:
                plot_arrows(curves, G, pos, ax, edge_weight_scale)
        else:
            nx.draw_networkx_edges(
                G,
                pos,
                ax=ax,
                width=[
                    np.clip(
                        v["weight"] * edge_weight_scale,
                        _min_edge_weight,
                        edge_width_limit,
                    ) for _, v in G.edges.items()
                ],
                alpha=edge_alpha,
                edge_color="black",
                arrows=show_arrows,
                arrowstyle="-|>",
            )

        if key in label_col and is_categorical_dtype(label_col[key]):
            values = label_col[key]
            if keyloc in ("obs", "obsm"):
                values = subset_to_ixs(values)
            categories = values.cat.categories
            color_key = _colors(key)
            if color_key in data.uns:
                mapper = dict(zip(categories, data.uns[color_key]))
            else:
                # matplotlib colormaps are callable: integer `i` -> i-th RGBA color
                mapper = dict(
                    zip(categories, map(cat_cmap, range(len(categories)))))

            colors = [mapper[v] for v in values]
            seen = set(values)

            nodes_kwargs = dict(cmap=cat_cmap, node_color=colors)  # noqa
            if legend_loc is not None:
                # invisible single-point lines, only used to populate the legend
                x, y = next(iter(pos.values()))
                for label in sorted(seen):
                    ax.plot([x], [y], label=label, color=mapper[label])
                ax.legend(loc=legend_loc)
        else:
            values = list(node_v.values())
            vmin, vmax = np.min(values), np.max(values)
            nodes_kwargs = dict(  # noqa
                cmap=cont_cmap,
                node_color=values,
                vmin=vmin,
                vmax=vmax)

            divider = make_axes_locatable(ax)
            cax = divider.append_axes("right", size="1.5%", pad=0.05)
            _ = mpl.colorbar.ColorbarBase(cax,
                                          cmap=cont_cmap,
                                          norm=mpl.colors.Normalize(vmin=vmin,
                                                                    vmax=vmax))

        nx.draw_networkx_nodes(G,
                               pos,
                               node_size=node_size,
                               ax=ax,
                               **nodes_kwargs)

        ax.set_title(key)
        ax.axis("off")

    if save is not None:
        save_fig(fig, save)

    fig.show()
コード例 #13
0
    def _read_from_adata(
        self, g2m_key: Optional[str] = None, s_key: Optional[str] = None, **kwargs
    ) -> None:
        """
        Restore results for the current direction that were previously saved in :attr:`_adata`.

        Every attribute is loaded only if its key is present. Otherwise a debug message is logged and the
        attribute keeps its current value.

        Params
        ------
        g2m_key
            Key in ``adata.obs`` with G2M scores, loaded into ``._G2M_score``. If it is falsy or missing,
            only a debug message is logged.
        s_key
            Key in ``adata.obs`` with S-phase scores, loaded into ``._S_score``. If it is falsy or missing,
            only a debug message is logged.
        kwargs
            Unused.

        Returns
        -------
        None
            Nothing, but may update ``._eig``, ``._meta_states``, ``._meta_states_colors``, ``._lin_probs``,
            ``._dp``, ``._G2M_score``, ``._S_score`` and ``._meta_states_probs``. Whenever ``._lin_probs`` is
            (re)built, it is also written back to ``adata.obsm``.
        """
        if f"eig_{self._direction}" in self._adata.uns.keys():
            self._eig = self._adata.uns[f"eig_{self._direction}"]
        else:
            logg.debug(
                f"DEBUG: `eig_{self._direction}` not found. Setting `.eig` to `None`"
            )

        if self._rc_key in self._adata.obs.keys():
            self._meta_states = self._adata.obs[self._rc_key]
        else:
            logg.debug(
                f"DEBUG: `{self._rc_key}` not found in `adata.obs`. Setting `.metastable_states` to `None`"
            )

        if _colors(self._rc_key) in self._adata.uns.keys():
            self._meta_states_colors = self._adata.uns[_colors(self._rc_key)]
        else:
            logg.debug(
                f"DEBUG: `{_colors(self._rc_key)}` not found in `adata.uns`. "
                f"Setting `.metastable_states_colors` to `None`"
            )

        if self._lin_key in self._adata.obsm.keys():
            n_lineages = self._adata.obsm[self._lin_key].shape[1]
            # default names/colors; the ones saved in `adata.uns` (if any) replace them below
            self._lin_probs = Lineage(
                self._adata.obsm[self._lin_key],
                names=[f"Lineage {i + 1}" for i in range(n_lineages)],
                colors=_create_categorical_colors(n_lineages),
            )
            self._adata.obsm[self._lin_key] = self._lin_probs
        else:
            logg.debug(
                f"DEBUG: `{self._lin_key}` not found in `adata.obsm`. Setting `.lin_probs` to `None`"
            )

        if _dp(self._lin_key) in self._adata.obs.keys():
            self._dp = self._adata.obs[_dp(self._lin_key)]
        else:
            logg.debug(
                f"DEBUG: `{_dp(self._lin_key)}` not found in `adata.obs`. Setting `.diff_potential` to `None`"
            )

        if g2m_key and g2m_key in self._adata.obs.keys():
            self._G2M_score = self._adata.obs[g2m_key]
        else:
            logg.debug(
                f"DEBUG: `{g2m_key}` not found in `adata.obs`. Setting `.G2M_score` to `None`"
            )

        if s_key and s_key in self._adata.obs.keys():
            self._S_score = self._adata.obs[s_key]
        else:
            logg.debug(
                f"DEBUG: `{s_key}` not found in `adata.obs`. Setting `.S_score` to `None`"
            )

        if _probs(self._rc_key) in self._adata.obs.keys():
            self._meta_states_probs = self._adata.obs[_probs(self._rc_key)]
        else:
            logg.debug(
                f"DEBUG: `{_probs(self._rc_key)}` not found in `adata.obs`. "
                f"Setting `.metastable_states_probs` to `None`"
            )

        if self._lin_probs is not None:
            names_key = _lin_names(self._lin_key)
            colors_key = _colors(self._lin_key)
            names, colors = self._lin_probs.names, self._lin_probs.colors
            restored = False

            if names_key in self._adata.uns.keys():
                names = self._adata.uns[names_key]
                restored = True
            else:
                logg.debug(
                    f"DEBUG: `{names_key}` not found in `adata.uns`. "
                    f"Using default names"
                )

            if colors_key in self._adata.uns.keys():
                colors = self._adata.uns[colors_key]
                restored = True
            else:
                logg.debug(
                    f"DEBUG: `{colors_key}` not found in `adata.uns`. "
                    f"Using default colors"
                )

            if restored:
                # rebuild once with both saved names and colors instead of once per attribute
                self._lin_probs = Lineage(
                    np.array(self._lin_probs), names=names, colors=colors
                )
                self._adata.obsm[self._lin_key] = self._lin_probs
コード例 #14
0
    def compute_lin_probs(
        self,
        keys: Optional[Sequence[str]] = None,
        check_irred: bool = False,
        norm_by_frequ: bool = False,
    ) -> None:
        """
        Compute absorption probabilities for a Markov chain.

        For each cell, this computes the probability of it reaching any of the approximate recurrent classes.
        This also computes the entropy over absorption probabilities, which is a measure of cell plasticity, see
        [Setty19]_.

        Params
        ------
        keys
            Comma separated sequence of keys defining the recurrent classes.
        check_irred
            Check whether the matrix restricted to the given transient states is irreducible.
        norm_by_frequ
            Divide absorption probabilities for `rc_i` by `|rc_i|`.

        Returns
        -------
        None
            Nothing, but updates the following fields: :paramref:`lineage_probabilities`, :paramref:`diff_potential`.
            The probabilities, the entropy, and the lineage names and colors are also written to ``adata``.

        Raises
        ------
        RuntimeError
            If the metastable states have not been computed yet.
        """
        if self._meta_states is None:
            raise RuntimeError(
                "Compute approximate recurrent classes first as `.compute_metastable_states()`"
            )
        if keys is not None:
            keys = sorted(set(keys))

        # Note: There are three relevant data structures here
        # - self.metastable_states: pd.Series which contains annotations for approx rcs. Associated colors in
        #   self.metastable_states_colors
        # - self.lin_probs: Lineage object which contains the lineage probabilities with associated names and
        #   colors. It is only replaced once the computation below has succeeded.
        # - metastable_states_: pd.Series, temporary copy of self.metastable_states used in the context of this
        #   function. In this copy, some metastable_states may be removed or combined with others
        start = logg.info("Computing absorption probabilities")

        # we don't expect the abs. probs. to be sparse, therefore, make T dense. See scipy docs about sparse lin solve.
        t = self._T.A if self._is_sparse else self._T

        # colors are created in `compute_metastable_states`, this is just in case
        self._check_and_create_colors()

        # process the current annotations according to `keys`
        metastable_states_, colors_ = _process_series(
            series=self._meta_states, keys=keys, colors=self._meta_states_colors
        )
        categories = metastable_states_.cat.categories

        if self._lin_probs is not None:
            logg.debug("DEBUG: Overwriting `.lin_probs`")
        # Placeholder used only to let `Lineage` process the names and colors. It is kept local so that a failure
        # further below (e.g. in the linear solve) cannot leave `.lin_probs` holding uninitialized values.
        template = Lineage(
            np.empty((1, len(colors_))),
            names=categories,
            colors=colors_,
        )

        # warn in case only one state is left
        if len(categories) == 1:
            logg.warning(
                "There is only one recurrent class, all cells will have probability 1 of going there"
            )

        # Recurrent cells are exactly those annotated with some category (every non-NaN value of a categorical is
        # one of its categories). All unannotated cells are transient.
        rec_mask = np.asarray(metastable_states_.notna())
        rec_indices, trans_indices = np.where(rec_mask)[0], np.where(~rec_mask)[0]

        # create Q (restriction transient-transient), S (restriction transient-recurrent) and I (Q-sized identity)
        q = t[trans_indices, :][:, trans_indices]
        s = t[trans_indices, :][:, rec_indices]
        eye = np.eye(len(trans_indices))

        if check_irred:
            # NOTE(review): this inspects `self._is_irreducible` as set by `compute_partition()`; confirm that it
            # describes the restriction Q (as the warning states) rather than the full chain.
            if self._is_irreducible is None:
                self.compute_partition()
            if not self._is_irreducible:
                logg.warning("Restriction Q is not irreducible")

        # compute abs probs. Since we don't expect sparse solution, dense computation is faster.
        logg.debug("DEBUG: Solving the linear system to find absorption probabilities")
        abs_states = solve(eye - q, s)

        # Column j of `abs_states` belongs to the j-th recurrent cell (same order as `rec_indices`).
        # Aggregate to class level by summing the columns of all recurrent cells sharing a category.
        rec_states = metastable_states_[rec_mask]
        rec_classes_red = {key: np.where(rec_states == key)[0] for key in categories}
        _abs_classes = np.concatenate(
            [
                np.sum(abs_states[:, rec_classes_red[key]], axis=1)[:, None]
                for key in categories
            ],
            axis=1,
        )

        if norm_by_frequ:
            logg.debug("DEBUG: Normalizing by frequency")
            _abs_classes /= [len(value) for value in rec_classes_red.values()]
        _abs_classes = _normalize(_abs_classes)

        # Assemble the full matrix. Transient cells get the solved probabilities, and each recurrent cell is
        # absorbed into its own class with probability one.
        abs_classes = np.zeros((self._n_states, len(categories)))
        abs_classes[trans_indices, :] = _abs_classes
        for col, key in enumerate(categories):
            # map positions within `rec_states` back to positions among all cells
            abs_classes[rec_indices[rec_classes_red[key]], col] = 1

        self._dp = entropy(abs_classes.T)
        self._lin_probs = Lineage(
            abs_classes,
            names=list(template.names),
            colors=list(template.colors),
        )

        self._adata.obsm[self._lin_key] = self._lin_probs
        self._adata.obs[_dp(self._lin_key)] = self._dp
        self._adata.uns[_lin_names(self._lin_key)] = self._lin_probs.names
        self._adata.uns[_colors(self._lin_key)] = self._lin_probs.colors

        logg.info("    Finish", time=start)
コード例 #15
0
ファイル: _base_estimator.py プロジェクト: dpeerlab/cellrank
    def _set_categorical_labels(
        self,
        attr_key: str,
        pretty_attr_key: str,
        cat_key: str,
        add_to_existing_error_msg: str,
        categories: Union[Series, Dict[Any, Any]],
        cluster_key: Optional[str] = None,
        en_cutoff: Optional[float] = None,
        p_thresh: Optional[float] = None,
        add_to_existing: bool = False,
    ):
        """
        Store categorical state labels on the estimator and in ``adata``.

        Params
        ------
        attr_key
            Name of the estimator attribute receiving the labels. Colors go to ``f"{attr_key}_colors"``.
        pretty_attr_key
            Public name of that attribute, used only in the debug message when it is overwritten.
        cat_key
            Key in ``adata.obs`` for the labels. Colors are written to ``adata.uns[_colors(cat_key)]``.
        add_to_existing_error_msg
            Message of the :class:`RuntimeError` raised when ``add_to_existing=True`` but nothing is stored yet.
        categories
            Categorical :class:`pandas.Series`, or a dictionary that is converted to one over ``adata.obs_names``.
        cluster_key
            If given, key in ``adata.obs`` of a reference annotation used to rename the categories and choose their
            colors via ``_map_names_and_colors``.
        en_cutoff
            Forwarded to ``_map_names_and_colors``.
        p_thresh
            If given, ``_detect_cc_stages`` is called on the labels with this threshold.
        add_to_existing
            Merge ``categories`` into the labels already stored under ``attr_key``.

        Returns
        -------
        None
            Nothing, but sets ``attr_key`` and ``f"{attr_key}_colors"`` on the estimator and writes
            ``adata.obs[cat_key]`` together with its colors in ``adata.uns``.

        Raises
        ------
        TypeError
            If ``categories`` is not categorical.
        RuntimeError
            If ``add_to_existing=True`` but no labels are stored yet.
        KeyError
            If ``cluster_key`` is not in ``adata.obs``.
        """
        if isinstance(categories, dict):
            categories = _convert_to_categorical_series(
                categories, list(self.adata.obs_names)
            )
        if not is_categorical_dtype(categories):
            raise TypeError(
                f"Object must be `categorical`, found `{infer_dtype(categories)}`."
            )

        if add_to_existing:
            if getattr(self, attr_key) is None:
                raise RuntimeError(add_to_existing_error_msg)
            categories = _merge_categorical_series(
                getattr(self, attr_key), categories, inplace=False
            )

        if cluster_key is not None:
            logg.debug(f"DEBUG: Creating colors based on `{cluster_key}`")

            # check that we can load the reference series from adata
            if cluster_key not in self._adata.obs:
                raise KeyError(f"Cluster key `{cluster_key}` not found in `adata.obs`.")
            series_reference = self._adata.obs[cluster_key]

            # reuse the reference colors if they exist, otherwise create new ones
            if _colors(cluster_key) in self._adata.uns.keys():
                colors_reference = _convert_to_hex_colors(
                    self._adata.uns[_colors(cluster_key)]
                )
            else:
                colors_reference = _create_categorical_colors(
                    len(series_reference.cat.categories)
                )

            approx_rcs_names, colors = _map_names_and_colors(
                series_reference=series_reference,
                series_query=categories,
                colors_reference=colors_reference,
                en_cutoff=en_cutoff,
            )
            setattr(self, f"{attr_key}_colors", colors)
            # `rename_categories` returns a new series, so the caller's object is not mutated. It also avoids
            # assigning to `.cat.categories`, which recent pandas versions deprecate/reject. `list(...)` keeps
            # the positional (not dict-like) mapping of the old assignment.
            categories = categories.cat.rename_categories(list(approx_rcs_names))
        else:
            setattr(
                self,
                f"{attr_key}_colors",
                _create_categorical_colors(len(categories.cat.categories)),
            )

        if p_thresh is not None:
            self._detect_cc_stages(categories, p_thresh=p_thresh)

        # write to class and adata
        if getattr(self, attr_key) is not None:
            logg.debug(f"DEBUG: Overwriting `.{pretty_attr_key}`")

        setattr(self, attr_key, categories)
        self._adata.obs[cat_key] = categories.values
        self._adata.uns[_colors(cat_key)] = getattr(self, f"{attr_key}_colors")