def composition(
    adata: AnnData,
    key,
    figsize: Optional[Tuple[float, float]] = None,
    dpi: Optional[float] = None,
    save: Optional[Union[str, Path]] = None,
) -> None:
    """
    Plot a pie chart showing the composition of a categorical annotation.

    .. image:: https://raw.githubusercontent.com/theislab/cellrank/master/resources/images/composition.png
       :width: 400px
       :align: center

    Params
    ------
    adata
        Annotated data object.
    key
        Key in :paramref:`adata` `.obs` containing categorical observation.
        If `adata.uns` holds an entry `'{key}_colors'`, it is used to color the wedges.
    figsize
        Size of the figure.
    dpi
        Dots per inch.
    save
        Filename where to save the plot. If `None`, the figure is only shown.

    Returns
    -------
    None
        Nothing, just plots the pie chart. Optionally saves the figure based on :paramref:`save`.

    Raises
    ------
    KeyError
        If :paramref:`key` is not present in `adata.obs`.
    TypeError
        If `adata.obs[key]` is not categorical.
    """
    if key not in adata.obs:
        raise KeyError(f"Key `{key!r}` not found in `adata.obs`.")

    annotation = adata.obs[key]
    if not is_categorical_dtype(annotation):
        raise TypeError(f"Observation `adata.obs[{key!r}]` is not categorical.")

    categories = annotation.cat.categories
    palette = adata.uns.get(f"{key}_colors", None)

    # number of observations in each category, in category order
    counts = []
    for category in categories:
        counts.append(np.sum(annotation == category))
    fractions = counts / np.sum(counts)

    # one wedge per category
    fig, ax = plt.subplots(figsize=figsize, dpi=dpi)
    ax.pie(x=fractions, labels=categories, colors=palette)
    ax.set_title(f"Composition by {key}")

    if save is not None:
        save_fig(fig, save)

    fig.show()
def _trends_helper(
    adata: AnnData,
    models: Dict[str, Dict[str, Any]],
    gene: str,
    ln_key: str,
    lineage_names: Optional[Sequence[str]] = None,
    same_plot: bool = False,
    sharey: bool = True,
    cmap=None,
    fig: mpl.figure.Figure = None,
    ax: mpl.axes.Axes = None,
    save: Optional[Union[str, Path]] = None,
    **kwargs,
) -> None:
    """
    Plot the fitted expression trend of one gene for the given lineages.

    Params
    ------
    adata: :class:`anndata.AnnData`
        Annotated data object.
    models
        Fitted models, indexed as `models[gene][lineage_name]`. Each model's `.plot` method is called.
    gene
        Name of the gene to plot; used to index :paramref:`models` and as a label.
    ln_key
        Lineage key; `_colors(ln_key)` is looked up in `adata.uns` for the lineage colors
        when :paramref:`cmap` does not provide them.
    lineage_names
        Names of the lineages to plot. Must be a sized collection, its length is taken immediately.
    same_plot
        Whether to draw all lineages into one axes instead of one axes per lineage.
    sharey
        Whether to share the y-axis. Only used when :paramref:`same_plot` `=False`.
    cmap
        Colormap whose `.colors` attribute, if present, supplies the per-lineage colors.
    fig
        Figure to use, if `None`, create a new one.
    ax
        Ax to use, if `None`, create a new one.
    save
        Filename where to save the plot. If `None`, just shows the plots.
    **kwargs
        Keyword arguments for :meth:`cellrank.ul.models.Model.plot`.
        The keys `'perc'`, `'hide_cells'`, `'show_cbar'` and `'color'` are consumed here,
        `'figsize'` is read to size a newly created figure.

    Returns
    -------
    None
        Nothing, just plots the trends. Optionally saves the figure based on :paramref:`save`.
    """
    n_lineages = len(lineage_names)
    requested_figsize = kwargs.get("figsize", None)

    if not same_plot:
        # one column per lineage
        fig, axes = plt.subplots(
            ncols=n_lineages,
            figsize=requested_figsize or (6 * n_lineages, 6),
            sharey=sharey,
            constrained_layout=True,
        )
        axes = np.ravel(axes)
    else:
        if fig is None and ax is None:
            fig, ax = plt.subplots(
                1,
                figsize=requested_figsize or (15, 10),
                constrained_layout=True,
            )
        # every lineage is drawn into the very same axes
        axes = [ax] * n_lineages

    percs = kwargs.pop("perc", None)
    if percs is None or not isinstance(percs[0], (tuple, list)):
        percs = [percs]

    # we need to show colorbar always if percs differ
    same_perc = False
    if len(percs) != n_lineages or n_lineages == 1:
        if len(percs) != 1:
            raise ValueError(
                f"Percentile must be a collection of size `1` or `{n_lineages}`, got `{len(percs)}`."
            )
        same_perc = True
        percs = percs * n_lineages

    hide_cells = kwargs.pop("hide_cells", False)
    show_cbar = kwargs.pop("show_cbar", True)
    lineage_color = kwargs.pop("color", "black")

    if cmap is not None and hasattr(cmap, "colors"):
        lc = cmap.colors
    else:
        lc = adata.uns.get(f"{_colors(ln_key)}", cm.Set1.colors)

    last = n_lineages - 1
    for i, (name, ax, perc) in enumerate(zip(lineage_names, axes, percs)):
        # differing percentiles always need their own colorbar; otherwise
        # only the last panel gets one (and only if the caller wants it)
        if not same_perc:
            cbar_here = True
        elif not show_cbar:
            cbar_here = False
        else:
            cbar_here = i == last

        if same_plot and name is not None:
            line_color = lc[i]
        else:
            line_color = lineage_color

        if not same_plot or name is None:
            ylabel = gene
        else:
            ylabel = "expression"

        models[gene][name].plot(
            ax=ax,
            fig=fig,
            perc=perc,
            show_cbar=cbar_here,
            title="No lineage" if name is None else name,
            hide_cells=hide_cells or (same_plot and i != last),
            same_plot=same_plot,
            color=line_color,
            ylabel=ylabel,
            **kwargs,
        )

    if same_plot and lineage_names != [None]:
        ax.set_title(gene)
        ax.legend()

    if save is not None:
        save_fig(fig, save)
def heatmap(
    adata: AnnData,
    model: _model_type,
    genes: Sequence[str],
    final: bool = True,
    kind: str = "lineages",
    lineages: Optional[Union[str, Sequence[str]]] = None,
    start_lineage: Optional[Union[str, Sequence[str]]] = None,
    end_lineage: Optional[Union[str, Sequence[str]]] = None,
    lineage_height: float = 0.1,
    cluster_genes: bool = False,
    xlabel: Optional[str] = None,
    cmap: colors.ListedColormap = cm.Spectral_r,
    n_jobs: Optional[int] = 1,
    backend: str = "multiprocessing",
    hspace: float = 0.25,
    figsize: Optional[Tuple[float, float]] = None,
    dpi: Optional[int] = None,
    save: Optional[Union[str, Path]] = None,
    show_progress_bar: bool = True,
    **kwargs,
) -> None:
    """
    Plot a heatmap of smoothed gene expression along specified lineages.

    .. image:: https://raw.githubusercontent.com/theislab/cellrank/master/resources/images/heatmap.png
       :width: 400px
       :align: center

    Params
    ------
    adata : :class:`anndata.AnnData`
        Annotated data object.
    model
        Model to fit.

        - If a :class:`dict`, gene and lineage specific models can be specified. Use `'*'` to indicate
          all genes or lineages, for example `{'Map2': {'*': ...}, 'Dcx': {'Alpha': ..., '*': ...}}`.
    genes
        Genes in :paramref:`adata` `.var_names` to plot. A single gene name is also accepted.
    final
        Whether to consider cells going to final states or vice versa.
    kind
        Variant of the heatmap.

        - If `'genes'`, one panel per gene in :paramref:`genes`, with one bar per lineage
          in :paramref:`lineages`.
        - If `'lineages'`, one panel per lineage in :paramref:`lineages`, with one row per gene
          in :paramref:`genes`.
    lineages
        Names of the lineages for which to plot. A single lineage name is also accepted.
        If `None`, all lineages are used.
    start_lineage
        Lineage from which to select cells with lowest pseudotime as starting points.
        If specified, the trends start at the earliest pseudotime point within that lineage,
        otherwise they start from time `0`.
    end_lineage
        Lineage from which to select cells with highest pseudotime as endpoints.
        If specified, the trends end at the latest pseudotime point within that lineage,
        otherwise, it is determined automatically.
    lineage_height
        Height of a bar when :paramref:`kind` `='genes'`.
    cluster_genes
        Whether to use :func:`seaborn.clustermap` when :paramref:`kind` `='lineages'`.
        One clustermap figure is created per lineage; only the last one is saved.
    xlabel
        Label on the x-axis. If `None`, it is taken from `time_key` in :paramref:`kwargs`.
    cmap
        Colormap to use when visualizing the smoothed expression.
    n_jobs
        Number of parallel jobs. If `-1`, use all available cores. If `None` or `1`, the execution is sequential.
    backend
        Which backend to use for multiprocessing.
        See :class:`joblib.Parallel` for valid options.
    hspace
        Vertical space between the panels when :paramref:`kind` `='lineages'`
        and :paramref:`cluster_genes` `=False`.
    figsize
        Size of the figure. If `None`, it is derived from the number of genes and lineages.
    dpi
        Dots per inch.
    save
        Filename where to save the plot. If `None`, just shows the plot.
    show_progress_bar
        Whether to show a progress bar tracking models fitted.
    **kwargs
        Keyword arguments for :meth:`cellrank.ul.models.Model.prepare`.

    Returns
    -------
    None
        Nothing, just plots the heatmap variant depending on :paramref:`kind`.
        Optionally saves the figure based on :paramref:`save`.

    Raises
    ------
    KeyError
        If the lineage key is not found in `adata.obsm`.
    ValueError
        If a start/end lineage is not among :paramref:`lineages` or :paramref:`kind` is unknown.
    """

    def gene_per_lineage():
        def color_fill_rec(ax, x, y1, y2, colors=None, cmap=cmap, **kwargs):
            dx = x[1] - x[0]

            # NOTE: the loop variables deliberately reuse the argument names
            for (color, x, y1, y2) in zip(cmap(colors), x, y1, y2):
                ax.add_patch(
                    plt.Rectangle((x, y1), dx, y2 - y1, color=color, ec=color, **kwargs)
                )

            # `x` and `y2` now hold the values of the last rectangle; the invisible
            # point is presumably drawn to make the axes rescale to the patches
            ax.plot(x, y2, lw=0)

        fig, axes = plt.subplots(
            nrows=len(genes),
            figsize=(15, len(genes) + len(lineages)) if figsize is None else figsize,
            dpi=dpi,
        )

        if not isinstance(axes, Iterable):
            axes = [axes]
        axes = np.ravel(axes)

        for ax, (gene, models) in zip(axes, data.items()):
            c = np.array([m.y_test for m in models.values()])
            c_min, c_max = np.nanmin(c), np.nanmax(c)
            norm = colors.Normalize(vmin=c_min, vmax=c_max)

            ix = 0
            ys = [ix]

            # stack one bar of height `lineage_height` per lineage
            for x, c in ((m.x_test, m.y_test) for m in models.values()):
                y = np.ones_like(x)
                color_fill_rec(
                    ax, x, y * ix, y * (ix + lineage_height), colors=norm(c)
                )
                ix += lineage_height
                ys.append(ix)

            xs = np.array([m.x_test for m in models.values()])
            x_min, x_max = np.nanmin(xs), np.nanmax(xs)
            ax.set_xticks(np.linspace(x_min, x_max, 11))

            # `ys` holds the lower edge of every bar plus the top of the last one;
            # only the lower edges get a (centered) tick so that the number of
            # ticks matches the number of lineage labels
            ax.set_yticks(np.array(ys[:-1]) + lineage_height / 2)
            ax.set_yticklabels(lineages)
            ax.set_title(gene)
            ax.set_ylabel("Lineage")

            for pos in ["top", "bottom", "left", "right"]:
                ax.spines[pos].set_visible(False)

            cax, _ = mpl.colorbar.make_axes(ax)
            _ = mpl.colorbar.ColorbarBase(
                cax, norm=norm, cmap=cmap, label="Expression"
            )

            ax.tick_params(
                top=False,
                bottom=False,
                left=False,
                right=False,
                labelleft=True,
                labelbottom=False,
            )

        # only the last panel shows the x-axis ticks and label
        ax.xaxis.set_major_formatter(FormatStrFormatter("%.1f"))
        ax.tick_params(
            top=False,
            bottom=True,
            left=False,
            right=False,
            labelleft=True,
            labelbottom=True,
        )
        ax.set_xlabel(xlabel)

        return fig

    def lineage_per_gene():
        data_t = defaultdict(dict)  # transpose
        for gene, lns in data.items():
            for ln, d in lns.items():
                data_t[ln][gene] = d

        fig, ax = None, None
        if not cluster_genes:
            fig, axes = plt.subplots(
                nrows=len(lineages),
                figsize=(15, 5 + len(genes)) if figsize is None else figsize,
                dpi=dpi,
            )
            fig.subplots_adjust(hspace=hspace, bottom=0)

            if not isinstance(axes, Iterable):
                axes = [axes]
            axes = np.ravel(axes)
        else:
            # clustermap creates its own figure; one placeholder per lineage
            # (not per gene) so that `zip` below does not drop any lineage
            axes = [None] * len(data_t)

        for ax, (lname, models) in zip(axes, data_t.items()):
            # label the rows by the genes the models were actually fitted for
            df = pd.DataFrame(
                [m.y_test for m in models.values()], index=list(models.keys())
            )
            df.index.name = "Genes"

            if cluster_genes:
                g = sns.clustermap(
                    df,
                    cmap=cmap,
                    xticklabels=False,
                    cbar_kws={"label": "Expression"},
                    row_cluster=True,
                    col_cluster=False,
                )
                g.ax_heatmap.set_title(lname)
                fig = g.fig
            else:
                xs = np.array([m.x_test for m in models.values()])
                x_min, x_max = np.nanmin(xs), np.nanmax(xs)

                sns.heatmap(
                    df,
                    ax=ax,
                    cmap=cmap,
                    xticklabels=False,
                    cbar_kws={"label": "Expression"},
                )

                ax.set_title(lname)
                ax.set_xticks(np.linspace(0, len(df.columns), 10))
                ax.set_xticklabels(
                    [round(n, 3) for n in np.linspace(x_min, x_max, 10)],
                    rotation=90,
                )

        if not cluster_genes:
            ax.set_xlabel(xlabel)

        return fig

    lineage_key = str(LinKey.FORWARD if final else LinKey.BACKWARD)
    if lineage_key not in adata.obsm:
        raise KeyError(f"Lineages key `{lineage_key!r}` not found in `adata.obsm`.")

    if lineages is None:
        lineages = adata.obsm[lineage_key].names
    elif isinstance(lineages, str):
        # a bare string would otherwise be iterated character by character
        lineages = [lineages]

    # indexing raises if a lineage name is unknown
    for lineage_name in lineages:
        _ = adata.obsm[lineage_key][lineage_name]

    if isinstance(genes, str):
        genes = [genes]
    check_collection(adata, genes, "var_names")

    if isinstance(start_lineage, (str, type(None))):
        start_lineage = [start_lineage] * len(lineages)
    if isinstance(end_lineage, (str, type(None))):
        end_lineage = [end_lineage] * len(lineages)

    xlabel = kwargs.get("time_key", None) if xlabel is None else xlabel

    # these are passed positionally to the fitting routine below
    _ = kwargs.pop("start_lineage", None)
    _ = kwargs.pop("end_lineage", None)

    for typp, clusters in zip(["Start", "End"], [start_lineage, end_lineage]):
        for cl in filter(lambda c: c is not None, clusters):
            if cl not in lineages:
                raise ValueError(f"{typp} lineage `{cl!r}` not found in lineage names.")

    # fail before the (potentially expensive) model fitting
    if kind not in ("genes", "lineages"):
        raise ValueError(
            f"Unknown heatmap kind `{kind!r}`. Valid options are: `'lineages'`, `'genes'`."
        )

    kwargs["models"] = _create_models(model, genes, lineages)
    if _is_any_gam_mgcv(kwargs["models"]):
        logg.debug(
            "DEBUG: Setting backend to multiprocessing because model is `GamMGCV`"
        )
        backend = "multiprocessing"

    n_jobs = _get_n_cores(n_jobs, len(genes))

    start = logg.info(f"Computing trends using `{n_jobs}` core(s)")
    data = parallelize(
        _fit,
        genes,
        unit="gene",
        backend=backend,
        n_jobs=n_jobs,
        extractor=lambda data: {k: v for d in data for k, v in d.items()},
        show_progress_bar=show_progress_bar,
    )(lineages, start_lineage, end_lineage, **kwargs)
    logg.info("    Finish", time=start)

    logg.debug(f"DEBUG: Plotting {kind} heatmap")
    if kind == "genes":
        fig = gene_per_lineage()
    else:
        fig = lineage_per_gene()

    if save is not None and fig is not None:
        save_fig(fig, save)
