コード例 #1
0
ファイル: _utils.py プロジェクト: dpeerlab/cellrank
def composition(
    adata: AnnData,
    key: str,
    figsize: Optional[Tuple[float, float]] = None,
    dpi: Optional[float] = None,
    save: Optional[Union[str, Path]] = None,
) -> None:
    """
    Plot pie chart for categorical annotation.

    .. image:: https://raw.githubusercontent.com/theislab/cellrank/master/resources/images/composition.png
       :width: 400px
       :align: center

    Params
    ------
    adata
        Annotated data object.
    key
        Key in :paramref:`adata` `.obs` containing categorical observation.
    figsize
        Size of the figure.
    dpi
        Dots per inch.
    save
        Filename where to save the plots.
        If `None`, just shows the plot.

    Returns
    -------
    None
        Nothing, just plots the pie chart.
        Optionally saves the figure based on :paramref:`save`.

    Raises
    ------
    KeyError
        If :paramref:`key` is not in :paramref:`adata` `.obs`.
    TypeError
        If the observation under :paramref:`key` is not categorical.
    ValueError
        If there are no (non-missing) observations to plot.
    """

    if key not in adata.obs:
        raise KeyError(f"Key `{key!r}` not found in `adata.obs`.")
    if not is_categorical_dtype(adata.obs[key]):
        raise TypeError(
            f"Observation `adata.obs[{key!r}]` is not categorical.")

    annotation = adata.obs[key]
    cats = annotation.cat.categories
    colors = adata.uns.get(f"{key}_colors", None)

    # Count the members of every category in a single pass over the integer
    # codes; missing values are encoded as -1 and are left out, just as an
    # equality comparison against each category would leave them out.
    codes = np.asarray(annotation.cat.codes)
    counts = np.bincount(codes[codes >= 0], minlength=len(cats))

    total = counts.sum()
    if total == 0:
        # avoid a 0 / 0 division that would feed NaNs into the pie chart
        raise ValueError(
            f"Unable to plot composition, `adata.obs[{key!r}]` contains no observations."
        )
    cats_frac = counts / total

    # plot these fractions in a pie plot
    fig, ax = plt.subplots(figsize=figsize, dpi=dpi)

    ax.pie(x=cats_frac, labels=cats, colors=colors)
    ax.set_title(f"Composition by {key}")

    if save is not None:
        save_fig(fig, save)

    fig.show()
コード例 #2
0
ファイル: _utils.py プロジェクト: dpeerlab/cellrank
def _trends_helper(
    adata: AnnData,
    models: Dict[str, Dict[str, Any]],
    gene: str,
    ln_key: str,
    lineage_names: Optional[Sequence[str]] = None,
    same_plot: bool = False,
    sharey: bool = True,
    cmap=None,
    fig: mpl.figure.Figure = None,
    ax: mpl.axes.Axes = None,
    save: Optional[Union[str, Path]] = None,
    **kwargs,
) -> None:
    """
    Plot an expression gene for some lineages.

    Params
    ------
    adata: :class:`anndata.AnnData`
        Annotated data object.
    models
        Fitted models, indexed first by gene and then by lineage name.
        The model at `models[gene][lineage]` is plotted for every requested lineage.
    gene
        Name of the gene in `adata.var_names`.
    ln_key
        Lineage key; it is passed to `_colors` to look up the lineage colors in
        :paramref:`adata` `.uns` when :paramref:`cmap` does not provide them.
    lineage_names
        Lineages to plot. If `None`, all the lineages present in `models[gene]` are used.
    same_plot
        Whether to plot all lineages into one axis instead of one axis per lineage.
    sharey
        Whether the per-lineage axes share the y-axis.
        Only used when :paramref:`same_plot` `=False`.
    cmap
        If it has a `colors` attribute, these colors are used for the lineages
        when :paramref:`same_plot` `=True`.
    fig
        Figure to use, if `None`, create a new one.
        Only used when :paramref:`same_plot` `=True`, otherwise a new figure is always created.
    ax
        Ax to use, if `None`, create a new one.
        Only used when :paramref:`same_plot` `=True`, otherwise new axes are always created.
    save
        Filename where to save the plot.
        If `None`, just shows the plots.
    **kwargs
        Keyword arguments for :meth:`cellrank.ul.models.Model.plot`.
        The keys `'perc'`, `'hide_cells'`, `'show_cbar'` and `'color'` are consumed here
        and passed on explicitly, `'figsize'` is additionally used to size a new figure.

    Returns
    -------
    None
        Nothing, just plots the trends.
        Optionally saves the figure based on :paramref:`save`.

    Raises
    ------
    ValueError
        If the percentiles are neither of size `1` nor of the same size as the lineages.
    """

    if lineage_names is None:
        # the signature allows `None`: use every lineage which has a model for this gene
        lineage_names = list(models[gene].keys())
    else:
        # normalize to a list so that the `!= [None]` comparison below
        # behaves the same for tuples, arrays and other sequences
        lineage_names = list(lineage_names)
    n_lineages = len(lineage_names)

    if same_plot:
        if fig is None and ax is None:
            fig, ax = plt.subplots(
                1,
                figsize=kwargs.get("figsize", None) or (15, 10),
                constrained_layout=True,
            )
        # every lineage is drawn into the very same axis
        axes = [ax] * n_lineages
    else:
        fig, axes = plt.subplots(
            ncols=n_lineages,
            figsize=kwargs.get("figsize", None) or (6 * n_lineages, 6),
            sharey=sharey,
            constrained_layout=True,
        )
    axes = np.ravel(axes)

    # a single percentile (or `None`) is wrapped so that it can be broadcast below
    percs = kwargs.pop("perc", None)
    if percs is None or not isinstance(percs[0], (tuple, list)):
        percs = [percs]

    same_perc = False  # we need to show colorbar always if percs differ
    if len(percs) != n_lineages or n_lineages == 1:
        if len(percs) != 1:
            raise ValueError(
                f"Percentile must be a collection of size `1` or `{n_lineages}`, got `{len(percs)}`."
            )
        same_perc = True
        percs = percs * n_lineages

    hide_cells = kwargs.pop("hide_cells", False)
    show_cbar = kwargs.pop("show_cbar", True)
    lineage_color = kwargs.pop("color", "black")

    lc = (cmap.colors if cmap is not None and hasattr(cmap, "colors") else
          adata.uns.get(f"{_colors(ln_key)}", cm.Set1.colors))

    last = n_lineages - 1
    for i, (name, ax, perc) in enumerate(zip(lineage_names, axes, percs)):
        title = name if name is not None else "No lineage"

        if not same_perc:
            # percentiles differ between the lineages, each one needs its colorbar
            cbar = True
        elif not show_cbar:
            cbar = False
        else:
            # one shared colorbar, drawn together with the last lineage
            cbar = i == last

        models[gene][name].plot(
            ax=ax,
            fig=fig,
            perc=perc,
            show_cbar=cbar,
            title=title,
            hide_cells=hide_cells or (same_plot and i != last),
            same_plot=same_plot,
            color=lc[i] if same_plot and name is not None else lineage_color,
            ylabel=gene if not same_plot or name is None else "expression",
            **kwargs,
        )

    if same_plot and lineage_names != [None]:
        ax.set_title(gene)
        ax.legend()

    if save is not None:
        save_fig(fig, save)
