Code example #1
File: _graph.py Project: sophial05/squidpy
def nhood_enrichment(
    adata: AnnData,
    cluster_key: str,
    mode: Literal["zscore", "count"] = "zscore",
    annotate: bool = False,
    method: Optional[str] = None,
    title: Optional[str] = None,
    cmap: str = "viridis",
    palette: Palette_t = None,
    cbar_kwargs: Mapping[str, Any] = MappingProxyType({}),
    figsize: Optional[Tuple[float, float]] = None,
    dpi: Optional[int] = None,
    save: Optional[Union[str, Path]] = None,
    **kwargs: Any,
) -> None:
    """
    Plot neighborhood enrichment.

    The enrichment is computed by :func:`squidpy.gr.nhood_enrichment`.

    Parameters
    ----------
    %(adata)s
    %(cluster_key)s
    mode
        Which :func:`squidpy.gr.nhood_enrichment` result to plot. Valid options are:

            - `'zscore'` - z-score values of enrichment statistic.
            - `'count'` - enrichment count.

    %(heatmap_plotting)s
    kwargs
        Keyword arguments for :func:`matplotlib.pyplot.text`.

    Returns
    -------
    %(plotting_returns)s
    """
    _assert_categorical_obs(adata, key=cluster_key)
    array = _get_data(adata, cluster_key=cluster_key, func_name="nhood_enrichment")[mode]

    ad = AnnData(X=array, obs={cluster_key: pd.Categorical(adata.obs[cluster_key].cat.categories)})
    _maybe_set_colors(source=adata, target=ad, key=cluster_key, palette=palette)
    if title is None:
        title = "Neighborhood enrichment"
    fig = _heatmap(
        ad,
        key=cluster_key,
        title=title,
        method=method,
        cont_cmap=cmap,
        annotate=annotate,
        figsize=(2 * ad.n_obs // 3, 2 * ad.n_obs // 3) if figsize is None else figsize,
        dpi=dpi,
        cbar_kwargs=cbar_kwargs,
        **kwargs,
    )

    if save is not None:
        save_fig(fig, path=save)
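A minimal usage sketch pairing this plotting function with its computational counterpart, assuming the bundled `imc` demo dataset and its "cell type" annotation:

import squidpy as sq

# assumed demo dataset with a categorical "cell type" column
adata = sq.datasets.imc()

# build the spatial neighbor graph, then run the permutation test
sq.gr.spatial_neighbors(adata)
sq.gr.nhood_enrichment(adata, cluster_key="cell type", seed=0)

# plot the z-scores stored above; mode="count" would plot raw counts instead
sq.pl.nhood_enrichment(adata, cluster_key="cell type", mode="zscore", annotate=True)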
Code example #2
File: _graph.py Project: sophial05/squidpy
def ripley_k(
    adata: AnnData,
    cluster_key: str,
    palette: Palette_t = None,
    figsize: Optional[Tuple[float, float]] = None,
    dpi: Optional[int] = None,
    save: Optional[Union[str, Path]] = None,
    legend_kwargs: Mapping[str, Any] = MappingProxyType({}),
    **kwargs: Any,
) -> None:
    """
    Plot Ripley's K estimate for each cluster.

    The estimate is computed by :func:`squidpy.gr.ripley_k`.

    Parameters
    ----------
    %(adata)s
    %(cluster_key)s
    %(cat_plotting)s
    legend_kwargs
        Keyword arguments for :func:`matplotlib.pyplot.legend`.
    kwargs
        Keyword arguments for :func:`seaborn.lineplot`.

    Returns
    -------
    %(plotting_returns)s
    """
    _assert_categorical_obs(adata, key=cluster_key)
    df = _get_data(adata, cluster_key=cluster_key, func_name="ripley_k")

    legend_kwargs = dict(legend_kwargs)
    if "loc" not in legend_kwargs:
        legend_kwargs["loc"] = "center left"
        legend_kwargs.setdefault("bbox_to_anchor", (1, 0.5))

    categories = adata.obs[cluster_key].cat.categories
    palette = _get_palette(adata, cluster_key=cluster_key, categories=categories) if palette is None else palette

    fig, ax = plt.subplots(figsize=figsize, dpi=dpi)
    sns.lineplot(
        x="distance",
        y="ripley_k",
        hue=cluster_key,
        hue_order=categories,
        data=df,
        palette=palette,
        ax=ax,
        **kwargs,
    )
    ax.legend(**legend_kwargs)
    ax.set_ylabel("value")
    ax.set_title("Ripley's K")

    if save is not None:
        save_fig(fig, path=save)
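A usage sketch for this plot, assuming the older `squidpy.gr.ripley_k` API of the squidpy version shown here (newer releases replaced it with `squidpy.gr.ripley`) and the same assumed `imc` demo dataset:

import squidpy as sq

adata = sq.datasets.imc()  # assumed demo dataset with a "cell type" column

# compute Ripley's K per cluster, then draw one line per category
sq.gr.ripley_k(adata, cluster_key="cell type")
sq.pl.ripley_k(adata, cluster_key="cell type", legend_kwargs={"loc": "upper right"})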
Code example #3
File: _graph.py Project: sophial05/squidpy
def interaction_matrix(
    adata: AnnData,
    cluster_key: str,
    annotate: bool = False,
    method: Optional[str] = None,
    title: Optional[str] = None,
    cmap: str = "viridis",
    palette: Palette_t = None,
    cbar_kwargs: Mapping[str, Any] = MappingProxyType({}),
    figsize: Optional[Tuple[float, float]] = None,
    dpi: Optional[int] = None,
    save: Optional[Union[str, Path]] = None,
    **kwargs: Any,
) -> None:
    """
    Plot cluster interaction matrix.

    The interaction matrix is computed by :func:`squidpy.gr.interaction_matrix`.

    Parameters
    ----------
    %(adata)s
    %(cluster_key)s
    %(heatmap_plotting)s
    kwargs
        Keyword arguments for :func:`matplotlib.pyplot.text`.

    Returns
    -------
    %(plotting_returns)s
    """
    _assert_categorical_obs(adata, key=cluster_key)
    array = _get_data(adata, cluster_key=cluster_key, func_name="interaction_matrix")

    ad = AnnData(X=array, obs={cluster_key: pd.Categorical(adata.obs[cluster_key].cat.categories)})
    _maybe_set_colors(source=adata, target=ad, key=cluster_key, palette=palette)
    if title is None:
        title = "Interaction matrix"
    fig = _heatmap(
        ad,
        key=cluster_key,
        title=title,
        method=method,
        cont_cmap=cmap,
        annotate=annotate,
        figsize=(2 * ad.n_obs // 3, 2 * ad.n_obs // 3) if figsize is None else figsize,
        dpi=dpi,
        cbar_kwargs=cbar_kwargs,
        **kwargs,
    )

    if save is not None:
        save_fig(fig, path=save)
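A sketch combining the computation with this heatmap, assuming the `imc` demo dataset; `method="ward"` is any linkage method accepted by :func:`scipy.cluster.hierarchy.linkage`:

import squidpy as sq

adata = sq.datasets.imc()  # assumed demo dataset
sq.gr.spatial_neighbors(adata)
sq.gr.interaction_matrix(adata, cluster_key="cell type")

# cluster the rows hierarchically and print the value inside each cell
sq.pl.interaction_matrix(adata, cluster_key="cell type", method="ward", annotate=True)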
Code example #4
File: _nhood.py Project: michalk8/squidpy
def interaction_matrix(
    adata: AnnData,
    cluster_key: str,
    connectivity_key: Optional[str] = None,
    normalized: bool = False,
    copy: bool = False,
) -> Optional[np.ndarray]:
    """
    Compute interaction matrix for clusters.

    Parameters
    ----------
    %(adata)s
    %(cluster_key)s
    %(conn_key)s
    normalized
        If `True`, each row is normalized to sum to 1.
    %(copy)s

    Returns
    -------
    If ``copy = True``, returns the interaction matrix.

    Otherwise, modifies the ``adata`` with the following key:

        - :attr:`anndata.AnnData.uns` ``['{cluster_key}_interactions']`` - the interaction matrix.
    """
    connectivity_key = Key.obsp.spatial_conn(connectivity_key)
    _assert_categorical_obs(adata, cluster_key)
    _assert_connectivity_key(adata, connectivity_key)

    graph = nx.from_scipy_sparse_matrix(adata.obsp[connectivity_key])
    cluster = {
        i: {
            cluster_key: x
        }
        for i, x in enumerate(adata.obs[cluster_key])
    }

    nx.set_node_attributes(graph, cluster)
    int_mat = np.asarray(
        nx.attr_matrix(graph,
                       node_attr=cluster_key,
                       normalized=normalized,
                       rc_order=adata.obs[cluster_key].cat.categories))

    if copy:
        return int_mat
    _save_data(adata,
               attr="uns",
               key=Key.uns.interaction_matrix(cluster_key),
               data=int_mat)
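A short sketch of the `normalized` and `copy` options, under the same assumed demo dataset:

import squidpy as sq

adata = sq.datasets.imc()  # assumed demo dataset
sq.gr.spatial_neighbors(adata)

# return the row-normalized matrix directly instead of writing to adata.uns
mat = sq.gr.interaction_matrix(adata, cluster_key="cell type", normalized=True, copy=True)
print(mat.sum(axis=1))  # each row sums to 1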
Code example #5
File: _nhood.py Project: michalk8/squidpy
def centrality_scores(
    adata: AnnData,
    cluster_key: str,
    score: Optional[Union[str, Iterable[str]]] = None,
    connectivity_key: Optional[str] = None,
    copy: bool = False,
    n_jobs: Optional[int] = None,
    backend: str = "loky",
    show_progress_bar: bool = False,
) -> Optional[pd.DataFrame]:
    """
    Compute centrality scores per cluster or cell type.

    Inspired by usage in Gene Regulatory Networks (GRNs) in :cite:`celloracle`.

    Parameters
    ----------
    %(adata)s
    %(cluster_key)s
    score
        Centrality measures as described in :class:`networkx.algorithms.centrality` :cite:`networkx`.
        If `None`, use all the options below. Valid options are:

            - `{c.CLOSENESS.s!r}` - measure of how close the group is to other nodes.
            - `{c.CLUSTERING.s!r}` - measure of the degree to which nodes cluster together.
            - `{c.DEGREE.s!r}` - fraction of non-group members connected to group members.

    %(conn_key)s
    %(copy)s
    %(parallelize)s

    Returns
    -------
    If ``copy = True``, returns a :class:`pandas.DataFrame`. Otherwise, modifies the ``adata`` with the following key:

        - :attr:`anndata.AnnData.uns` ``['{{cluster_key}}_centrality_scores']`` - the centrality scores,
          as mentioned above.
    """
    connectivity_key = Key.obsp.spatial_conn(connectivity_key)
    _assert_categorical_obs(adata, cluster_key)
    _assert_connectivity_key(adata, connectivity_key)

    if isinstance(score, (str, Centrality)):
        centrality = [score]
    elif score is None:
        centrality = [c.s for c in Centrality]
    else:  # an iterable of centrality names
        centrality = list(score)

    centralities = [Centrality(c) for c in centrality]

    graph = nx.from_scipy_sparse_matrix(adata.obsp[connectivity_key])

    cat = adata.obs[cluster_key].cat.categories.values
    clusters = adata.obs[cluster_key].values

    fun_dict = {}
    for c in centralities:
        if c == Centrality.CLOSENESS:
            fun_dict[c.s] = partial(
                nx.algorithms.centrality.group_closeness_centrality, graph)
        elif c == Centrality.DEGREE:
            fun_dict[c.s] = partial(
                nx.algorithms.centrality.group_degree_centrality, graph)
        elif c == Centrality.CLUSTERING:
            fun_dict[c.s] = partial(nx.algorithms.cluster.average_clustering,
                                    graph)
        else:
            raise NotImplementedError(
                f"Centrality `{c}` is not yet implemented.")

    n_jobs = _get_n_cores(n_jobs)
    start = logg.info(
        f"Calculating centralities `{centralities}` using `{n_jobs}` core(s)")

    res_list = []
    for k, v in fun_dict.items():
        df = parallelize(
            _centrality_scores_helper,
            collection=cat,
            extractor=pd.concat,
            n_jobs=n_jobs,
            backend=backend,
            show_progress_bar=show_progress_bar,
        )(clusters=clusters, fun=v, method=k)
        res_list.append(df)

    df = pd.concat(res_list, axis=1)

    if copy:
        return df
    _save_data(adata,
               attr="uns",
               key=Key.uns.centrality_scores(cluster_key),
               data=df,
               time=start)
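A usage sketch requesting a subset of measures, assuming the score names match the `Centrality` enum values listed in the docstring:

import squidpy as sq

adata = sq.datasets.imc()  # assumed demo dataset
sq.gr.spatial_neighbors(adata)

# compute two of the three measures on 4 cores and return the DataFrame
df = sq.gr.centrality_scores(
    adata,
    cluster_key="cell type",
    score=["degree_centrality", "average_clustering"],
    n_jobs=4,
    copy=True,
)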
Code example #6
File: _nhood.py Project: michalk8/squidpy
def nhood_enrichment(
    adata: AnnData,
    cluster_key: str,
    connectivity_key: Optional[str] = None,
    n_perms: int = 1000,
    numba_parallel: bool = False,
    seed: Optional[int] = None,
    copy: bool = False,
    n_jobs: Optional[int] = None,
    backend: str = "loky",
    show_progress_bar: bool = True,
) -> Optional[Tuple[np.ndarray, np.ndarray]]:
    """
    Compute neighborhood enrichment by permutation test.

    Parameters
    ----------
    %(adata)s
    %(cluster_key)s
    %(conn_key)s
    %(n_perms)s
    %(numba_parallel)s
    %(seed)s
    %(copy)s
    %(parallelize)s

    Returns
    -------
    If ``copy = True``, returns a :class:`tuple` with the z-score and the enrichment count.

    Otherwise, modifies the ``adata`` with the following keys:

        - :attr:`anndata.AnnData.uns` ``['{cluster_key}_nhood_enrichment']['zscore']`` - the enrichment z-score.
        - :attr:`anndata.AnnData.uns` ``['{cluster_key}_nhood_enrichment']['count']`` - the enrichment count.
    """
    connectivity_key = Key.obsp.spatial_conn(connectivity_key)
    _assert_categorical_obs(adata, cluster_key)
    _assert_connectivity_key(adata, connectivity_key)
    _assert_positive(n_perms, name="n_perms")

    adj = adata.obsp[connectivity_key]
    original_clust = adata.obs[cluster_key]
    clust_map = {
        v: i
        for i, v in enumerate(original_clust.cat.categories.values)
    }  # map categories
    int_clust = np.array([clust_map[c] for c in original_clust], dtype=ndt)

    indices, indptr = (adj.indices.astype(ndt), adj.indptr.astype(ndt))
    n_cls = len(clust_map)

    _test = _create_function(n_cls, parallel=numba_parallel)
    count = _test(indices, indptr, int_clust)

    n_jobs = _get_n_cores(n_jobs)
    start = logg.info(
        f"Calculating neighborhood enrichment using `{n_jobs}` core(s)")

    perms = parallelize(
        _nhood_enrichment_helper,
        collection=np.arange(n_perms),
        extractor=np.vstack,
        n_jobs=n_jobs,
        backend=backend,
        show_progress_bar=show_progress_bar,
    )(callback=_test,
      indices=indices,
      indptr=indptr,
      int_clust=int_clust,
      n_cls=n_cls,
      seed=seed)
    zscore = (count - perms.mean(axis=0)) / perms.std(axis=0)

    if copy:
        return zscore, count

    _save_data(
        adata,
        attr="uns",
        key=Key.uns.nhood_enrichment(cluster_key),
        data={
            "zscore": zscore,
            "count": count
        },
        time=start,
    )
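A quick, reproducible run with a reduced permutation count; a sketch assuming the `imc` demo dataset:

import squidpy as sq

adata = sq.datasets.imc()  # assumed demo dataset
sq.gr.spatial_neighbors(adata)

# copy=True returns the (zscore, count) tuple instead of writing to adata.uns
zscore, count = sq.gr.nhood_enrichment(
    adata, cluster_key="cell type", n_perms=100, seed=0, copy=True
)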
Code example #7
def _heatmap(
    adata: AnnData,
    key: str,
    title: str = "",
    method: Optional[str] = None,
    cont_cmap: Union[str, mcolors.Colormap] = "viridis",
    annotate: bool = True,
    figsize: Optional[Tuple[float, float]] = None,
    dpi: Optional[int] = None,
    cbar_kwargs: Mapping[str, Any] = MappingProxyType({}),
    **kwargs: Any,
) -> mpl.figure.Figure:
    _assert_categorical_obs(adata, key=key)

    cbar_kwargs = dict(cbar_kwargs)
    fig, ax = plt.subplots(constrained_layout=True, dpi=dpi, figsize=figsize)

    if method is not None:
        row_order, col_order, row_link, col_link = _dendrogram(adata.X, method, optimal_ordering=adata.n_obs <= 1500)
    else:
        row_order = col_order = np.arange(len(adata.uns[Key.uns.colors(key)]))

    row_order = row_order[::-1]
    row_labels = adata.obs[key][row_order]
    data = adata[row_order, col_order].X

    row_cmap, col_cmap, row_norm, col_norm, n_cls = _get_cmap_norm(adata, key, order=(row_order, col_order))

    row_sm = mpl.cm.ScalarMappable(cmap=row_cmap, norm=row_norm)
    col_sm = mpl.cm.ScalarMappable(cmap=col_cmap, norm=col_norm)

    norm = mpl.colors.Normalize(vmin=kwargs.pop("vmin", np.nanmin(data)), vmax=kwargs.pop("vmax", np.nanmax(data)))
    cont_cmap = copy(plt.get_cmap(cont_cmap))
    cont_cmap.set_bad(color="grey")

    im = ax.imshow(data[::-1], cmap=cont_cmap, norm=norm)

    ax.grid(False)
    ax.tick_params(top=False, bottom=False, labeltop=False, labelbottom=False)
    ax.set_xticks([])
    ax.set_yticks([])

    if annotate:
        _annotate_heatmap(im, cmap=cont_cmap, **kwargs)

    divider = make_axes_locatable(ax)
    row_cats = divider.append_axes("left", size="2%", pad=0)
    col_cats = divider.append_axes("top", size="2%", pad=0)
    cax = divider.append_axes("right", size="1%", pad=0.1)
    if method is not None:  # cluster rows but don't plot dendrogram
        col_ax = divider.append_axes("top", size="5%")
        sch.dendrogram(col_link, no_labels=True, ax=col_ax, color_threshold=0, above_threshold_color="black")
        col_ax.axis("off")

    _ = fig.colorbar(
        im,
        cax=cax,
        ticks=np.linspace(norm.vmin, norm.vmax, 10),
        orientation="vertical",
        format="%0.2f",
        **cbar_kwargs,
    )

    # column labels colorbar
    c = fig.colorbar(col_sm, cax=col_cats, orientation="horizontal")
    c.set_ticks([])
    (col_cats if method is None else col_ax).set_title(title)

    # row labels colorbar
    c = fig.colorbar(row_sm, cax=row_cats, orientation="vertical", ticklocation="left")
    c.set_ticks(np.arange(n_cls) + 0.5)
    c.set_ticklabels(row_labels)
    c.set_label(key)

    return fig
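`_heatmap` is the private helper shared by the `nhood_enrichment` and `interaction_matrix` plots above and is normally reached through them; a minimal direct-call sketch on a synthetic `AnnData`, assuming the helper is importable in scope:

import numpy as np
import pandas as pd
from anndata import AnnData

# a synthetic 3x3 matrix with one categorical annotation; _heatmap expects
# matching colors under `<key>_colors` in .uns
ad = AnnData(
    X=np.random.rand(3, 3),
    obs={"group": pd.Categorical(["a", "b", "c"])},
)
ad.uns["group_colors"] = ["#1f77b4", "#ff7f0e", "#2ca02c"]

fig = _heatmap(ad, key="group", title="demo", annotate=False)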
Code example #8
File: _nhood.py Project: sophial05/squidpy
def interaction_matrix(
    adata: AnnData,
    cluster_key: str,
    connectivity_key: Optional[str] = None,
    normalized: bool = False,
    copy: bool = False,
    weights: bool = False,
) -> Optional[np.ndarray]:
    """
    Compute interaction matrix for clusters.

    Parameters
    ----------
    %(adata)s
    %(cluster_key)s
    %(conn_key)s
    normalized
        If `True`, each row is normalized to sum to 1.
    %(copy)s
    weights
        Whether to use edge weights or binarize.

    Returns
    -------
    If ``copy = True``, returns the interaction matrix.

    Otherwise, modifies the ``adata`` with the following key:

        - :attr:`anndata.AnnData.uns` ``['{cluster_key}_interactions']`` - the interaction matrix.
    """
    connectivity_key = Key.obsp.spatial_conn(connectivity_key)
    _assert_categorical_obs(adata, cluster_key)
    _assert_connectivity_key(adata, connectivity_key)

    cats = adata.obs[cluster_key]
    mask = ~pd.isnull(cats).values
    cats = cats.loc[mask]
    if not len(cats):
        raise RuntimeError(
            f"After removing NaNs in `adata.obs[{cluster_key!r}]`, none remain."
        )

    g = adata.obsp[connectivity_key]
    g = g[mask, :][:, mask]
    n_cats = len(cats.cat.categories)

    if weights:
        g_data = g.data
    else:
        g_data = np.broadcast_to(1, shape=len(g.data))
    if pd.api.types.is_bool_dtype(g.dtype) or pd.api.types.is_integer_dtype(
            g.dtype):
        dtype = np.intp
    else:
        dtype = np.float_
    output = np.zeros((n_cats, n_cats), dtype=dtype)

    _interaction_matrix(g_data, g.indices, g.indptr, cats.cat.codes.to_numpy(),
                        output)

    if normalized:
        output = output / output.sum(axis=1).reshape((-1, 1))

    if copy:
        return output
    _save_data(adata,
               attr="uns",
               key=Key.uns.interaction_matrix(cluster_key),
               data=output)
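This variant adds the `weights` flag; a sketch of using it, with the same assumed demo dataset:

import squidpy as sq

adata = sq.datasets.imc()  # assumed demo dataset
sq.gr.spatial_neighbors(adata)

# sum edge weights from the connectivity graph instead of binarizing them
mat = sq.gr.interaction_matrix(adata, cluster_key="cell type", weights=True, copy=True)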
Code example #9
File: _ligrec.py Project: sophial05/squidpy
    def test(
        self,
        cluster_key: str,
        clusters: Optional[Cluster_t] = None,
        n_perms: int = 1000,
        threshold: float = 0.01,
        seed: Optional[int] = None,
        corr_method: Optional[str] = None,
        corr_axis: Union[str, CorrAxis] = CorrAxis.INTERACTIONS.v,
        alpha: float = 0.05,
        copy: bool = False,
        key_added: Optional[str] = None,
        numba_parallel: Optional[bool] = None,
        **kwargs: Any,
    ) -> Optional[Mapping[str, pd.DataFrame]]:
        """
        Perform the permutation test as described in :cite:`cellphonedb`.

        Parameters
        ----------
        %(cluster_key)s
        clusters
            Clusters from :attr:`anndata.AnnData.obs` ``['{{cluster_key}}']``. Can be specified either as a sequence
            of :class:`tuple` or just a sequence of cluster names, in which case all combinations are considered.
        %(n_perms)s
        threshold
            Do not perform the permutation test if any of the interacting components is expressed
            in fewer than ``threshold`` percent of cells within a given cluster.
        %(seed)s
        %(corr_method)s
        corr_axis
            Axis over which to perform the FDR correction. Only used when ``corr_method != None``. Valid options are:

                - `{fa.INTERACTIONS.s!r}` - correct interactions by performing FDR correction across the clusters.
                - `{fa.CLUSTERS.s!r}` - correct clusters by performing FDR correction across the interactions.
        alpha
            Significance level for FDR correction. Only used when ``corr_method != None``.
        %(copy)s
        key_added
            Key in :attr:`anndata.AnnData.uns` where the result is stored if ``copy = False``.
            If `None`, ``'{{cluster_key}}_ligrec'`` will be used.
        %(numba_parallel)s
        %(parallelize)s

        Returns
        -------
        %(ligrec_test_returns)s
        """
        _assert_positive(n_perms, name="n_perms")
        _assert_categorical_obs(self._adata, key=cluster_key)

        if corr_method is not None:
            corr_axis = CorrAxis(corr_axis)
        if TYPE_CHECKING:
            assert isinstance(corr_axis, CorrAxis)

        if len(self._adata.obs[cluster_key].cat.categories) <= 1:
            raise ValueError(
                f"Expected at least `2` clusters, found `{len(self._adata.obs[cluster_key].cat.categories)}`."
            )
        if TYPE_CHECKING:
            assert isinstance(self.interactions, pd.DataFrame)
            assert isinstance(self._filtered_data, pd.DataFrame)

        interactions = self.interactions[[SOURCE, TARGET]]
        self._filtered_data["clusters"] = self._adata.obs[cluster_key].astype(
            "string").astype("category").values

        if clusters is None:
            clusters = list(
                map(str, self._adata.obs[cluster_key].cat.categories))
        if all(isinstance(c, str) for c in clusters):
            clusters = list(product(
                clusters, repeat=2))  # type: ignore[no-redef,assignment]
        clusters = sorted(
            _check_tuple_needles(
                clusters,  # type: ignore[arg-type]
                self._filtered_data["clusters"].cat.categories,
                msg="Invalid cluster `{0!r}`.",
                reraise=True,
            ))
        clusters_flat = list({c for cs in clusters for c in cs})

        data = self._filtered_data.loc[
            np.isin(self._filtered_data["clusters"], clusters_flat), :]
        data["clusters"] = data["clusters"].cat.remove_unused_categories()
        cat = data["clusters"].cat

        cluster_mapper = dict(zip(cat.categories, range(len(cat.categories))))
        gene_mapper = dict(zip(data.columns[:-1],
                               range(len(data.columns) -
                                     1)))  # -1 for 'clusters'

        data.columns = [
            gene_mapper[c] if c != "clusters" else c for c in data.columns
        ]
        clusters_ = np.array([[cluster_mapper[c1], cluster_mapper[c2]]
                              for c1, c2 in clusters],
                             dtype=np.uint32)

        cat.rename_categories(cluster_mapper, inplace=True)
        # much faster than applymap (tested on 1M interactions)
        interactions_ = np.vectorize(lambda g: gene_mapper[g])(
            interactions.values)

        n_jobs = _get_n_cores(kwargs.pop("n_jobs", None))
        start = logg.info(
            f"Running `{n_perms}` permutations on `{len(interactions)}` interactions "
            f"and `{len(clusters)}` cluster combinations using `{n_jobs}` core(s)"
        )
        res = _analysis(
            data,
            interactions_,
            clusters_,
            threshold=threshold,
            n_perms=n_perms,
            seed=seed,
            n_jobs=n_jobs,
            numba_parallel=numba_parallel,
            **kwargs,
        )

        res = {
            "means":
            _create_sparse_df(
                res.means,
                index=pd.MultiIndex.from_frame(interactions,
                                               names=[SOURCE, TARGET]),
                columns=pd.MultiIndex.from_tuples(
                    clusters, names=["cluster_1", "cluster_2"]),
                fill_value=0,
            ),
            "pvalues":
            _create_sparse_df(
                res.pvalues,
                index=pd.MultiIndex.from_frame(interactions,
                                               names=[SOURCE, TARGET]),
                columns=pd.MultiIndex.from_tuples(
                    clusters, names=["cluster_1", "cluster_2"]),
                fill_value=np.nan,
            ),
            "metadata":
            self.interactions[self.interactions.columns.difference(
                [SOURCE, TARGET])],
        }
        res["metadata"].index = res["means"].index.copy()

        if TYPE_CHECKING:
            assert isinstance(res, dict)

        if corr_method is not None:
            logg.info(f"Performing FDR correction across the `{corr_axis.v}` "
                      f"using method `{corr_method}` at level `{alpha}`")
            res["pvalues"] = _fdr_correct(res["pvalues"],
                                          corr_method,
                                          corr_axis,
                                          alpha=alpha)

        if copy:
            logg.info("Finish", time=start)
            return res

        _save_data(self._adata,
                   attr="uns",
                   key=Key.uns.ligrec(cluster_key, key_added),
                   data=res,
                   time=start)
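`test` is usually reached through the functional wrapper `squidpy.gr.ligrec`; a hedged sketch, assuming scanpy's `pbmc68k_reduced` demo data (gene symbols in `var_names`, a categorical "louvain" column) and an interaction database fetched automatically when none is given:

import scanpy as sc
import squidpy as sq

adata = sc.datasets.pbmc68k_reduced()

# run 100 permutations and FDR-correct across interactions
res = sq.gr.ligrec(
    adata,
    cluster_key="louvain",
    n_perms=100,
    seed=0,
    corr_method="fdr_bh",
    copy=True,
)
print(res["pvalues"].head())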
Code example #10
File: _ppatterns.py Project: sophial05/squidpy
def co_occurrence(
    adata: AnnData,
    cluster_key: str,
    spatial_key: str = Key.obsm.spatial,
    n_steps: int = 50,
    copy: bool = False,
    n_splits: Optional[int] = None,
    n_jobs: Optional[int] = None,
    backend: str = "loky",
    show_progress_bar: bool = True,
) -> Optional[Tuple[np.ndarray, np.ndarray]]:
    """
    Compute co-occurrence probability of clusters.

    The co-occurrence is computed across ``n_steps`` distance thresholds in spatial dimensions.

    Parameters
    ----------
    %(adata)s
    %(cluster_key)s
    %(spatial_key)s
    n_steps
        Number of distance thresholds at which co-occurrence is computed.

    %(copy)s
    n_splits
        Number of splits in which to divide the spatial coordinates in
        :attr:`anndata.AnnData.obsm` ``['{spatial_key}']``.
    %(parallelize)s

    Returns
    -------
    If ``copy = True``, returns the co-occurrence probability and the distance threshold intervals.

    Otherwise, modifies the ``adata`` with the following keys:

        - :attr:`anndata.AnnData.uns` ``['{cluster_key}_co_occurrence']['occ']`` - the co-occurrence probabilities
          across interval thresholds.
        - :attr:`anndata.AnnData.uns` ``['{cluster_key}_co_occurrence']['interval']`` - the distance thresholds
          computed at ``n_steps``.
    """
    _assert_categorical_obs(adata, key=cluster_key)
    _assert_spatial_basis(adata, key=spatial_key)

    spatial = adata.obsm[spatial_key].astype(fp)
    original_clust = adata.obs[cluster_key]

    # find minimum, second minimum and maximum for thresholding
    thres_min, thres_max = _find_min_max(spatial)

    # annotate cluster idx
    clust_map = {
        v: i
        for i, v in enumerate(original_clust.cat.categories.values)
    }
    labs = np.array([clust_map[c] for c in original_clust], dtype=ip)

    labs_unique = np.array(list(clust_map.values()), dtype=ip)

    # create intervals thresholds
    interval = np.linspace(thres_min, thres_max, num=n_steps, dtype=fp)

    n_obs = spatial.shape[0]
    if n_splits is None:
        size_arr = (n_obs**2 * 4) / 1024 / 1024  # expected size in MiB of a float32 NxN distance matrix
        if size_arr > 2_000:
            s = 1
            while 2_048 < (n_obs / s):
                s += 1
            n_splits = s
            logg.warning(
                f"`n_splits` was automatically set to `{n_splits}` to prevent "
                f"an `{n_obs} x {n_obs}` distance matrix from being created"
            )
        else:
            n_splits = 1
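A usage sketch for `squidpy.gr.co_occurrence`, assuming the `imc` demo dataset:

import squidpy as sq

adata = sq.datasets.imc()  # assumed demo dataset with spatial coordinates

# copy=True returns the probabilities and the distance thresholds directly
occ, interval = sq.gr.co_occurrence(adata, cluster_key="cell type", copy=True)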
Code example #11
File: _graph.py Project: sophial05/squidpy
def centrality_scores(
    adata: AnnData,
    cluster_key: str,
    score: Optional[Union[str, Sequence[str]]] = None,
    legend_kwargs: Mapping[str, Any] = MappingProxyType({}),
    palette: Palette_t = None,
    figsize: Optional[Tuple[float, float]] = None,
    dpi: Optional[int] = None,
    save: Optional[Union[str, Path]] = None,
    **kwargs: Any,
) -> None:
    """
    Plot centrality scores.

    The centrality scores are computed by :func:`squidpy.gr.centrality_scores`.

    Parameters
    ----------
    %(adata)s
    %(cluster_key)s
    score
        Whether to plot all scores or only selected ones.
    legend_kwargs
        Keyword arguments for :func:`matplotlib.pyplot.legend`.
    %(cat_plotting)s

    Returns
    -------
    %(plotting_returns)s
    """
    _assert_categorical_obs(adata, key=cluster_key)
    df = _get_data(adata, cluster_key=cluster_key, func_name="centrality_scores").copy()

    legend_kwargs = dict(legend_kwargs)
    if "loc" not in legend_kwargs:
        legend_kwargs["loc"] = "center left"
        legend_kwargs.setdefault("bbox_to_anchor", (1, 0.5))

    scores = df.columns.values
    df[cluster_key] = df.index.values

    clusters = adata.obs[cluster_key].cat.categories
    palette = _get_palette(adata, cluster_key=cluster_key, categories=clusters) if palette is None else palette

    score = scores if score is None else score
    score = _assert_non_empty_sequence(score, name="centrality scores")
    score = sorted(_get_valid_values(score, scores))

    fig, axs = plt.subplots(1, len(score), figsize=figsize, dpi=dpi, constrained_layout=True)
    axs = np.ravel(axs)  # make into iterable
    for g, ax in zip(score, axs):
        sns.scatterplot(
            x=g,
            y=cluster_key,
            data=df,
            hue=cluster_key,
            hue_order=clusters,
            palette=palette,
            ax=ax,
            **kwargs,
        )
        ax.set_title(str(g).replace("_", " ").capitalize())
        ax.set_xlabel("value")

        ax.set_yticks([])
        ax.legend(**legend_kwargs)

    if save is not None:
        save_fig(fig, path=save)
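A sketch pairing the computation with this plot, selecting a single score; same assumed demo dataset:

import squidpy as sq

adata = sq.datasets.imc()  # assumed demo dataset
sq.gr.spatial_neighbors(adata)
sq.gr.centrality_scores(adata, cluster_key="cell type")

# plot only one of the computed measures
sq.pl.centrality_scores(adata, cluster_key="cell type", score="degree_centrality")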
Code example #12
File: _graph.py Project: sophial05/squidpy
def co_occurrence(
    adata: AnnData,
    cluster_key: str,
    palette: Palette_t = None,
    clusters: Optional[Union[str, Sequence[str]]] = None,
    figsize: Optional[Tuple[float, float]] = None,
    dpi: Optional[int] = None,
    save: Optional[Union[str, Path]] = None,
    legend_kwargs: Mapping[str, Any] = MappingProxyType({}),
    **kwargs: Any,
) -> None:
    """
    Plot co-occurrence probability ratio for each cluster.

    The co-occurrence is computed by :func:`squidpy.gr.co_occurrence`.

    Parameters
    ----------
    %(adata)s
    %(cluster_key)s
    clusters
        Cluster instances for which to plot conditional probability.
    %(cat_plotting)s
    legend_kwargs
        Keyword arguments for :func:`matplotlib.pyplot.legend`.
    kwargs
        Keyword arguments for :func:`seaborn.lineplot`.

    Returns
    -------
    %(plotting_returns)s
    """
    _assert_categorical_obs(adata, key=cluster_key)
    occurrence_data = _get_data(adata, cluster_key=cluster_key, func_name="co_occurrence")

    legend_kwargs = dict(legend_kwargs)
    if "loc" not in legend_kwargs:
        legend_kwargs["loc"] = "center left"
        legend_kwargs.setdefault("bbox_to_anchor", (1, 0.5))

    out = occurrence_data["occ"]
    interval = occurrence_data["interval"][1:]
    categories = adata.obs[cluster_key].cat.categories

    clusters = categories if clusters is None else clusters
    clusters = _assert_non_empty_sequence(clusters, name="clusters")
    clusters = sorted(_get_valid_values(clusters, categories))

    palette = _get_palette(adata, cluster_key=cluster_key, categories=categories) if palette is None else palette

    fig, axs = plt.subplots(
        1,
        len(clusters),
        figsize=(5 * len(clusters), 5) if figsize is None else figsize,
        dpi=dpi,
        constrained_layout=True,
    )
    axs = np.ravel(axs)  # make into iterable

    for g, ax in zip(clusters, axs):
        idx = np.where(categories == g)[0][0]
        df = pd.DataFrame(out[idx, :, :].T, columns=categories).melt(var_name=cluster_key, value_name="probability")
        df["distance"] = np.tile(interval, len(categories))

        sns.lineplot(
            x="distance",
            y="probability",
            data=df,
            dashes=False,
            hue=cluster_key,
            hue_order=categories,
            palette=palette,
            ax=ax,
            **kwargs,
        )
        ax.legend(**legend_kwargs)
        ax.set_title(rf"$\frac{{p(exp|{g})}}{{p(exp)}}$")
        ax.set_ylabel("value")

    if save is not None:
        save_fig(fig, path=save)
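A closing usage sketch, conditioning the plot on a single cluster; the category name is an assumption about the `imc` demo annotation:

import squidpy as sq

adata = sq.datasets.imc()  # assumed demo dataset
sq.gr.co_occurrence(adata, cluster_key="cell type")

# one panel for the selected cluster instead of one per category
sq.pl.co_occurrence(adata, cluster_key="cell type", clusters="basal CK tumor cell")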