def nhood_enrichment(
    adata: AnnData,
    cluster_key: str,
    mode: Literal["zscore", "count"] = "zscore",
    annotate: bool = False,
    method: Optional[str] = None,
    title: Optional[str] = None,
    cmap: str = "viridis",
    palette: Palette_t = None,
    cbar_kwargs: Mapping[str, Any] = MappingProxyType({}),
    figsize: Optional[Tuple[float, float]] = None,
    dpi: Optional[int] = None,
    save: Optional[Union[str, Path]] = None,
    **kwargs: Any,
) -> None:
    """
    Plot neighborhood enrichment.

    The enrichment is computed by :func:`squidpy.gr.nhood_enrichment`.

    Parameters
    ----------
    %(adata)s
    %(cluster_key)s
    mode
        Which :func:`squidpy.gr.nhood_enrichment` result to plot. Valid options are:

            - `'zscore'` - z-score values of enrichment statistic.
            - `'count'` - enrichment count.

    %(heatmap_plotting)s
    kwargs
        Keyword arguments for :func:`matplotlib.pyplot.text`.

    Returns
    -------
    %(plotting_returns)s
    """
    _assert_categorical_obs(adata, key=cluster_key)
    array = _get_data(adata, cluster_key=cluster_key, func_name="nhood_enrichment")[mode]

    ad = AnnData(X=array, obs={cluster_key: pd.Categorical(adata.obs[cluster_key].cat.categories)})
    _maybe_set_colors(source=adata, target=ad, key=cluster_key, palette=palette)
    if title is None:
        title = "Neighborhood enrichment"
    fig = _heatmap(
        ad,
        key=cluster_key,
        title=title,
        method=method,
        cont_cmap=cmap,
        annotate=annotate,
        figsize=(2 * ad.n_obs // 3, 2 * ad.n_obs // 3) if figsize is None else figsize,
        dpi=dpi,
        cbar_kwargs=cbar_kwargs,
        **kwargs,
    )

    if save is not None:
        save_fig(fig, path=save)
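# A minimal usage sketch for the plotting function above (hypothetical helper,
# not part of the module). It assumes `adata` carries a categorical cluster
# annotation under `cluster_key` and that `squidpy.gr.spatial_neighbors` and
# `squidpy.gr.nhood_enrichment` are available, as documented in the docstring.
def _example_plot_nhood_enrichment(adata: AnnData, cluster_key: str = "leiden") -> None:
    import squidpy as sq

    sq.gr.spatial_neighbors(adata)  # build the spatial connectivity graph
    sq.gr.nhood_enrichment(adata, cluster_key=cluster_key)  # store z-score/count in `adata.uns`
    nhood_enrichment(adata, cluster_key=cluster_key, mode="zscore", annotate=True)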
def ripley_k(
    adata: AnnData,
    cluster_key: str,
    palette: Palette_t = None,
    figsize: Optional[Tuple[float, float]] = None,
    dpi: Optional[int] = None,
    save: Optional[Union[str, Path]] = None,
    legend_kwargs: Mapping[str, Any] = MappingProxyType({}),
    **kwargs: Any,
) -> None:
    """
    Plot Ripley's K estimate for each cluster.

    The estimate is computed by :func:`squidpy.gr.ripley_k`.

    Parameters
    ----------
    %(adata)s
    %(cluster_key)s
    %(cat_plotting)s
    legend_kwargs
        Keyword arguments for :func:`matplotlib.pyplot.legend`.
    kwargs
        Keyword arguments for :func:`seaborn.lineplot`.

    Returns
    -------
    %(plotting_returns)s
    """
    _assert_categorical_obs(adata, key=cluster_key)
    df = _get_data(adata, cluster_key=cluster_key, func_name="ripley_k")

    legend_kwargs = dict(legend_kwargs)
    if "loc" not in legend_kwargs:
        legend_kwargs["loc"] = "center left"
        legend_kwargs.setdefault("bbox_to_anchor", (1, 0.5))

    categories = adata.obs[cluster_key].cat.categories
    palette = _get_palette(adata, cluster_key=cluster_key, categories=categories) if palette is None else palette

    fig, ax = plt.subplots(figsize=figsize, dpi=dpi)
    sns.lineplot(
        x="distance",
        y="ripley_k",
        hue=cluster_key,
        hue_order=categories,
        data=df,
        palette=palette,
        ax=ax,
        **kwargs,
    )
    ax.legend(**legend_kwargs)
    ax.set_ylabel("value")
    ax.set_title("Ripley's K")

    if save is not None:
        save_fig(fig, path=save)
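# Usage sketch (hypothetical helper). It assumes the older `squidpy.gr.ripley_k`
# API that stores a long-format DataFrame with `distance` and `ripley_k`
# columns under `adata.uns`, as the lineplot call above expects.
def _example_plot_ripley_k(adata: AnnData, cluster_key: str = "leiden") -> None:
    import squidpy as sq

    sq.gr.ripley_k(adata, cluster_key=cluster_key)  # compute the estimate per cluster
    ripley_k(adata, cluster_key=cluster_key)  # one curve per cluster, legend outside the axes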
def interaction_matrix(
    adata: AnnData,
    cluster_key: str,
    annotate: bool = False,
    method: Optional[str] = None,
    title: Optional[str] = None,
    cmap: str = "viridis",
    palette: Palette_t = None,
    cbar_kwargs: Mapping[str, Any] = MappingProxyType({}),
    figsize: Optional[Tuple[float, float]] = None,
    dpi: Optional[int] = None,
    save: Optional[Union[str, Path]] = None,
    **kwargs: Any,
) -> None:
    """
    Plot cluster interaction matrix.

    The interaction matrix is computed by :func:`squidpy.gr.interaction_matrix`.

    Parameters
    ----------
    %(adata)s
    %(cluster_key)s
    %(heatmap_plotting)s
    kwargs
        Keyword arguments for :func:`matplotlib.pyplot.text`.

    Returns
    -------
    %(plotting_returns)s
    """
    _assert_categorical_obs(adata, key=cluster_key)
    array = _get_data(adata, cluster_key=cluster_key, func_name="interaction_matrix")

    ad = AnnData(X=array, obs={cluster_key: pd.Categorical(adata.obs[cluster_key].cat.categories)})
    _maybe_set_colors(source=adata, target=ad, key=cluster_key, palette=palette)
    if title is None:
        title = "Interaction matrix"
    fig = _heatmap(
        ad,
        key=cluster_key,
        title=title,
        method=method,
        cont_cmap=cmap,
        annotate=annotate,
        figsize=(2 * ad.n_obs // 3, 2 * ad.n_obs // 3) if figsize is None else figsize,
        dpi=dpi,
        cbar_kwargs=cbar_kwargs,
        **kwargs,
    )

    if save is not None:
        save_fig(fig, path=save)
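# Usage sketch (hypothetical helper), mirroring the `nhood_enrichment` example
# above: compute the matrix with `squidpy.gr.interaction_matrix`, then plot it.
# `method` is assumed to be a scipy linkage method consumed by `_dendrogram`.
def _example_plot_interaction_matrix(adata: AnnData, cluster_key: str = "leiden") -> None:
    import squidpy as sq

    sq.gr.spatial_neighbors(adata)
    sq.gr.interaction_matrix(adata, cluster_key=cluster_key)
    interaction_matrix(adata, cluster_key=cluster_key, method="ward", annotate=False)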
def interaction_matrix(
    adata: AnnData,
    cluster_key: str,
    connectivity_key: Optional[str] = None,
    normalized: bool = False,
    copy: bool = False,
) -> Optional[np.ndarray]:
    """
    Compute interaction matrix for clusters.

    Parameters
    ----------
    %(adata)s
    %(cluster_key)s
    %(conn_key)s
    normalized
        If `True`, each row is normalized to sum to 1.
    %(copy)s

    Returns
    -------
    If ``copy = True``, returns the interaction matrix.

    Otherwise, modifies the ``adata`` with the following key:

        - :attr:`anndata.AnnData.uns` ``['{cluster_key}_interactions']`` - the interaction matrix.
    """
    connectivity_key = Key.obsp.spatial_conn(connectivity_key)
    _assert_categorical_obs(adata, cluster_key)
    _assert_connectivity_key(adata, connectivity_key)

    # NOTE: `nx.from_scipy_sparse_matrix` was removed in networkx 3.0;
    # this implementation requires networkx < 3.0.
    graph = nx.from_scipy_sparse_matrix(adata.obsp[connectivity_key])

    cluster = {i: {cluster_key: x} for i, x in enumerate(adata.obs[cluster_key])}
    nx.set_node_attributes(graph, cluster)
    int_mat = np.asarray(
        nx.attr_matrix(
            graph,
            node_attr=cluster_key,
            normalized=normalized,
            rc_order=adata.obs[cluster_key].cat.categories,
        )
    )

    if copy:
        return int_mat

    _save_data(adata, attr="uns", key=Key.uns.interaction_matrix(cluster_key), data=int_mat)
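# Minimal sketch of the networkx primitive the function above relies on:
# `nx.attr_matrix` aggregates the adjacency matrix by a node attribute, so
# entry (i, j) counts connections between attribute classes i and j (rows can
# optionally be normalized). Assumes networkx < 3.0, where
# `from_scipy_sparse_matrix`-era APIs are available; the helper name is
# illustrative only.
def _example_attr_matrix() -> None:
    import networkx as nx
    import numpy as np

    g = nx.path_graph(4)  # nodes 0-1-2-3 in a chain
    nx.set_node_attributes(g, {0: "a", 1: "a", 2: "b", 3: "b"}, name="cluster")
    mat = np.asarray(nx.attr_matrix(g, node_attr="cluster", rc_order=["a", "b"]))
    print(mat)  # 2x2 cluster-by-cluster interaction counts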
def centrality_scores(
    adata: AnnData,
    cluster_key: str,
    score: Optional[Union[str, Iterable[str]]] = None,
    connectivity_key: Optional[str] = None,
    copy: bool = False,
    n_jobs: Optional[int] = None,
    backend: str = "loky",
    show_progress_bar: bool = False,
) -> Optional[pd.DataFrame]:
    """
    Compute centrality scores per cluster or cell type.

    Inspired by usage in Gene Regulatory Networks (GRNs) in :cite:`celloracle`.

    Parameters
    ----------
    %(adata)s
    %(cluster_key)s
    score
        Centrality measures as described in :class:`networkx.algorithms.centrality` :cite:`networkx`.
        If `None`, use all the options below. Valid options are:

            - `{c.CLOSENESS.s!r}` - measure of how close the group is to other nodes.
            - `{c.CLUSTERING.s!r}` - measure of the degree to which nodes cluster together.
            - `{c.DEGREE.s!r}` - fraction of non-group members connected to group members.

    %(conn_key)s
    %(copy)s
    %(parallelize)s

    Returns
    -------
    If ``copy = True``, returns a :class:`pandas.DataFrame`. Otherwise, modifies the ``adata``
    with the following key:

        - :attr:`anndata.AnnData.uns` ``['{{cluster_key}}_centrality_scores']`` - the centrality scores,
          as mentioned above.
    """
    connectivity_key = Key.obsp.spatial_conn(connectivity_key)
    _assert_categorical_obs(adata, cluster_key)
    _assert_connectivity_key(adata, connectivity_key)

    if isinstance(score, (str, Centrality)):
        centrality = [score]
    elif score is None:
        centrality = [c.s for c in Centrality]
    else:
        # an iterable of measure names; without this branch, passing a list
        # (allowed by the signature) would raise a `NameError` below
        centrality = list(score)

    centralities = [Centrality(c) for c in centrality]

    graph = nx.from_scipy_sparse_matrix(adata.obsp[connectivity_key])

    cat = adata.obs[cluster_key].cat.categories.values
    clusters = adata.obs[cluster_key].values

    fun_dict = {}
    for c in centralities:
        if c == Centrality.CLOSENESS:
            fun_dict[c.s] = partial(nx.algorithms.centrality.group_closeness_centrality, graph)
        elif c == Centrality.DEGREE:
            fun_dict[c.s] = partial(nx.algorithms.centrality.group_degree_centrality, graph)
        elif c == Centrality.CLUSTERING:
            fun_dict[c.s] = partial(nx.algorithms.cluster.average_clustering, graph)
        else:
            raise NotImplementedError(f"Centrality `{c}` is not yet implemented.")

    n_jobs = _get_n_cores(n_jobs)
    start = logg.info(f"Calculating centralities `{centralities}` using `{n_jobs}` core(s)")

    res_list = []
    for k, v in fun_dict.items():
        df = parallelize(
            _centrality_scores_helper,
            collection=cat,
            extractor=pd.concat,
            n_jobs=n_jobs,
            backend=backend,
            show_progress_bar=show_progress_bar,
        )(clusters=clusters, fun=v, method=k)
        res_list.append(df)

    df = pd.concat(res_list, axis=1)

    if copy:
        return df
    _save_data(adata, attr="uns", key=Key.uns.centrality_scores(cluster_key), data=df, time=start)
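# Usage sketch (hypothetical helper). `score=None` requests all implemented
# measures; with `copy=True` the scores come back as a DataFrame instead of
# being written to `adata.uns`.
def _example_centrality_scores(adata: AnnData, cluster_key: str = "leiden") -> None:
    import squidpy as sq

    sq.gr.spatial_neighbors(adata)
    df = sq.gr.centrality_scores(adata, cluster_key=cluster_key, score=None, copy=True)
    print(df.head())  # one row per cluster, one column per centrality measure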
def nhood_enrichment(
    adata: AnnData,
    cluster_key: str,
    connectivity_key: Optional[str] = None,
    n_perms: int = 1000,
    numba_parallel: bool = False,
    seed: Optional[int] = None,
    copy: bool = False,
    n_jobs: Optional[int] = None,
    backend: str = "loky",
    show_progress_bar: bool = True,
) -> Optional[Tuple[np.ndarray, np.ndarray]]:
    """
    Compute neighborhood enrichment by permutation test.

    Parameters
    ----------
    %(adata)s
    %(cluster_key)s
    %(conn_key)s
    %(n_perms)s
    %(numba_parallel)s
    %(seed)s
    %(copy)s
    %(parallelize)s

    Returns
    -------
    If ``copy = True``, returns a :class:`tuple` with the z-score and the enrichment count.

    Otherwise, modifies the ``adata`` with the following keys:

        - :attr:`anndata.AnnData.uns` ``['{cluster_key}_nhood_enrichment']['zscore']`` - the enrichment z-score.
        - :attr:`anndata.AnnData.uns` ``['{cluster_key}_nhood_enrichment']['count']`` - the enrichment count.
    """
    connectivity_key = Key.obsp.spatial_conn(connectivity_key)
    _assert_categorical_obs(adata, cluster_key)
    _assert_connectivity_key(adata, connectivity_key)
    _assert_positive(n_perms, name="n_perms")

    adj = adata.obsp[connectivity_key]
    original_clust = adata.obs[cluster_key]
    clust_map = {v: i for i, v in enumerate(original_clust.cat.categories.values)}  # map categories
    int_clust = np.array([clust_map[c] for c in original_clust], dtype=ndt)

    indices, indptr = (adj.indices.astype(ndt), adj.indptr.astype(ndt))
    n_cls = len(clust_map)

    _test = _create_function(n_cls, parallel=numba_parallel)
    count = _test(indices, indptr, int_clust)

    n_jobs = _get_n_cores(n_jobs)
    start = logg.info(f"Calculating neighborhood enrichment using `{n_jobs}` core(s)")

    perms = parallelize(
        _nhood_enrichment_helper,
        collection=np.arange(n_perms),
        extractor=np.vstack,
        n_jobs=n_jobs,
        backend=backend,
        show_progress_bar=show_progress_bar,
    )(callback=_test, indices=indices, indptr=indptr, int_clust=int_clust, n_cls=n_cls, seed=seed)
    zscore = (count - perms.mean(axis=0)) / perms.std(axis=0)

    if copy:
        return zscore, count

    _save_data(
        adata,
        attr="uns",
        key=Key.uns.nhood_enrichment(cluster_key),
        data={"zscore": zscore, "count": count},
        time=start,
    )
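# Usage sketch (hypothetical helper). A fixed `seed` makes the permutation
# null distribution, and hence the z-scores, reproducible.
def _example_nhood_enrichment(adata: AnnData, cluster_key: str = "leiden") -> None:
    import squidpy as sq

    sq.gr.spatial_neighbors(adata)
    zscore, count = sq.gr.nhood_enrichment(
        adata, cluster_key=cluster_key, n_perms=100, seed=0, copy=True
    )
    print(zscore.shape)  # (n_clusters, n_clusters)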
def _heatmap(
    adata: AnnData,
    key: str,
    title: str = "",
    method: Optional[str] = None,
    cont_cmap: Union[str, mcolors.Colormap] = "viridis",
    annotate: bool = True,
    figsize: Optional[Tuple[float, float]] = None,
    dpi: Optional[int] = None,
    cbar_kwargs: Mapping[str, Any] = MappingProxyType({}),
    **kwargs: Any,
) -> mpl.figure.Figure:
    _assert_categorical_obs(adata, key=key)

    cbar_kwargs = dict(cbar_kwargs)

    fig, ax = plt.subplots(constrained_layout=True, dpi=dpi, figsize=figsize)

    if method is not None:
        row_order, col_order, row_link, col_link = _dendrogram(
            adata.X, method, optimal_ordering=adata.n_obs <= 1500
        )
    else:
        row_order = col_order = np.arange(len(adata.uns[Key.uns.colors(key)]))

    row_order = row_order[::-1]
    row_labels = adata.obs[key][row_order]
    data = adata[row_order, col_order].X

    row_cmap, col_cmap, row_norm, col_norm, n_cls = _get_cmap_norm(adata, key, order=(row_order, col_order))
    row_sm = mpl.cm.ScalarMappable(cmap=row_cmap, norm=row_norm)
    col_sm = mpl.cm.ScalarMappable(cmap=col_cmap, norm=col_norm)

    norm = mpl.colors.Normalize(vmin=kwargs.pop("vmin", np.nanmin(data)), vmax=kwargs.pop("vmax", np.nanmax(data)))
    cont_cmap = copy(plt.get_cmap(cont_cmap))
    cont_cmap.set_bad(color="grey")

    im = ax.imshow(data[::-1], cmap=cont_cmap, norm=norm)

    ax.grid(False)
    ax.tick_params(top=False, bottom=False, labeltop=False, labelbottom=False)
    ax.set_xticks([])
    ax.set_yticks([])

    if annotate:
        _annotate_heatmap(im, cmap=cont_cmap, **kwargs)

    divider = make_axes_locatable(ax)
    row_cats = divider.append_axes("left", size="2%", pad=0)
    col_cats = divider.append_axes("top", size="2%", pad=0)
    cax = divider.append_axes("right", size="1%", pad=0.1)
    if method is not None:
        # rows are clustered as well, but only the column dendrogram is drawn
        col_ax = divider.append_axes("top", size="5%")
        sch.dendrogram(col_link, no_labels=True, ax=col_ax, color_threshold=0, above_threshold_color="black")
        col_ax.axis("off")

    _ = fig.colorbar(
        im,
        cax=cax,
        ticks=np.linspace(norm.vmin, norm.vmax, 10),
        orientation="vertical",
        format="%0.2f",
        **cbar_kwargs,
    )

    # column labels colorbar
    c = fig.colorbar(col_sm, cax=col_cats, orientation="horizontal")
    c.set_ticks([])
    (col_cats if method is None else col_ax).set_title(title)

    # row labels colorbar
    c = fig.colorbar(row_sm, cax=row_cats, orientation="vertical", ticklocation="left")
    c.set_ticks(np.arange(n_cls) + 0.5)
    c.set_ticklabels(row_labels)
    c.set_label(key)

    return fig
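# Minimal sketch of how the public plotting functions above drive this private
# helper: wrap a square score matrix in an `AnnData` whose `obs[key]` holds the
# cluster categories and whose `uns` carries one color per category (normally
# filled in by `_maybe_set_colors`). The `'{key}_colors'` key name is an
# assumption about what `Key.uns.colors(key)` resolves to.
def _example_heatmap() -> None:
    import numpy as np

    rng = np.random.default_rng(0)
    cats = ["a", "b", "c"]
    ad = AnnData(X=rng.normal(size=(3, 3)), obs={"cluster": pd.Categorical(cats)})
    ad.uns["cluster_colors"] = ["#1f77b4", "#ff7f0e", "#2ca02c"]  # one color per category
    fig = _heatmap(ad, key="cluster", title="demo", annotate=False)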
def interaction_matrix(
    adata: AnnData,
    cluster_key: str,
    connectivity_key: Optional[str] = None,
    normalized: bool = False,
    copy: bool = False,
    weights: bool = False,
) -> Optional[np.ndarray]:
    """
    Compute interaction matrix for clusters.

    Parameters
    ----------
    %(adata)s
    %(cluster_key)s
    %(conn_key)s
    normalized
        If `True`, each row is normalized to sum to 1.
    %(copy)s
    weights
        Whether to use edge weights or binarize.

    Returns
    -------
    If ``copy = True``, returns the interaction matrix.

    Otherwise, modifies the ``adata`` with the following key:

        - :attr:`anndata.AnnData.uns` ``['{cluster_key}_interactions']`` - the interaction matrix.
    """
    connectivity_key = Key.obsp.spatial_conn(connectivity_key)
    _assert_categorical_obs(adata, cluster_key)
    _assert_connectivity_key(adata, connectivity_key)

    cats = adata.obs[cluster_key]
    mask = ~pd.isnull(cats).values
    cats = cats.loc[mask]
    if not len(cats):
        raise RuntimeError(f"After removing NaNs in `adata.obs[{cluster_key!r}]`, none remain.")

    g = adata.obsp[connectivity_key]
    g = g[mask, :][:, mask]
    n_cats = len(cats.cat.categories)

    if weights:
        g_data = g.data
    else:
        g_data = np.broadcast_to(1, shape=len(g.data))
    if pd.api.types.is_bool_dtype(g.dtype) or pd.api.types.is_integer_dtype(g.dtype):
        dtype = np.intp
    else:
        dtype = np.float_
    output = np.zeros((n_cats, n_cats), dtype=dtype)

    _interaction_matrix(g_data, g.indices, g.indptr, cats.cat.codes.to_numpy(), output)

    if normalized:
        output = output / output.sum(axis=1).reshape((-1, 1))

    if copy:
        return output

    _save_data(adata, attr="uns", key=Key.uns.interaction_matrix(cluster_key), data=output)
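# Usage sketch (hypothetical helper) for the sparse implementation above.
# With `normalized=True` each row sums to 1, i.e. each row becomes the
# distribution of neighboring cluster labels for that cluster.
def _example_interaction_matrix(adata: AnnData, cluster_key: str = "leiden") -> None:
    import squidpy as sq

    sq.gr.spatial_neighbors(adata)
    mat = sq.gr.interaction_matrix(adata, cluster_key=cluster_key, normalized=True, copy=True)
    print(mat.sum(axis=1))  # all ones when normalized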
def test(
    self,
    cluster_key: str,
    clusters: Optional[Cluster_t] = None,
    n_perms: int = 1000,
    threshold: float = 0.01,
    seed: Optional[int] = None,
    corr_method: Optional[str] = None,
    corr_axis: Union[str, CorrAxis] = CorrAxis.INTERACTIONS.v,
    alpha: float = 0.05,
    copy: bool = False,
    key_added: Optional[str] = None,
    numba_parallel: Optional[bool] = None,
    **kwargs: Any,
) -> Optional[Mapping[str, pd.DataFrame]]:
    """
    Perform the permutation test as described in :cite:`cellphonedb`.

    Parameters
    ----------
    %(cluster_key)s
    clusters
        Clusters from :attr:`anndata.AnnData.obs` ``['{{cluster_key}}']``. Can be specified either as a
        sequence of :class:`tuple` or just a sequence of cluster names, in which case all combinations
        are considered.
    %(n_perms)s
    threshold
        Do not perform permutation test if any of the interacting components is expressed
        in less than ``threshold`` percent of cells within a given cluster.
    %(seed)s
    %(corr_method)s
    corr_axis
        Axis over which to perform the FDR correction. Only used when ``corr_method != None``. Valid options are:

            - `{fa.INTERACTIONS.s!r}` - correct interactions by performing FDR correction across the clusters.
            - `{fa.CLUSTERS.s!r}` - correct clusters by performing FDR correction across the interactions.

    alpha
        Significance level for FDR correction. Only used when ``corr_method != None``.
    %(copy)s
    key_added
        Key in :attr:`anndata.AnnData.uns` where the result is stored if ``copy = False``.
        If `None`, ``'{{cluster_key}}_ligrec'`` will be used.
    %(numba_parallel)s
    %(parallelize)s

    Returns
    -------
    %(ligrec_test_returns)s
    """
    _assert_positive(n_perms, name="n_perms")
    _assert_categorical_obs(self._adata, key=cluster_key)

    if corr_method is not None:
        corr_axis = CorrAxis(corr_axis)
        if TYPE_CHECKING:
            assert isinstance(corr_axis, CorrAxis)
    if len(self._adata.obs[cluster_key].cat.categories) <= 1:
        raise ValueError(
            f"Expected at least `2` clusters, found `{len(self._adata.obs[cluster_key].cat.categories)}`."
        )
    if TYPE_CHECKING:
        assert isinstance(self.interactions, pd.DataFrame)
        assert isinstance(self._filtered_data, pd.DataFrame)

    interactions = self.interactions[[SOURCE, TARGET]]
    self._filtered_data["clusters"] = self._adata.obs[cluster_key].astype("string").astype("category").values

    if clusters is None:
        clusters = list(map(str, self._adata.obs[cluster_key].cat.categories))
    if all(isinstance(c, str) for c in clusters):
        clusters = list(product(clusters, repeat=2))  # type: ignore[no-redef,assignment]
    clusters = sorted(
        _check_tuple_needles(
            clusters,  # type: ignore[arg-type]
            self._filtered_data["clusters"].cat.categories,
            msg="Invalid cluster `{0!r}`.",
            reraise=True,
        )
    )
    clusters_flat = list({c for cs in clusters for c in cs})

    data = self._filtered_data.loc[np.isin(self._filtered_data["clusters"], clusters_flat), :]
    data["clusters"] = data["clusters"].cat.remove_unused_categories()
    cat = data["clusters"].cat

    cluster_mapper = dict(zip(cat.categories, range(len(cat.categories))))
    gene_mapper = dict(zip(data.columns[:-1], range(len(data.columns) - 1)))  # -1 for 'clusters'

    data.columns = [gene_mapper[c] if c != "clusters" else c for c in data.columns]
    clusters_ = np.array([[cluster_mapper[c1], cluster_mapper[c2]] for c1, c2 in clusters], dtype=np.uint32)

    cat.rename_categories(cluster_mapper, inplace=True)
    # much faster than applymap (tested on 1M interactions)
    interactions_ = np.vectorize(lambda g: gene_mapper[g])(interactions.values)

    n_jobs = _get_n_cores(kwargs.pop("n_jobs", None))
    start = logg.info(
        f"Running `{n_perms}` permutations on `{len(interactions)}` interactions "
        f"and `{len(clusters)}` cluster combinations using `{n_jobs}` core(s)"
    )
    res = _analysis(
        data,
        interactions_,
        clusters_,
        threshold=threshold,
        n_perms=n_perms,
        seed=seed,
        n_jobs=n_jobs,
        numba_parallel=numba_parallel,
        **kwargs,
    )

    res = {
        "means": _create_sparse_df(
            res.means,
            index=pd.MultiIndex.from_frame(interactions, names=[SOURCE, TARGET]),
            columns=pd.MultiIndex.from_tuples(clusters, names=["cluster_1", "cluster_2"]),
            fill_value=0,
        ),
        "pvalues": _create_sparse_df(
            res.pvalues,
            index=pd.MultiIndex.from_frame(interactions, names=[SOURCE, TARGET]),
            columns=pd.MultiIndex.from_tuples(clusters, names=["cluster_1", "cluster_2"]),
            fill_value=np.nan,
        ),
        "metadata": self.interactions[self.interactions.columns.difference([SOURCE, TARGET])],
    }
    res["metadata"].index = res["means"].index.copy()

    if TYPE_CHECKING:
        assert isinstance(res, dict)

    if corr_method is not None:
        logg.info(
            f"Performing FDR correction across the `{corr_axis.v}` "
            f"using method `{corr_method}` at level `{alpha}`"
        )
        res["pvalues"] = _fdr_correct(res["pvalues"], corr_method, corr_axis, alpha=alpha)

    if copy:
        logg.info("Finish", time=start)
        return res

    _save_data(self._adata, attr="uns", key=Key.uns.ligrec(cluster_key, key_added), data=res, time=start)
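# Usage sketch (hypothetical helper). It assumes the module-level
# `squidpy.gr.ligrec` convenience wrapper, which constructs the permutation
# test and calls the `test` method above; `copy=True` returns the
# means/pvalues/metadata mapping instead of writing to `adata.uns`.
def _example_ligrec(adata: AnnData, cluster_key: str = "leiden") -> None:
    import squidpy as sq

    res = sq.gr.ligrec(
        adata,
        cluster_key=cluster_key,
        n_perms=100,
        seed=0,
        corr_method="fdr_bh",
        copy=True,
    )
    print(res["pvalues"].head())  # rows: (source, target) pairs; columns: cluster pairs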
def co_occurrence(
    adata: AnnData,
    cluster_key: str,
    spatial_key: str = Key.obsm.spatial,
    n_steps: int = 50,
    copy: bool = False,
    n_splits: Optional[int] = None,
    n_jobs: Optional[int] = None,
    backend: str = "loky",
    show_progress_bar: bool = True,
) -> Optional[Tuple[np.ndarray, np.ndarray]]:
    """
    Compute co-occurrence probability of clusters.

    The co-occurrence is computed across ``n_steps`` distance thresholds in spatial dimensions.

    Parameters
    ----------
    %(adata)s
    %(cluster_key)s
    %(spatial_key)s
    n_steps
        Number of distance thresholds at which co-occurrence is computed.
    %(copy)s
    n_splits
        Number of splits in which to divide the spatial coordinates in
        :attr:`anndata.AnnData.obsm` ``['{spatial_key}']``.
    %(parallelize)s

    Returns
    -------
    If ``copy = True``, returns the co-occurrence probability and the distance thresholds intervals.

    Otherwise, modifies the ``adata`` with the following keys:

        - :attr:`anndata.AnnData.uns` ``['{cluster_key}_co_occurrence']['occ']`` - the co-occurrence
          probabilities across interval thresholds.
        - :attr:`anndata.AnnData.uns` ``['{cluster_key}_co_occurrence']['interval']`` - the distance
          thresholds computed at ``n_steps``.
    """
    _assert_categorical_obs(adata, key=cluster_key)
    _assert_spatial_basis(adata, key=spatial_key)

    spatial = adata.obsm[spatial_key].astype(fp)
    original_clust = adata.obs[cluster_key]

    # find minimum and maximum distances for thresholding
    thres_min, thres_max = _find_min_max(spatial)

    # annotate cluster idx
    clust_map = {v: i for i, v in enumerate(original_clust.cat.categories.values)}
    labs = np.array([clust_map[c] for c in original_clust], dtype=ip)
    labs_unique = np.array(list(clust_map.values()), dtype=ip)

    # create intervals thresholds
    interval = np.linspace(thres_min, thres_max, num=n_steps, dtype=fp)

    n_obs = spatial.shape[0]
    if n_splits is None:
        size_arr = (n_obs**2 * 4) / 1024 / 1024  # expected memory usage of the full distance matrix, in MiB
        if size_arr > 2_000:
            s = 1
            while 2_048 < (n_obs / s):
                s += 1
            n_splits = s
            logg.warning(
                f"`n_splits` was automatically set to `{n_splits}` to prevent "
                f"creating the full `{n_obs}x{n_obs}` distance matrix"
            )
        else:
            n_splits = 1
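# Usage sketch (hypothetical helper), based on the documented return values;
# the shape of `occ` is inferred from the companion plotting function below,
# which consumes one (n_clusters, n_steps - 1) slice per reference cluster.
def _example_co_occurrence(adata: AnnData, cluster_key: str = "leiden") -> None:
    import squidpy as sq

    occ, interval = sq.gr.co_occurrence(adata, cluster_key=cluster_key, n_steps=50, copy=True)
    print(occ.shape, interval.shape)  # (n_clusters, n_clusters, n_steps - 1), (n_steps,)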
def centrality_scores(
    adata: AnnData,
    cluster_key: str,
    score: Optional[Union[str, Sequence[str]]] = None,
    legend_kwargs: Mapping[str, Any] = MappingProxyType({}),
    palette: Palette_t = None,
    figsize: Optional[Tuple[float, float]] = None,
    dpi: Optional[int] = None,
    save: Optional[Union[str, Path]] = None,
    **kwargs: Any,
) -> None:
    """
    Plot centrality scores.

    The centrality scores are computed by :func:`squidpy.gr.centrality_scores`.

    Parameters
    ----------
    %(adata)s
    %(cluster_key)s
    score
        Whether to plot all scores or only selected ones.
    legend_kwargs
        Keyword arguments for :func:`matplotlib.pyplot.legend`.
    %(cat_plotting)s

    Returns
    -------
    %(plotting_returns)s
    """
    _assert_categorical_obs(adata, key=cluster_key)
    df = _get_data(adata, cluster_key=cluster_key, func_name="centrality_scores").copy()

    legend_kwargs = dict(legend_kwargs)
    if "loc" not in legend_kwargs:
        legend_kwargs["loc"] = "center left"
        legend_kwargs.setdefault("bbox_to_anchor", (1, 0.5))

    scores = df.columns.values
    df[cluster_key] = df.index.values

    clusters = adata.obs[cluster_key].cat.categories
    palette = _get_palette(adata, cluster_key=cluster_key, categories=clusters) if palette is None else palette

    score = scores if score is None else score
    score = _assert_non_empty_sequence(score, name="centrality scores")
    score = sorted(_get_valid_values(score, scores))

    fig, axs = plt.subplots(1, len(score), figsize=figsize, dpi=dpi, constrained_layout=True)
    axs = np.ravel(axs)  # make into iterable
    for g, ax in zip(score, axs):
        sns.scatterplot(
            x=g,
            y=cluster_key,
            data=df,
            hue=cluster_key,
            hue_order=clusters,
            palette=palette,
            ax=ax,
            **kwargs,
        )
        ax.set_title(str(g).replace("_", " ").capitalize())
        ax.set_xlabel("value")
        ax.set_yticks([])
        ax.legend(**legend_kwargs)

    if save is not None:
        save_fig(fig, path=save)
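# Usage sketch (hypothetical helper): compute the scores, then plot a chosen
# subset. The score names are assumed to match the `Centrality` enum values
# stored as columns by `squidpy.gr.centrality_scores`.
def _example_plot_centrality_scores(adata: AnnData, cluster_key: str = "leiden") -> None:
    import squidpy as sq

    sq.gr.spatial_neighbors(adata)
    sq.gr.centrality_scores(adata, cluster_key=cluster_key)
    centrality_scores(adata, cluster_key=cluster_key, score=["degree_centrality", "average_clustering"])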
def co_occurrence(
    adata: AnnData,
    cluster_key: str,
    palette: Palette_t = None,
    clusters: Optional[Union[str, Sequence[str]]] = None,
    figsize: Optional[Tuple[float, float]] = None,
    dpi: Optional[int] = None,
    save: Optional[Union[str, Path]] = None,
    legend_kwargs: Mapping[str, Any] = MappingProxyType({}),
    **kwargs: Any,
) -> None:
    """
    Plot co-occurrence probability ratio for each cluster.

    The co-occurrence is computed by :func:`squidpy.gr.co_occurrence`.

    Parameters
    ----------
    %(adata)s
    %(cluster_key)s
    clusters
        Cluster instances for which to plot conditional probability.
    %(cat_plotting)s
    legend_kwargs
        Keyword arguments for :func:`matplotlib.pyplot.legend`.
    kwargs
        Keyword arguments for :func:`seaborn.lineplot`.

    Returns
    -------
    %(plotting_returns)s
    """
    _assert_categorical_obs(adata, key=cluster_key)
    occurrence_data = _get_data(adata, cluster_key=cluster_key, func_name="co_occurrence")

    legend_kwargs = dict(legend_kwargs)
    if "loc" not in legend_kwargs:
        legend_kwargs["loc"] = "center left"
        legend_kwargs.setdefault("bbox_to_anchor", (1, 0.5))

    out = occurrence_data["occ"]
    interval = occurrence_data["interval"][1:]
    categories = adata.obs[cluster_key].cat.categories

    clusters = categories if clusters is None else clusters
    clusters = _assert_non_empty_sequence(clusters, name="clusters")
    clusters = sorted(_get_valid_values(clusters, categories))

    palette = _get_palette(adata, cluster_key=cluster_key, categories=categories) if palette is None else palette

    fig, axs = plt.subplots(
        1,
        len(clusters),
        figsize=(5 * len(clusters), 5) if figsize is None else figsize,
        dpi=dpi,
        constrained_layout=True,
    )
    axs = np.ravel(axs)  # make into iterable

    for g, ax in zip(clusters, axs):
        idx = np.where(categories == g)[0][0]
        df = pd.DataFrame(out[idx, :, :].T, columns=categories).melt(
            var_name=cluster_key, value_name="probability"
        )
        df["distance"] = np.tile(interval, len(categories))

        sns.lineplot(
            x="distance",
            y="probability",
            data=df,
            dashes=False,
            hue=cluster_key,
            hue_order=categories,
            palette=palette,
            ax=ax,
            **kwargs,
        )
        ax.legend(**legend_kwargs)
        ax.set_title(rf"$\frac{{p(exp|{g})}}{{p(exp)}}$")
        ax.set_ylabel("value")

    if save is not None:
        save_fig(fig, path=save)
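# Usage sketch (hypothetical helper): compute the probabilities, then plot the
# conditional-probability ratio for one reference cluster. The cluster name
# passed via `clusters` must be a category of `adata.obs[cluster_key]`.
def _example_plot_co_occurrence(adata: AnnData, cluster_key: str = "leiden") -> None:
    import squidpy as sq

    sq.gr.co_occurrence(adata, cluster_key=cluster_key)
    co_occurrence(adata, cluster_key=cluster_key, clusters="0")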