コード例 #3
0
ファイル: _heatmap.py プロジェクト: opnumten/cellrank
def heatmap(
    adata: AnnData,
    model: _model_type,
    genes: Sequence[str],
    final: bool = True,
    kind: str = "lineages",
    lineages: Optional[Union[str, Sequence[str]]] = None,
    start_lineage: Optional[Union[str, Sequence[str]]] = None,
    end_lineage: Optional[Union[str, Sequence[str]]] = None,
    lineage_height: float = 0.1,
    cluster_genes: bool = False,
    xlabel: Optional[str] = None,
    cmap: colors.ListedColormap = cm.Spectral_r,
    n_jobs: Optional[int] = 1,
    backend: str = "multiprocessing",
    hspace: float = 0.25,
    figsize: Optional[Tuple[float, float]] = None,
    dpi: Optional[int] = None,
    save: Optional[Union[str, Path]] = None,
    show_progress_bar: bool = True,
    **kwargs,
) -> None:
    """
    Plot a heatmap of smoothed gene expression along specified lineages.

    .. image:: https://raw.githubusercontent.com/theislab/cellrank/master/resources/images/heatmap.png
       :width: 400px
       :align: center

    Params
    ------
    adata : :class:`anndata.AnnData`
        Annotated data object.
    model
        Model to fit.

        - If a :class:`dict`, gene and lineage specific models can be specified. Use `'*'` to indicate
        all genes or lineages, for example `{'Map2': {'*': ...}, 'Dcx': {'Alpha': ..., '*': ...}}`.
    genes
        Genes in :paramref:`adata` `.var_names` to plot.
    final
        Whether to consider cells going to final states or vice versa.
    kind
        Variant of the heatmap.

        - If `'genes'`, group by :paramref:`genes` for each lineage in :paramref:`lineages`.
        - If `'lineages'`, group by :paramref:`lineages` for each gene in :paramref:`genes`.
    lineages
        Names of the lineages for which to plot. If `None`, use all the lineages.
    start_lineage
        Lineage from which to select cells with lowest pseudotime as starting points.
        If specified, the trends start at the earliest pseudotime point within that lineage,
        otherwise they start from time `0`.
    end_lineage
        Lineage from which to select cells with highest pseudotime as endpoints.
        If specified, the trends end at the latest pseudotime point within that lineage,
        otherwise, it is determined automatically.
    lineage_height
        Height of a lineage bar when :paramref:`kind` `='genes'`.
    cluster_genes
        Whether to use :func:`seaborn.clustermap` when :paramref:`kind` `='lineages'`.
        In that case, one figure per lineage is created and only the last one can be saved.
    xlabel
        Label on the x-axis. If `None`, it is determined based on :paramref:`time_key`.
    cmap
        Colormap to use when visualizing the smoothed expression.
    n_jobs
        Number of parallel jobs. If `-1`, use all available cores. If `None` or `1`, the execution is sequential.
    backend
        Which backend to use for multiprocessing.
        See :class:`joblib.Parallel` for valid options.
    hspace
        Height space between the subplots when :paramref:`kind` `='lineages'`
        and :paramref:`cluster_genes` `=False`.
    figsize
        Size of the figure.
        If `None`, it will be set to (15, len(:paramref:`genes`) + len(:paramref:`lineages`))
        when :paramref:`kind` `='genes'` and to (15, 5 + len(:paramref:`genes`)) otherwise.
    dpi
        Dots per inch.
    save
        Filename where to save the plot.
        If `None`, just shows the plot.
    show_progress_bar
        Whether to show a progress bar tracking models fitted.
    **kwargs
        Keyword arguments for :meth:`cellrank.ul.models.Model.prepare`.

    Returns
    -------
    None
        Nothing, just plots the heatmap variant depending on :paramref:`kind`.
        Optionally saves the figure based on :paramref:`save`.

    Raises
    ------
    KeyError
        If the lineages are not found in :paramref:`adata` `.obsm`.
    ValueError
        If a start or an end lineage is not among :paramref:`lineages`
        or if :paramref:`kind` is not a valid option.
    """
    def gene_per_lineage():
        def color_fill_rec(ax, x, y1, y2, colors=None, cmap=cmap, **kwargs):
            # width of one rectangle, taken from the first two x values
            dx = x[1] - x[0]

            # the loop variables must not reuse the names `x`/`y2`, otherwise
            # the `ax.plot` call below would only see the last scalar values
            for (color, xi, y1i, y2i) in zip(cmap(colors), x, y1, y2):
                ax.add_patch(
                    plt.Rectangle((xi, y1i),
                                  dx,
                                  y2i - y1i,
                                  color=color,
                                  ec=color,
                                  **kwargs))

            # invisible line, presumably there to make the axis autoscale to the data
            ax.plot(x, y2, lw=0)

        fig, axes = plt.subplots(
            nrows=len(genes),
            figsize=(15, len(genes) +
                     len(lineages)) if figsize is None else figsize,
            dpi=dpi,
        )

        if not isinstance(axes, Iterable):
            axes = [axes]
        axes = np.ravel(axes)

        for ax, (gene, models) in zip(axes, data.items()):
            # the color scale is shared by all the lineages of a gene
            c = np.array([m.y_test for m in models.values()])
            c_min, c_max = np.nanmin(c), np.nanmax(c)
            norm = colors.Normalize(vmin=c_min, vmax=c_max)

            # the bars are stacked on top of each other; remember where each
            # bar starts, so that there is exactly one tick per lineage
            ix = 0
            bar_starts = []

            for x, c in ((m.x_test, m.y_test) for m in models.values()):
                y = np.ones_like(x)
                color_fill_rec(ax,
                               x,
                               y * ix,
                               y * (ix + lineage_height),
                               colors=norm(c))
                bar_starts.append(ix)
                ix += lineage_height

            xs = np.array([m.x_test for m in models.values()])
            x_min, x_max = np.nanmin(xs), np.nanmax(xs)
            ax.set_xticks(np.linspace(x_min, x_max, 11))

            # ticks in the middle of each bar, as many as there are labels
            ax.set_yticks(np.array(bar_starts) + lineage_height / 2)
            ax.set_yticklabels(lineages)
            ax.set_title(gene)
            ax.set_ylabel("Lineage")

            for pos in ["top", "bottom", "left", "right"]:
                ax.spines[pos].set_visible(False)

            cax, _ = mpl.colorbar.make_axes(ax)
            _ = mpl.colorbar.ColorbarBase(cax,
                                          norm=norm,
                                          cmap=cmap,
                                          label="Expression")

            ax.tick_params(
                top=False,
                bottom=False,
                left=False,
                right=False,
                labelleft=True,
                labelbottom=False,
            )

        # only the last (bottom) axis shows the x-axis ticks and the label
        ax.xaxis.set_major_formatter(FormatStrFormatter("%.1f"))
        ax.tick_params(
            top=False,
            bottom=True,
            left=False,
            right=False,
            labelleft=True,
            labelbottom=True,
        )
        ax.set_xlabel(xlabel)

        return fig

    def lineage_per_gene():
        data_t = defaultdict(dict)  # transpose
        for gene, lns in data.items():
            for ln, d in lns.items():
                data_t[ln][gene] = d

        fig, ax = None, None
        if not cluster_genes:
            fig, axes = plt.subplots(
                nrows=len(lineages),
                figsize=(15, 5 + len(genes)) if figsize is None else figsize,
                dpi=dpi,
            )
            fig.subplots_adjust(hspace=hspace, bottom=0)

            if not isinstance(axes, Iterable):
                axes = [axes]
            axes = np.ravel(axes)
        else:
            # placeholders, `clustermap` creates its own figure; there must be
            # one per lineage, otherwise `zip` below silently drops lineages
            axes = [None] * len(data_t)

        for ax, (lname, models) in zip(axes, data_t.items()):
            # label the rows by the genes the models actually belong to
            df = pd.DataFrame([m.y_test for m in models.values()],
                              index=list(models.keys()))
            df.index.name = "Genes"
            if cluster_genes:
                g = sns.clustermap(
                    df,
                    cmap=cmap,
                    xticklabels=False,
                    cbar_kws={"label": "Expression"},
                    row_cluster=True,
                    col_cluster=False,
                )
                g.ax_heatmap.set_title(lname)
                # only the figure of the last lineage is kept (and possibly saved)
                fig = g.fig
            else:
                xs = np.array([m.x_test for m in models.values()])
                x_min, x_max = np.nanmin(xs), np.nanmax(xs)

                sns.heatmap(
                    df,
                    ax=ax,
                    cmap=cmap,
                    xticklabels=False,
                    cbar_kws={"label": "Expression"},
                )

                ax.set_title(lname)
                ax.set_xticks(np.linspace(0, len(df.columns), 10))
                ax.set_xticklabels(
                    list(
                        map(lambda n: round(n, 3),
                            np.linspace(x_min, x_max, 10))),
                    rotation=90,
                )

        if not cluster_genes and ax is not None:
            ax.set_xlabel(xlabel)

        return fig

    lineage_key = str(LinKey.FORWARD if final else LinKey.BACKWARD)
    if lineage_key not in adata.obsm:
        raise KeyError(
            f"Lineages key `{lineage_key!r}` not found in `adata.obsm`.")

    if lineages is None:
        lineages = adata.obsm[lineage_key].names
    elif isinstance(lineages, str):
        # a single name would otherwise be iterated character by character
        lineages = [lineages]

    # accessing each lineage verifies that it exists
    for lineage_name in lineages:
        _ = adata.obsm[lineage_key][lineage_name]

    if isinstance(genes, str):
        genes = [genes]
    check_collection(adata, genes, "var_names")

    if isinstance(start_lineage, (str, type(None))):
        start_lineage = [start_lineage] * len(lineages)
    if isinstance(end_lineage, (str, type(None))):
        end_lineage = [end_lineage] * len(lineages)

    xlabel = kwargs.get("time_key", None) if xlabel is None else xlabel

    _ = kwargs.pop("start_lineage", None)
    _ = kwargs.pop("end_lineage", None)

    for typp, clusters in zip(["Start", "End"], [start_lineage, end_lineage]):
        for cl in filter(lambda c: c is not None, clusters):
            if cl not in lineages:
                raise ValueError(
                    f"{typp} lineage `{cl!r}` not found in lineage names.")

    # fail early, before spending time on fitting the models
    if kind not in ("genes", "lineages"):
        raise ValueError(
            f"Unknown heatmap kind `{kind!r}`. Valid options are: `'lineages'`, `'genes'`."
        )

    # NOTE(review): `final` only selects the lineage key above, it is not forwarded
    # to `_fit` through `kwargs` - confirm that this is intended.
    kwargs["models"] = _create_models(model, genes, lineages)
    if _is_any_gam_mgcv(kwargs["models"]):
        logg.debug(
            "DEBUG: Setting backend to multiprocessing because model is `GamMGCV`"
        )
        backend = "multiprocessing"

    n_jobs = _get_n_cores(n_jobs, len(genes))
    start = logg.info(f"Computing trends using `{n_jobs}` core(s)")
    data = parallelize(
        _fit,
        genes,
        unit="gene",
        backend=backend,
        n_jobs=n_jobs,
        # merge the per-chunk dictionaries into a single one
        extractor=lambda data: {k: v
                                for d in data for k, v in d.items()},
        show_progress_bar=show_progress_bar,
    )(lineages, start_lineage, end_lineage, **kwargs)
    logg.info("    Finish", time=start)
    logg.debug(f"DEBUG: Plotting {kind} heatmap")

    if kind == "genes":
        fig = gene_per_lineage()
    else:
        fig = lineage_per_gene()

    if save is not None and fig is not None:
        save_fig(fig, save)
コード例 #4
0
ファイル: _cluster_fates.py プロジェクト: dpeerlab/cellrank
def similarity_plot(
    adata: AnnData,
    cluster_key: str = "clusters",
    clusters: Optional[List[str]] = None,
    n_samples: int = 1000,
    cmap: mpl.colors.ListedColormap = cm.viridis,
    fontsize: float = 14,
    rotation: float = 45,
    title: Optional[str] = "similarity",
    figsize: Tuple[float, float] = (12, 10),
    dpi: Optional[int] = None,
    final: bool = True,
    save: Optional[Union[str, Path]] = None,
) -> None:
    """
    Compare clusters with respect to their root/final probabilities.

    For each cluster, we compute how likely an 'average cell' is to go towards the final states/come from the root
    states. We then compare these averaged probabilities using Cramér's V statistic, see
    `here <https://en.wikipedia.org/wiki/Cram%C3%A9r%27s_V>`_. The similarity is defined as :math:`1 - Cramér's V`.

    .. image:: https://raw.githubusercontent.com/theislab/cellrank/master/resources/images/similarity_plot.png
       :width: 400px
       :align: center

    Params
    ------
    adata: :class:`anndata.AnnData`
        Annotated data object.
    cluster_key
        Key in :paramref:`adata` `.obs` corresponding to the clustering.
    clusters
        Clusters in :paramref:`adata` `.obs` to consider.
        If `None`, all clusters will be considered.
    n_samples
        Number of samples per cluster.
    cmap
        Colormap to use.
    fontsize
        Font size of the labels.
    rotation
        Rotation of labels on x-axis.
    title
        Title of the figure.
    figsize
        Size of the figure.
    dpi
        Dots per inch.
    final
        Whether to consider cells going to final states or vice versa.
    save
        Filename where to save the plot.
        If `None`, just shows the plot.

    Returns
    -------
    None
        Nothing, just plots the similarity matrix.
        Optionally saves the figure based on :paramref:`save`.
    """

    logg.debug("DEBUG: Getting the counts")
    data = _counts(
        adata,
        cluster_key=cluster_key,
        clusters=clusters,
        n_samples=n_samples,
        final=final,
    )

    cluster_names = list(data.keys())
    logg.debug("DEBUG: Calculating Cramer`s V statistic")
    # pairwise similarity matrix, rows and columns follow `cluster_names`
    sim = [[1 - _cramers_v(data[name2], data[name]) for name in cluster_names]
           for name2 in cluster_names]

    # Plotting function
    fig, ax = plt.subplots(figsize=figsize, dpi=dpi)

    # The color limits have to be fixed on the image itself: a `norm` passed
    # to `Figure.colorbar` is ignored when a mappable is supplied, so the scale
    # would otherwise follow the data range instead of the [0, 1] range.
    im = ax.imshow(sim, cmap=cmap, vmin=0, vmax=1)

    ax.set_xticks(range(len(cluster_names)))
    ax.set_yticks(range(len(cluster_names)))

    ax.set_xticklabels(cluster_names, fontsize=fontsize, rotation=rotation)
    ax.set_yticklabels(cluster_names, fontsize=fontsize)

    ax.set_title(title)
    ax.tick_params(top=False, bottom=True, labeltop=False, labelbottom=True)

    cbar = ax.figure.colorbar(im, ax=ax)
    cbar.set_ticks(np.linspace(0, 1, 10))

    if save is not None:
        save_fig(fig, save)

    fig.show()