def similarity_plot(
    adata: AnnData,
    cluster_key: str = "clusters",
    clusters: Optional[List[str]] = None,
    n_samples: int = 1000,
    cmap: mpl.colors.ListedColormap = cm.viridis,
    fontsize: float = 14,
    rotation: float = 45,
    title: Optional[str] = "similarity",
    figsize: Tuple[float, float] = (12, 10),
    dpi: Optional[int] = None,
    final: bool = True,
    save: Optional[Union[str, Path]] = None,
) -> None:
    """
    Compare clusters with respect to their root/final probabilities.

    For each cluster, we compute how likely an 'average cell' is to go towards to final states/come from the root
    states. We then compare these averaged probabilities using Cramér's V statistic, see
    `here <https://en.wikipedia.org/wiki/Cram%C3%A9r%27s_V>`_. The similarity is defined as :math:`1 - Cramér's V`.

    .. image:: https://raw.githubusercontent.com/theislab/cellrank/master/resources/images/similarity_plot.png
       :width: 400px
       :align: center

    Params
    ------
    adata: :class:`anndata.AnnData`
        Annotated data object.
    cluster_key
        Key in :paramref:`adata` `.obs` corresponding to the clustering.
    clusters
        Clusters in :paramref:`adata` `.obs` to consider. If `None`, all clusters will be considered.
    n_samples
        Number of samples per cluster.
    cmap
        Colormap to use.
    fontsize
        Font size of the labels.
    rotation
        Rotation of labels on x-axis.
    title
        Title of the figure.
    figsize
        Size of the figure.
    dpi
        Dots per inch.
    final
        Whether to consider cells going to final states or vice versa.
    save
        Filename where to save the plot. If `None`, just shows the plot.

    Returns
    -------
    None
        Nothing, just plots the similarity matrix. Optionally saves the figure based on :paramref:`save`.
    """
    logg.debug("DEBUG: Getting the counts")
    counts = _counts(
        adata,
        cluster_key=cluster_key,
        clusters=clusters,
        n_samples=n_samples,
        final=final,
    )
    names = list(counts.keys())

    logg.debug("DEBUG: Calculating Cramer`s V statistic")
    # pairwise similarity, rows and columns follow the order of `names`
    similarity = []
    for row_name in names:
        row = []
        for col_name in names:
            row.append(1 - _cramers_v(counts[row_name], counts[col_name]))
        similarity.append(row)

    fig, ax = plt.subplots(figsize=figsize, dpi=dpi)
    image = ax.imshow(similarity, cmap=cmap)

    tick_positions = range(len(names))
    ax.set_xticks(tick_positions)
    ax.set_yticks(tick_positions)
    ax.set_xticklabels(names, fontsize=fontsize, rotation=rotation)
    ax.set_yticklabels(names, fontsize=fontsize)

    ax.set_title(title)
    ax.tick_params(top=False, bottom=True, labeltop=False, labelbottom=True)

    colorbar = ax.figure.colorbar(
        image, ax=ax, norm=mpl.colors.Normalize(vmin=0, vmax=1)
    )
    colorbar.set_ticks(np.linspace(0, 1, 10))

    if save is not None:
        save_fig(fig, save)

    fig.show()