コード例 #5
0
def gene_trends(
    adata: AnnData,
    model: _model_type,
    genes: Union[str, Sequence[str]],
    lineages: Optional[Union[str, Sequence[str]]] = None,
    data_key: str = "X",
    final: bool = True,
    start_lineage: Optional[Union[str, Sequence[str]]] = None,
    end_lineage: Optional[Union[str, Sequence[str]]] = None,
    conf_int: bool = True,
    same_plot: bool = False,
    hide_cells: bool = False,
    perc: Optional[Union[Tuple[float, float], Sequence[Tuple[float,
                                                             float]]]] = None,
    lineage_cmap: Optional[matplotlib.colors.ListedColormap] = None,
    abs_prob_cmap: matplotlib.colors.ListedColormap = cm.viridis,
    cell_color: str = "black",
    color: str = "black",
    cell_alpha: float = 0.6,
    lineage_alpha: float = 0.2,
    size: float = 15,
    lw: float = 2,
    show_cbar: bool = True,
    margins: float = 0.015,
    sharey: bool = False,
    figsize: Optional[Tuple[float, float]] = None,
    dpi: Optional[int] = None,
    ncols: int = 2,
    n_jobs: Optional[int] = 1,
    backend: str = "multiprocessing",
    ext: str = "png",
    suptitle: Optional[str] = None,
    save: Optional[Union[str, Path]] = None,
    dirname: Optional[str] = None,
    plot_kwargs: Mapping = MappingProxyType({}),
    show_progres_bar: bool = True,
    **kwargs,
) -> None:
    """
    Plot gene expression trends along lineages.

    Each lineage is defined via its lineage weights which we compute using :func:`cellrank.tl.lineages`. This
    function accepts any `scikit-learn` model wrapped in :class:`cellrank.ul.models.SKLearnModel`
    to fit gene expression, where we take the lineage weights into account in the loss function.

    .. image:: https://raw.githubusercontent.com/theislab/cellrank/master/resources/images/gene_trends.png
       :width: 400px
       :align: center

    Params
    ------
    adata : :class:`anndata.AnnData`
        Annotated data object.
    genes
        Genes in :paramref:`adata` `.var_names` to plot.
        If :paramref:`data_key` `='obs'`, they are looked up in :paramref:`adata` `.obs` instead.
    model
        Model to fit.

        - If a :class:`dict`, gene and lineage specific models can be specified. Use `'*'` to indicate
        all genes or lineages, for example `{'Map2': {'*': ...}, 'Dcx': {'Alpha': ..., '*': ...}}`.
    lineages
        Lineages names for which to show the gene expression.
        If `None`, use all the lineages. If all the names are `None`, no lineage is used.
    data_key
        Key in :paramref:`adata` `.layers` or `'X'` for :paramref:`adata` `.X` where the data is stored.
    final
        Whether to consider cells going to final states or vice versa.
    start_lineage
        Lineage from which to select cells with lowest pseudotime as starting points.
        If specified, the trends start at the earliest pseudotime within that lineage,
        otherwise they start from time `0`.
    end_lineage
        Lineage from which to select cells with highest pseudotime as endpoints.
        If specified, the trends end at the latest pseudotime within that lineage,
        otherwise, it is determined automatically.
    conf_int
        Whether to compute and show confidence intervals.
    same_plot
        Whether to plot all lineages for each gene in the same plot.
    hide_cells
        If `True`, hide all the cells.
    perc
        Percentile for colors. Valid values are in range `[0, 100]`.
        This can improve visualization. Can be specified separately for each lineage.
    lineage_cmap
        Colormap to use when coloring in the lineages.
        Only used when :paramref:`same_plot` `=True`.
    abs_prob_cmap
        Colormap to use when visualizing the absorption probabilities for each lineage.
        Only used when :paramref:`same_plot` `=False`.
    cell_color
        Color of the cells when not visualizing absorption probabilities.
        Only used when :paramref:`same_plot` `=True`.
    color
        Color for the lineages, when each lineage is on
        separate plot, otherwise according to :paramref:`lineage_cmap`.
    cell_alpha
        Alpha channel for cells.
    lineage_alpha
        Alpha channel for lineage confidence intervals.
    size
        Size of the points.
    lw
        Line width of the smoothed values.
    show_cbar
        Whether to show colorbar. Always shown when percentiles for lineages differ.
    margins
        Margins around the plot.
    sharey
        Whether to share y-axis.
    figsize
        Size of the figure.
    dpi
        Dots per inch.
    ncols
        Number of columns of the plot when plotting multiple genes.
        Only used when :paramref:`same_plot` `=True`.
    suptitle
        Suptitle of the figure.
        Only used when :paramref:`same_plot` `=True`.
    n_jobs
        Number of parallel jobs. If `-1`, use all available cores. If `None` or `1`, the execution is sequential.
    backend
        Which backend to use for multiprocessing.
        See :class:`joblib.Parallel` for valid options.
    ext
        Extension to use when saving files, such as `'pdf'`.
        Only used when :paramref:`same_plot` `=False`.
    save
        Filename where to save the plots.
        If `None`, just show the plots.
        Only used when :paramref:`same_plot` `=True`.
    dirname
        Directory where to save the plots, one per gene in :paramref:`genes`.
        If `None`, just show the plots.
        Only used when :paramref:`same_plot` `=False`.
        The figures will be saved as :paramref:`dirname` /`{gene}`. :paramref:`ext`.
    plot_kwargs:
        Keyword arguments for :meth:`cellrank.ul.models.Model.plot`.
    show_progres_bar
        Whether to show a progress bar tracking models fitted.
    kwargs
        Keyword arguments for :meth:`cellrank.ul.models.Model.prepare`.

    Returns
    -------
    None
        Nothing, just plots and optionally saves the plots.

    Raises
    ------
    KeyError
        If the lineages are not found in :paramref:`adata` `.obsm`.
    RuntimeError
        If :paramref:`dirname` is specified, but the figures directory is `None`.
    ValueError
        If the number of start or end lineages differs from the number of lineages.
    """

    if isinstance(genes, str):
        genes = [genes]
    genes = _make_unique(genes)

    if data_key != "obs":
        check_collection(adata, genes, "var_names")
    else:
        check_collection(adata, genes, "obs")

    nrows = int(np.ceil(len(genes) / ncols))
    fig = None
    # placeholders, the helper creates a new figure per gene when `same_plot=False`
    axes = [None] * len(genes)

    if same_plot:
        fig, axes = plt.subplots(
            nrows=nrows,
            ncols=ncols,
            sharey=sharey,
            figsize=(15 * ncols, 10 * nrows) if figsize is None else figsize,
        )
        axes = np.ravel(axes)
    elif dirname is not None:
        figdir = sc.settings.figdir
        if figdir is None:
            raise RuntimeError(
                f"Invalid combination: figures directory `cellrank.settings.figdir` is `None`, "
                f"but dirname was specified to `{dirname}`.")

        # absolute directories are used as they are, relative ones live under `figdir`
        if os.path.isabs(dirname):
            if not os.path.isdir(dirname):
                os.makedirs(dirname, exist_ok=True)
        elif not os.path.isdir(os.path.join(figdir, dirname)):
            os.makedirs(os.path.join(figdir, dirname), exist_ok=True)
    elif save is not None:
        logg.warning(
            "No directory specified for saving. Ignoring `save` argument")

    ln_key = str(LinKey.FORWARD if final else LinKey.BACKWARD)
    if ln_key not in adata.obsm:
        raise KeyError(f"Lineages key `{ln_key!r}` not found in `adata.obsm`.")

    if lineages is None:
        lineages = adata.obsm[ln_key].names
    elif isinstance(lineages, str):
        lineages = [lineages]
    elif all(map(lambda ln: ln is None,
                 lineages)):  # no lineage, all the weights are 1
        lineages = [None]
        show_cbar = False
        logg.debug("DEBUG: All lineages are `None`, setting weights to be `1`")
    lineages = _make_unique(lineages)

    # accessing each lineage verifies that it exists
    for ln in filter(lambda ln: ln is not None, lineages):
        _ = adata.obsm[ln_key][ln]
    n_lineages = len(lineages)

    if isinstance(start_lineage, (str, type(None))):
        start_lineage = [start_lineage] * n_lineages
    if isinstance(end_lineage, (str, type(None))):
        end_lineage = [end_lineage] * n_lineages

    if len(start_lineage) != n_lineages:
        raise ValueError(
            f"Expected the number of start lineages to be the same as number of lineages "
            f"({n_lineages}), found `{len(start_lineage)}`.")
    if len(end_lineage) != n_lineages:
        # report the size of the end lineages, not of the start lineages
        raise ValueError(
            f"Expected the number of end lineages to be the same as number of lineages "
            f"({n_lineages}), found `{len(end_lineage)}`.")

    kwargs["models"] = _create_models(model, genes, lineages)
    kwargs["data_key"] = data_key
    kwargs["final"] = final
    kwargs["conf_int"] = conf_int

    # copy, the default is a read-only mapping
    plot_kwargs = dict(plot_kwargs)
    if plot_kwargs.get("xlabel", None) is None:
        plot_kwargs["xlabel"] = kwargs.get("time_key", None)

    if _is_any_gam_mgcv(kwargs["models"]):
        logg.debug(
            "DEBUG: Setting backend to multiprocessing because model is `GamMGCV`"
        )
        backend = "multiprocessing"

    n_jobs = _get_n_cores(n_jobs, len(genes))

    start = logg.info(f"Computing trends using `{n_jobs}` core(s)")
    models = parallelize(
        _fit,
        genes,
        unit="gene" if data_key != "obs" else "obs",
        backend=backend,
        n_jobs=n_jobs,
        # merge the per-chunk dictionaries into a single one
        extractor=lambda modelss:
        {k: v
         for m in modelss for k, v in m.items()},
        show_progress_bar=show_progres_bar,
    )(lineages, start_lineage, end_lineage, **kwargs)
    logg.info("    Finish", time=start)

    logg.debug("DEBUG: Plotting trends")
    for i, (gene, ax) in enumerate(zip(genes, axes)):
        # per-gene files are only written when each gene has its own figure
        f = (None if (same_plot or dirname is None) else os.path.join(
            dirname, f"{gene}.{ext}"))
        _trends_helper(
            adata,
            models,
            gene=gene,
            lineage_names=lineages,
            ln_key=ln_key,
            same_plot=same_plot,
            hide_cells=hide_cells,
            perc=perc,
            cmap=lineage_cmap,
            abs_prob_cmap=abs_prob_cmap,
            cell_color=cell_color,
            color=color,
            alpha=cell_alpha,
            lineage_alpha=lineage_alpha,
            size=size,
            lw=lw,
            show_cbar=show_cbar,
            margins=margins,
            sharey=sharey,
            dpi=dpi,
            figsize=figsize,
            fig=fig,
            ax=ax,
            save=f,
            **plot_kwargs,
        )

    if same_plot:
        # remove the axes of the grid which were not used by any gene
        for j in range(len(genes), len(axes)):
            axes[j].remove()

        if suptitle is not None:
            fig.suptitle(suptitle)

        if save is not None:
            save_fig(fig, save)
コード例 #6
0
ファイル: _cluster_fates.py プロジェクト: opnumten/cellrank
def cluster_fates(
    adata: AnnData,
    cluster_key: Optional[str] = "louvain",
    clusters: Optional[Union[str, Sequence[str]]] = None,
    lineages: Optional[Union[str, Sequence[str]]] = None,
    mode: str = "bar",
    final: bool = True,
    basis: Optional[str] = None,
    show_cbar: bool = True,
    ncols: Optional[int] = None,
    sharey: bool = False,
    save: Optional[Union[str, Path]] = None,
    legend_kwargs: Mapping[str, Any] = MappingProxyType({"loc": "best"}),
    figsize: Optional[Tuple[float, float]] = None,
    dpi: Optional[int] = None,
    **kwargs,
) -> None:
    """
    Produces plots that aggregate lineage probabilities to a cluster level.

    This can be used to investigate how likely a certain cluster is to go to the final cells, or in turn to have
    descended from the root cells. For mode `'paga'` and `'paga_pie'`, we use *PAGA*, see [Wolf19]_.

    .. image:: https://raw.githubusercontent.com/theislab/cellrank/master/resources/images/cluster_fates.png
       :width: 400px
       :align: center

    Params
    ------
    adata : :class:`anndata.AnnData`
        Annotated data object.
    cluster_key
        Key in :paramref:`adata` `.obs` containing the clusters.
        If `None`, all cells are aggregated into a single group called `'All'`; this is only
        allowed for :paramref:`mode` `'bar'` and `'violin'`.
    clusters
        Clusters to visualize.
        If `None`, all clusters will be plotted. For :paramref:`mode` `'paga'` and `'paga_pie'`,
        all clusters are always used, regardless of this argument.
    lineages
        Lineages for which to visualize absorption probabilities.
        If `None`, use all available lineages.
    mode
        Type of plot to show.

        - If `'bar'`, plot barplots for specified :paramref:`clusters` and :paramref:`lineages`.
        - If `'paga'`, plot `N` :func:`scanpy.pl.paga` plots, one for each lineage in :paramref:`lineages`.
        - If `'paga_pie'`, visualize absorption probabilities as a pie chart for each cluster
          for the given :paramref:`lineages`.
        - If `'violin'`, group the data by lineages and plot the fate distribution per cluster.

        Best for looking at the distribution of fates within one cluster.
    final
        Whether to consider cells going to final states or vice versa.
    basis
        Basis for scatterplot to use when :paramref:`mode` `='paga_pie'`. If `None`, don't show the scatterplot.
    show_cbar
        Whether to show colorbar when :paramref:`mode` is `'paga'`.
    ncols
        Number of columns when :paramref:`mode` is `'bar'`, `'paga'` or `'violin'`.
    sharey
        Whether to share y-axis when :paramref:`mode` is `'bar'`.
    save
        Filename where to save the plots.
        If `None`, just shows the plot.
    legend_kwargs
        Keyword arguments for :func:`matplotlib.axes.Axes.legend`, such as `'loc'` for legend position.
        Only used when :paramref:`mode` is `'paga_pie'`.
    figsize
        Size of the figure. If `None`, it will be set automatically.
    dpi
        Dots per inch.
    kwargs
        Keyword arguments for :func:`scanpy.pl.paga`, :func:`scvelo.pl.paga`, :func:`scanpy.pl.violin`
        or :func:`matplotlib.pyplot.bar`, depending on :paramref:`mode`.

    Returns
    -------
    None
        Nothing, just plots the fates for specified :paramref:`clusters` and :paramref:`lineages`.
        Optionally saves the figure based on :paramref:`save`.

    Raises
    ------
    ValueError
        If :paramref:`mode` is invalid, if :paramref:`cluster_key` is `None` for a mode other than
        `'bar'` or `'violin'`, or if a requested lineage does not exist.
    KeyError
        If :paramref:`cluster_key`, a requested cluster, the lineage key or the PAGA results
        (for the PAGA-based modes) are missing from :paramref:`adata`.
    """
    def plot_bar():
        # one panel per cluster, arranged on a grid with `cols` columns
        cols = 4 if ncols is None else ncols
        n_rows = ceil(len(clusters) / cols)
        fig = plt.figure(None, (3.5 * cols,
                                5 * n_rows) if figsize is None else figsize,
                         dpi=dpi)
        fig.tight_layout()

        gs = plt.GridSpec(n_rows, cols, figure=fig, wspace=0.7, hspace=0.9)

        ax = None
        colors = list(adata.obsm[lk][:, lin_names].colors)

        for g, k in zip(gs, d.keys()):
            current_ax = fig.add_subplot(g, sharey=ax)
            current_ax.bar(
                x=np.arange(len(lin_names)),
                height=d[k][0],
                color=colors,
                yerr=d[k][1],
                ecolor="black",
                capsize=10,
                **kwargs,
            )
            # every subsequent panel shares its y-axis with the previous one
            if sharey:
                ax = current_ax

            current_ax.set_xticks(np.arange(len(lin_names)))
            current_ax.set_xticklabels(lin_names, rotation="vertical")
            current_ax.set_xlabel(points)
            current_ax.set_ylabel("Absorption probability")
            current_ax.set_title(k)

        return fig

    def plot_paga():
        kwargs["save"] = None  # saving is handled by the outer function
        kwargs["show"] = False
        if "cmap" not in kwargs:
            kwargs["cmap"] = cm.viridis

        cols = len(lin_names) if ncols is None else ncols
        nrows = ceil(len(lin_names) / cols)
        fig, axes = plt.subplots(
            nrows,
            cols,
            figsize=(6 * cols, 4 * nrows) if figsize is None else figsize,
            constrained_layout=True,
            dpi=dpi,
        )

        i = 0
        axes = [axes] if not isinstance(axes, np.ndarray) else np.ravel(axes)
        vmin, vmax = np.inf, -np.inf

        for i, (ax, lineage_name) in enumerate(zip(axes, lin_names)):
            # mean probability towards the i-th lineage, one value per cluster
            colors = [v[0][i] for v in d.values()]
            # keep track of the global range for the shared colorbar
            vmin, vmax = np.nanmin(colors + [vmin]), np.nanmax(colors + [vmax])
            kwargs["ax"] = ax
            kwargs["colors"] = tuple(colors)
            kwargs["title"] = lineage_name

            sc.pl.paga(adata, **kwargs)

        if show_cbar:
            norm = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax)
            cax, _ = mpl.colorbar.make_axes(
                ax, aspect=100)  # new matplotlib feature
            _ = mpl.colorbar.ColorbarBase(cax,
                                          norm=norm,
                                          cmap=kwargs["cmap"],
                                          label="Absorption probability")

        # remove the unused panels of the grid
        for ax in axes[i + 1:]:
            ax.remove()

        return fig

    def plot_paga_pie():
        colors = list(adata.obsm[lk][:, lin_names].colors)
        # for each cluster (by position), map lineage color -> mean probability
        colors = {
            i: odict(zip(colors, mean))
            for i, (mean, _) in enumerate(d.values())
        }

        fig, ax = plt.subplots(figsize=figsize, dpi=dpi)

        kwargs["ax"] = ax
        kwargs["show"] = False
        kwargs["colorbar"] = False  # has to be disabled

        kwargs["node_colors"] = colors
        kwargs.pop("save", None)  # we will handle saving

        kwargs["transitions"] = kwargs.get("transitions", None)
        kwargs["legend_loc"] = kwargs.get("legend_loc", None) or "on data"

        if basis is not None:
            kwargs["basis"] = basis
            kwargs["scatter_flag"] = True
            kwargs["color"] = cluster_key

        scv.pl.paga(adata, **kwargs)

        if basis is not None and kwargs["legend_loc"] not in ("none",
                                                              "on data"):
            first_legend = _position_legend(
                ax,
                legend_loc=kwargs["legend_loc"],
                **{k: v
                   for k, v in legend_kwargs.items() if k != "loc"},
                title=cluster_key,
            )
            fig.add_artist(first_legend)

        if legend_kwargs.get("loc", None) is not None:
            # we need to use these, because scvelo can have its own handles and
            # they would be plotted here
            handles = []
            for lineage_name, color in zip(lin_names, colors[0].keys()):
                handles += [ax.scatter([], [], label=lineage_name, c=color)]
            if len(colors[0].keys()) != len(adata.obsm[lk].names):
                handles += [ax.scatter([], [], label="Rest", c="grey")]

            ax.legend(**legend_kwargs, handles=handles, title=points)

        return fig

    def plot_violin():
        kwargs["show"] = False
        kwargs.pop("ax", None)
        kwargs.pop("keys", None)
        kwargs.pop("save", None)  # we will handle saving
        kwargs["groupby"] = cluster_key
        if kwargs.get("rotation", None) is None:
            kwargs["rotation"] = 90

        data = adata.obsm[lk]
        to_clean = []

        # temporarily expose the lineage probabilities as columns of `adata.obs`
        for name in lin_names:
            if name not in adata.obs_keys():
                to_clean.append(name)
                adata.obs[name] = np.array(
                    data[:, name])  # TODO: better approach - dummy adata

        cols = len(lin_names) if ncols is None else ncols
        nrows = ceil(len(lin_names) / cols)
        fig, axes = plt.subplots(
            nrows,
            cols,
            figsize=(6 * cols, 4 * nrows) if figsize is None else figsize,
            dpi=dpi,
        )
        if not isinstance(axes, np.ndarray):
            axes = [axes]
        axes = np.ravel(axes)

        i = 0
        for i, (name, ax) in enumerate(zip(lin_names, axes)):
            ax.set_title(name)  # ylabel not yet supported
            sc.pl.violin(adata, keys=[name], ax=ax, **kwargs)
        for ax in axes[i + 1:]:
            ax.remove()
        # drop the temporary columns again
        for name in to_clean:
            del adata.obs[name]

        return fig

    def plot_violin_no_cluster_key():
        kwargs["show"] = False
        kwargs.pop("ax", None)
        kwargs.pop("keys", None)  # don't care
        kwargs.pop("save", None)
        kwargs["groupby"] = points

        # dummy object with one row per (lineage, cell) pair, grouped by lineage
        data = np.ravel(np.array(adata.obsm[lk]).T)[..., np.newaxis]
        dadata = AnnData(np.zeros_like(data))
        dadata.obs["Absorption probability"] = data
        dadata.obs[points] = (pd.Series(
            np.ravel([[n] * adata.n_obs for n in adata.obsm[lk].names
                      ])).astype("category").values)
        dadata.uns[f"{points}_colors"] = adata.obsm[lk].colors

        fig, ax = plt.subplots(figsize=figsize if figsize is not None else
                               (8, 6),
                               dpi=dpi)
        ax.set_title("All Clusters")
        sc.pl.violin(dadata, keys=["Absorption probability"], ax=ax, **kwargs)

        return fig

    if mode not in _cluster_fates_modes:
        raise ValueError(
            f"Invalid mode: `{mode!r}`. Valid options are: `{_cluster_fates_modes}`."
        )

    # Explicit flag instead of comparing cluster names against the "All" sentinel,
    # so that a real cluster which happens to be called "All" is handled correctly.
    is_all = cluster_key is None

    if not is_all:
        if cluster_key not in adata.obs:
            raise KeyError(f"Key `{cluster_key!r}` not found in `adata.obs`.")
    elif mode not in ("bar", "violin"):
        raise ValueError(
            "Not specifying cluster key is only available for modes `'bar'` and `'violin'`."
        )

    if not is_all:
        if clusters is not None:
            if isinstance(clusters, str):
                clusters = [clusters]
            clusters = _make_unique(clusters)
            if mode in ("paga", "paga_pie"):
                logg.debug(
                    f"DEBUG: Setting `clusters` to all available ones because of `mode={mode!r}`"
                )
                clusters = list(adata.obs[cluster_key].cat.categories)
            else:
                for cname in clusters:
                    if cname not in adata.obs[cluster_key].cat.categories:
                        raise KeyError(
                            f"Cluster `{cname!r}` not found in `adata.obs[{cluster_key!r}]`"
                        )
        else:
            clusters = list(adata.obs[cluster_key].cat.categories)
    else:
        clusters = ["All"]

    lk = str(LinKey.FORWARD if final else LinKey.BACKWARD)
    points = "Endpoints" if final else "Startpoints"
    if lk not in adata.obsm:
        raise KeyError(f"Lineages key `{lk!r}` not found in `adata.obsm`.")

    if lineages is not None:
        if isinstance(lineages, str):
            lineages = [lineages]
        lineages = _make_unique(lineages)
        for ep in lineages:
            if ep not in adata.obsm[lk].names:
                raise ValueError(
                    f"Endpoint `{ep!r}` not found in `adata.obsm[{lk!r}].names`."
                )
        lin_names = list(lineages)
    else:
        # must be list for sc.pl.violin, else cats str
        lin_names = list(adata.obsm[lk].names)

    if mode == "violin" and not is_all:
        # TODO: temporary fix, until subclassing is made ready
        # work on a copy restricted to the requested clusters and re-wrap the
        # probabilities so that lineage names and colors are available on the copy
        names, colors = adata.obsm[lk].names, adata.obsm[lk].colors
        adata = adata[np.isin(adata.obs[cluster_key], clusters)].copy()
        adata.obsm[lk] = Lineage(adata.obsm[lk], names=names, colors=colors)

    # cluster name -> [mean, standard error of the mean] over the cluster's cells
    d = odict()
    for name in clusters:
        # builtin `bool` instead of the `np.bool` alias, which newer NumPy no longer provides
        mask = (np.ones((adata.n_obs, ), dtype=bool) if is_all else
                (adata.obs[cluster_key] == name).values)
        mask = list(np.array(mask, dtype=bool))
        data = adata.obsm[lk][mask, lin_names].X
        mean = np.nanmean(data, axis=0)
        std = np.nanstd(data, axis=0) / np.sqrt(data.shape[0])
        d[name] = [mean, std]

    logg.debug(f"DEBUG: Using mode: `{mode!r}`")
    if mode == "bar":
        fig = plot_bar()
    elif mode == "paga":
        if "paga" not in adata.uns:
            raise KeyError("Compute PAGA first as `scanpy.tl.paga()`.")
        fig = plot_paga()
    elif mode == "paga_pie":
        if "paga" not in adata.uns:
            raise KeyError("Compute PAGA first as `scanpy.tl.paga()`.")
        fig = plot_paga_pie()
    elif mode == "violin":
        fig = plot_violin_no_cluster_key() if is_all else plot_violin()
    else:
        raise ValueError(
            f"Invalid mode `{mode!r}`. Valid options are: `{_cluster_fates_modes}`."
        )

    if save is not None:
        save_fig(fig, save)

    fig.show()