def gene_trends(
    adata: AnnData,
    model: _model_type,
    genes: Union[str, Sequence[str]],
    lineages: Optional[Union[str, Sequence[str]]] = None,
    data_key: str = "X",
    final: bool = True,
    start_lineage: Optional[Union[str, Sequence[str]]] = None,
    end_lineage: Optional[Union[str, Sequence[str]]] = None,
    conf_int: bool = True,
    same_plot: bool = False,
    hide_cells: bool = False,
    perc: Optional[Union[Tuple[float, float], Sequence[Tuple[float, float]]]] = None,
    lineage_cmap: Optional[matplotlib.colors.ListedColormap] = None,
    abs_prob_cmap: matplotlib.colors.ListedColormap = cm.viridis,
    cell_color: str = "black",
    color: str = "black",
    cell_alpha: float = 0.6,
    lineage_alpha: float = 0.2,
    size: float = 15,
    lw: float = 2,
    show_cbar: bool = True,
    margins: float = 0.015,
    sharey: bool = False,
    figsize: Optional[Tuple[float, float]] = None,
    dpi: Optional[int] = None,
    ncols: int = 2,
    n_jobs: Optional[int] = 1,
    backend: str = "multiprocessing",
    ext: str = "png",
    suptitle: Optional[str] = None,
    save: Optional[Union[str, Path]] = None,
    dirname: Optional[str] = None,
    plot_kwargs: Mapping = MappingProxyType({}),
    show_progres_bar: bool = True,
    **kwargs,
) -> None:
    """
    Plot gene expression trends along lineages.

    Each lineage is defined via its lineage weights which we compute using :func:`cellrank.tl.lineages`.
    This function accepts any `scikit-learn` model wrapped in :class:`cellrank.ul.models.SKLearnModel`
    to fit gene expression, where we take the lineage weights into account in the loss function.

    .. image:: https://raw.githubusercontent.com/theislab/cellrank/master/resources/images/gene_trends.png
       :width: 400px
       :align: center

    Params
    ------
    adata : :class:`anndata.AnnData`
        Annotated data object.
    model
        Model to fit.

        - If a :class:`dict`, gene and lineage specific models can be specified. Use `'*'` to indicate
          all genes or lineages, for example `{'Map2': {'*': ...}, 'Dcx': {'Alpha': ..., '*': ...}}`.
    genes
        Genes in :paramref:`adata` `.var_names` to plot, or keys in :paramref:`adata` `.obs`
        when :paramref:`data_key` `='obs'`.
    lineages
        Lineage names for which to show the gene expression. If `None`, all lineages are used.
        If every entry is `None`, no lineage is used and the colorbar is hidden.
    data_key
        Key in :paramref:`adata` `.layers` or `'X'` for :paramref:`adata` `.X` where the data is stored.
    final
        Whether to consider cells going to final states or vice versa.
    start_lineage
        Lineage from which to select cells with lowest pseudotime as starting points.
        If specified, the trends start at the earliest pseudotime within that lineage,
        otherwise they start from time `0`.
    end_lineage
        Lineage from which to select cells with highest pseudotime as endpoints.
        If specified, the trends end at the latest pseudotime within that lineage,
        otherwise, it is determined automatically.
    conf_int
        Whether to compute and show confidence intervals.
    same_plot
        Whether to plot all lineages for each gene in the same plot.
    hide_cells
        If `True`, hide all the cells.
    perc
        Percentile for colors. Valid values are in range `[0, 100]`.
        This can improve visualization. Can be specified separately for each lineage.
    lineage_cmap
        Colormap to use when coloring in the lineages.
        Only used when :paramref:`same_plot` `=True`.
    abs_prob_cmap
        Colormap to use when visualizing the absorption probabilities for each lineage.
        Only used when :paramref:`same_plot` `=False`.
    cell_color
        Color of the cells when not visualizing absorption probabilities.
        Only used when :paramref:`same_plot` `=True`.
    color
        Color for the lineages, when each lineage is on separate plot,
        otherwise according to :paramref:`lineage_cmap`.
    cell_alpha
        Alpha channel for cells.
    lineage_alpha
        Alpha channel for lineage confidence intervals.
    size
        Size of the points.
    lw
        Line width of the smoothed values.
    show_cbar
        Whether to show colorbar. Always shown when percentiles for lineages differ.
    margins
        Margins around the plot.
    sharey
        Whether to share y-axis.
    figsize
        Size of the figure.
    dpi
        Dots per inch.
    ncols
        Number of columns of the plot when plotting multiple genes.
        Only used when :paramref:`same_plot` `=True`.
    n_jobs
        Number of parallel jobs. If `-1`, use all available cores. If `None` or `1`, the execution is sequential.
    backend
        Which backend to use for multiprocessing.
        See :class:`joblib.Parallel` for valid options.
    ext
        Extension to use when saving files, such as `'pdf'`.
        Only used when :paramref:`same_plot` `=False`.
    suptitle
        Suptitle of the figure. Only used when :paramref:`same_plot` `=True`.
    save
        Filename where to save the plots. If `None`, just show the plots.
        Only used when :paramref:`same_plot` `=True`.
    dirname
        Directory where to save the plots, one per gene in :paramref:`genes`.
        If `None`, just show the plots.
        Only used when :paramref:`same_plot` `=False`.
        The figures will be saved as :paramref:`dirname` /`{gene}`. :paramref:`ext`.
    plot_kwargs:
        Keyword arguments for :meth:`cellrank.ul.models.Model.plot`.
    show_progres_bar
        Whether to show a progress bar tracking models fitted.
        (The parameter name is misspelled; it is kept for backward compatibility.)
    kwargs
        Keyword arguments for :meth:`cellrank.ul.models.Model.prepare`.

    Returns
    -------
    None
        Nothing, just plots and optionally saves the plots.

    Raises
    ------
    KeyError
        If the lineage key is not found in `adata.obsm`.
    ValueError
        If the number of start or end lineages differs from the number of lineages.
    RuntimeError
        If :paramref:`dirname` is given but the figures directory is `None`.
    """
    if isinstance(genes, str):
        genes = [genes]
    genes = _make_unique(genes)

    if data_key != "obs":
        check_collection(adata, genes, "var_names")
    else:
        check_collection(adata, genes, "obs")

    nrows = int(np.ceil(len(genes) / ncols))
    fig = None
    axes = [None] * len(genes)

    if same_plot:
        fig, axes = plt.subplots(
            nrows=nrows,
            ncols=ncols,
            sharey=sharey,
            figsize=(15 * ncols, 10 * nrows) if figsize is None else figsize,
        )
        axes = np.ravel(axes)
    elif dirname is not None:
        figdir = sc.settings.figdir
        if figdir is None:
            raise RuntimeError(
                f"Invalid combination: figures directory `cellrank.settings.figdir` is `None`, "
                f"but dirname was specified to `{dirname}`."
            )
        # relative directories live under the figures directory
        target_dir = dirname if os.path.isabs(dirname) else os.path.join(figdir, dirname)
        os.makedirs(target_dir, exist_ok=True)
    elif save is not None:
        logg.warning("No directory specified for saving. Ignoring `save` argument")

    ln_key = str(LinKey.FORWARD if final else LinKey.BACKWARD)
    if ln_key not in adata.obsm:
        raise KeyError(f"Lineages key `{ln_key!r}` not found in `adata.obsm`.")

    if lineages is None:
        lineages = adata.obsm[ln_key].names
    elif isinstance(lineages, str):
        lineages = [lineages]
    elif all(ln is None for ln in lineages):
        # no lineage, all the weights are 1
        lineages = [None]
        show_cbar = False
        logg.debug("DEBUG: All lineages are `None`, setting weights to be `1`")
    lineages = _make_unique(lineages)

    # indexing raises if a lineage name is unknown
    for ln in filter(lambda ln: ln is not None, lineages):
        _ = adata.obsm[ln_key][ln]
    n_lineages = len(lineages)

    if isinstance(start_lineage, (str, type(None))):
        start_lineage = [start_lineage] * n_lineages
    if isinstance(end_lineage, (str, type(None))):
        end_lineage = [end_lineage] * n_lineages

    if len(start_lineage) != n_lineages:
        raise ValueError(
            f"Expected the number of start lineages to be the same as number of lineages "
            f"({n_lineages}), found `{len(start_lineage)}`."
        )
    if len(end_lineage) != n_lineages:
        raise ValueError(
            f"Expected the number of end lineages to be the same as number of lineages "
            f"({n_lineages}), found `{len(end_lineage)}`."
        )

    kwargs["models"] = _create_models(model, genes, lineages)
    kwargs["data_key"] = data_key
    kwargs["final"] = final
    kwargs["conf_int"] = conf_int

    # `plot_kwargs` may be an immutable mapping, work on a private copy
    plot_kwargs = dict(plot_kwargs)
    if plot_kwargs.get("xlabel", None) is None:
        plot_kwargs["xlabel"] = kwargs.get("time_key", None)

    if _is_any_gam_mgcv(kwargs["models"]):
        logg.debug(
            "DEBUG: Setting backend to multiprocessing because model is `GamMGCV`"
        )
        backend = "multiprocessing"

    n_jobs = _get_n_cores(n_jobs, len(genes))

    start = logg.info(f"Computing trends using `{n_jobs}` core(s)")
    models = parallelize(
        _fit,
        genes,
        unit="gene" if data_key != "obs" else "obs",
        backend=backend,
        n_jobs=n_jobs,
        extractor=lambda modelss: {k: v for m in modelss for k, v in m.items()},
        show_progress_bar=show_progres_bar,
    )(lineages, start_lineage, end_lineage, **kwargs)
    logg.info("    Finish", time=start)

    logg.debug("DEBUG: Plotting trends")
    for gene, ax in zip(genes, axes):
        # per-gene files are only written when each gene gets its own figure
        if same_plot or dirname is None:
            f = None
        else:
            f = os.path.join(dirname, f"{gene}.{ext}")

        _trends_helper(
            adata,
            models,
            gene=gene,
            lineage_names=lineages,
            ln_key=ln_key,
            same_plot=same_plot,
            hide_cells=hide_cells,
            perc=perc,
            cmap=lineage_cmap,
            abs_prob_cmap=abs_prob_cmap,
            cell_color=cell_color,
            color=color,
            alpha=cell_alpha,
            lineage_alpha=lineage_alpha,
            size=size,
            lw=lw,
            show_cbar=show_cbar,
            margins=margins,
            sharey=sharey,
            dpi=dpi,
            figsize=figsize,
            fig=fig,
            ax=ax,
            save=f,
            **plot_kwargs,
        )

    if same_plot:
        # drop the unused panels of the grid
        for j in range(len(genes), len(axes)):
            axes[j].remove()

        fig.suptitle(suptitle)

        if save is not None:
            save_fig(fig, save)