コード例 #7
0
ファイル: _cluster_fates.py プロジェクト: dpeerlab/cellrank
def cluster_fates(
    adata: AnnData,
    cluster_key: Optional[str] = "louvain",
    clusters: Optional[Union[str, Sequence[str]]] = None,
    lineages: Optional[Union[str, Sequence[str]]] = None,
    mode: str = "bar",
    final: bool = True,
    basis: Optional[str] = None,
    show_cbar: bool = True,
    ncols: Optional[int] = None,
    sharey: bool = False,
    save: Optional[Union[str, Path]] = None,
    legend_kwargs: Mapping[str, Any] = MappingProxyType({"loc": "best"}),
    figsize: Optional[Tuple[float, float]] = None,
    dpi: Optional[int] = None,
    **kwargs,
) -> None:
    """
    Plot aggregate lineage probabilities at a cluster level.

    This can be used to investigate how likely a certain cluster is to go to the final states, or in turn to have
    descended from the root states. For mode `'paga'` and `'paga_pie'`, we use *PAGA*, see [Wolf19]_.

    .. image:: https://raw.githubusercontent.com/theislab/cellrank/master/resources/images/cluster_fates.png
       :width: 400px
       :align: center

    Params
    ------
    adata : :class:`anndata.AnnData`
        Annotated data object.
    cluster_key
        Key in :paramref:`adata` `.obs` containing the clusters.
        If `None`, all cells are aggregated into a single group; this is only
        allowed for :paramref:`mode` `'bar'` and `'violin'`.
    clusters
        Clusters to visualize.
        If `None`, all clusters will be plotted. For :paramref:`mode` `'paga'` and `'paga_pie'`,
        all clusters are always used, regardless of this argument.
    lineages
        Lineages for which to visualize absorption probabilities.
        If `None`, use all available lineages.
    mode
        Type of plot to show.

        - `'bar'`: barplot, one panel per cluster
        - `'paga'`: PAGA plot, one per root/final state, colored in by fate
        - `'paga_pie'`: PAGA plot with pie charts indicating aggregated fates
        - `'violin'`: violin plots, one per root/final state
        - `'heatmap'`: seaborn heatmap, showing average fates per cluster
        - `'clustermap'`: same as heatmap, but with dendrogram
    final
        Whether to consider cells going to final states or vice versa.
    basis
        Basis for scatterplot to use when :paramref:`mode` is `'paga'` or `'paga_pie'`.
        If `None`, don't show the scatterplot.
    show_cbar
        Whether to show colorbar when :paramref:`mode` is `'paga'`, `'heatmap'` or `'clustermap'`.
    ncols
        Number of columns when :paramref:`mode` is `'bar'`, `'paga'` or `'violin'`.
    sharey
        Whether to share y-axis when :paramref:`mode` is `'bar'` or `'violin'`.
    save
        Filename where to save the plots.
        If `None`, just shows the plot.
    legend_kwargs
        Keyword arguments for :func:`matplotlib.axes.Axes.legend`, such as `'loc'` for legend position.
        For `mode='paga_pie'` and `basis='...'`, this controls the placement of the absorption probabilities legend.
    figsize
        Size of the figure. If `None`, it will be set automatically.
    dpi
        Dots per inch.
    kwargs
        Keyword arguments for :func:`scvelo.pl.paga`, :func:`scanpy.pl.violin`, :func:`matplotlib.pyplot.bar`
        or the seaborn heatmap/clustermap functions, depending on :paramref:`mode`.
        The key `'xticks_rotation'` (default `45`) is consumed here and controls the rotation of the x-tick
        labels; for :paramref:`mode` `'bar'`, the labels are vertical unless it is given explicitly.
        For :paramref:`mode` `'heatmap'` and `'clustermap'`, the key `'title'` sets the plot title.

    Returns
    -------
    None
        Nothing, just plots the fates for specified :paramref:`clusters` and :paramref:`lineages`.
        Optionally saves the figure based on :paramref:`save`.

    Raises
    ------
    ValueError
        If :paramref:`mode` is invalid, if :paramref:`cluster_key` is `None` for a mode other than
        `'bar'` or `'violin'`, or if a requested lineage does not exist.
    KeyError
        If :paramref:`cluster_key`, a requested cluster, the lineage key or the PAGA results
        (for the PAGA-based modes) are missing from :paramref:`adata`.
    """
    def plot_bar():
        # one panel per cluster, arranged on a grid with `cols` columns
        cols = 4 if ncols is None else ncols
        n_rows = ceil(len(clusters) / cols)
        fig = plt.figure(None, (3.5 * cols,
                                5 * n_rows) if figsize is None else figsize,
                         dpi=dpi)
        fig.tight_layout()

        gs = plt.GridSpec(n_rows, cols, figure=fig, wspace=0.7, hspace=0.9)

        ax = None
        colors = list(adata.obsm[lk][:, lin_names].colors)

        for g, k in zip(gs, d.keys()):
            current_ax = fig.add_subplot(g, sharey=ax)
            current_ax.bar(
                x=np.arange(len(lin_names)),
                height=d[k][0],
                color=colors,
                yerr=d[k][1],
                ecolor="black",
                capsize=10,
                **kwargs,
            )
            # every subsequent panel shares its y-axis with the previous one
            if sharey:
                ax = current_ax

            current_ax.set_xticks(np.arange(len(lin_names)))
            current_ax.set_xticklabels(
                lin_names, rotation=xrot if has_xrot else "vertical")
            if not is_all:
                current_ax.set_xlabel(points)
            current_ax.set_ylabel("probability")
            current_ax.set_title(k)

        return fig

    def plot_paga():
        kwargs["save"] = None  # saving is handled by the outer function
        kwargs["show"] = False
        if "cmap" not in kwargs:
            kwargs["cmap"] = cm.viridis

        cols = len(lin_names) if ncols is None else ncols
        nrows = ceil(len(lin_names) / cols)
        fig, axes = plt.subplots(
            nrows,
            cols,
            figsize=(7 * cols, 4 * nrows) if figsize is None else figsize,
            constrained_layout=True,
            dpi=dpi,
        )

        i = 0
        axes = [axes] if not isinstance(axes, np.ndarray) else np.ravel(axes)
        vmin, vmax = np.inf, -np.inf

        if basis is not None:
            kwargs["basis"] = basis
            kwargs["scatter_flag"] = True
            kwargs["color"] = cluster_key

        for i, (ax, lineage_name) in enumerate(zip(axes, lin_names)):
            # mean probability towards the i-th lineage, one value per cluster
            colors = [v[0][i] for v in d.values()]
            # Track the global value range; without this the colorbar below would be
            # normalized with vmin=inf and vmax=-inf.
            vmin, vmax = np.nanmin(colors + [vmin]), np.nanmax(colors + [vmax])
            kwargs["ax"] = ax
            kwargs["colors"] = tuple(colors)
            kwargs["title"] = f"{dir_prefix} {lineage_name}"

            scv.pl.paga(adata, **kwargs)

        if show_cbar:
            norm = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax)
            cax, _ = mpl.colorbar.make_axes(
                ax, aspect=100)  # new matplotlib feature
            _ = mpl.colorbar.ColorbarBase(cax,
                                          norm=norm,
                                          cmap=kwargs["cmap"],
                                          label="probability")

        # remove the unused panels of the grid
        for ax in axes[i + 1:]:  # noqa
            ax.remove()

        return fig

    def plot_paga_pie():
        colors = list(adata.obsm[lk][:, lin_names].colors)
        # for each cluster (by position), map lineage color -> mean probability
        colors = {
            i: odict(zip(colors, mean))
            for i, (mean, _) in enumerate(d.values())
        }

        fig, ax = plt.subplots(figsize=figsize, dpi=dpi)

        kwargs["ax"] = ax
        kwargs["show"] = False
        kwargs["colorbar"] = False  # has to be disabled

        kwargs["node_colors"] = colors
        kwargs.pop("save", None)  # we will handle saving

        kwargs["transitions"] = kwargs.get("transitions", None)
        # remember the legend location requested by the user; unless it is
        # "on data", the cluster legend is drawn below instead of by scvelo
        if "legend_loc" in kwargs:
            orig_ll = kwargs["legend_loc"]
            if orig_ll != "on data":
                kwargs["legend_loc"] = "none"  # we will handle legend
        else:
            orig_ll = None
            kwargs["legend_loc"] = "on data"

        if basis is not None:
            kwargs["basis"] = basis
            kwargs["scatter_flag"] = True
            kwargs["color"] = cluster_key

        ax = scv.pl.paga(adata, **kwargs)
        ax.set_title(kwargs.get("title", cluster_key))

        if basis is not None and orig_ll not in ("none", "on data", None):
            # legend for the clusters shown in the scatterplot
            handles = []
            for cluster_name, color in zip(
                    adata.obs[f"{cluster_key}"].cat.categories,
                    adata.uns[f"{cluster_key}_colors"],
            ):
                handles += [ax.scatter([], [], label=cluster_name, c=color)]
            first_legend = _position_legend(
                ax,
                legend_loc=orig_ll,
                handles=handles,
                **{k: v
                   for k, v in legend_kwargs.items() if k != "loc"},
                title=cluster_key,
            )
            fig.add_artist(first_legend)

        if legend_kwargs.get("loc", None) not in ("none", "on data", None):
            # we need to use these, because scvelo can have its own handles and
            # they would be plotted here
            handles = []
            for lineage_name, color in zip(lin_names, colors[0].keys()):
                handles += [ax.scatter([], [], label=lineage_name, c=color)]
            if len(colors[0].keys()) != len(adata.obsm[lk].names):
                handles += [ax.scatter([], [], label="Rest", c="grey")]

            second_legend = _position_legend(
                ax,
                legend_loc=legend_kwargs["loc"],
                handles=handles,
                **{k: v
                   for k, v in legend_kwargs.items() if k != "loc"},
                title=points,
            )
            fig.add_artist(second_legend)

        return fig

    def plot_violin():
        kwargs.pop("ax", None)
        kwargs.pop("keys", None)
        kwargs.pop("save", None)  # we will handle saving

        kwargs["show"] = False
        kwargs["groupby"] = cluster_key
        kwargs["rotation"] = xrot

        data = adata.obsm[lk]
        to_clean = []

        # temporarily expose the lineage probabilities as columns of `adata.obs`
        for name in lin_names:
            # TODO: once ylabel is implemented, the prefix isn't necessary
            key = f"{dir_prefix} {name}"
            if key not in adata.obs_keys():
                to_clean.append(key)
                adata.obs[key] = np.array(
                    data[:, name])  # TODO: better approach - dummy adata

        cols = len(lin_names) if ncols is None else ncols
        nrows = ceil(len(lin_names) / cols)
        fig, axes = plt.subplots(
            nrows,
            cols,
            figsize=(6 * cols, 4 * nrows) if figsize is None else figsize,
            sharey=sharey,
            dpi=dpi,
        )
        if not isinstance(axes, np.ndarray):
            axes = [axes]
        axes = np.ravel(axes)

        i = 0
        for i, (name, ax) in enumerate(zip(lin_names, axes)):
            key = f"{dir_prefix} {name}"
            ax.set_title(key)
            # only the first panel gets a y-label
            sc.pl.violin(adata,
                         ylabel="" if i else "probability",
                         keys=key,
                         ax=ax,
                         **kwargs)
        for ax in axes[i + 1:]:  # noqa
            ax.remove()
        # drop the temporary columns again
        for name in to_clean:
            del adata.obs[name]

        return fig

    def plot_violin_no_cluster_key():
        kwargs.pop("ax", None)
        kwargs.pop("keys", None)  # don't care
        kwargs.pop("save", None)

        kwargs["show"] = False
        kwargs["groupby"] = points
        kwargs["xlabel"] = None
        kwargs["rotation"] = xrot

        # dummy object with one row per (lineage, cell) pair, grouped by lineage
        data = np.ravel(np.array(adata.obsm[lk]).T)[..., np.newaxis]
        dadata = AnnData(np.zeros_like(data))
        dadata.obs["probability"] = data
        dadata.obs[points] = (pd.Series(
            np.ravel([[f"{dir_prefix.lower()} {n}"] * adata.n_obs
                      for n in adata.obsm[lk].names
                      ])).astype("category").values)
        dadata.uns[f"{points}_colors"] = adata.obsm[lk].colors

        fig, ax = plt.subplots(figsize=figsize if figsize is not None else
                               (8, 6),
                               dpi=dpi)
        ax.set_title(points.capitalize())
        sc.pl.violin(dadata, keys=["probability"], ax=ax, **kwargs)

        return fig

    def plot_heatmap():
        title = kwargs.pop("title", None)
        if not title:
            title = "average fate per cluster"
        # rows are lineages, columns are clusters
        data = pd.DataFrame([mean for mean, _ in d.values()],
                            columns=lin_names,
                            index=clusters).T

        if "cmap" not in kwargs:
            kwargs["cmap"] = "viridis"

        if use_clustermap:
            kwargs["cbar_pos"] = (0, 0.9, 0.025, 0.15) if show_cbar else None
            max_size = float(max(data.shape))

            g = clustermap(
                data,
                robust=True,
                annot=True,
                fmt=".2f",
                row_colors=adata.obsm[lk][lin_names].colors,
                dendrogram_ratio=(
                    0.15 * data.shape[0] / max_size,
                    0.15 * data.shape[1] / max_size,
                ),
                figsize=figsize,
                **kwargs,
            )
            g.ax_heatmap.set_xlabel(cluster_key)
            g.ax_heatmap.set_ylabel("lineage")
            g.ax_col_dendrogram.set_title(title)

            fig = g.fig
            g = g.ax_heatmap
        else:
            fig, ax = plt.subplots(figsize=figsize, dpi=dpi)
            g = heatmap(
                data,
                robust=True,
                annot=True,
                fmt=".2f",
                cbar=show_cbar,
                ax=ax,
                **kwargs,
            )
            ax.set_title(title)
            ax.set_xlabel(cluster_key)
            ax.set_ylabel("lineage")

        g.set_xticklabels(g.get_xticklabels(), rotation=xrot)
        g.set_yticklabels(g.get_yticklabels(), rotation=0)

        return fig

    if mode not in _cluster_fates_modes:
        raise ValueError(
            f"Invalid mode: `{mode!r}`. Valid options are: `{_cluster_fates_modes}`."
        )
    if cluster_key is not None:
        if cluster_key not in adata.obs:
            raise KeyError(f"Key `{cluster_key!r}` not found in `adata.obs`.")
    elif mode not in ("bar", "violin"):
        raise ValueError(
            f"Not specifying cluster key is only available for modes `'bar'` and `'violin'`, found `mode={mode!r}`."
        )

    lk = str(LinKey.FORWARD if final else LinKey.BACKWARD)
    points = "final states" if final else "root states"
    dir_prefix = "To" if final else "From"

    if cluster_key is not None:
        is_all = False
        if clusters is not None:
            if isinstance(clusters, str):
                clusters = [clusters]
            clusters = _make_unique(clusters)
            if mode in ("paga", "paga_pie"):
                logg.debug(
                    f"DEBUG: Setting `clusters` to all available ones because of `mode={mode!r}`"
                )
                clusters = list(adata.obs[cluster_key].cat.categories)
            else:
                for cname in clusters:
                    if cname not in adata.obs[cluster_key].cat.categories:
                        raise KeyError(
                            f"Cluster `{cname!r}` not found in `adata.obs[{cluster_key!r}]`"
                        )
        else:
            clusters = list(adata.obs[cluster_key].cat.categories)
    else:
        # no clustering: aggregate over all cells as a single group
        is_all = True
        clusters = [points]

    if lk not in adata.obsm:
        raise KeyError(f"Lineages key `{lk!r}` not found in `adata.obsm`.")

    if lineages is not None:
        if isinstance(lineages, str):
            lineages = [lineages]
        lineages = _make_unique(lineages)
        for ep in lineages:
            if ep not in adata.obsm[lk].names:
                raise ValueError(
                    f"Endpoint `{ep!r}` not found in `adata.obsm[{lk!r}].names`."
                )
        lin_names = list(lineages)
    else:
        # must be list for sc.pl.violin, else cats str
        lin_names = list(adata.obsm[lk].names)

    if mode == "violin" and not is_all:
        # work on a copy restricted to the requested clusters
        adata = adata[np.isin(adata.obs[cluster_key], clusters)].copy()

    # cluster name -> [mean, standard error of the mean] over the cluster's cells
    d = odict()
    for name in clusters:
        # builtin `bool` instead of the `np.bool` alias, which newer NumPy no longer provides
        mask = (np.ones((adata.n_obs, ), dtype=bool) if is_all else
                (adata.obs[cluster_key] == name).values)
        mask = list(np.array(mask, dtype=bool))
        data = adata.obsm[lk][mask, lin_names].X
        mean = np.nanmean(data, axis=0)
        std = np.nanstd(data, axis=0) / np.sqrt(data.shape[0])
        d[name] = [mean, std]

    # must be consumed before `kwargs` is forwarded to any plotting function
    has_xrot = "xticks_rotation" in kwargs
    xrot = kwargs.pop("xticks_rotation", 45)

    logg.debug(f"DEBUG: Using mode: `{mode!r}`")

    # 'clustermap' is handled by the heatmap plotter with a flag
    use_clustermap = mode == "clustermap"
    if use_clustermap:
        mode = "heatmap"

    if mode == "bar":
        fig = plot_bar()
    elif mode == "paga":
        if "paga" not in adata.uns:
            raise KeyError("Compute PAGA first as `scanpy.tl.paga()`.")
        fig = plot_paga()
    elif mode == "paga_pie":
        if "paga" not in adata.uns:
            raise KeyError("Compute PAGA first as `scanpy.tl.paga()`.")
        fig = plot_paga_pie()
    elif mode == "violin":
        fig = plot_violin_no_cluster_key(
        ) if cluster_key is None else plot_violin()
    elif mode == "heatmap":
        fig = plot_heatmap()
    else:
        raise ValueError(
            f"Invalid mode `{mode!r}`. Valid options are: `{_cluster_fates_modes}`."
        )

    if save is not None:
        save_fig(fig, save)

    fig.show()