def cluster_fates(
    adata: AnnData,
    cluster_key: Optional[str] = "louvain",
    clusters: Optional[Union[str, Sequence[str]]] = None,
    lineages: Optional[Union[str, Sequence[str]]] = None,
    mode: str = "bar",
    final: bool = True,
    basis: Optional[str] = None,
    show_cbar: bool = True,
    ncols: Optional[int] = None,
    sharey: bool = False,
    save: Optional[Union[str, Path]] = None,
    legend_kwargs: Mapping[str, Any] = MappingProxyType({"loc": "best"}),
    figsize: Optional[Tuple[float, float]] = None,
    dpi: Optional[int] = None,
    **kwargs,
) -> None:
    """
    Produces plots that aggregate lineage probabilities to a cluster level.

    This can be used to investigate how likely a certain cluster is to go to the final cells, or in turn to have
    descended from the root cells. For mode `'paga'` and `'paga_pie'`, we use *PAGA*, see [Wolf19]_.

    .. image:: https://raw.githubusercontent.com/theislab/cellrank/master/resources/images/cluster_fates.png
       :width: 400px
       :align: center

    Params
    ------
    adata : :class:`anndata.AnnData`
        Annotated data object.
    cluster_key
        Key in :paramref:`adata` `.obs` containing the clusters.
        May only be `None` for modes `'bar'` and `'violin'`, in which case all cells form one group.
    clusters
        Clusters to visualize. If `None`, all clusters will be plotted.
        For modes `'paga'` and `'paga_pie'`, all clusters are always used.
    lineages
        Lineages for which to visualize absorption probabilities. If `None`, use all available lineages.
    mode
        Type of plot to show.

        - If `'bar'`, plot barplots for specified :paramref:`clusters` and :paramref:`lineages`.
        - If `'paga'`, plot `N` :func:`scanpy.pl.paga` plots, one for each lineage in :paramref:`lineages`.
        - If `'paga_pie'`, visualize absorption probabilities as a pie chart for each cluster
          for the given :paramref:`lineages`.
        - If `'violin'`, group the data by lineages and plot the fate distribution per cluster.

        Best for looking at the distribution of fates within one cluster.
    final
        Whether to consider cells going to final states or vice versa.
    basis
        Basis for scatterplot to use when :paramref:`mode` `='paga_pie'`. If `None`, don't show the scatterplot.
    show_cbar
        Whether to show colorbar when :paramref:`mode` is `'paga'`.
    ncols
        Number of columns when :paramref:`mode` is `'bar'`, `'paga'` or `'violin'`.
    sharey
        Whether to share y-axis when :paramref:`mode` is `'bar'`.
    save
        Filename where to save the plots. If `None`, just shows the plot.
    legend_kwargs
        Keyword arguments for :func:`matplotlib.axes.Axes.legend`, such as `'loc'` for legend position.
        Only used when :paramref:`mode` is `'paga_pie'`.
    figsize
        Size of the figure. If `None`, it will be set automatically.
    dpi
        Dots per inch.
    kwargs
        Keyword arguments for :func:`scvelo.pl.paga`, :func:`scanpy.pl.violin` or :func:`matplotlib.pyplot.bar`,
        depending on :paramref:`mode`.

    Returns
    -------
    None
        Nothing, just plots the fates for specified :paramref:`clusters` and :paramref:`lineages`.
        Optionally saves the figure based on :paramref:`save`.

    Raises
    ------
    ValueError
        If :paramref:`mode` is invalid, if :paramref:`cluster_key` is `None` for an unsupported mode
        or if a lineage is not found.
    KeyError
        If :paramref:`cluster_key`, a cluster, the lineage key or the PAGA results are not found.
    """

    def plot_bar():
        cols = 4 if ncols is None else ncols
        n_rows = ceil(len(clusters) / cols)
        fig = plt.figure(
            None, (3.5 * cols, 5 * n_rows) if figsize is None else figsize, dpi=dpi
        )
        fig.tight_layout()

        gs = plt.GridSpec(n_rows, cols, figure=fig, wspace=0.7, hspace=0.9)

        ax = None
        colors = list(adata.obsm[lk][:, lin_names].colors)

        # one panel per cluster: bar height is the mean, the error bar the standard error
        for g, k in zip(gs, d.keys()):
            current_ax = fig.add_subplot(g, sharey=ax)
            current_ax.bar(
                x=np.arange(len(lin_names)),
                height=d[k][0],
                color=colors,
                yerr=d[k][1],
                ecolor="black",
                capsize=10,
                **kwargs,
            )
            if sharey:
                ax = current_ax

            current_ax.set_xticks(np.arange(len(lin_names)))
            current_ax.set_xticklabels(lin_names, rotation="vertical")
            current_ax.set_xlabel(points)
            current_ax.set_ylabel("Absorption probability")
            current_ax.set_title(k)

        return fig

    def plot_paga():
        kwargs["save"] = None
        kwargs["show"] = False
        if "cmap" not in kwargs:
            kwargs["cmap"] = cm.viridis

        cols = len(lin_names) if ncols is None else ncols
        nrows = ceil(len(lin_names) / cols)
        fig, axes = plt.subplots(
            nrows,
            cols,
            figsize=(6 * cols, 4 * nrows) if figsize is None else figsize,
            constrained_layout=True,
            dpi=dpi,
        )

        i = 0
        axes = [axes] if not isinstance(axes, np.ndarray) else np.ravel(axes)
        vmin, vmax = np.inf, -np.inf

        for i, (ax, lineage_name) in enumerate(zip(axes, lin_names)):
            # mean probability of every cluster towards the i-th lineage
            colors = [v[0][i] for v in d.values()]
            vmin, vmax = np.nanmin(colors + [vmin]), np.nanmax(colors + [vmax])

            kwargs["ax"] = ax
            kwargs["colors"] = tuple(colors)
            kwargs["title"] = lineage_name

            sc.pl.paga(adata, **kwargs)

        if show_cbar:
            # a single colorbar spanning the value range of all panels
            norm = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax)
            cax, _ = mpl.colorbar.make_axes(ax, aspect=100)  # new matplotlib feature
            _ = mpl.colorbar.ColorbarBase(
                cax, norm=norm, cmap=kwargs["cmap"], label="Absorption probability"
            )

        # drop the unused panels of the grid
        for ax in axes[i + 1 :]:
            ax.remove()

        return fig

    def plot_paga_pie():
        colors = list(adata.obsm[lk][:, lin_names].colors)
        # per cluster (by position): lineage color -> mean probability
        colors = {
            i: odict(zip(colors, mean)) for i, (mean, _) in enumerate(d.values())
        }

        fig, ax = plt.subplots(figsize=figsize, dpi=dpi)

        kwargs["ax"] = ax
        kwargs["show"] = False
        kwargs["colorbar"] = False  # has to be disabled
        kwargs["node_colors"] = colors

        kwargs.pop("save", None)  # we will handle saving
        kwargs["transitions"] = kwargs.get("transitions", None)
        kwargs["legend_loc"] = kwargs.get("legend_loc", None) or "on data"

        if basis is not None:
            kwargs["basis"] = basis
            kwargs["scatter_flag"] = True
            kwargs["color"] = cluster_key

        scv.pl.paga(adata, **kwargs)

        if basis is not None and kwargs["legend_loc"] not in ("none", "on data"):
            first_legend = _position_legend(
                ax,
                legend_loc=kwargs["legend_loc"],
                **{k: v for k, v in legend_kwargs.items() if k != "loc"},
                title=cluster_key,
            )
            fig.add_artist(first_legend)

        if legend_kwargs.get("loc", None) is not None:
            # we need to use these, because scvelo can have its own handles and
            # they would be plotted here
            handles = []
            for lineage_name, color in zip(lin_names, colors[0].keys()):
                handles += [ax.scatter([], [], label=lineage_name, c=color)]
            if len(colors[0].keys()) != len(adata.obsm[lk].names):
                handles += [ax.scatter([], [], label="Rest", c="grey")]

            ax.legend(**legend_kwargs, handles=handles, title=points)

        return fig

    def plot_violin():
        kwargs["show"] = False
        kwargs.pop("ax", None)
        kwargs.pop("keys", None)
        kwargs.pop("save", None)  # we will handle saving

        kwargs["groupby"] = cluster_key
        if kwargs.get("rotation", None) is None:
            kwargs["rotation"] = 90

        data = adata.obsm[lk]
        to_clean = []

        try:
            # temporarily expose the lineage probabilities as `.obs` columns
            for name in lin_names:
                if name not in adata.obs_keys():
                    to_clean.append(name)
                    adata.obs[name] = np.array(
                        data[:, name]
                    )  # TODO: better approach - dummy adata

            cols = len(lin_names) if ncols is None else ncols
            nrows = ceil(len(lin_names) / cols)
            fig, axes = plt.subplots(
                nrows,
                cols,
                figsize=(6 * cols, 4 * nrows) if figsize is None else figsize,
                dpi=dpi,
            )
            if not isinstance(axes, np.ndarray):
                axes = [axes]
            axes = np.ravel(axes)

            i = 0
            for i, (name, ax) in enumerate(zip(lin_names, axes)):
                ax.set_title(name)  # ylabel not yet supported
                sc.pl.violin(adata, keys=[name], ax=ax, **kwargs)
            for ax in axes[i + 1 :]:
                ax.remove()
        finally:
            # remove the temporary columns even if plotting failed
            for name in to_clean:
                del adata.obs[name]

        return fig

    def plot_violin_no_cluster_key():
        kwargs["show"] = False
        kwargs.pop("ax", None)
        kwargs.pop("keys", None)  # don't care
        kwargs.pop("save", None)

        kwargs["groupby"] = points

        # stack the probabilities of all lineages into one column of a dummy object
        data = np.ravel(np.array(adata.obsm[lk]).T)[..., np.newaxis]
        dadata = AnnData(np.zeros_like(data))
        dadata.obs["Absorption probability"] = data
        dadata.obs[points] = (
            pd.Series(np.ravel([[n] * adata.n_obs for n in adata.obsm[lk].names]))
            .astype("category")
            .values
        )
        dadata.uns[f"{points}_colors"] = adata.obsm[lk].colors

        fig, ax = plt.subplots(
            figsize=figsize if figsize is not None else (8, 6), dpi=dpi
        )
        ax.set_title("All Clusters")
        sc.pl.violin(dadata, keys=["Absorption probability"], ax=ax, **kwargs)

        return fig

    if mode not in _cluster_fates_modes:
        raise ValueError(
            f"Invalid mode: `{mode!r}`. Valid options are: `{_cluster_fates_modes}`."
        )

    if cluster_key is not None:
        if cluster_key not in adata.obs:
            raise KeyError(f"Key `{cluster_key!r}` not found in `adata.obs`.")
    elif mode not in ("bar", "violin"):
        raise ValueError(
            f"Not specifying cluster key is only available for modes `'bar'` and `'violin'`."
        )

    if cluster_key is not None:
        if clusters is not None:
            if isinstance(clusters, str):
                clusters = [clusters]
            clusters = _make_unique(clusters)
            if mode in ("paga", "paga_pie"):
                logg.debug(
                    f"DEBUG: Setting `clusters` to all available ones because of `mode={mode!r}`"
                )
                clusters = list(adata.obs[cluster_key].cat.categories)
            else:
                for cname in clusters:
                    if cname not in adata.obs[cluster_key].cat.categories:
                        raise KeyError(
                            f"Cluster `{cname!r}` not found in `adata.obs[{cluster_key!r}]`"
                        )
        else:
            clusters = list(adata.obs[cluster_key].cat.categories)
    else:
        clusters = ["All"]

    lk = str(LinKey.FORWARD if final else LinKey.BACKWARD)
    points = "Endpoints" if final else "Startpoints"
    if lk not in adata.obsm:
        raise KeyError(f"Lineages key `{lk!r}` not found in `adata.obsm`.")

    if lineages is not None:
        if isinstance(lineages, str):
            lineages = [lineages]
        lineages = _make_unique(lineages)
        for ep in lineages:
            if ep not in adata.obsm[lk].names:
                raise ValueError(
                    f"Endpoint `{ep!r}` not found in `adata.obsm[{lk!r}].names`."
                )
        lin_names = list(lineages)
    else:
        # must be list for sc.pl.violin, else cats str
        lin_names = list(adata.obsm[lk].names)

    if mode == "violin" and clusters != ["All"]:
        # TODO: temporary fix, until subclassing is made ready
        names, colors = adata.obsm[lk].names, adata.obsm[lk].colors
        adata = adata[np.isin(adata.obs[cluster_key], clusters)].copy()
        adata.obsm[lk] = Lineage(adata.obsm[lk], names=names, colors=colors)

    # cluster name -> [mean, standard error] of the lineage probabilities
    d = odict()
    for name in clusters:
        # the builtin `bool` replaces the `np.bool` alias removed from NumPy
        if name == "All":
            mask = np.ones((adata.n_obs,), dtype=bool)
        else:
            mask = (adata.obs[cluster_key] == name).values
        mask = list(np.array(mask, dtype=bool))

        data = adata.obsm[lk][mask, lin_names].X
        mean = np.nanmean(data, axis=0)
        std = np.nanstd(data, axis=0) / np.sqrt(data.shape[0])
        d[name] = [mean, std]

    logg.debug(f"DEBUG: Using mode: `{mode!r}`")

    if mode == "bar":
        fig = plot_bar()
    elif mode == "paga":
        if "paga" not in adata.uns:
            raise KeyError("Compute PAGA first as `scanpy.tl.paga()`.")
        fig = plot_paga()
    elif mode == "paga_pie":
        if "paga" not in adata.uns:
            raise KeyError("Compute PAGA first as `scanpy.tl.paga()`.")
        fig = plot_paga_pie()
    elif mode == "violin":
        fig = plot_violin_no_cluster_key() if cluster_key is None else plot_violin()
    else:
        raise ValueError(
            f"Invalid mode `{mode!r}`. Valid options are: `{_cluster_fates_modes}`."
        )

    if save is not None:
        save_fig(fig, save)

    fig.show()