コード例 #8
0
ファイル: _graph.py プロジェクト: dpeerlab/cellrank
def graph(
        data: Union[AnnData, np.ndarray, spmatrix],
        graph_key: Optional[str] = None,
        ixs: Optional[np.array] = None,
        layout: Union[str, Dict, Callable] = nx.kamada_kawai_layout,
        keys: Sequence[KEYS] = ("incoming", ),
        keylocs: Union[KEYLOCS, Sequence[KEYLOCS]] = "uns",
        node_size: float = 400,
        labels: Optional[Union[Sequence[str], Sequence[Sequence[str]]]] = None,
        top_n_edges: Optional[Union[int, Tuple[int, bool, str]]] = None,
        self_loops: bool = True,
        self_loop_radius_frac: Optional[float] = None,
        filter_edges: Optional[Tuple[float, float]] = None,
        edge_reductions: Union[Callable, Sequence[Callable]] = np.sum,
        edge_weight_scale: float = 10,
        edge_width_limit: Optional[float] = None,
        edge_alpha: float = 1.0,
        edge_normalize: bool = False,
        edge_use_curved: bool = True,
        show_arrows: bool = True,
        font_size: int = 12,
        font_color: str = "black",
        cat_cmap: ListedColormap = cm.Set3,
        cont_cmap: ListedColormap = cm.viridis,
        legend_loc: Optional[str] = "best",
        figsize: Tuple[float, float] = (15, 10),
        dpi: Optional[int] = None,
        save: Optional[Union[str, Path]] = None,
        layout_kwargs: Dict = MappingProxyType({}),
) -> None:
    """
    Plot a graph, visualizing incoming and outgoing edges or self-transitions.

    This is a utility function to look in more detail at the transition matrix in areas of interest, e.g. around an
    endpoint of development. This function is meant to visualise a small subset of nodes (~100-500) and the most likely
    transitions between them. Note that limiting edges visualized using :paramref:`top_n_edges` will speed things up,
    as well as reduce the visual clutter.

    One subplot (row) is drawn per entry of :paramref:`keys`.

    .. image:: https://raw.githubusercontent.com/theislab/cellrank/master/resources/images/graph.png
       :width: 400px
       :align: center

    Params
    ------
    data
        The graph data, stored either in `.uns` [ :paramref:`graph_key` ] `['T']` of an :class:`anndata.AnnData`
        object, or given directly as a sparse or a dense matrix.
    graph_key
        Key in :paramref:`data` `.uns` where the graph is stored.
        Only used (and then required) when :paramref:`data` is an :class:`anndata.AnnData` object.
    ixs
        Subset of indices of the graph to visualize. If `None`, all nodes are used.
    layout
        Layout to use for graph drawing.

        - If :class:`str`, search for embedding in :paramref:`data` `.obsm[X_` :paramref:`layout` `]`.
          Use :paramref:`layout_kwargs` = `{'components': [0, 1]}` to select components.
        - If :class:`dict`, keys should be values in interval [0, len(:paramref:`ixs`))
          and values `(x, y)` pairs corresponding to node positions.
        - If callable, it is called as `layout(G, **layout_kwargs)` and must return node positions.
    keys
        Keys in :paramref:`data` `.obs`, :paramref:`data` `.obsm` or :paramref:`data` `.uns` to color the nodes.

        - If `'incoming'`, `'outgoing'` or `'self_loops'` to
          visualize reduction (see :paramref:`edge_reductions`) for each node based
          on incoming or outgoing edges, respectively.
    keylocs
        Locations of :paramref:`keys`, can be `'obs'`, `'obsm'` or `'uns'`.
        A single location (or a sequence of length 1) is used for every key.
    node_size
        Size of the nodes.
    labels
        Labels of the nodes.
    top_n_edges
        Either top N outgoing edges in descending order or a tuple
        `(top_n_edges, in_ascending_order, {'incoming', 'outgoing'})`.
        If `None`, show all edges.
    self_loops
        Whether visualize self transitions and also to consider them in :paramref:`top_n_edges`.
    self_loop_radius_frac
        Fraction of a unit circle to visualize self transitions.

        If `None`, use :paramref:`node_size` / 2000 when :paramref:`node_size` >= 200,
        otherwise :paramref:`node_size` / 1000.
    filter_edges
        Whether to remove all edges not in `[min, max]` interval.
        Either bound may be `None` to leave that side open.
    edge_reductions
        Aggregation function to use when coloring nodes by edge weights.
    edge_weight_scale
        Number by which to scale the width of the edges. Useful when the weights are small.
    edge_width_limit
        Upper bound for the width of the edges. Useful when weights are unevenly distributed.
    edge_alpha
        Alpha channel value for edges and arrows.
    edge_normalize
        If `True`, normalize edges to `[0, 1]` interval prior to applying any scaling or truncation.
        Skipped when the graph has no edges or all edges share the same weight.
    edge_use_curved
        If `True`, use curved edges. This can improve visualization at a small performance cost.
    show_arrows
        Whether to show the arrows. Setting this to `False` may dramatically speed things up.
    font_size
        Font size for node labels.
    font_color
        Label color of the nodes.
    cat_cmap
        Categorical colormap used when :paramref:`keys` contain categorical variables.
    cont_cmap
        Continuous colormap used when :paramref:`keys` contain continuous variables.
    legend_loc
        Location of the legend.
    figsize
        Size of the figure.
    dpi
        Dots per inch.
    save
        Filename where to save the plots.
        If `None`, just shows the plot.
    layout_kwargs
        Additional kwargs for :paramref:`layout`.

    Returns
    -------
    None
        Nothing, just plots the graph.
        Optionally saves the figure based on :paramref:`save`.
    """
    def plot_arrows(curves, G, pos, ax, edge_weight_scale):
        # draw one arrow head in the middle of every non-self-loop curved edge
        for line, (edge, val) in zip(curves, G.edges.items()):
            if edge[0] == edge[1]:
                continue

            mask = (~np.isnan(line)).all(axis=1)
            line = line[mask, :]
            if not len(line):  # can be all NaNs
                continue

            line = line.reshape((-1, 2))
            X, Y = line[:, 0], line[:, 1]

            node_start = pos[edge[0]]
            # reverse
            if np.where(np.isclose(node_start - line,
                                   [0, 0]).all(axis=1))[0][0]:
                X, Y = X[::-1], Y[::-1]

            mid = len(X) // 2
            posA, posB = zip(X[mid:mid + 2], Y[mid:mid + 2])  # noqa

            arrow = FancyArrowPatch(
                posA=posA,
                posB=posB,
                # we clip because too small values
                # cause it to crash
                arrowstyle=ArrowStyle.CurveFilledB(
                    head_length=np.clip(
                        val["weight"] * edge_weight_scale * 4,
                        _min_edge_weight,
                        edge_width_limit,
                    ),
                    head_width=np.clip(
                        val["weight"] * edge_weight_scale * 2,
                        _min_edge_weight,
                        edge_width_limit,
                    ),
                ),
                color="k",
                zorder=float("inf"),
                alpha=edge_alpha,
                linewidth=0,
            )
            ax.add_artist(arrow)

    def normalize_weights():
        # min-max scale the edge weights of `G` in place
        weights = np.array([v["weight"] for v in G.edges.values()])
        if not weights.size:
            # no edges left, `np.min` would raise on an empty array
            return

        minn = np.min(weights)
        span = np.max(weights) - minn
        if span == 0:
            # all weights are equal; dividing by zero would turn them into NaN
            return

        weights = (weights - minn) / span
        for v, w in zip(G.edges.values(), weights):
            v["weight"] = w

    def remove_top_n_edges():
        # keep only the N highest (or lowest) weighted edges per node in `G`
        if top_n_edges is None:
            return

        if isinstance(top_n_edges, (tuple, list)):
            n_keep, ascending, group_by = top_n_edges
        else:
            # a bare number means "top N outgoing edges, heaviest first";
            # the group name must be one of the values accepted below
            n_keep, ascending, group_by = top_n_edges, False, "outgoing"

        if group_by not in ("incoming", "outgoing"):
            raise ValueError(
                "Argument `groupby` in `top_n_edges` must be either `'incoming`' or `'outgoing'`."
            )

        if not G.number_of_edges():
            # nothing to rank, and `zip(*G.edges)` could not be unpacked
            return

        source, target = zip(*G.edges)
        weights = [v["weight"] for v in G.edges.values()]
        tmp = pd.DataFrame({
            "outgoing": source,
            "incoming": target,
            "w": weights
        })

        if not self_loops:
            # remove self loops
            tmp = tmp[tmp["incoming"] != tmp["outgoing"]]

        # sort once, then take the first N rows of every group; unlike
        # `groupby.apply` this never drops the grouping column from the result
        kept = tmp.sort_values("w", ascending=ascending).groupby(
            group_by, sort=False).head(max(n_keep, 0))
        to_keep = set(map(tuple, kept[["outgoing", "incoming"]].values))

        for e in list(G.edges):
            if e not in to_keep:
                G.remove_edge(*e)

    def remove_low_weight_edges():
        # drop the edges of `G` whose weight lies outside `filter_edges`
        if filter_edges is None or filter_edges == (None, None):
            return

        minn, maxx = filter_edges
        minn = minn if minn is not None else -np.inf
        maxx = maxx if maxx is not None else np.inf

        for e, attr in list(G.edges.items()):
            if attr["weight"] < minn or attr["weight"] > maxx:
                G.remove_edge(*e)

    # lower bound for edge widths and arrow head sizes
    _min_edge_weight = 0.00001

    if edge_width_limit is None:
        logg.debug("DEBUG: Not limiting width of edges")
        edge_width_limit = float("inf")

    if self_loop_radius_frac is None:
        self_loop_radius_frac = (node_size /
                                 2000 if node_size >= 200 else node_size /
                                 1000)
        logg.debug(
            f"Setting self loop radius fraction to `{self_loop_radius_frac}`")

    if not isinstance(keylocs, (tuple, list)):
        keylocs = [keylocs] * len(keys)
    elif len(keylocs) == 1:
        # broadcast the single location to every key, however many there are
        keylocs = list(keylocs) * len(keys)
    elif all(map(lambda k: k in ("incoming", "outgoing", "self_loops"), keys)):
        # don't care about keylocs since they are irrelevant
        logg.debug("DEBUG: Ignoring key locations")
        keylocs = [None] * len(keys)

    for k in ("obs", "obsm"):
        if k in keylocs and ixs is None:
            raise ValueError(
                f"Invalid combination: `ixs` is None and found `{k!r}` in `keylocs`."
            )

    if not isinstance(edge_reductions, (tuple, list)):
        edge_reductions = [edge_reductions] * len(keys)
    if not all(map(callable, edge_reductions)):
        raise ValueError("Not all edge_reductions functions are callable.")

    # bring `labels` to one (possibly `None`) entry per key
    if not isinstance(labels, (tuple, list)):
        labels = [labels] * len(keys)
    elif not len(labels):
        labels = [None] * len(keys)
    elif not isinstance(labels[0], (tuple, list)):
        labels = [labels] * len(keys)

    if len(labels) != len(keys):
        raise ValueError("`Keys` and `labels` must be of the same shape.")

    if isinstance(data, AnnData):
        if graph_key is None:
            raise ValueError(
                "Argument `graph_key` cannot be `None` when `adata` is `anndata.Anndata` object."
            )
        gdata = data.uns[graph_key]["T"]
    elif isinstance(data, (np.ndarray, spmatrix)):
        gdata = data
    else:
        raise TypeError(
            f"Expected argument `data` to be one of `AnnData`, `numpy.ndarray`, `scipy.sparse.spmatrix`, "
            f"found `{type(data).__name__}`")
    is_sparse = issparse(gdata)

    if ixs is not None:
        gdata = gdata[ixs, :][:, ixs]
    else:
        ixs = list(range(gdata.shape[0]))

    start = logg.info("Creating graph")
    G = (nx.from_scipy_sparse_matrix(gdata, create_using=nx.DiGraph)
         if is_sparse else nx.from_numpy_array(gdata, create_using=nx.DiGraph))

    remove_low_weight_edges()
    remove_top_n_edges()
    if edge_normalize:
        normalize_weights()
    logg.info("    Finish", time=start)

    # do NOT recreate the graph, for the edge reductions
    # gdata = nx.to_numpy_array(G)

    fig, axes = plt.subplots(nrows=len(keys),
                             ncols=1,
                             figsize=figsize,
                             dpi=dpi)
    if not isinstance(axes, np.ndarray):
        axes = np.array([axes])
    axes = np.ravel(axes)

    if isinstance(layout, str):
        if f"X_{layout}" not in data.obsm:
            raise KeyError(
                f"Unable to find embedding `'X_{layout}'` in `adata.obsm`.")
        components = layout_kwargs.get("components", [0, 1])
        if len(components) != 2:
            raise ValueError(
                f"Components in `layout_kwargs` must be of length `2`, found `{len(components)}`."
            )
        emb = data.obsm[f"X_{layout}"][:, components]
        pos = {i: emb[ix, :] for i, ix in enumerate(ixs)}
        logg.info(f"Embedding graph using `{layout!r}` layout")
    elif isinstance(layout, dict):
        rng = range(len(ixs))
        for k, v in layout.items():
            if k not in rng:
                raise ValueError(
                    f"Key in `layout` must be in `range(len(ixs))`, found `{k}`."
                )
            if len(v) != 2:
                raise ValueError(
                    f"Value in `layout` must be a tuple or list of length 2, found `{v}`."
                )
        pos = layout
        logg.debug("DEBUG: Using pre-specified layout")
    elif callable(layout):
        start = logg.info(
            f"Embedding graph using `{layout.__name__!r}` layout")
        pos = layout(G, **layout_kwargs)
        logg.info("    Finish", time=start)
    else:
        raise TypeError(f"Argument `layout` must be either a `string`, "
                        f"a `dict` or a `callable`, found `{type(layout)}`.")

    # the curved edges are computed once and copied into every subplot
    curves, lc = None, None
    if edge_use_curved:
        try:
            from ._utils import curved_edges

            logg.debug("DEBUG: Creating curved edges")
            curves = curved_edges(G,
                                  pos,
                                  self_loop_radius_frac,
                                  polarity="directed")
            lc = LineCollection(
                curves,
                colors="black",
                linewidths=np.clip(
                    np.ravel([v["weight"]
                              for v in G.edges.values()]) * edge_weight_scale,
                    0,
                    edge_width_limit,
                ),
                alpha=edge_alpha,
            )
        except ImportError as e:
            # fall back to straight edges, telling the user only once
            global _msg_shown
            if not _msg_shown:
                print(
                    str(e)[:-1],
                    "in order to use curved edges or specify `edge_use_curved=False`.",
                )
                _msg_shown = True

    for ax, keyloc, key, labs, er in zip(axes, keylocs, keys, labels,
                                         edge_reductions):
        label_col = {}  # dummy value

        if key in ("incoming", "outgoing", "self_loops"):
            if key in ("incoming", "outgoing"):
                vals = np.array(er(gdata,
                                   axis=int(key == "outgoing"))).flatten()
            else:
                vals = gdata.diagonal() if is_sparse else np.diag(gdata)
            node_v = dict(zip(pos.keys(), vals))
        else:
            label_col = getattr(data, keyloc)
            if key in label_col:
                node_v = dict(zip(pos.keys(), label_col[key]))
            else:
                raise RuntimeError(
                    f"Key `{key!r}` not found in `adata.{keyloc!r}`.")

        if labs is not None:
            if len(labs) != len(pos):
                # report the size of this panel's labels, not the number of panels
                raise RuntimeError(
                    f"Number of labels ({len(labs)}) and nodes ({len(pos)}) mismatch."
                )
            nx.draw_networkx_labels(
                G,
                pos,
                labels=labs if isinstance(labs, dict) else dict(
                    zip(pos.keys(), labs)),
                ax=ax,
                font_color=font_color,
                font_size=font_size,
            )

        if lc is not None and curves is not None:
            ax.add_collection(deepcopy(lc))  # copying necessary
            if show_arrows:
                plot_arrows(curves, G, pos, ax, edge_weight_scale)
        else:
            nx.draw_networkx_edges(
                G,
                pos,
                width=[
                    np.clip(
                        v["weight"] * edge_weight_scale,
                        _min_edge_weight,
                        edge_width_limit,
                    ) for _, v in G.edges.items()
                ],
                alpha=edge_alpha,
                edge_color="black",
                arrows=True,
                arrowstyle="-|>",
                # without an explicit axes the edges land on the current
                # (i.e. last created) subplot instead of this one
                ax=ax,
            )

        if key in label_col and is_categorical_dtype(label_col[key]):
            values = label_col[key]
            if keyloc in ("obs", "obsm"):
                values = values[ixs]
            categories = values.cat.categories
            color_key = _colors(key)
            if color_key in data.uns:
                mapper = dict(zip(categories, data.uns[color_key]))
            else:
                # a colormap is called with an integer index to get a color;
                # it does not provide a `.get` method
                mapper = dict(
                    zip(categories, map(cat_cmap, range(len(categories)))))

            colors = []
            seen = set()

            for v in values:
                colors.append(mapper[v])
                seen.add(v)

            nodes_kwargs = dict(cmap=cat_cmap, node_color=colors)  # noqa
            if legend_loc is not None:
                # plot one point per category on top of the first node,
                # purely to create the legend handles
                x, y = pos[0]
                for label in sorted(seen):
                    ax.plot([x], [y], label=label, color=mapper[label])
                ax.legend(loc=legend_loc)
        else:
            values = list(node_v.values())
            vmin, vmax = np.min(values), np.max(values)
            nodes_kwargs = dict(  # noqa
                cmap=cont_cmap,
                node_color=values,
                vmin=vmin,
                vmax=vmax)

            divider = make_axes_locatable(ax)
            cax = divider.append_axes("right", size="1.5%", pad=0.05)
            _ = mpl.colorbar.ColorbarBase(cax,
                                          cmap=cont_cmap,
                                          norm=mpl.colors.Normalize(vmin=vmin,
                                                                    vmax=vmax))

        nx.draw_networkx_nodes(G,
                               pos,
                               node_size=node_size,
                               ax=ax,
                               **nodes_kwargs)

        ax.set_title(key)
        ax.axis("off")

    if save is not None:
        save_fig(fig, save)

    fig.show()
コード例 #9
0
def cluster_lineage(
    adata: AnnData,
    model: _model_type,
    genes: Sequence[str],
    lineage: str,
    final: bool = True,
    clusters: Optional[Sequence[str]] = None,
    n_points: int = 200,
    time_key: str = "latent_time",
    cluster_key: str = "louvain",
    norm: bool = True,
    recompute: bool = False,
    ncols: int = 3,
    sharey: bool = False,
    n_jobs: Optional[int] = 1,
    backend: str = "multiprocessing",
    pca_kwargs: Dict = MappingProxyType({"svd_solver": "arpack"}),
    neighbors_kwargs: Dict = MappingProxyType({"use_rep": "X"}),
    louvain_kwargs: Dict = MappingProxyType({}),
    key_added: Optional[str] = None,
    save: Optional[Union[str, Path]] = None,
    figsize: Optional[Tuple[float, float]] = None,
    dpi: Optional[int] = None,
    show_progress_bar: bool = True,
    **kwargs,
) -> None:
    """
    Cluster gene expression trends within a lineage and plot the clusters.

    This function is based on Palantir, see [Setty19]_. It can be used to discover modules of genes that drive
    development along a given lineage. Consider running this function on a subset of genes which are potential lineage
    drivers, identified e.g. by running :func:`cellrank.tl.gene_importance`.

    .. image:: https://raw.githubusercontent.com/theislab/cellrank/master/resources/images/cluster_lineage.png
       :width: 400px
       :align: center

    Params
    ------
    adata : :class:`anndata.AnnData`
        Annotated data object.
    model
        Model to fit.

        - If a :class:`dict`, gene and lineage specific models can be specified. Use `'*'` to indicate
          all genes or lineages, for example `{'Map2': {'*': ...}, 'Dcx': {'Alpha': ..., '*': ...}}`.
    genes
        Genes in :paramref:`adata` `.var_names` to cluster.
    lineage
        Name of the lineage along which to cluster the genes.
    final
        Whether to consider cells going to final states or vice versa.
    clusters
        Cluster identifiers to plot. If `None`, all clusters will be considered.
        Useful when plotting previously computed clusters.
    n_points
        Number of points used for prediction.
    time_key
        Key in :paramref:`adata` `.obs` where the pseudotime is stored.
    cluster_key
        Key in `.obs` of the trends object (see Returns) where the clustering of the genes is stored.
    norm
        Whether to z-normalize each trend to have `0` mean, `1` variance.
    recompute
        If `True`, recompute the clustering, otherwise try to find already existing one.
    ncols
        Number of columns for the plot.
    sharey
        Whether to share y-axis across multiple plots.
    n_jobs
        Number of parallel jobs. If `-1`, use all available cores. If `None` or `1`, the execution is sequential.
    backend
        Which backend to use for multiprocessing.
        See :class:`joblib.Parallel` for valid options.
    pca_kwargs
        Keyword arguments for :func:`scanpy.pp.pca`.
        The number of components (`'n_comps'`, default `50`) is capped at one less than the
        smaller dimension of the trends matrix.
    neighbors_kwargs
        Keyword arguments for :func:`scanpy.pp.neighbors`.
    louvain_kwargs
        Keyword arguments for :func:`scanpy.tl.louvain`.
    key_added
        Optional suffix appended to the key in :paramref:`adata` `.uns` under which the result is stored.
    save
        Filename where to save the plot.
        If `None`, the figure is not saved.
    figsize
        Size of the figure. If `None`, it will be set automatically.
    dpi
        Dots per inch.
    show_progress_bar
        Whether to show a progress bar tracking models fitted.
    kwargs:
        Keyword arguments for :meth:`cellrank.ul.models.Model.prepare`.

    Returns
    -------
    None
        Plots the clusters of :paramref:`genes` for the given :paramref:`lineage`.
        Optionally saves the figure based on :paramref:`save`.

        Updates :paramref:`adata` `.uns` with the following key:

        lineage_{:paramref:`lineage`}_trend[_{:paramref:`key_added`}]:
            - :class:`anndata.AnnData` object of shape `len` (:paramref:`genes`) x :paramref:`n_points`
              containing the clustered genes.
    """

    lineage_key = str(LinKey.FORWARD if final else LinKey.BACKWARD)
    if lineage_key not in adata.obsm:
        raise KeyError(
            f"Lineages key `{lineage_key!r}` not found in `adata.obsm`.")

    # result is unused; presumably this lookup fails early for an unknown lineage name
    _ = adata.obsm[lineage_key][lineage]

    check_collection(adata, genes, "var_names")

    key_to_add = f"lineage_{lineage}_trend"
    if key_added is not None:
        logg.debug(f"DEBUG: Adding key `{key_added!r}`")
        key_to_add += f"_{key_added}"

    if recompute or key_to_add not in adata.uns:
        kwargs["time_key"] = time_key  # kwargs for the model.prepare
        kwargs["n_test_points"] = n_points
        kwargs["final"] = final

        models = _create_models(model, genes, [lineage])
        if _is_any_gam_mgcv(models):
            backend = "multiprocessing"

        n_jobs = _get_n_cores(n_jobs, len(genes))

        start = logg.info(f"Computing trends using `{n_jobs}` core(s)")
        trends = parallelize(
            _cl_process,
            genes,
            as_array=True,
            unit="gene",
            n_jobs=n_jobs,
            backend=backend,
            show_progress_bar=show_progress_bar,
        )(models, lineage, norm, **kwargs)
        logg.info("    Finish", time=start)

        # one row per gene, one column per prediction point
        trends = AnnData(np.vstack(trends))
        trends.obs_names = genes

        # sanity check
        if trends.n_obs != len(genes):
            raise RuntimeError(
                f"Expected to find `{len(genes)}` genes, found `{trends.n_obs}`."
            )
        if n_points is not None and trends.n_vars != n_points:
            raise RuntimeError(
                f"Expected to find `{n_points}` points, found `{trends.n_vars}`."
            )

        pca_kwargs = dict(pca_kwargs)
        n_comps = pca_kwargs.pop("n_comps", 50)  # default value
        # the default ARPACK solver needs strictly fewer components than the
        # smaller dimension of the matrix, so cap at `min(shape) - 1`
        max_comps = min(trends.n_obs, trends.n_vars) - 1
        if n_comps > max_comps:
            n_comps = max_comps

        sc.pp.pca(trends, n_comps=n_comps, **pca_kwargs)
        sc.pp.neighbors(trends, **neighbors_kwargs)

        louvain_kwargs = dict(louvain_kwargs)
        louvain_kwargs["key_added"] = cluster_key
        sc.tl.louvain(trends, **louvain_kwargs)

        adata.uns[key_to_add] = trends
    else:
        logg.info(f"Loading data from `adata.uns[{key_to_add!r}]`")
        trends = adata.uns[key_to_add]

    if clusters is None:
        if cluster_key not in trends.obs:
            raise KeyError(f"Invalid cluster key `{cluster_key!r}`.")
        clusters = trends.obs[cluster_key].cat.categories

    nrows = int(np.ceil(len(clusters) / ncols))
    fig, axes = plt.subplots(
        nrows,
        ncols,
        figsize=(ncols * 10, nrows * 10) if figsize is None else figsize,
        sharey=sharey,
        dpi=dpi,
    )

    if not isinstance(axes, Iterable):
        axes = [axes]
    axes = np.ravel(axes)

    j = 0
    for j, (ax, c) in enumerate(zip(axes, clusters)):  # noqa
        data = trends[trends.obs[cluster_key] == c].X
        mean, sd = np.mean(data, axis=0), np.std(data, axis=0)

        # individual gene trends in the background
        for i in range(data.shape[0]):
            ax.plot(data[i], color="gray", lw=0.5)

        # cluster mean with a +/- one standard deviation band
        ax.plot(mean, lw=2, color="black")
        ax.plot(mean - sd, lw=1.5, color="black", linestyle="--")
        ax.plot(mean + sd, lw=1.5, color="black", linestyle="--")
        ax.fill_between(range(len(mean)),
                        mean - sd,
                        mean + sd,
                        color="black",
                        alpha=0.1)

        ax.set_title(f"Cluster {c}")
        ax.set_xticks([])

        if not sharey:
            ax.set_yticks([])

    # drop the unused panels of the grid
    for j in range(j + 1, len(axes)):
        axes[j].remove()

    if save is not None:
        save_fig(fig, save)