def cluster_fates(
    adata: AnnData,
    cluster_key: Optional[str] = "louvain",
    clusters: Optional[Union[str, Sequence[str]]] = None,
    lineages: Optional[Union[str, Sequence[str]]] = None,
    mode: str = "bar",
    final: bool = True,
    basis: Optional[str] = None,
    show_cbar: bool = True,
    ncols: Optional[int] = None,
    sharey: bool = False,
    save: Optional[Union[str, Path]] = None,
    legend_kwargs: Mapping[str, Any] = MappingProxyType({"loc": "best"}),
    figsize: Optional[Tuple[float, float]] = None,
    dpi: Optional[int] = None,
    **kwargs,
) -> None:
    """
    Plot aggregate lineage probabilities at a cluster level.

    This can be used to investigate how likely a certain cluster is to go to the final states, or in turn to have
    descended from the root states. For mode `'paga'` and `'paga_pie'`, we use *PAGA*, see [Wolf19]_.

    .. image:: https://raw.githubusercontent.com/theislab/cellrank/master/resources/images/cluster_fates.png
       :width: 400px
       :align: center

    Params
    ------
    adata : :class:`anndata.AnnData`
        Annotated data object.
    cluster_key
        Key in :paramref:`adata` `.obs` containing the clusters. If `None`, all cells are aggregated together,
        which is only allowed for modes `'bar'` and `'violin'`.
    clusters
        Clusters to visualize. If `None`, all clusters will be plotted. For modes `'paga'` and `'paga_pie'`,
        all clusters are always used.
    lineages
        Lineages for which to visualize absorption probabilities. If `None`, use all available lineages.
    mode
        Type of plot to show.

        - `'bar'`: barplot, one panel per cluster
        - `'paga'`: scanpy's PAGA, one per root/final state, colored in by fate
        - `'paga_pie'`: scanpy's PAGA with pie charts indicating aggregated fates
        - `'violin'`: violin plots, one per root/final state
        - `'heatmap'`: seaborn heatmap, showing average fates per cluster
        - `'clustermap'`: same as heatmap, but with dendrogram
    final
        Whether to consider cells going to final states or vice versa.
    basis
        Basis for scatterplot to use when :paramref:`mode` `='paga_pie'`. If `None`, don't show the scatterplot.
    show_cbar
        Whether to show colorbar when :paramref:`mode` is `'paga_pie'`.
    ncols
        Number of columns when :paramref:`mode` is `'bar'` or `'paga'`.
    sharey
        Whether to share y-axis when :paramref:`mode` is `'bar'`.
    save
        Filename where to save the plots. If `None`, just shows the plot.
    legend_kwargs
        Keyword arguments for :func:`matplotlib.axes.Axes.legend`, such as `'loc'` for legend position.
        For `mode='paga_pie'` and `basis='...'`, this controls the placement of the
        absorption probabilities legend.
    figsize
        Size of the figure. If `None`, it will be set automatically.
    dpi
        Dots per inch.
    kwargs
        Keyword arguments for :func:`scvelo.pl.paga`, :func:`scanpy.pl.violin` or :func:`matplotlib.pyplot.bar`,
        depending on :paramref:`mode`. The key `'xticks_rotation'` is consumed by this function and used
        as the rotation of the x tick labels.

    Returns
    -------
    None
        Nothing, just plots the fates for specified :paramref:`clusters` and :paramref:`lineages`.
        Optionally saves the figure based on :paramref:`save`.

    Raises
    ------
    ValueError
        If :paramref:`mode` is invalid, if :paramref:`cluster_key` is `None` for a mode other than
        `'bar'` or `'violin'`, or if a requested lineage does not exist.
    KeyError
        If the cluster key, a requested cluster, the lineage key or the PAGA results are missing.
    """

    def plot_bar():
        # one bar panel per entry of `d` (cluster), bars are the mean probabilities with standard errors
        cols = 4 if ncols is None else ncols
        n_rows = ceil(len(clusters) / cols)
        fig = plt.figure(
            None,
            (3.5 * cols, 5 * n_rows) if figsize is None else figsize,
            dpi=dpi,
        )
        fig.tight_layout()

        gs = plt.GridSpec(n_rows, cols, figure=fig, wspace=0.7, hspace=0.9)

        ax = None  # axis to share `y` with, stays `None` unless `sharey` is set
        colors = list(adata.obsm[lk][:, lin_names].colors)

        for g, k in zip(gs, d.keys()):
            current_ax = fig.add_subplot(g, sharey=ax)
            current_ax.bar(
                x=np.arange(len(lin_names)),
                height=d[k][0],
                color=colors,
                yerr=d[k][1],
                ecolor="black",
                capsize=10,
                **kwargs,
            )
            if sharey:
                ax = current_ax

            current_ax.set_xticks(np.arange(len(lin_names)))
            current_ax.set_xticklabels(
                lin_names, rotation=xrot if has_xrot else "vertical"
            )
            if not is_all:
                current_ax.set_xlabel(points)
            current_ax.set_ylabel("probability")
            current_ax.set_title(k)

        return fig

    def plot_paga():
        # one PAGA graph per lineage, nodes colored by the mean probability of that lineage
        kwargs["save"] = None
        kwargs["show"] = False
        if "cmap" not in kwargs:
            kwargs["cmap"] = cm.viridis

        cols = len(lin_names) if ncols is None else ncols
        nrows = ceil(len(lin_names) / cols)
        fig, axes = plt.subplots(
            nrows,
            cols,
            figsize=(7 * cols, 4 * nrows) if figsize is None else figsize,
            constrained_layout=True,
            dpi=dpi,
        )

        i = 0
        axes = [axes] if not isinstance(axes, np.ndarray) else np.ravel(axes)
        vmin, vmax = np.inf, -np.inf

        if basis is not None:
            kwargs["basis"] = basis
            kwargs["scatter_flag"] = True
            kwargs["color"] = cluster_key

        for i, (ax, lineage_name) in enumerate(zip(axes, lin_names)):
            colors = [v[0][i] for v in d.values()]
            # track the overall range of the plotted means, otherwise the colorbar
            # would be normalized to the initial (inf, -inf) interval
            vmin = np.nanmin(colors + [vmin])
            vmax = np.nanmax(colors + [vmax])

            kwargs["ax"] = ax
            kwargs["colors"] = tuple(colors)
            kwargs["title"] = f"{dir_prefix} {lineage_name}"

            scv.pl.paga(adata, **kwargs)

        if show_cbar:
            # drawn once, next to the last used axis, when the range over all lineages is known
            norm = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax)
            cax, _ = mpl.colorbar.make_axes(ax, aspect=100)  # new matplotlib feature
            _ = mpl.colorbar.ColorbarBase(
                cax, norm=norm, cmap=kwargs["cmap"], label="probability"
            )

        # remove the panels of the grid that were not used
        for ax in axes[i + 1:]:  # noqa
            ax.remove()

        return fig

    def plot_paga_pie():
        # single PAGA graph, each node is a pie chart of the mean lineage probabilities
        lin_colors = list(adata.obsm[lk][:, lin_names].colors)
        # node index -> {lineage color: mean probability}
        colors = {
            i: odict(zip(lin_colors, mean)) for i, (mean, _) in enumerate(d.values())
        }

        fig, ax = plt.subplots(figsize=figsize, dpi=dpi)

        kwargs["ax"] = ax
        kwargs["show"] = False
        kwargs["colorbar"] = False  # has to be disabled
        kwargs["node_colors"] = colors
        kwargs.pop("save", None)  # we will handle saving
        kwargs["transitions"] = kwargs.get("transitions", None)

        if "legend_loc" in kwargs:
            orig_ll = kwargs["legend_loc"]
            if orig_ll != "on data":
                kwargs["legend_loc"] = "none"  # we will handle legend
        else:
            orig_ll = None
            kwargs["legend_loc"] = "on data"

        if basis is not None:
            kwargs["basis"] = basis
            kwargs["scatter_flag"] = True
            kwargs["color"] = cluster_key

        ax = scv.pl.paga(adata, **kwargs)
        ax.set_title(kwargs.get("title", cluster_key))

        if basis is not None and orig_ll not in ("none", "on data", None):
            # legend of the clusters shown in the scatterplot
            handles = []
            for cluster_name, color in zip(
                adata.obs[f"{cluster_key}"].cat.categories,
                adata.uns[f"{cluster_key}_colors"],
            ):
                handles += [ax.scatter([], [], label=cluster_name, c=color)]
            first_legend = _position_legend(
                ax,
                legend_loc=orig_ll,
                handles=handles,
                **{k: v for k, v in legend_kwargs.items() if k != "loc"},
                title=cluster_key,
            )
            fig.add_artist(first_legend)

        if legend_kwargs.get("loc", None) not in ("none", "on data", None):
            # we need to use these, because scvelo can have its own handles and
            # they would be plotted here
            handles = []
            for lineage_name, color in zip(lin_names, colors[0].keys()):
                handles += [ax.scatter([], [], label=lineage_name, c=color)]
            if len(colors[0].keys()) != len(adata.obsm[lk].names):
                handles += [ax.scatter([], [], label="Rest", c="grey")]

            second_legend = _position_legend(
                ax,
                legend_loc=legend_kwargs["loc"],
                handles=handles,
                **{k: v for k, v in legend_kwargs.items() if k != "loc"},
                title=points,
            )
            fig.add_artist(second_legend)

        return fig

    def plot_violin():
        # one violin panel per lineage, grouped by cluster
        kwargs.pop("ax", None)
        kwargs.pop("keys", None)
        kwargs.pop("save", None)  # we will handle saving

        kwargs["show"] = False
        kwargs["groupby"] = cluster_key
        kwargs["rotation"] = xrot

        data = adata.obsm[lk]
        to_clean = []  # temporary `.obs` columns to delete once plotted

        for name in lin_names:
            # TODO: once ylabel is implemented, the prefix isn't necessary
            key = f"{dir_prefix} {name}"
            if key not in adata.obs_keys():
                to_clean.append(key)
            adata.obs[key] = np.array(data[:, name])  # TODO: better approach - dummy adata

        cols = len(lin_names) if ncols is None else ncols
        nrows = ceil(len(lin_names) / cols)

        fig, axes = plt.subplots(
            nrows,
            cols,
            figsize=(6 * cols, 4 * nrows) if figsize is None else figsize,
            sharey=sharey,
            dpi=dpi,
        )
        if not isinstance(axes, np.ndarray):
            axes = [axes]
        axes = np.ravel(axes)

        i = 0
        for i, (name, ax) in enumerate(zip(lin_names, axes)):
            key = f"{dir_prefix} {name}"
            ax.set_title(key)
            sc.pl.violin(
                adata, ylabel="" if i else "probability", keys=key, ax=ax, **kwargs
            )

        # remove the panels of the grid that were not used
        for ax in axes[i + 1:]:  # noqa
            ax.remove()
        for name in to_clean:
            del adata.obs[name]

        return fig

    def plot_violin_no_cluster_key():
        # single violin panel, all cells together, one violin per lineage
        kwargs.pop("ax", None)
        kwargs.pop("keys", None)  # don't care
        kwargs.pop("save", None)

        kwargs["show"] = False
        kwargs["groupby"] = points
        kwargs["xlabel"] = None
        kwargs["rotation"] = xrot

        # stack the probabilities of all lineages into one long column of a dummy object
        data = np.ravel(np.array(adata.obsm[lk]).T)[..., np.newaxis]
        dadata = AnnData(np.zeros_like(data))
        dadata.obs["probability"] = data
        dadata.obs[points] = (
            pd.Series(
                np.ravel(
                    [
                        [f"{dir_prefix.lower()} {n}"] * adata.n_obs
                        for n in adata.obsm[lk].names
                    ]
                )
            )
            .astype("category")
            .values
        )
        dadata.uns[f"{points}_colors"] = adata.obsm[lk].colors

        fig, ax = plt.subplots(
            figsize=figsize if figsize is not None else (8, 6), dpi=dpi
        )
        ax.set_title(points.capitalize())
        sc.pl.violin(dadata, keys=["probability"], ax=ax, **kwargs)

        return fig

    def plot_heatmap():
        # lineages x clusters matrix of the mean probabilities, optionally with dendrograms
        title = kwargs.pop("title", None)
        if not title:
            title = "average fate per cluster"
        data = pd.DataFrame(
            [mean for mean, _ in d.values()], columns=lin_names, index=clusters
        ).T

        if "cmap" not in kwargs:
            kwargs["cmap"] = "viridis"

        if use_clustermap:
            kwargs["cbar_pos"] = (0, 0.9, 0.025, 0.15) if show_cbar else None
            max_size = float(max(data.shape))

            g = clustermap(
                data,
                robust=True,
                annot=True,
                fmt=".2f",
                row_colors=adata.obsm[lk][lin_names].colors,
                dendrogram_ratio=(
                    0.15 * data.shape[0] / max_size,
                    0.15 * data.shape[1] / max_size,
                ),
                figsize=figsize,
                **kwargs,
            )
            g.ax_heatmap.set_xlabel(cluster_key)
            g.ax_heatmap.set_ylabel("lineage")
            g.ax_col_dendrogram.set_title(title)

            fig = g.fig
            g = g.ax_heatmap
        else:
            fig, ax = plt.subplots(figsize=figsize, dpi=dpi)
            g = heatmap(
                data,
                robust=True,
                annot=True,
                fmt=".2f",
                cbar=show_cbar,
                ax=ax,
                **kwargs,
            )
            ax.set_title(title)
            ax.set_xlabel(cluster_key)
            ax.set_ylabel("lineage")

        g.set_xticklabels(g.get_xticklabels(), rotation=xrot)
        g.set_yticklabels(g.get_yticklabels(), rotation=0)

        return fig

    if mode not in _cluster_fates_modes:
        raise ValueError(
            f"Invalid mode: `{mode!r}`. Valid options are: `{_cluster_fates_modes}`."
        )

    if cluster_key is not None:
        if cluster_key not in adata.obs:
            raise KeyError(f"Key `{cluster_key!r}` not found in `adata.obs`.")
    elif mode not in ("bar", "violin"):
        raise ValueError(
            f"Not specifying cluster key is only available for modes `'bar'` and `'violin'`, found `mode={mode!r}`."
        )

    lk = str(LinKey.FORWARD if final else LinKey.BACKWARD)
    points = "final states" if final else "root states"
    dir_prefix = "To" if final else "From"

    if cluster_key is not None:
        is_all = False
        if clusters is not None:
            if isinstance(clusters, str):
                clusters = [clusters]
            clusters = _make_unique(clusters)
            if mode in ("paga", "paga_pie"):
                logg.debug(
                    f"DEBUG: Setting `clusters` to all available ones because of `mode={mode!r}`"
                )
                clusters = list(adata.obs[cluster_key].cat.categories)
            else:
                for cname in clusters:
                    if cname not in adata.obs[cluster_key].cat.categories:
                        raise KeyError(
                            f"Cluster `{cname!r}` not found in `adata.obs[{cluster_key!r}]`"
                        )
        else:
            clusters = list(adata.obs[cluster_key].cat.categories)
    else:
        # no clustering: aggregate over all cells under a single pseudo-cluster
        is_all = True
        clusters = [points]

    if lk not in adata.obsm:
        raise KeyError(f"Lineages key `{lk!r}` not found in `adata.obsm`.")

    if lineages is not None:
        if isinstance(lineages, str):
            lineages = [lineages]
        lineages = _make_unique(lineages)
        for ep in lineages:
            if ep not in adata.obsm[lk].names:
                raise ValueError(
                    f"Endpoint `{ep!r}` not found in `adata.obsm[{lk!r}].names`."
                )
        lin_names = list(lineages)
    else:
        # must be list for sc.pl.violin, else cats str
        lin_names = list(adata.obsm[lk].names)

    if mode == "violin" and not is_all:
        # work on a copy restricted to the requested clusters
        adata = adata[np.isin(adata.obs[cluster_key], clusters)].copy()

    # cluster name -> [mean probabilities, standard errors of the mean], one value per lineage
    d = odict()
    for name in clusters:
        # builtin `bool` is used as dtype, the `np.bool` alias is not available in recent NumPy
        mask = (
            np.ones((adata.n_obs,), dtype=bool)
            if is_all
            else (adata.obs[cluster_key] == name).values
        )
        mask = list(np.array(mask, dtype=bool))
        data = adata.obsm[lk][mask, lin_names].X
        mean = np.nanmean(data, axis=0)
        std = np.nanstd(data, axis=0) / np.sqrt(data.shape[0])
        d[name] = [mean, std]

    has_xrot = "xticks_rotation" in kwargs
    xrot = kwargs.pop("xticks_rotation", 45)

    logg.debug(f"DEBUG: Using mode: `{mode!r}`")

    # clustermap shares the implementation with heatmap
    use_clustermap = mode == "clustermap"
    if use_clustermap:
        mode = "heatmap"

    if mode == "bar":
        fig = plot_bar()
    elif mode == "paga":
        if "paga" not in adata.uns:
            raise KeyError("Compute PAGA first as `scanpy.tl.paga()`.")
        fig = plot_paga()
    elif mode == "paga_pie":
        if "paga" not in adata.uns:
            raise KeyError("Compute PAGA first as `scanpy.tl.paga()`.")
        fig = plot_paga_pie()
    elif mode == "violin":
        fig = plot_violin_no_cluster_key() if cluster_key is None else plot_violin()
    elif mode == "heatmap":
        fig = plot_heatmap()
    else:
        raise ValueError(
            f"Invalid mode `{mode!r}`. Valid options are: `{_cluster_fates_modes}`."
        )

    if save is not None:
        save_fig(fig, save)

    fig.show()
def graph(
    data: Union[AnnData, np.ndarray, spmatrix],
    graph_key: Optional[str] = None,
    ixs: Optional[np.ndarray] = None,
    layout: Union[str, Dict, Callable] = nx.kamada_kawai_layout,
    keys: Sequence[KEYS] = ("incoming",),
    keylocs: Union[KEYLOCS, Sequence[KEYLOCS]] = "uns",
    node_size: float = 400,
    labels: Optional[Union[Sequence[str], Sequence[Sequence[str]]]] = None,
    top_n_edges: Optional[Union[int, Tuple[int, bool, str]]] = None,
    self_loops: bool = True,
    self_loop_radius_frac: Optional[float] = None,
    filter_edges: Optional[Tuple[float, float]] = None,
    edge_reductions: Union[Callable, Sequence[Callable]] = np.sum,
    edge_weight_scale: float = 10,
    edge_width_limit: Optional[float] = None,
    edge_alpha: float = 1.0,
    edge_normalize: bool = False,
    edge_use_curved: bool = True,
    show_arrows: bool = True,
    font_size: int = 12,
    font_color: str = "black",
    cat_cmap: ListedColormap = cm.Set3,
    cont_cmap: ListedColormap = cm.viridis,
    legend_loc: Optional[str] = "best",
    figsize: Tuple[float, float] = (15, 10),
    dpi: Optional[int] = None,
    save: Optional[Union[str, Path]] = None,
    layout_kwargs: Dict = MappingProxyType({}),
) -> None:
    """
    Plot a graph, visualizing incoming and outgoing edges or self-transitions.

    This is a utility function to look in more detail at the transition matrix in areas of interest, e.g. around an
    endpoint of development. This function is meant to visualise a small subset of nodes (~100-500) and the most likely
    transitions between them. Note that limiting edges visualized using :paramref:`top_n_edges` will speed things up,
    as well as reduce the visual clutter.

    .. image:: https://raw.githubusercontent.com/theislab/cellrank/master/resources/images/graph.png
       :width: 400px
       :align: center

    Params
    ------
    data
        The graph data, stored either in `.uns` [ :paramref:`graph_key` ], or as a sparse or a dense matrix.
    graph_key
        Key in :paramref:`data` `.uns` where the graph is stored.
        Only used when :paramref:`data` is :class:`anndata.AnnData` object.
    ixs
        Subset of indices of the graph to visualize.
    layout
        Layout to use for graph drawing.

        - If :class:`str`, search for embedding in :paramref:`data` `.obsm[X_` :paramref:`layout` `]`.
          Use :paramref:`layout_kwargs` = `{'components': [0, 1]}` to select components.
        - If :class:`dict`, keys should be values in interval [0, len(:paramref:`ixs`))
          and values `(x, y)` pairs corresponding to node positions.
        - If callable, it is called as `layout(G, **layout_kwargs)` and must return the node positions.
    keys
        Keys in :paramref:`data` `.obs`, :paramref:`data` `.obsm` or :paramref:`data` `.uns` to color the nodes.

        - If `'incoming'`, `'outgoing'` or `'self_loops'` to visualize reduction (see :paramref:`edge_reductions`)
          for each node based on incoming or outgoing edges, respectively.
    keylocs
        Locations of :paramref:`keys`, can be `'obs'`, `'obsm'` or `'uns'`. A single location is
        used for all the :paramref:`keys`.
    node_size
        Size of the nodes.
    labels
        Labels of the nodes.
    top_n_edges
        Either top N outgoing edges in descending order or a tuple
        `(top_n_edges, in_ascending_order, {'incoming', 'outgoing'})`. If `None`, show all edges.
    self_loops
        Whether visualize self transitions and also to consider them in :paramref:`top_n_edges`.
    self_loop_radius_frac
        Fraction of a unit circle to visualize self transitions. If `None`, it is derived
        from :paramref:`node_size`.
    filter_edges
        Whether to remove all edges not in `[min, max]` interval.
    edge_reductions
        Aggregation function to use when coloring nodes by edge weights.
    edge_weight_scale
        Number by which to scale the width of the edges. Useful when the weights are small.
    edge_width_limit
        Upper bound for the width of the edges. Useful when weights are unevenly distributed.
    edge_alpha
        Alpha channel value for edges and arrows.
    edge_normalize
        If `True`, normalize edges to `[0, 1]` interval prior to applying any scaling or truncation.
    edge_use_curved
        If `True`, use curved edges. This can improve visualization at a small performance cost.
    show_arrows
        Whether to show the arrows. Setting this to `False` may dramatically speed things up.
    font_size
        Font size for node labels.
    font_color
        Label color of the nodes.
    cat_cmap
        Categorical colormap used when :paramref:`keys` contain categorical variables.
    cont_cmap
        Continuous colormap used when :paramref:`keys` contain continuous variables.
    legend_loc
        Location of the legend.
    figsize
        Size of the figure.
    dpi
        Dots per inch.
    save
        Filename where to save the plots. If `None`, just shows the plot.
    layout_kwargs
        Additional kwargs for :paramref:`layout`.

    Returns
    -------
    None
        Nothing, just plots the graph. Optionally saves the figure based on :paramref:`save`.

    Raises
    ------
    ValueError
        If the arguments are inconsistent, e.g. `'obs'` key location without :paramref:`ixs`,
        non-callable edge reductions, mismatching number of labels and keys or an invalid layout.
    TypeError
        If :paramref:`data` or :paramref:`layout` is of an unsupported type.
    KeyError
        If the embedding requested through :paramref:`layout` is not found.
    RuntimeError
        If a key is not found in its location or the number of labels does not match the number of nodes.
    """
    # only set when the curved edges helper can't be imported, so that the hint is printed once
    global _msg_shown

    def plot_arrows(curves, G, pos, ax, edge_weight_scale):
        # draw an arrow head in the middle of each curved edge, self loops are skipped
        for line, (edge, val) in zip(curves, G.edges.items()):
            if edge[0] == edge[1]:
                continue

            mask = (~np.isnan(line)).all(axis=1)
            line = line[mask, :]
            if not len(line):  # can be all NaNs
                continue

            line = line.reshape((-1, 2))
            X, Y = line[:, 0], line[:, 1]

            node_start = pos[edge[0]]
            # reverse
            if np.where(np.isclose(node_start - line, [0, 0]).all(axis=1))[0][0]:
                X, Y = X[::-1], Y[::-1]

            mid = len(X) // 2
            posA, posB = zip(X[mid:mid + 2], Y[mid:mid + 2])  # noqa

            arrow = FancyArrowPatch(
                posA=posA,
                posB=posB,
                # we clip because too small values
                # cause it to crash
                arrowstyle=ArrowStyle.CurveFilledB(
                    head_length=np.clip(
                        val["weight"] * edge_weight_scale * 4,
                        _min_edge_weight,
                        edge_width_limit,
                    ),
                    head_width=np.clip(
                        val["weight"] * edge_weight_scale * 2,
                        _min_edge_weight,
                        edge_width_limit,
                    ),
                ),
                color="k",
                zorder=float("inf"),
                alpha=edge_alpha,
                linewidth=0,
            )
            ax.add_artist(arrow)

    def normalize_weights():
        # min-max scale the edge weights in-place
        weights = np.array([v["weight"] for v in G.edges.values()])
        minn = np.min(weights)
        weights = (weights - minn) / (np.max(weights) - minn)
        for v, w in zip(G.edges.values(), weights):
            v["weight"] = w

    def remove_top_n_edges():
        # keep only the `n` heaviest (or lightest) edges per source/target node
        if top_n_edges is None:
            return

        if isinstance(top_n_edges, (tuple, list)):
            n_keep, ascending, group_by = top_n_edges
        else:
            # a plain number means top N outgoing edges in descending order
            n_keep, ascending, group_by = top_n_edges, False, "outgoing"

        if group_by not in ("incoming", "outgoing"):
            raise ValueError(
                "Argument `groupby` in `top_n_edges` must be either `'incoming`' or `'outgoing'`."
            )

        source, target = zip(*G.edges)
        weights = [v["weight"] for v in G.edges.values()]
        tmp = pd.DataFrame({"outgoing": source, "incoming": target, "w": weights})

        if not self_loops:
            # remove self loops
            tmp = tmp[tmp["incoming"] != tmp["outgoing"]]

        to_keep = set(
            map(
                tuple,
                tmp.groupby(group_by).apply(
                    lambda g: g.sort_values("w", ascending=ascending).take(
                        range(min(n_keep, len(g)))
                    )
                )[["outgoing", "incoming"]].values,
            )
        )

        for e in list(G.edges):
            if e not in to_keep:
                G.remove_edge(*e)

    def remove_low_weight_edges():
        # drop the edges whose weight lies outside of the `filter_edges` interval
        if filter_edges is None or filter_edges == (None, None):
            return

        minn, maxx = filter_edges
        minn = minn if minn is not None else -np.inf
        maxx = maxx if maxx is not None else np.inf

        for e, attr in list(G.edges.items()):
            if attr["weight"] < minn or attr["weight"] > maxx:
                G.remove_edge(*e)

    _min_edge_weight = 0.00001

    if edge_width_limit is None:
        logg.debug("DEBUG: Not limiting width of edges")
        edge_width_limit = float("inf")

    if self_loop_radius_frac is None:
        self_loop_radius_frac = (
            node_size / 2000 if node_size >= 200 else node_size / 1000
        )
        logg.debug(f"Setting self loop radius fraction to `{self_loop_radius_frac}`")

    if not isinstance(keylocs, (tuple, list)):
        keylocs = [keylocs] * len(keys)
    elif len(keylocs) == 1:
        # broadcast the single location to every key
        keylocs = keylocs * len(keys)
    elif all(map(lambda k: k in ("incoming", "outgoing", "self_loops"), keys)):
        # don't care about keylocs since they are irrelevant
        logg.debug("DEBUG: Ignoring key locations")
        keylocs = [None] * len(keys)

    for k in ("obs", "obsm"):
        if k in keylocs and ixs is None:
            raise ValueError(
                f"Invalid combination: `ixs` is None and found `{k!r}` in `keylocs`."
            )

    if not isinstance(edge_reductions, (tuple, list)):
        edge_reductions = [edge_reductions] * len(keys)
    if not all(map(callable, edge_reductions)):
        raise ValueError("Not all edge_reductions functions are callable.")

    # one collection of labels (or `None`) per key
    if not isinstance(labels, (tuple, list)):
        labels = [labels] * len(keys)
    elif not len(labels):
        labels = [None] * len(keys)
    elif not isinstance(labels[0], (tuple, list)):
        labels = [labels] * len(keys)

    if len(labels) != len(keys):
        raise ValueError("`Keys` and `labels` must be of the same shape.")

    if isinstance(data, AnnData):
        if graph_key is None:
            raise ValueError(
                "Argument `graph_key` cannot be `None` when `adata` is `anndata.Anndata` object."
            )
        gdata = data.uns[graph_key]["T"]
    elif isinstance(data, (np.ndarray, spmatrix)):
        gdata = data
    else:
        raise TypeError(
            f"Expected argument `data` to be one of `AnnData`, `numpy.ndarray`, `scipy.sparse.spmatrix`, "
            f"found `{type(data).__name__}`"
        )

    is_sparse = issparse(gdata)

    if ixs is not None:
        gdata = gdata[ixs, :][:, ixs]
    else:
        ixs = list(range(gdata.shape[0]))

    start = logg.info("Creating graph")
    G = (
        nx.from_scipy_sparse_matrix(gdata, create_using=nx.DiGraph)
        if is_sparse
        else nx.from_numpy_array(gdata, create_using=nx.DiGraph)
    )

    remove_low_weight_edges()
    remove_top_n_edges()
    if edge_normalize:
        normalize_weights()
    logg.info(" Finish", time=start)

    # do NOT recreate the graph, for the edge reductions
    # gdata = nx.to_numpy_array(G)

    fig, axes = plt.subplots(nrows=len(keys), ncols=1, figsize=figsize, dpi=dpi)
    if not isinstance(axes, np.ndarray):
        axes = np.array([axes])
    axes = np.ravel(axes)

    if isinstance(layout, str):
        if f"X_{layout}" not in data.obsm:
            raise KeyError(f"Unable to find embedding `'X_{layout}'` in `adata.obsm`.")
        components = layout_kwargs.get("components", [0, 1])
        if len(components) != 2:
            raise ValueError(
                f"Components in `layout_kwargs` must be of length `2`, found `{len(components)}`."
            )
        emb = data.obsm[f"X_{layout}"][:, components]
        pos = {i: emb[ix, :] for i, ix in enumerate(ixs)}
        logg.info(f"Embedding graph using `{layout!r}` layout")
    elif isinstance(layout, dict):
        rng = range(len(ixs))
        for k, v in layout.items():
            if k not in rng:
                raise ValueError(
                    f"Key in `layout` must be in `range(len(ixs))`, found `{k}`."
                )
            if len(v) != 2:
                raise ValueError(
                    f"Value in `layout` must be a tuple or list of length 2, found `{v}`."
                )
        pos = layout
        logg.debug("DEBUG: Using pre-specified layout")
    elif callable(layout):
        start = logg.info(f"Embedding graph using `{layout.__name__!r}` layout")
        pos = layout(G, **layout_kwargs)
        logg.info(" Finish", time=start)
    else:
        raise TypeError(
            f"Argument `layout` must be either a `string`, "
            f"a `dict` or a `callable`, found `{type(layout)}`."
        )

    # when these stay `None`, straight edges are drawn by networkx instead
    curves, lc = None, None
    if edge_use_curved:
        try:
            from ._utils import curved_edges

            logg.debug("DEBUG: Creating curved edges")
            curves = curved_edges(G, pos, self_loop_radius_frac, polarity="directed")
            lc = LineCollection(
                curves,
                colors="black",
                linewidths=np.clip(
                    np.ravel([v["weight"] for v in G.edges.values()])
                    * edge_weight_scale,
                    0,
                    edge_width_limit,
                ),
                alpha=edge_alpha,
            )
        except ImportError as e:
            if not _msg_shown:
                print(
                    str(e)[:-1],
                    "in order to use curved edges or specify `edge_use_curved=False`.",
                )
                _msg_shown = True

    for ax, keyloc, key, labs, er in zip(axes, keylocs, keys, labels, edge_reductions):
        label_col = {}  # dummy value

        if key in ("incoming", "outgoing", "self_loops"):
            if key in ("incoming", "outgoing"):
                vals = np.array(er(gdata, axis=int(key == "outgoing"))).flatten()
            else:
                vals = gdata.diagonal() if is_sparse else np.diag(gdata)
            node_v = dict(zip(pos.keys(), vals))
        else:
            label_col = getattr(data, keyloc)
            if key in label_col:
                # NOTE(review): values are paired with the nodes positionally - confirm that
                # `label_col[key]` is aligned with the nodes when `ixs` selects a subset
                node_v = dict(zip(pos.keys(), label_col[key]))
            else:
                raise RuntimeError(f"Key `{key!r}` not found in `adata.{keyloc!r}`.")

        if labs is not None:
            if len(labs) != len(pos):
                raise RuntimeError(
                    f"Number of labels ({len(labs)}) and nodes ({len(pos)}) mismatch."
                )
            nx.draw_networkx_labels(
                G,
                pos,
                labels=labs if isinstance(labs, dict) else dict(zip(pos.keys(), labs)),
                ax=ax,
                font_color=font_color,
                font_size=font_size,
            )

        if lc is not None and curves is not None:
            ax.add_collection(deepcopy(lc))  # copying necessary
            if show_arrows:
                plot_arrows(curves, G, pos, ax, edge_weight_scale)
        else:
            nx.draw_networkx_edges(
                G,
                pos,
                width=[
                    np.clip(
                        v["weight"] * edge_weight_scale,
                        _min_edge_weight,
                        edge_width_limit,
                    )
                    for _, v in G.edges.items()
                ],
                alpha=edge_alpha,
                edge_color="black",
                arrows=True,
                arrowstyle="-|>",
                # draw into the panel of this key, not into the current axis
                ax=ax,
            )

        if key in label_col and is_categorical_dtype(label_col[key]):
            values = label_col[key]
            if keyloc in ("obs", "obsm"):
                values = values[ixs]
            categories = values.cat.categories
            color_key = _colors(key)
            if color_key in data.uns:
                mapper = dict(zip(categories, data.uns[color_key]))
            else:
                # colormaps are callables, an integer selects the entry of the lookup table
                mapper = dict(zip(categories, map(cat_cmap, range(len(categories)))))

            colors = []
            seen = set()

            for v in values:
                colors.append(mapper[v])
                seen.add(v)

            nodes_kwargs = dict(cmap=cat_cmap, node_color=colors)  # noqa
            if legend_loc is not None:
                # single-point plots serve only as legend handles
                x, y = pos[0]
                for label in sorted(seen):
                    ax.plot([x], [y], label=label, color=mapper[label])
                ax.legend(loc=legend_loc)
        else:
            values = list(node_v.values())
            vmin, vmax = np.min(values), np.max(values)
            nodes_kwargs = dict(  # noqa
                cmap=cont_cmap, node_color=values, vmin=vmin, vmax=vmax
            )

            divider = make_axes_locatable(ax)
            cax = divider.append_axes("right", size="1.5%", pad=0.05)
            _ = mpl.colorbar.ColorbarBase(
                cax,
                cmap=cont_cmap,
                norm=mpl.colors.Normalize(vmin=vmin, vmax=vmax),
            )

        nx.draw_networkx_nodes(G, pos, node_size=node_size, ax=ax, **nodes_kwargs)

        ax.set_title(key)
        ax.axis("off")

    if save is not None:
        save_fig(fig, save)

    fig.show()