コード例 #10
0
    def plot(
        self,
        figsize: Tuple[float, float] = (15, 10),
        same_plot: bool = False,
        hide_cells: bool = False,
        perc: Tuple[float, float] = None,
        abs_prob_cmap: mcolors.ListedColormap = cm.viridis,
        cell_color: str = "black",
        color: str = "black",
        alpha: float = 0.8,
        lineage_alpha: float = 0.2,
        title: Optional[str] = None,
        size: int = 15,
        lw: float = 2,
        show_cbar: bool = True,
        margins: float = 0.015,
        xlabel: str = "Pseudotime",
        ylabel: str = "Expression",
        show_conf_int: bool = True,
        dpi: int = None,
        fig: mpl.figure.Figure = None,
        ax: mpl.axes.Axes = None,
        return_fig: bool = False,
        save: Optional[str] = None,
    ) -> Optional[mpl.figure.Figure]:
        """
        Plots the smoothed gene expression.

        Params
        ------
        figsize
            Size of the figure. Only used when a new figure is created.
        same_plot
            Whether to plot all trends in the same plot.
            If `True`, cells are drawn in :paramref:`cell_color` and no colorbar is added.
        hide_cells
            Whether to hide the cells.
        perc
            Percentile by which to clip the absorption probabilities.
        abs_prob_cmap
            Colormap to use when coloring in the absorption probabilities.
        cell_color
            Color for the cells when not coloring absorption probabilities.
        color
            Color for the lineages.
        alpha
            Alpha channel for cells.
        lineage_alpha
            Alpha channel for lineage confidence intervals.
        title
            Title of the plot.
        size
            Size of the points.
        lw
            Line width for the smoothed values.
        show_cbar
            Whether to show colorbar.
        margins
            Margins around the plot.
        xlabel
            Label on the x-axis.
        ylabel
            Label on the y-axis.
        show_conf_int
            Whether to show the confidence interval.
        dpi
            Dots per inch.
        fig
            Figure to use. If `None`, the figure of :paramref:`ax` is used,
            or a new one is created when :paramref:`ax` is `None` as well.
        ax: :class:`matplotlib.axes.Axes`
            Ax to use, if `None`, create a new one (together with a new figure).
        return_fig
            If `True`, return the figure object.
        save
            Filename where to save the plot.
            If `None`, the figure is not saved.

        Returns
        -------
        :class:`matplotlib.figure.Figure` or None
            The figure if :paramref:`return_fig` is `True`, otherwise nothing, just plots the fitted model.
        """

        if ax is None:
            # no axes to draw on: start from a fresh figure
            fig, ax = plt.subplots(figsize=figsize, constrained_layout=True)
        elif fig is None:
            # keep the caller's axes instead of silently replacing it
            fig = ax.figure

        if dpi is not None:
            fig.set_dpi(dpi)

        # color limits, shared by the scatter plot and the colorbar
        vmin, vmax = _minmax(self.w, perc)
        if not hide_cells:
            _ = ax.scatter(
                self.x_all.squeeze(),
                self.y_all.squeeze(),
                # a single color when the weights carry no information
                c=cell_color if same_plot or np.allclose(self.w_all, 1.0) else
                self.w_all.squeeze(),
                s=size,
                cmap=abs_prob_cmap,
                vmin=vmin,
                vmax=vmax,
                alpha=alpha,
            )

        ax.plot(self.x_test, self.y_test, color=color, lw=lw, label=title)
        ax.set_title(title)
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        ax.margins(margins)

        if show_conf_int and self.conf_int is not None:
            ax.fill_between(
                self.x_test.squeeze(),
                self.conf_int[:, 0],
                self.conf_int[:, 1],
                alpha=lineage_alpha,
                color=color,
                linestyle="--",
            )

        if show_cbar and not hide_cells and not same_plot:
            norm = mcolors.Normalize(vmin=vmin, vmax=vmax)
            cax, _ = mpl.colorbar.make_axes(ax, aspect=200)
            _ = mpl.colorbar.ColorbarBase(cax,
                                          norm=norm,
                                          cmap=abs_prob_cmap,
                                          label="Absorption probability")

        if save is not None:
            save_fig(fig, save)

        if return_fig:
            return fig
コード例 #11
0
ファイル: _base_estimator.py プロジェクト: dpeerlab/cellrank
    def _plot_real_spectrum(
        self,
        dpi: int = 100,
        figsize: Optional[Tuple[float, float]] = None,
        legend_loc: Optional[str] = None,
        title: Optional[str] = None,
        save: Optional[Union[str, Path]] = None,
    ) -> None:
        """
        Scatter the real parts of the stored eigenvalues against their index.

        Eigenvalues with zero and non-zero imaginary part are drawn as two
        separately labelled groups and the eigengap is marked by a dashed
        vertical line. If no eigendecomposition is stored yet, it is computed
        with default parameters first.

        Params
        ------
        dpi
            Dots per inch.
        figsize
            Size of the figure.
        legend_loc
            Location parameter for the legend.
        title
            Figure title. If `None`, a title is generated from the parameters
            of the eigendecomposition.
        save
            Filename where to save the plots. If `None`, just shows the plot.

        Returns
        -------
        None
            Nothing, just plots the spectrum.
        """

        if self._eig is None:
            logg.warning(
                "No eigendecomposition found, computing with default parameters"
            )
            self.compute_eig()

        eig = self._eig
        eigvals, params = eig["D"], eig["params"]
        n_eigvals = len(eigvals)
        positions = np.arange(n_eigvals)
        real_part = eigvals.real
        is_real = eigvals.imag == 0

        fig, ax = plt.subplots(nrows=1, ncols=1, dpi=dpi, figsize=figsize)

        # purely real eigenvalues first, then the complex ones
        groups = (
            (is_real, "real eigenvalue"),
            (~is_real, "complex eigenvalue"),
        )
        for selector, group_label in groups:
            ax.scatter(positions[selector],
                       real_part[selector],
                       marker="o",
                       label=group_label)

        # dashed vertical line marking the eigengap
        ax.axvline(eig["eigengap"], label="eigengap", ls="--")

        ax.set_xlabel("index")
        ax.set_xticks(range(n_eigvals))
        ax.set_ylabel(r"Re($\lambda_i$)")

        # criterion by which the eigenvalues were selected
        ordering = "real part" if params["which"] == "LR" else "magnitude"
        if title is None:
            title = f"real part of top {params['k']} eigenvalues according to their {ordering}"
        ax.set_title(title)

        ax.legend(loc=legend_loc)

        if save is not None:
            save_fig(fig, save)

        fig.show()
コード例 #12
0
ファイル: _base_estimator.py プロジェクト: dpeerlab/cellrank
    def plot_spectrum(
        self,
        real_only: bool = False,
        dpi: int = 100,
        figsize: Optional[Tuple[float, float]] = (5, 5),
        legend_loc: Optional[str] = None,
        title: Optional[str] = None,
        save: Optional[Union[str, Path]] = None,
    ) -> None:
        """
        Scatter the stored eigenvalues in the complex plane together with the unit circle.

        If no eigendecomposition is stored yet, a warning is logged and
        :meth:`compute_eig` is called without arguments before plotting.

        Params
        ------
        real_only
            If `True`, delegate to :meth:`_plot_real_spectrum` and draw nothing here.
        dpi
            Dots per inch.
        figsize
            Size of the figure.
        legend_loc
            Location parameter for the legend.
        title
            Figure title. If `None`, a title is built from the stored
            eigendecomposition parameters.
        save
            Filename where to save the plot. If `None`, the figure is only shown.

        Returns
        -------
        None
            Nothing, just plots the spectrum in the complex plane.
        """

        def centered_window(lo, hi, width):
            # interval of length `width` that shares its midpoint with [lo, hi]
            middle = lo + (hi - lo) / 2
            return middle - width / 2, middle + width / 2

        if self._eig is None:
            logg.warning(
                "No eigendecomposition found, computing with default parameters"
            )
            self.compute_eig()

        if real_only:
            # the real-part plot lives in its own method; forward every option
            self._plot_real_spectrum(
                dpi=dpi,
                figsize=figsize,
                legend_loc=legend_loc,
                title=title,
                save=save,
            )
            return

        eigvals, eig_params = self._eig["D"], self._eig["params"]

        fig, ax = plt.subplots(nrows=1, ncols=1, dpi=dpi, figsize=figsize)

        # extent of the eigenvalues along the real and imaginary axes
        re_part, im_part = eigvals.real, eigvals.imag
        re_lo, re_hi = np.min(re_part), np.max(re_part)
        im_lo, im_hi = np.min(im_part), np.max(im_part)

        # both axes get the same span: the larger extent plus a small margin
        span = np.max([re_hi - re_lo, im_hi - im_lo]) + 0.05
        xlim = centered_window(re_lo, re_hi, span)
        ylim = centered_window(im_lo, im_hi, span)

        # eigenvalues first, then the unit circle on top
        ax.scatter(eigvals.real, eigvals.imag, marker=".", label="eigenvalue")
        angles = np.linspace(0, 2 * np.pi, 500)
        ax.plot(np.sin(angles), np.cos(angles), "k-", label="unit circle")

        # axis labels and limits
        ax.set_xlabel(r"Re($\lambda$)")
        ax.set_xlim(*xlim)
        ax.set_ylabel(r"Im($\lambda$)")
        ax.set_ylim(*ylim)

        # the default title names the criterion the eigenvalues were selected by
        if eig_params["which"] == "LR":
            key = "real part"
        else:
            key = "magnitude"

        ax.set_title(
            f"top {eig_params['k']} eigenvalues according to their {key}"
            if title is None
            else title
        )

        ax.legend(loc=legend_loc)

        if save is not None:
            save_fig(fig, save)

        fig.show()