def cluster_lineage(
    adata: AnnData,
    model: _model_type,
    genes: Sequence[str],
    lineage: str,
    final: bool = True,
    clusters: Optional[Sequence[str]] = None,
    n_points: int = 200,
    time_key: str = "latent_time",
    cluster_key: str = "louvain",
    norm: bool = True,
    recompute: bool = False,
    ncols: int = 3,
    sharey: bool = False,
    n_jobs: Optional[int] = 1,
    backend: str = "multiprocessing",
    pca_kwargs: Dict = MappingProxyType({"svd_solver": "arpack"}),
    neighbors_kwargs: Dict = MappingProxyType({"use_rep": "X"}),
    louvain_kwargs: Dict = MappingProxyType({}),
    key_added: Optional[str] = None,
    save: Optional[Union[str, Path]] = None,
    figsize: Optional[Tuple[float, float]] = None,
    dpi: Optional[int] = None,
    show_progress_bar: bool = True,
    **kwargs,
) -> None:
    """
    Cluster gene expression trends within a lineage and plot the clusters.

    This function is based on Palantir, see [Setty19]_. It can be used to discover modules of genes that drive
    development along a given lineage. Consider running this function on a subset of genes which are potential lineage
    drivers, identified e.g. by running :func:`cellrank.tl.gene_importance`.

    .. image:: https://raw.githubusercontent.com/theislab/cellrank/master/resources/images/cluster_lineage.png
       :width: 400px
       :align: center

    Params
    ------
    adata : :class:`anndata.AnnData`
        Annotated data object.
    model
        Model to fit.

        - If a :class:`dict`, gene and lineage specific models can be specified. Use `'*'` to indicate
          all genes or lineages, for example `{'Map2': {'*': ...}, 'Dcx': {'Alpha': ..., '*': ...}}`.
    genes
        Genes in :paramref:`adata` `.var_names` to cluster.
    lineage
        Name of the lineage along which to cluster the genes.
    final
        Whether to consider cells going to final states or vice versa.
    clusters
        Cluster identifiers to plot. If `None`, all clusters will be considered.
        Useful when plotting previously computed clusters.
    n_points
        Number of points used for prediction.
    time_key
        Key in :paramref:`adata` `.obs` where the pseudotime is stored.
    cluster_key
        Key in `.obs` of the trends object where the clustering of the genes is stored.
    norm
        Whether to z-normalize each trend to have `0` mean, `1` variance.
    recompute
        If `True`, recompute the clustering, otherwise try to find already existing one.
    ncols
        Number of columns for the plot.
    sharey
        Whether to share y-axis across multiple plots.
    n_jobs
        Number of parallel jobs. If `-1`, use all available cores. If `None` or `1`, the execution is sequential.
    backend
        Which backend to use for multiprocessing.
        See :class:`joblib.Parallel` for valid options.
    pca_kwargs
        Keyword arguments for :func:`scanpy.pp.pca`.
    neighbors_kwargs
        Keyword arguments for :func:`scanpy.pp.neighbors`.
    louvain_kwargs
        Keyword arguments for :func:`scanpy.tl.louvain`.
    key_added
        Optional suffix of the key in :paramref:`adata` `.uns` under which the trends are stored.
    save
        Filename where to save the plot.
        If `None`, just shows the plot.
    figsize
        Size of the figure. If `None`, it will be set automatically.
    dpi
        Dots per inch.
    show_progress_bar
        Whether to show a progress bar tracking models fitted.
    kwargs
        Keyword arguments for :meth:`cellrank.ul.models.Model.prepare`.

    Returns
    -------
    None
        Plots the clusters of :paramref:`genes` for the given :paramref:`lineage`.
        Optionally saves the figure based on :paramref:`save`.

        Updates :paramref:`adata` `.uns` with the following key:

        lineage_{:paramref:`lineage`}_trend, followed by _{:paramref:`key_added`} if it is specified:

            - :class:`anndata.AnnData` object of shape `len` (:paramref:`genes`) x :paramref:`n_points`
              containing the clustered genes.

    Raises
    ------
    KeyError
        If the lineages are not found in :paramref:`adata` `.obsm` or if :paramref:`cluster_key`
        is not found in the trends object.
    RuntimeError
        If the shape of the computed trends doesn't match the number of genes and points.
    """
    lineage_key = str(LinKey.FORWARD if final else LinKey.BACKWARD)
    if lineage_key not in adata.obsm:
        raise KeyError(f"Lineages key `{lineage_key!r}` not found in `adata.obsm`.")

    # accessed only for its side effect: fail early on an invalid lineage name
    _ = adata.obsm[lineage_key][lineage]

    check_collection(adata, genes, "var_names")

    key_to_add = f"lineage_{lineage}_trend"
    if key_added is not None:
        logg.debug(f"DEBUG: Adding key `{key_added!r}`")
        key_to_add += f"_{key_added}"

    if recompute or key_to_add not in adata.uns:
        kwargs["time_key"] = time_key  # kwargs for the model.prepare
        kwargs["n_test_points"] = n_points
        kwargs["final"] = final

        models = _create_models(model, genes, [lineage])
        if _is_any_gam_mgcv(models):
            backend = "multiprocessing"

        n_jobs = _get_n_cores(n_jobs, len(genes))

        start = logg.info(f"Computing trends using `{n_jobs}` core(s)")
        trends = parallelize(
            _cl_process,
            genes,
            as_array=True,
            unit="gene",
            n_jobs=n_jobs,
            backend=backend,
            show_progress_bar=show_progress_bar,
        )(models, lineage, norm, **kwargs)
        logg.info(" Finish", time=start)

        # genes are the observations, prediction points are the variables
        trends = AnnData(np.vstack(trends))
        trends.obs_names = genes

        # sanity check
        if trends.n_obs != len(genes):
            raise RuntimeError(
                f"Expected to find `{len(genes)}` genes, found `{trends.n_obs}`."
            )
        if n_points is not None and trends.n_vars != n_points:
            raise RuntimeError(
                f"Expected to find `{n_points}` points, found `{trends.n_vars}`."
            )

        # the defaults are read-only mappings, copy them before modifying
        pca_kwargs = dict(pca_kwargs)
        n_comps = pca_kwargs.pop("n_comps", 50)  # default value
        if n_comps > len(genes):
            n_comps = len(genes) - 1

        sc.pp.pca(trends, n_comps=n_comps, **pca_kwargs)
        sc.pp.neighbors(trends, **neighbors_kwargs)

        louvain_kwargs = dict(louvain_kwargs)
        louvain_kwargs["key_added"] = cluster_key
        sc.tl.louvain(trends, **louvain_kwargs)

        adata.uns[key_to_add] = trends
    else:
        logg.info(f"Loading data from `adata.uns[{key_to_add!r}]`")
        trends = adata.uns[key_to_add]

    if clusters is None:
        if cluster_key not in trends.obs:
            raise KeyError(f"Invalid cluster key `{cluster_key!r}`.")
        clusters = trends.obs[cluster_key].cat.categories

    nrows = int(np.ceil(len(clusters) / ncols))
    fig, axes = plt.subplots(
        nrows,
        ncols,
        figsize=(ncols * 10, nrows * 10) if figsize is None else figsize,
        sharey=sharey,
        dpi=dpi,
    )

    if not isinstance(axes, Iterable):
        axes = [axes]
    axes = np.ravel(axes)

    j = 0
    for j, (ax, c) in enumerate(zip(axes, clusters)):  # noqa
        data = trends[trends.obs[cluster_key] == c].X
        mean, sd = np.mean(data, axis=0), np.var(data, axis=0)
        sd = np.sqrt(sd)

        # individual gene trends in the background
        for i in range(data.shape[0]):
            ax.plot(data[i], color="gray", lw=0.5)

        # cluster mean with a band of one standard deviation
        ax.plot(mean, lw=2, color="black")
        ax.plot(mean - sd, lw=1.5, color="black", linestyle="--")
        ax.plot(mean + sd, lw=1.5, color="black", linestyle="--")
        ax.fill_between(
            range(len(mean)), mean - sd, mean + sd, color="black", alpha=0.1
        )

        ax.set_title(f"Cluster {c}")
        ax.set_xticks([])

        if not sharey:
            ax.set_yticks([])

    # remove the panels of the grid that were not used
    for unused_ax in axes[j + 1:]:
        unused_ax.remove()

    if save is not None:
        save_fig(fig, save)
def plot(
    self,
    figsize: Tuple[float, float] = (15, 10),
    same_plot: bool = False,
    hide_cells: bool = False,
    perc: Optional[Tuple[float, float]] = None,
    abs_prob_cmap: mcolors.ListedColormap = cm.viridis,
    cell_color: str = "black",
    color: str = "black",
    alpha: float = 0.8,
    lineage_alpha: float = 0.2,
    title: Optional[str] = None,
    size: int = 15,
    lw: float = 2,
    show_cbar: bool = True,
    margins: float = 0.015,
    xlabel: str = "Pseudotime",
    ylabel: str = "Expression",
    show_conf_int: bool = True,
    dpi: Optional[int] = None,
    fig: Optional[mpl.figure.Figure] = None,
    ax: Optional[mpl.axes.Axes] = None,
    return_fig: bool = False,
    save: Optional[str] = None,
) -> Optional[mpl.figure.Figure]:
    """
    Plots the smoothed gene expression.

    Params
    ------
    figsize
        Size of the figure. Only used when a new figure is created.
    same_plot
        Whether to plot all trends in the same plot. If `True`, the cells are
        drawn in :paramref:`cell_color` and no colorbar is added.
    hide_cells
        Whether to hide the cells.
    perc
        Percentile by which to clip the absorption probabilities.
    abs_prob_cmap
        Colormap to use when coloring in the absorption probabilities.
    cell_color
        Color for the cells when not coloring absorption probabilities.
    color
        Color for the lineages.
    alpha
        Alpha channel for cells.
    lineage_alpha
        Alpha channel for lineage confidence intervals.
    title
        Title of the plot. Also used as the legend label of the fitted line.
    size
        Size of the points.
    lw
        Line width for the smoothed values.
    show_cbar
        Whether to show colorbar.
    margins
        Margins around the plot.
    xlabel
        Label on the x-axis.
    ylabel
        Label on the y-axis.
    show_conf_int
        Whether to show the confidence interval.
    dpi
        Dots per inch.
    fig
        Figure to use. It is only honored together with :paramref:`ax`;
        if :paramref:`ax` is `None`, a new figure and ax are created.
    ax: :class:`matplotlib.axes.Axes`
        Ax to use, if `None`, create a new one. If given without
        :paramref:`fig`, the figure the ax belongs to is used.
    return_fig
        If `True`, return the figure object.
    save
        Filename where to save the plot. If `None`, the figure is not saved.

    Returns
    -------
    :class:`matplotlib.figure.Figure` or None
        The figure if :paramref:`return_fig` is `True`, otherwise nothing,
        just plots the fitted model.
        Optionally saves the figure based on :paramref:`save`.
    """

    if ax is None:
        # without an ax there is nothing to draw on, start from scratch
        fig, ax = plt.subplots(figsize=figsize, constrained_layout=True)
    elif fig is None:
        # an ax unambiguously determines its figure; previously a lone `ax`
        # was silently thrown away and a brand new figure was created
        fig = ax.figure

    if dpi is not None:
        fig.set_dpi(dpi)

    vmin, vmax = _minmax(self.w, perc)

    if not hide_cells:
        if same_plot or np.allclose(self.w_all, 1.0):
            # single color for all cells: colormap arguments would be ignored
            # by matplotlib (with a warning), so they are not passed at all
            _ = ax.scatter(
                self.x_all.squeeze(),
                self.y_all.squeeze(),
                c=cell_color,
                s=size,
                alpha=alpha,
            )
        else:
            _ = ax.scatter(
                self.x_all.squeeze(),
                self.y_all.squeeze(),
                c=self.w_all.squeeze(),
                s=size,
                cmap=abs_prob_cmap,
                vmin=vmin,
                vmax=vmax,
                alpha=alpha,
            )

    ax.plot(self.x_test, self.y_test, color=color, lw=lw, label=title)

    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.margins(margins)

    if show_conf_int and self.conf_int is not None:
        ax.fill_between(
            self.x_test.squeeze(),
            self.conf_int[:, 0],
            self.conf_int[:, 1],
            alpha=lineage_alpha,
            color=color,
            linestyle="--",
        )

    if show_cbar and not hide_cells and not same_plot:
        # the colorbar is built from `vmin`/`vmax` directly, independently of
        # the scatter above, so it is also drawn when the cells share one color
        norm = mcolors.Normalize(vmin=vmin, vmax=vmax)
        cax, _ = mpl.colorbar.make_axes(ax, aspect=200)
        _ = mpl.colorbar.ColorbarBase(
            cax, norm=norm, cmap=abs_prob_cmap, label="Absorption probability"
        )

    if save is not None:
        save_fig(fig, save)

    if return_fig:
        return fig
def _plot_real_spectrum(
    self,
    dpi: int = 100,
    figsize: Optional[Tuple[float, float]] = None,
    legend_loc: Optional[str] = None,
    title: Optional[str] = None,
    save: Optional[Union[str, Path]] = None,
) -> None:
    """
    Plot the real part of the top eigenvalues against their index.

    Eigenvalues whose imaginary part is exactly zero are drawn as
    "real eigenvalue", all the others as "complex eigenvalue". A dashed
    vertical line marks the stored eigengap.

    Params
    ------
    dpi
        Dots per inch.
    figsize
        Size of the figure.
    legend_loc
        Location parameter for the legend.
    title
        Figure title. If `None`, a title is built from the parameters of the
        stored eigendecomposition.
    save
        Filename where to save the plots. If `None`, just shows the plot.

    Returns
    -------
    None
        Nothing, just plots the spectrum.
    """

    if self._eig is None:
        logg.warning(
            "No eigendecomposition found, computing with default parameters"
        )
        self.compute_eig()

    # pull the eigenvalues and the parameters they were computed with
    eigvals = self._eig["D"]
    settings = self._eig["params"]
    real_part = eigvals.real
    positions = np.arange(len(eigvals))
    is_real = eigvals.imag == 0

    fig, ax = plt.subplots(nrows=1, ncols=1, dpi=dpi, figsize=figsize)

    # one scatter per group so that each gets its own legend entry
    groups = (
        (is_real, "real eigenvalue"),
        (~is_real, "complex eigenvalue"),
    )
    for selector, legend_label in groups:
        ax.scatter(
            positions[selector], real_part[selector], marker="o", label=legend_label
        )

    # eigengap marker and axis decoration
    ax.axvline(self._eig["eigengap"], label="eigengap", ls="--")
    ax.set_xlabel("index")
    ax.set_xticks(range(len(eigvals)))
    ax.set_ylabel(r"Re($\lambda_i$)")

    ordering = "real part" if settings["which"] == "LR" else "magnitude"
    if title is not None:
        heading = title
    else:
        heading = (
            f"real part of top {settings['k']} eigenvalues according to their {ordering}"
        )
    ax.set_title(heading)

    ax.legend(loc=legend_loc)

    if save is not None:
        save_fig(fig, save)

    fig.show()
def plot_spectrum(
    self,
    real_only: bool = False,
    dpi: int = 100,
    figsize: Optional[Tuple[float, float]] = (5, 5),
    legend_loc: Optional[str] = None,
    title: Optional[str] = None,
    save: Optional[Union[str, Path]] = None,
) -> None:
    """
    Plot the top eigenvalues in complex plane.

    The eigenvalues are drawn together with the unit circle. Both axes are
    given limits of the same length (the larger of the two data ranges plus
    `0.05`), centered on the data.

    Params
    ------
    real_only
        Whether to plot only the real part of the spectrum. If `True`,
        the plotting is delegated to :meth:`_plot_real_spectrum`.
    dpi
        Dots per inch.
    figsize
        Size of the figure.
    legend_loc
        Location parameter for the legend.
    title
        Figure title. If `None`, a title is built from the parameters of the
        stored eigendecomposition.
    save
        Filename where to save the plots. If `None`, just shows the plot.

    Returns
    -------
    None
        Nothing, just plots the spectrum in complex plane.
    """

    if self._eig is None:
        logg.warning(
            "No eigendecomposition found, computing with default parameters"
        )
        self.compute_eig()

    if real_only:
        self._plot_real_spectrum(
            dpi=dpi, figsize=figsize, legend_loc=legend_loc, save=save, title=title
        )
        return

    eigvals = self._eig["D"]
    settings = self._eig["params"]

    # create figure and axes
    fig, ax = plt.subplots(nrows=1, ncols=1, dpi=dpi, figsize=figsize)

    # extent of the data along the real and the imaginary axis
    re_part, im_part = eigvals.real, eigvals.imag
    re_lo, re_hi = np.min(re_part), np.max(re_part)
    im_lo, im_hi = np.min(im_part), np.max(im_part)
    side = np.max([re_hi - re_lo, im_hi - im_lo]) + 0.05

    def centered(lo, hi):
        # interval of length `side` sharing its midpoint with [lo, hi]
        midpoint = lo + (hi - lo) / 2
        return midpoint - side / 2, midpoint + side / 2

    x_limits = centered(re_lo, re_hi)
    y_limits = centered(im_lo, im_hi)

    # eigenvalues and the unit circle
    ax.scatter(re_part, im_part, marker=".", label="eigenvalue")
    angles = np.linspace(0, 2 * np.pi, 500)
    ax.plot(np.sin(angles), np.cos(angles), "k-", label="unit circle")

    # labels and ranges
    ax.set_xlabel(r"Re($\lambda$)")
    ax.set_xlim(*x_limits)
    ax.set_ylabel(r"Im($\lambda$)")
    ax.set_ylim(*y_limits)

    ordering = "real part" if settings["which"] == "LR" else "magnitude"
    if title is not None:
        heading = title
    else:
        heading = f"top {settings['k']} eigenvalues according to their {ordering}"
    ax.set_title(heading)

    ax.legend(loc=legend_loc)

    if save is not None:
        save_fig(fig, save)

    fig.show()