Beispiel #1
0
def composition(
    adata: AnnData,
    key: str,
    fontsize: Optional[str] = None,
    figsize: Optional[Tuple[float, float]] = None,
    dpi: Optional[float] = None,
    save: Optional[Union[str, Path]] = None,
    **kwargs,
) -> None:
    """
    Draw a pie chart with the share of each category of an annotation.

    .. image:: https://raw.githubusercontent.com/theislab/cellrank/master/resources/images/composition.png
       :width: 400px
       :align: center

    Parameters
    ----------
    %(adata)s
    key
        Key in ``adata.obs`` containing categorical observation.
    fontsize
        Font size for the pie chart labels.
    %(plotting)s
    **kwargs
        Keyword arguments for :func:`matplotlib.pyplot.pie`.

    Returns
    -------
    %(just_plots)s

    Raises
    ------
    KeyError
        If ``key`` is not present in ``adata.obs``.
    TypeError
        If ``adata.obs[key]`` is not categorical.
    """

    if key not in adata.obs:
        raise KeyError(f"Key `{key!r}` not found in `adata.obs`.")

    annotation = adata.obs[key]
    if not is_categorical_dtype(annotation):
        raise TypeError(
            f"Observation `adata.obs[{key!r}]` is not categorical.")

    categories = annotation.cat.categories
    # colors are optional; `None` lets matplotlib fall back to its own cycle
    palette = adata.uns.get(f"{key}_colors", None)

    # number of observations per category, normalized to fractions
    counts = [np.sum(annotation == category) for category in categories]
    fractions = counts / np.sum(counts)

    fig, ax = plt.subplots(figsize=figsize, dpi=dpi)
    ax.pie(
        x=fractions,
        labels=categories,
        colors=palette,
        textprops={"fontsize": fontsize},
        **kwargs,
    )
    ax.set_title(f"composition by {key}")

    if save is not None:
        save_fig(fig, save)

    fig.show()
Beispiel #2
0
    def plot_spectrum(
        self,
        n: Optional[int] = None,
        real_only: bool = False,
        show_eigengap: bool = True,
        show_all_xticks: bool = True,
        legend_loc: Optional[str] = None,
        title: Optional[str] = None,
        figsize: Optional[Tuple[float, float]] = (5, 5),
        dpi: int = 100,
        save: Optional[Union[str, Path]] = None,
        marker: str = ".",
        **kwargs,
    ) -> None:
        """
        Plot the leading eigenvalues, either on the real line or in the complex plane.

        Parameters
        ----------
        n
            Number of eigenvalues to show. If `None`, show all that have been computed.
        real_only
            Whether to plot only the real part of the spectrum.
        show_eigengap
            When `real_only=True`, this determines whether to show the inferred eigengap as
            a dotted line.
        show_all_xticks
            When `real_only=True`, this determines whether to show the indices of all eigenvalues
            on the x-axis.
        legend_loc
            Location parameter for the legend.
        title
            Title of the figure.
        %(plotting)s
        marker
            Marker symbol used, valid options can be found in :mod:`matplotlib.markers`.
        **kwargs
            Keyword arguments for :func:`matplotlib.pyplot.scatter`.

        Returns
        -------
        %(just_plots)s

        Raises
        ------
        RuntimeError
            If the eigendecomposition has not been computed yet.
        ValueError
            If ``n`` is not positive.
        """

        eig = getattr(self, P.EIG.s)
        if eig is None:
            raise RuntimeError(
                f"Compute `.{P.EIG}` first as `.{F.COMPUTE.fmt(P.EIG)}()`.")

        if n is None:
            # default to every eigenvalue that is available
            n = len(eig["D"])
        elif n <= 0:
            raise ValueError(f"Expected `n` to be > 0, found `{n}`.")

        # options understood by both the real and the complex plotting helper
        common = dict(
            dpi=dpi,
            figsize=figsize,
            legend_loc=legend_loc,
            title=title,
            marker=marker,
            **kwargs,
        )

        if real_only:
            fig = self._plot_real_spectrum(
                n,
                show_eigengap=show_eigengap,
                show_all_xticks=show_all_xticks,
                **common,
            )
        else:
            fig = self._plot_complex_spectrum(n, **common)

        if save:
            save_fig(fig, save)

        fig.show()
Beispiel #3
0
    def plot_macrostate_composition(
        self,
        key: str,
        width: float = 0.8,
        title: Optional[str] = None,
        labelrot: float = 45,
        legend_loc: Optional[str] = "upper right out",
        figsize: Optional[Tuple[float, float]] = None,
        dpi: Optional[int] = None,
        save: Optional[Union[str, Path]] = None,
        show: bool = True,
    ) -> Optional[Axes]:
        """
        Plot a stacked bar chart of macrostates, split by a categorical annotation.

        Parameters
        ----------
        %(adata)s
        key
            Key from :attr:`adata` ``.obs`` containing categorical annotations.
        width
            Bar width in `[0, 1]`. Values outside of this interval are clipped.
        title
            Title of the figure. If `None`, create one automatically.
        labelrot
            Rotation of labels on x-axis.
        legend_loc
            Position of the legend. If `None`, don't show legend.
        %(plotting)s
        show
            If `False`, return :class:`matplotlib.pyplot.Axes`.

        Returns
        -------
        :class:`matplotlib.pyplot.Axes`
            The axis object if ``show=False``.
        %(just_plots)s
        """
        from cellrank.pl._utils import _position_legend

        macrostates = self._get(P.MACRO)
        if macrostates is None:
            raise RuntimeError(
                "Compute macrostates first as `.compute_macrostates()`.")

        obs = self.adata.obs
        if key not in obs:
            raise KeyError(f"Key `{key}` not found in `adata.obs`.")
        annotation = obs[key]
        if not is_categorical_dtype(annotation):
            raise TypeError(
                f"Expected `adata.obs[{key!r}]` to be `categorical`, "
                f"found `{infer_dtype(annotation)}`.")

        # count (annotation category, macrostate) pairs among the cells that have a macrostate
        assigned = ~macrostates.isnull()
        pairs = pd.DataFrame({"macrostates": macrostates, key: annotation})
        counts = pairs[assigned].groupby([key, "macrostates"]).size()

        categories = annotation.cat.categories
        try:
            palette = self.adata.uns[f"{key}_colors"]
        except KeyError:
            palette = _create_categorical_colors(len(categories))
        color_of = dict(zip(categories, palette))

        positions = np.arange(len(macrostates.cat.categories))
        # running height of each stacked bar
        stack_top = np.zeros_like(positions, dtype=np.float32)

        width = min(1, max(0, width))
        fig, ax = plt.subplots(figsize=figsize, dpi=dpi, tight_layout=True)
        for category, color in color_of.items():
            per_state = counts.loc[category]
            # do not add to legend if category is missing
            if not np.sum(per_state) > 0:
                continue
            ax.bar(
                positions,
                per_state,
                width,
                label=category,
                color=color,
                bottom=stack_top,
                ec="black",
                lw=0.5,
            )
            stack_top += np.array(per_state)

        tick_alignment = "center" if labelrot in (0, 90) else "right"
        ax.set_xticks(positions)
        # assuming at least 1 category, so that `per_state` is bound here
        ax.set_xticklabels(per_state.index,
                           rotation=labelrot,
                           ha=tick_alignment)

        y_max = stack_top.max()
        ax.set_ylim([0, y_max + 0.05 * y_max])
        ax.set_yticks(np.linspace(0, y_max, 5))
        ax.margins(0.05)

        ax.set_xlabel("macrostate")
        ax.set_ylabel("frequency")
        ax.set_title(f"distribution over {key}" if title is None else title)

        if legend_loc not in (None, "none"):
            _position_legend(ax, legend_loc=legend_loc)

        if save is not None:
            save_fig(fig, save)

        return None if show else ax
Beispiel #4
0
    def plot_coarse_T(
        self,
        show_stationary_dist: bool = True,
        show_initial_dist: bool = False,
        cmap: Union[str, mcolors.ListedColormap] = "viridis",
        xtick_rotation: float = 45,
        annotate: bool = True,
        show_cbar: bool = True,
        title: Optional[str] = None,
        figsize: Tuple[float, float] = (8, 8),
        dpi: int = 80,
        save: Optional[Union[Path, str]] = None,
        text_kwargs: Mapping[str, Any] = MappingProxyType({}),
        **kwargs,
    ) -> None:
        """
        Plot the coarse-grained transition matrix between macrostates.

        The matrix is drawn as a heatmap. Optionally, the coarse-grained stationary and/or initial
        distributions are drawn as single-row heatmaps below it, sharing the same color normalization.

        Parameters
        ----------
        show_stationary_dist
            Whether to show the stationary distribution, if present.
        show_initial_dist
            Whether to show the initial distribution.
        cmap
            Colormap to use.
        xtick_rotation
            Rotation of ticks on the x-axis.
        annotate
            Whether to display the text on each cell.
        show_cbar
            Whether to show colorbar.
        title
            Title of the figure.
        %(plotting)s
        text_kwargs
            Keyword arguments for :func:`matplotlib.pyplot.text`.
        kwargs
            Keyword arguments for :func:`matplotlib.pyplot.imshow`.

        Returns
        -------
        %(just_plots)s

        Raises
        ------
        RuntimeError
            If the coarse-grained transition matrix has not been computed.
        """
        def stylize_dist(ax,
                         data: np.ndarray,
                         xticks_labels: Union[List[str], Tuple[str]] = ()):
            # draw a distribution strip; `cmap` and `norm` come from the enclosing scope so that
            # the strip uses the same color scale as the main heatmap
            _ = ax.imshow(data, aspect="auto", cmap=cmap, norm=norm)
            for spine in ax.spines.values():
                spine.set_visible(False)

            if xticks_labels is not None:
                ax.set_xticks(np.arange(data.shape[1]))
                ax.set_xticklabels(xticks_labels)
                plt.setp(
                    ax.get_xticklabels(),
                    rotation=xtick_rotation,
                    ha="right",
                    rotation_mode="anchor",
                )
            else:
                # `None` means another strip below this one carries the labels
                ax.set_xticks([])
                ax.tick_params(which="both",
                               top=False,
                               right=False,
                               bottom=False,
                               left=False)

            ax.set_yticks([])

        def annotate_heatmap(im, valfmt: str = "{x:.2f}"):
            # modified from matplotlib's site

            data = im.get_array()
            kw = {"ha": "center", "va": "center"}
            kw.update(**text_kwargs)

            # Get the formatter in case a string is supplied
            if isinstance(valfmt, str):
                valfmt = mpl.ticker.StrMethodFormatter(valfmt)

            # Loop over the data and create a `Text` for each "pixel".
            # Change the text's color depending on the data.
            for i in range(data.shape[0]):
                for j in range(data.shape[1]):
                    kw.update(
                        color=_get_black_or_white(im.norm(data[i, j]), cmap))
                    im.axes.text(j, i, valfmt(data[i, j], None), **kw)

        def annotate_dist_ax(ax, data: np.ndarray, valfmt: str = "{x:.2f}"):
            # write the value of each entry into a distribution strip; uses the normalization
            # of the main heatmap image `im` from the enclosing scope to pick the text color
            if ax is None:
                return

            if isinstance(valfmt, str):
                valfmt = mpl.ticker.StrMethodFormatter(valfmt)

            kw = {"ha": "center", "va": "center"}
            kw.update(**text_kwargs)

            for i, val in enumerate(data):
                kw.update(color=_get_black_or_white(im.norm(val), cmap))
                ax.text(
                    i,
                    0,
                    valfmt(val, None),
                    **kw,
                )

        coarse_T = self._get(P.COARSE_T)
        coarse_stat_d = self._get(P.COARSE_STAT_D)
        coarse_init_d = self._get(P.COARSE_INIT_D)

        if coarse_T is None:
            raise RuntimeError(
                "Compute coarse-grained transition matrix first as `.compute_macrostates()` with `n_states > 1`."
            )

        # silently-missing distributions are downgraded to "don't show" with a warning
        if show_stationary_dist and coarse_stat_d is None:
            logg.warning("Coarse stationary distribution is `None`, ignoring")
            show_stationary_dist = False
        if show_initial_dist and coarse_init_d is None:
            logg.warning("Coarse initial distribution is `None`, ignoring")
            show_initial_dist = False

        # height/width ratios of the grid: main heatmap plus one thin row per distribution
        # and one thin column for the colorbar
        hrs, wrs = [1], [1]
        if show_stationary_dist:
            hrs += [0.05]
        if show_initial_dist:
            hrs += [0.05]
        if show_cbar:
            wrs += [0.025]

        dont_show_dist = not show_initial_dist and not show_stationary_dist

        fig = plt.figure(constrained_layout=False, figsize=figsize, dpi=dpi)
        gs = plt.GridSpec(
            1 + show_stationary_dist + show_initial_dist,
            1 + show_cbar,
            height_ratios=hrs,
            width_ratios=wrs,
            wspace=0.05,
            hspace=0.05,
        )
        if isinstance(cmap, str):
            cmap = plt.get_cmap(cmap)

        ax = fig.add_subplot(gs[0, 0])
        cax = fig.add_subplot(gs[:1, -1]) if show_cbar else None
        init_ax, stat_ax = None, None

        labels = list(self.coarse_T.columns)

        # collect every value that will be drawn so that all heatmaps share one color range;
        # each distribution is only included when it is actually shown (and therefore not `None`)
        tmp = coarse_T
        if show_stationary_dist:
            tmp = np.c_[tmp, coarse_stat_d]
        if show_initial_dist:
            tmp = np.c_[tmp, coarse_init_d]

        minn, maxx = np.nanmin(tmp), np.nanmax(tmp)
        norm = mpl.colors.Normalize(vmin=minn, vmax=maxx)

        if show_stationary_dist:
            stat_ax = fig.add_subplot(gs[1, 0])
            stylize_dist(
                stat_ax,
                np.array(coarse_stat_d).reshape(1, -1),
                # only the bottom-most strip carries the x-tick labels
                xticks_labels=labels if not show_initial_dist else None,
            )
            stat_ax.yaxis.set_label_position("right")
            stat_ax.set_ylabel("stationary dist",
                               rotation=0,
                               ha="left",
                               va="center")

        if show_initial_dist:
            # the initial distribution always occupies the last row of the grid
            init_ax = fig.add_subplot(gs[show_stationary_dist +
                                         show_initial_dist, 0])
            stylize_dist(init_ax,
                         np.array(coarse_init_d).reshape(1, -1),
                         xticks_labels=labels)

            init_ax.yaxis.set_label_position("right")
            init_ax.set_ylabel("initial dist",
                               rotation=0,
                               ha="left",
                               va="center")

        im = ax.imshow(coarse_T, aspect="auto", cmap=cmap, norm=norm, **kwargs)
        ax.set_title(
            "coarse-grained transition matrix" if title is None else title)

        if cax is not None:
            _ = mpl.colorbar.ColorbarBase(
                cax,
                cmap=cmap,
                norm=norm,
                ticks=np.linspace(minn, maxx, 10),
                format="%0.3f",
            )

        ax.set_yticks(np.arange(coarse_T.shape[0]))
        ax.set_yticklabels(labels)

        # x-ticks live on the main heatmap only when no distribution strip is drawn below it
        ax.tick_params(
            top=False,
            bottom=dont_show_dist,
            labeltop=False,
            labelbottom=dont_show_dist,
        )

        for spine in ax.spines.values():
            spine.set_visible(False)

        if dont_show_dist:
            ax.set_xticks(np.arange(coarse_T.shape[1]))
            ax.set_xticklabels(labels)
            plt.setp(
                ax.get_xticklabels(),
                rotation=xtick_rotation,
                ha="right",
                rotation_mode="anchor",
            )
        else:
            ax.set_xticks([])

        ax.set_yticks(np.arange(coarse_T.shape[0] + 1) - 0.5, minor=True)
        ax.tick_params(which="minor",
                       bottom=dont_show_dist,
                       left=False,
                       top=False)

        if annotate:
            annotate_heatmap(im)
            if show_stationary_dist:
                annotate_dist_ax(stat_ax, coarse_stat_d.values)
            if show_initial_dist:
                annotate_dist_ax(init_ax, coarse_init_d)

        if save:
            save_fig(fig, save)
Beispiel #5
0
def gene_trends(
    adata: AnnData,
    model: _input_model_type,
    genes: Union[str, Sequence[str]],
    lineages: Optional[Union[str, Sequence[str]]] = None,
    backward: bool = False,
    data_key: str = "X",
    time_key: str = "latent_time",
    transpose: bool = False,
    time_range: Optional[Union[_time_range_type,
                               List[_time_range_type]]] = None,
    callback: _callback_type = None,
    conf_int: Union[bool, float] = True,
    same_plot: bool = False,
    hide_cells: bool = False,
    perc: Optional[Union[Tuple[float, float], Sequence[Tuple[float,
                                                             float]]]] = None,
    lineage_cmap: Optional[matplotlib.colors.ListedColormap] = None,
    abs_prob_cmap: matplotlib.colors.ListedColormap = cm.viridis,
    cell_color: Optional[str] = None,
    cell_alpha: float = 0.6,
    lineage_alpha: float = 0.2,
    size: float = 15,
    lw: float = 2,
    cbar: bool = True,
    margins: float = 0.015,
    sharex: Optional[Union[str, bool]] = None,
    sharey: Optional[Union[str, bool]] = None,
    gene_as_title: Optional[bool] = None,
    legend_loc: Optional[str] = "best",
    obs_legend_loc: Optional[str] = "best",
    ncols: int = 2,
    suptitle: Optional[str] = None,
    return_models: bool = False,
    n_jobs: Optional[int] = 1,
    backend: str = _DEFAULT_BACKEND,
    show_progress_bar: bool = True,
    figsize: Optional[Tuple[float, float]] = None,
    dpi: Optional[int] = None,
    save: Optional[Union[str, Path]] = None,
    plot_kwargs: Mapping = MappingProxyType({}),
    **kwargs,
) -> Optional[_return_model_type]:
    """
    Plot gene expression trends along lineages.

    Each lineage is defined via its lineage weights which we compute using :func:`cellrank.tl.lineages`. This
    function accepts any model based off :class:`cellrank.ul.models.BaseModel` to fit gene expression,
    where we take the lineage weights into account in the loss function.

    Parameters
    ----------
    %(adata)s
    %(model)s
    %(genes)s
    lineages
        Names of the lineages to plot. If `None`, plot all lineages.
    %(backward)s
    data_key
        Key in ``adata.layers`` or `'X'` for ``adata.X`` where the data is stored.
    time_key
        Key in ``adata.obs`` where the pseudotime is stored.
    %(time_ranges)s
    transpose
        If ``same_plot=True``, group the trends by ``lineages`` instead of ``genes``. This enforces ``hide_cells=True``.
        If ``same_plot=False``, show ``lineages`` in rows and ``genes`` in columns.
    %(model_callback)s
    conf_int
        Whether to compute and show confidence interval. If the :paramref:`model` is :class:`cellrank.ul.models.GAMR`,
        it can also specify the confidence level, the default is `0.95`.
    same_plot
        Whether to plot all lineages for each gene in the same plot.
    hide_cells
        If `True`, hide all cells.
    perc
        Percentile for colors. Valid values are in interval `[0, 100]`.
        This can improve visualization. Can be specified individually for each lineage.
    lineage_cmap
        Categorical colormap to use when coloring in the lineages. If `None` and ``same_plot``,
        use the corresponding colors in ``adata.uns``, otherwise use `'black'`.
    abs_prob_cmap
        Continuous colormap to use when visualizing the absorption probabilities for each lineage.
        Only used when ``same_plot=False``.
    cell_color
        Key in :attr:`anndata.AnnData.obs` or :attr:`anndata.AnnData.var_names` used for coloring the cells.
    cell_alpha
        Alpha channel for cells.
    lineage_alpha
        Alpha channel for lineage confidence intervals.
    size
        Size of the points.
    lw
        Line width of the smoothed values.
    cbar
        Whether to show colorbar. Always shown when percentiles for lineages differ. Only used when ``same_plot=False``.
    margins
        Margins around the plot.
    sharex
        Whether to share x-axis. Valid options are `'row'`, `'col'` or `'none'`.
    sharey
        Whether to share y-axis. Valid options are `'row'`, `'col'` or `'none'`.
    gene_as_title
        Whether to show gene names as titles instead on y-axis.
    legend_loc
        Location of the legend displaying lineages. Only used when `same_plot=True`.
    obs_legend_loc
        Location of the legend when ``cell_color`` corresponds to a categorical variable.
    ncols
        Number of columns of the plot when plotting multiple genes. Only used when ``same_plot=True``.
    suptitle
        Suptitle of the figure.
    %(return_models)s
    %(parallel)s
    %(plotting)s
    plot_kwargs
        Keyword arguments for :meth:`cellrank.ul.models.BaseModel.plot`.
    kwargs
        Keyword arguments for :meth:`cellrank.ul.models.BaseModel.prepare`.

    Returns
    -------
    %(plots_or_returns_models)s
    """

    # accept a single gene name as well as a sequence of names
    if isinstance(genes, str):
        genes = [genes]
    # presumably drops duplicates while keeping the first-seen order - confirm in `_unique_order_preserving`
    genes = _unique_order_preserving(genes)
    # `data_key="obs"` means the "genes" are looked up in `adata.obs` instead of `adata.var_names`
    if data_key != "obs":
        _check_collection(adata,
                          genes,
                          "var_names",
                          use_raw=kwargs.get("use_raw", False))
    else:
        _check_collection(adata,
                          genes,
                          "obs",
                          use_raw=kwargs.get("use_raw", False))

    # the lineage probabilities live under a direction-dependent key in `adata.obsm`
    ln_key = str(AbsProbKey.BACKWARD if backward else AbsProbKey.FORWARD)
    if ln_key not in adata.obsm:
        raise KeyError(f"Lineages key `{ln_key!r}` not found in `adata.obsm`.")

    if lineages is None:
        lineages = adata.obsm[ln_key].names
    elif isinstance(lineages, str):
        lineages = [lineages]
    elif all(ln is None
             for ln in lineages):  # no lineage, all the weights are 1
        lineages = [None]
        cbar = False
        logg.debug("All lineages are `None`, setting the weights to `1`")
    lineages = _unique_order_preserving(lineages)

    # a single time range (or `None`) is broadcast to every lineage,
    # otherwise exactly one range per lineage is required
    if isinstance(time_range, (tuple, float, int, type(None))):
        time_range = [time_range] * len(lineages)
    elif len(time_range) != len(lineages):
        raise ValueError(
            f"Expected time ranges to be of length `{len(lineages)}`, found `{len(time_range)}`."
        )

    # these are forwarded to both `_create_callbacks` and `_fit_bulk` below
    kwargs["time_key"] = time_key
    kwargs["data_key"] = data_key
    kwargs["backward"] = backward
    kwargs["conf_int"] = conf_int  # prepare doesnt take or need this
    models = _create_models(model, genes, lineages)

    # `_fit_bulk` returns possibly updated `genes` and `lineages`;
    # NOTE(review): presumably entries whose fits failed are filtered out - confirm in `_fit_bulk`
    all_models, models, genes, lineages = _fit_bulk(
        models,
        _create_callbacks(adata, callback, genes, lineages, **kwargs),
        genes,
        lineages,
        time_range,
        return_models=True,
        filter_all_failed=False,
        parallel_kwargs={
            "show_progress_bar": show_progress_bar,
            "n_jobs": _get_n_cores(n_jobs, len(genes)),
            "backend": _get_backend(models, backend),
        },
        **kwargs,
    )

    lineages = sorted(lineages)
    # default lineage colors are taken from the lineage object, unless transposed
    tmp = adata.obsm[ln_key][lineages].colors
    if lineage_cmap is None and not transpose:
        lineage_cmap = tmp

    # copy so that the (possibly immutable, shared default) mapping is never modified
    plot_kwargs = dict(plot_kwargs)
    plot_kwargs["obs_legend_loc"] = obs_legend_loc
    if transpose:
        # swap the nesting of the model dictionaries and the roles of genes and lineages
        all_models = pd.DataFrame(all_models).T.to_dict()
        models = pd.DataFrame(models).T.to_dict()
        genes, lineages = lineages, genes
        hide_cells = same_plot or hide_cells
    else:
        # information overload otherwise
        plot_kwargs["lineage_probability"] = False
        plot_kwargs["lineage_probability_conf_int"] = False

    # truthiness table of `models`; for every column, find the index of the
    # first (`start_rows`) and the last (`end_rows`) row holding a truthy entry
    tmp = pd.DataFrame(models).T.astype(bool)
    start_rows = np.argmax(tmp.values, axis=0)
    end_rows = tmp.shape[0] - np.argmax(tmp[::-1].values, axis=0) - 1

    # resolve layout defaults that depend on whether everything goes into shared panels
    if same_plot:
        gene_as_title = True if gene_as_title is None else gene_as_title
        sharex = "all" if sharex is None else sharex
        if sharey is None:
            sharey = "row" if plot_kwargs.get("lineage_probability",
                                              False) else "none"
        # never create more columns than there are panels
        ncols = len(genes) if ncols >= len(genes) else ncols
        nrows = int(np.ceil(len(genes) / ncols))
    else:
        gene_as_title = False if gene_as_title is None else gene_as_title
        sharex = "col" if sharex is None else sharex
        if sharey is None:
            sharey = ("row" if not hide_cells or plot_kwargs.get(
                "lineage_probability", False) else "none")
        # one row per gene, one column per lineage
        nrows = len(genes)
        ncols = len(lineages)

    # `plot_kwargs` is already a private dict at this point; this copy is redundant but harmless
    plot_kwargs = dict(plot_kwargs)
    if plot_kwargs.get("xlabel", None) is None:
        plot_kwargs["xlabel"] = time_key

    fig, axes = plt.subplots(
        nrows=nrows,
        ncols=ncols,
        sharex=sharex,
        sharey=sharey,
        figsize=(6 * ncols, 4 * nrows) if figsize is None else figsize,
        tight_layout=True,
        dpi=dpi,
    )
    # `plt.subplots` squeezes singleton dimensions; force a 2-D array of axes
    axes = np.reshape(axes, (nrows, ncols))

    # `cnt` indexes into `genes`; it advances once per visited grid cell
    cnt = 0
    plot_kwargs["obs_legend_loc"] = None if same_plot else obs_legend_loc

    logg.info("Plotting trends")
    for row in range(len(axes)):
        for col in range(len(axes[row])):
            if cnt >= len(genes):
                break
            gene = genes[cnt]
            # when transposed, `gene` actually holds a lineage name (see the swap above),
            # so its color can be looked up in the lineage object
            if (same_plot and plot_kwargs.get("lineage_probability", False)
                    and transpose):
                lpc = adata.obsm[ln_key][gene].colors[0]
            else:
                lpc = None

            if same_plot:
                # show the observation legend only in the top-right panel
                plot_kwargs["obs_legend_loc"] = (obs_legend_loc if row == 0
                                                 and col == len(axes[0]) - 1
                                                 else None)

            # with `same_plot`, one axis per gene; otherwise the helper receives
            # the whole row of axes `axes[cnt]` (one axis per lineage)
            _trends_helper(
                models,
                gene=gene,
                lineage_names=lineages,
                transpose=transpose,
                same_plot=same_plot,
                hide_cells=hide_cells,
                perc=perc,
                lineage_cmap=lineage_cmap,
                abs_prob_cmap=abs_prob_cmap,
                lineage_probability_color=lpc,
                cell_color=cell_color,
                alpha=cell_alpha,
                lineage_alpha=lineage_alpha,
                size=size,
                lw=lw,
                cbar=cbar,
                margins=margins,
                sharey=sharey,
                gene_as_title=gene_as_title,
                legend_loc=legend_loc,
                figsize=figsize,
                fig=fig,
                axes=axes[row, col] if same_plot else axes[cnt],
                show_ylabel=col == 0,
                show_lineage=same_plot or (cnt == start_rows),
                show_xticks_and_label=((row + 1) * ncols + col >= len(genes))
                if same_plot else (cnt == end_rows),
                **plot_kwargs,
            )
            # plot legend on the 1st plot
            cnt += 1

            if not same_plot:
                plot_kwargs["obs_legend_loc"] = None

    # drop the axes of the grid that were left without a panel
    if same_plot and (col != ncols):
        for ax in np.ravel(axes)[cnt:]:
            ax.remove()

    fig.suptitle(suptitle, y=1.05)

    if save is not None:
        save_fig(fig, save)

    # the function implicitly returns `None` unless the fitted models are requested
    if return_models:
        return all_models
Beispiel #6
0
def graph(
        data: Union[AnnData, np.ndarray, spmatrix],
        graph_key: Optional[str] = None,
        ixs: Optional[np.ndarray] = None,
        layout: Union[str, Dict, Callable] = "umap",
        keys: Sequence[KEYS] = ("incoming", ),
        keylocs: Union[KEYLOCS, Sequence[KEYLOCS]] = "uns",
        node_size: float = 400,
        labels: Optional[Union[Sequence[str], Sequence[Sequence[str]]]] = None,
        top_n_edges: Optional[Union[int, Tuple[int, bool, str]]] = None,
        self_loops: bool = True,
        self_loop_radius_frac: Optional[float] = None,
        filter_edges: Optional[Tuple[float, float]] = None,
        edge_reductions: Union[Callable, Sequence[Callable]] = np.sum,
        edge_weight_scale: float = 10,
        edge_width_limit: Optional[float] = None,
        edge_alpha: float = 1.0,
        edge_normalize: bool = False,
        edge_use_curved: bool = True,
        show_arrows: bool = True,
        font_size: int = 12,
        font_color: str = "black",
        color_nodes: bool = True,
        cat_cmap: ListedColormap = cm.Set3,
        cont_cmap: ListedColormap = cm.viridis,
        legend_loc: Optional[str] = "best",
        figsize: Optional[Tuple[float, float]] = None,
        dpi: Optional[int] = None,
        save: Optional[Union[str, Path]] = None,
        layout_kwargs: Dict = MappingProxyType({}),
) -> None:
    """
    Plot a graph, visualizing incoming and outgoing edges or self-transitions.

    This is a utility function to look in more detail at the transition matrix in areas of interest, e.g. around an
    endpoint of development. This function is meant to visualise a small subset of nodes (~100-500) and the most likely
    transitions between them. Note that limiting edges visualized using ``top_n_edges`` will speed things up,
    as well as reduce the visual clutter.

    Parameters
    ----------
    data
        The graph data to be plotted.
    graph_key
        Key in ``adata.obsp`` or ``adata.uns`` where the graph is stored. Only used
        when ``data`` is :class:`~anndata.Anndata` object.
    ixs
        Subset of indices of the graph to visualize.
    layout
        Layout to use for graph drawing.

        - If :class:`str`, search for embedding in ``adata.obsm['X_{layout}']``.
          Use ``layout_kwargs={'components': [0, 1]}`` to select components.
          Only possible when ``data`` is :class:`anndata.AnnData` object.
        - If :class:`dict`, keys should be values in interval ``[0, len(ixs))``
          and values `(x, y)` pairs corresponding to node positions.
        - If a callable, it is called as ``layout(G, **layout_kwargs)`` and must return
          the node positions.
    keys
        Keys in ``adata.obs``, ``adata.obsm`` or ``adata.obsp`` to color the nodes.

        - If `'incoming'`, `'outgoing'` or `'self_loops'`, visualize reduction (see ``edge_reductions``)
          for each node based on incoming or outgoing edges, respectively.
    keylocs
        Locations of ``keys``. Can be any attribute of ``data`` if it's :class:`anndata.AnnData` object.
    node_size
        Size of the nodes.
    labels
        Labels of the nodes.
    top_n_edges
        Either top N outgoing edges in descending order or a tuple
        ``(top_n_edges, in_ascending_order, {'incoming', 'outgoing'})``. If `None`, show all edges.
    self_loops
        Whether to visualize self transitions and also to consider them in ``top_n_edges``.
    self_loop_radius_frac
        Fraction of a unit circle to visualize self transitions. If `None`, use ``node_size / 2000``
        when ``node_size >= 200``, otherwise ``node_size / 1000``.
    filter_edges
        Whether to remove all edges not in `[min, max]` interval.
    edge_reductions
        Aggregation function to use when coloring nodes by edge weights.
    edge_weight_scale
        Number by which to scale the width of the edges. Useful when the weights are small.
    edge_width_limit
        Upper bound for the width of the edges. Useful when weights are unevenly distributed.
    edge_alpha
        Alpha channel value for edges and arrows.
    edge_normalize
        If `True`, normalize edges to `[0, 1]` interval prior to applying any scaling or truncation.
    edge_use_curved
        If `True`, use curved edges. This can improve visualization at a small performance cost.
    show_arrows
        Whether to show the arrows. Setting this to `False` may dramatically speed things up.
    font_size
        Font size for node labels.
    font_color
        Label color of the nodes.
    color_nodes
        Whether to color the nodes
    cat_cmap
        Categorical colormap used when ``keys`` contain categorical variables.
    cont_cmap
        Continuous colormap used when ``keys`` contain continuous variables.
    legend_loc
        Location of the legend.
    %(plotting)s
    layout_kwargs
        Additional kwargs for ``layout``.

    Returns
    -------
    %(just_plots)s
    """

    from anndata import AnnData as _AnnData

    import networkx as nx

    def plot_arrows(curves, G, pos, ax, edge_weight_scale):
        """Draw one arrow head in the middle of every non-self-loop curved edge."""
        for line, (edge, val) in zip(curves, G.edges.items()):
            if edge[0] == edge[1]:
                continue

            mask = (~np.isnan(line)).all(axis=1)
            line = line[mask, :]
            if not len(line):  # can be all NaNs
                continue

            line = line.reshape((-1, 2))
            X, Y = line[:, 0], line[:, 1]

            node_start = pos[edge[0]]
            # the curve may be stored target -> source; if the source node is not
            # its first point (index != 0), flip it so the arrow points the right way
            if np.where(np.isclose(node_start - line,
                                   [0, 0]).all(axis=1))[0][0]:
                X, Y = X[::-1], Y[::-1]

            mid = len(X) // 2
            posA, posB = zip(X[mid:mid + 2], Y[mid:mid + 2])  # noqa

            arrow = FancyArrowPatch(
                posA=posA,
                posB=posB,
                # we clip because too small values
                # cause it to crash
                arrowstyle=ArrowStyle.CurveFilledB(
                    head_length=np.clip(
                        val["weight"] * edge_weight_scale * 4,
                        _min_edge_weight,
                        edge_width_limit,
                    ),
                    head_width=np.clip(
                        val["weight"] * edge_weight_scale * 2,
                        _min_edge_weight,
                        edge_width_limit,
                    ),
                ),
                color="k",
                zorder=float("inf"),
                alpha=edge_alpha,
                linewidth=0,
            )
            ax.add_artist(arrow)

    def normalize_weights():
        """Min-max scale the edge weights of ``G`` in-place."""
        weights = np.array([v["weight"] for v in G.edges.values()])
        if not len(weights):
            # nothing to normalize; `np.min` would raise on an empty array
            return
        minn = np.min(weights)
        weights = (weights - minn) / (np.max(weights) - minn)
        for v, w in zip(G.edges.values(), weights):
            v["weight"] = w

    def remove_top_n_edges():
        """Keep only the requested number of edges per node, removing the rest from ``G``."""
        if top_n_edges is None:
            return

        if isinstance(top_n_edges, (tuple, list)):
            n_keep, ascending, group_by = top_n_edges
        else:
            # a bare number means: top N outgoing edges in descending order
            n_keep, ascending, group_by = top_n_edges, False, "outgoing"

        if group_by not in ("incoming", "outgoing"):
            raise ValueError(
                "Argument `groupby` in `top_n_edges` must be either `'incoming`' or `'outgoing'`."
            )

        if not G.number_of_edges():
            # `zip(*G.edges)` cannot be unpacked when there are no edges
            return

        source, target = zip(*G.edges)
        weights = [v["weight"] for v in G.edges.values()]
        tmp = pd.DataFrame({
            "outgoing": source,
            "incoming": target,
            "w": weights
        })

        if not self_loops:
            # remove self loops
            tmp = tmp[tmp["incoming"] != tmp["outgoing"]]

        # sort once, then take the first `n_keep` rows of every group
        best = (tmp.sort_values("w", ascending=ascending).groupby(
            group_by, sort=False).head(max(n_keep, 0)))
        kept = set(map(tuple, best[["outgoing", "incoming"]].values))

        for e in list(G.edges):
            if e not in kept:
                G.remove_edge(*e)

    def remove_low_weight_edges():
        """Remove edges of ``G`` whose weight lies outside of ``filter_edges``."""
        if filter_edges is None or filter_edges == (None, None):
            return

        minn, maxx = filter_edges
        minn = minn if minn is not None else -np.inf
        maxx = maxx if maxx is not None else np.inf

        for e, attr in list(G.edges.items()):
            if attr["weight"] < minn or attr["weight"] > maxx:
                G.remove_edge(*e)

    # lower bound used when clipping edge widths/arrow heads
    _min_edge_weight = 0.00001

    if edge_width_limit is None:
        logg.debug("Not limiting width of edges")
        edge_width_limit = float("inf")

    if self_loop_radius_frac is None:
        self_loop_radius_frac = (node_size /
                                 2000 if node_size >= 200 else node_size /
                                 1000)
        logg.debug(
            f"Setting self loop radius fraction to `{self_loop_radius_frac}`")

    # broadcast the per-key arguments so that there is one entry for each key
    if not isinstance(keylocs, (tuple, list)):
        keylocs = [keylocs] * len(keys)
    elif len(keylocs) == 1:
        keylocs = list(keylocs) * len(keys)
    elif all(map(lambda k: k in ("incoming", "outgoing", "self_loops"), keys)):
        # don't care about keylocs since they are irrelevant
        logg.debug("Ignoring key locations")
        keylocs = [None] * len(keys)

    if not isinstance(edge_reductions, (tuple, list)):
        edge_reductions = [edge_reductions] * len(keys)
    if not all(map(callable, edge_reductions)):
        raise ValueError("Not all `edge_reductions` functions are callable.")

    if not isinstance(labels, (tuple, list)):
        labels = [labels] * len(keys)
    elif not len(labels):
        labels = [None] * len(keys)
    elif not isinstance(labels[0], (tuple, list)):
        labels = [labels] * len(keys)

    if len(keys) != len(labels):
        raise ValueError(
            f"`Keys` and `labels` must be of the same shape, found `{len(keys)}` and `{len(labels)}`."
        )

    if isinstance(data, _AnnData):
        if graph_key is None:
            raise ValueError(
                "Argument `graph_key` cannot be `None` when `data` is `anndata.Anndata` object."
            )
        gdata = _read_graph_data(data, graph_key)
    elif isinstance(data, (np.ndarray, spmatrix)):
        gdata = data
    else:
        raise TypeError(
            f"Expected argument `data` to be one of `anndata.AnnData`, `numpy.ndarray`, `scipy.sparse.spmatrix`, "
            f"found `{type(data).__name__!r}`.")
    is_sparse = issparse(gdata)

    if ixs is not None:
        gdata = gdata[ixs, :][:, ixs]
    else:
        ixs = list(range(gdata.shape[0]))

    start = logg.info("Creating graph")
    if is_sparse:
        # newer networkx releases provide `from_scipy_sparse_array` and no longer
        # ship `from_scipy_sparse_matrix`; prefer the new name, fall back to the old
        from_sparse = getattr(nx, "from_scipy_sparse_array", None)
        if from_sparse is None:
            from_sparse = nx.from_scipy_sparse_matrix
        G = from_sparse(gdata, create_using=nx.DiGraph)
    else:
        G = nx.from_numpy_array(gdata, create_using=nx.DiGraph)

    remove_low_weight_edges()
    remove_top_n_edges()
    if edge_normalize:
        normalize_weights()
    logg.info("    Finish", time=start)

    # do NOT recreate the graph, for the edge reductions
    # gdata = nx.to_numpy_array(G)

    if figsize is None:
        figsize = (12, 8 * len(keys))

    # one panel (row) per key
    fig, axes = plt.subplots(nrows=len(keys),
                             ncols=1,
                             figsize=figsize,
                             dpi=dpi)
    if not isinstance(axes, np.ndarray):
        axes = np.array([axes])
    axes = np.ravel(axes)

    if isinstance(layout, str):
        if not isinstance(data, _AnnData):
            # only `anndata.AnnData` carries the `.obsm` embeddings
            raise ValueError(
                f"Argument `layout` can be a `str` only when `data` is `anndata.AnnData` object, "
                f"found `{type(data).__name__!r}`. Specify `layout` as a `dict` or a `callable`."
            )
        if f"X_{layout}" not in data.obsm:
            raise KeyError(
                f"Unable to find embedding `'X_{layout}'` in `adata.obsm`.")
        components = layout_kwargs.get("components", [0, 1])
        if len(components) != 2:
            raise ValueError(
                f"Components in `layout_kwargs` must be of length `2`, found `{len(components)}`."
            )
        emb = data.obsm[f"X_{layout}"][:, components]
        # node `i` of the (possibly subsetted) graph is observation `ixs[i]`
        pos = {i: emb[ix, :] for i, ix in enumerate(ixs)}
        logg.info(f"Embedding graph using `{layout!r}` layout")
    elif isinstance(layout, dict):
        rng = range(len(ixs))
        for k, v in layout.items():
            if k not in rng:
                raise ValueError(
                    f"Key in `layout` must be in `range(len(ixs))`, found `{k}`."
                )
            if len(v) != 2:
                raise ValueError(
                    f"Value in `layout` must be a `tuple` or a `list` of length 2, found `{len(v)}`."
                )
        pos = layout
        logg.debug("Using precomputed layout")
    elif callable(layout):
        start = logg.info(
            f"Embedding graph using `{layout.__name__!r}` layout")
        pos = layout(G, **layout_kwargs)
        logg.info("    Finish", time=start)
    else:
        raise TypeError(f"Argument `layout` must be either a `string`, "
                        f"a `dict` or a `callable`, found `{type(layout)}`.")

    # the curved edges are computed once and copied into every panel
    curves, lc = None, None
    if edge_use_curved:
        try:
            from ._utils import _curved_edges

            logg.debug("Creating curved edges")
            curves = _curved_edges(G,
                                   pos,
                                   self_loop_radius_frac,
                                   polarity="directed")
            lc = LineCollection(
                curves,
                colors="black",
                linewidths=np.clip(
                    np.ravel([v["weight"]
                              for v in G.edges.values()]) * edge_weight_scale,
                    0,
                    edge_width_limit,
                ),
                alpha=edge_alpha,
            )
        except ImportError as e:
            # fall back to straight edges; print the hint only once per session
            global _msg_shown
            if not _msg_shown:
                print(
                    str(e)[:-1],
                    "in order to use curved edges or specify `edge_use_curved=False`.",
                )
                _msg_shown = True

    for ax, keyloc, key, labs, er in zip(axes, keylocs, keys, labels,
                                         edge_reductions):
        label_col = {}  # stays empty for the edge-based keys

        if key in ("incoming", "outgoing", "self_loops"):
            if key in ("incoming", "outgoing"):
                vals = er(gdata, axis=int(key == "outgoing"))
                if issparse(vals):
                    vals = vals.toarray()
                # reductions over sparse input can return a 2D `numpy.matrix`,
                # which `.flatten()` would keep 2D; force a flat array instead
                vals = np.asarray(vals).ravel()
            else:
                vals = gdata.diagonal() if is_sparse else np.diag(gdata)
            node_v = dict(zip(pos.keys(), vals))
        else:
            label_col = getattr(data, keyloc)
            if key in label_col:
                node_v = dict(zip(pos.keys(), label_col[key]))
            else:
                raise RuntimeError(
                    f"Key `{key!r}` not found in `adata.{keyloc}`.")

        if labs is not None:
            if len(labs) != len(pos):
                raise RuntimeError(
                    f"Number of labels ({len(labs)}) and nodes ({len(pos)}) mismatch."
                )
            nx.draw_networkx_labels(
                G,
                pos,
                labels=labs if isinstance(labs, dict) else dict(
                    zip(pos.keys(), labs)),
                ax=ax,
                font_color=font_color,
                font_size=font_size,
            )

        if lc is not None and curves is not None:
            ax.add_collection(deepcopy(lc))  # copying necessary
            if show_arrows:
                plot_arrows(curves, G, pos, ax, edge_weight_scale)
        else:
            nx.draw_networkx_edges(
                G,
                pos,
                width=[
                    np.clip(
                        v["weight"] * edge_weight_scale,
                        _min_edge_weight,
                        edge_width_limit,
                    ) for _, v in G.edges.items()
                ],
                alpha=edge_alpha,
                edge_color="black",
                arrows=True,
                arrowstyle="-|>",
                # draw into this panel, not into whatever the current axes is
                ax=ax,
            )

        if key in label_col and is_categorical_dtype(label_col[key]):
            values = label_col[key]
            if keyloc in ("obs", "obsm"):
                values = values[ixs]
            categories = values.cat.categories
            color_key = _colors(key)
            if color_key in data.uns:
                mapper = dict(zip(categories, data.uns[color_key]))
            else:
                # colormaps are callable; an integer selects the i-th color
                mapper = {
                    cat: cat_cmap(i)
                    for i, cat in enumerate(categories)
                }

            colors = []
            seen = set()

            for v in values:
                colors.append(mapper[v])
                seen.add(v)

            nodes_kwargs = dict(cmap=cat_cmap, node_color=colors)  # noqa
            if legend_loc is not None:
                # single-point dummy lines, plotted only to create legend entries
                x, y = pos[0]
                for label in sorted(seen):
                    ax.plot([x], [y], label=label, color=mapper[label])
                ax.legend(loc=legend_loc)
        else:
            values = list(node_v.values())
            vmin, vmax = np.min(values), np.max(values)
            nodes_kwargs = dict(  # noqa
                cmap=cont_cmap,
                node_color=values,
                vmin=vmin,
                vmax=vmax)

            divider = make_axes_locatable(ax)
            cax = divider.append_axes("right", size="1.5%", pad=0.05)
            _ = mpl.colorbar.ColorbarBase(cax,
                                          cmap=cont_cmap,
                                          norm=mpl.colors.Normalize(vmin=vmin,
                                                                    vmax=vmax))

        if color_nodes is False:
            nodes_kwargs = {}

        nx.draw_networkx_nodes(G,
                               pos,
                               node_size=node_size,
                               ax=ax,
                               **nodes_kwargs)

        ax.set_title(key)
        ax.axis("off")

    if save is not None:
        save_fig(fig, save)

    fig.show()
Beispiel #7
0
    def plot(
        self,
        figsize: Tuple[float, float] = (8, 5),
        same_plot: bool = False,
        hide_cells: bool = False,
        perc: Optional[Tuple[float, float]] = None,
        abs_prob_cmap: mcolors.ListedColormap = cm.viridis,
        cell_color: str = "black",
        lineage_color: str = "black",
        alpha: float = 0.8,
        lineage_alpha: float = 0.2,
        title: Optional[str] = None,
        size: int = 15,
        lw: float = 2,
        cbar: bool = True,
        margins: float = 0.015,
        xlabel: str = "pseudotime",
        ylabel: str = "expression",
        conf_int: bool = True,
        lineage_probability: bool = False,
        lineage_probability_conf_int: Union[bool, float] = False,
        lineage_probability_color: Optional[str] = None,
        dpi: Optional[int] = None,
        fig: Optional[mpl.figure.Figure] = None,
        ax: Optional[mpl.axes.Axes] = None,
        return_fig: bool = False,
        save: Optional[str] = None,
        **kwargs,
    ) -> Optional[mpl.figure.Figure]:
        """
        Plot the smoothed gene expression.

        Parameters
        ----------
        figsize
            Size of the figure. Only used when a new figure is created.
        same_plot
            Whether to plot all trends in the same plot.
        hide_cells
            Whether to hide the cells.
        perc
            Percentile by which to clip the absorption probabilities.
        abs_prob_cmap
            Colormap to use when coloring in the absorption probabilities.
        cell_color
            Color for the cells when not coloring absorption probabilities.
        lineage_color
            Color for the lineage.
        alpha
            Alpha channel for cells.
        lineage_alpha
            Alpha channel for lineage confidence intervals.
        title
            Title of the plot. If `None`, it is built from the gene and, if available, the lineage name.
        size
            Size of the points.
        lw
            Line width for the smoothed values.
        cbar
            Whether to show colorbar.
        margins
            Margins around the plot.
        xlabel
            Label on the x-axis.
        ylabel
            Label on the y-axis.
        conf_int
            Whether to show the confidence interval.
        lineage_probability
            Whether to show smoothed lineage probability as a dashed line.
            Note that this will require 1 additional model fit.
        lineage_probability_conf_int
            Whether to compute and show smoothed lineage probability confidence interval.
            If :paramref:`self` is :class:`cellrank.ul.models.GAMR`, it can also specify the confidence level,
            the default is `0.95`. Only used when ``lineage_probability=True``.
        lineage_probability_color
            Color to use when plotting the smoothed ``lineage_probability``.
            If `None`, it's the same as ``lineage_color``. Only used when ``lineage_probability=True``.
        dpi
            Dots per inch.
        fig
            Figure to use, if `None`, use the figure of ``ax`` or create a new one.
        ax: :class:`matplotlib.axes.Axes`
            Ax to use, if `None`, create a new one.
        return_fig
            If `True`, return the figure object.
        save
            Filename where to save the plot. If `None`, just shows the plots.
        **kwargs
            Keyword arguments for :meth:`matplotlib.axes.Axes.legend`, e.g. to disable the legend, specify ``loc=None``.
            Only available when ``lineage_probability=True``. The key ``scaler``, if present, is removed
            from these and used as the callable transforming the plotted y-values.

        Returns
        -------
        %(just_plots)s
        """

        if self.y_test is None:
            raise RuntimeError("Run `.predict()` first.")

        # honor whichever of `fig`/`ax` the caller supplied instead of
        # discarding it when the other one is missing
        if ax is None:
            if fig is None:
                fig, ax = plt.subplots(figsize=figsize,
                                       constrained_layout=True)
            else:
                ax = fig.add_subplot(111)
        elif fig is None:
            fig = ax.figure

        if dpi is not None:
            fig.set_dpi(dpi)

        # a confidence interval can only be drawn if one has been computed
        conf_int = conf_int and self.conf_int is not None
        # the cells cannot be shown if any of their data is missing
        hide_cells = (hide_cells or self.x_all is None or self.w_all is None
                      or self.y_all is None)

        lineage_probability_color = (lineage_color
                                     if lineage_probability_color is None else
                                     lineage_probability_color)

        # `scaler` is not a legend argument, so take it out of `kwargs`
        scaler = kwargs.pop(
            "scaler",
            self._create_scaler(
                lineage_probability,
                show_conf_int=conf_int,
            ),
        )

        if lineage_probability:
            if ylabel in ("expression", self._gene):
                ylabel = f"scaled {ylabel}"

        vmin, vmax = None, None
        if not hide_cells:
            vmin, vmax = _minmax(self.w_all, perc)
            _ = ax.scatter(
                self.x_all.squeeze(),
                scaler(self.y_all.squeeze()),
                # color by weights unless they are uninformative (all 1)
                # or multiple trends share this plot
                c=cell_color if same_plot or np.allclose(self.w_all, 1.0) else
                self.w_all.squeeze(),
                s=size,
                cmap=abs_prob_cmap,
                vmin=vmin,
                vmax=vmax,
                alpha=alpha,
            )

        if title is None:
            title = (f"{self._gene} @ {self._lineage}"
                     if self._lineage is not None else f"{self._gene}")

        ax.plot(self.x_test,
                scaler(self.y_test),
                color=lineage_color,
                lw=lw,
                label=title)

        # `title` is never `None` at this point
        ax.set_title(title)
        if ylabel is not None:
            ax.set_ylabel(ylabel)
        if xlabel is not None:
            ax.set_xlabel(xlabel)

        ax.margins(margins)

        if conf_int:
            ax.fill_between(
                self.x_test.squeeze(),
                scaler(self.conf_int[:, 0]),
                scaler(self.conf_int[:, 1]),
                alpha=lineage_alpha,
                color=lineage_color,
                linestyle="--",
            )

        if (lineage_probability and not isinstance(self, FittedModel)
                and not np.allclose(self.w, 1.0)):
            from cellrank.pl._utils import _is_any_gam_mgcv

            # fit a copy of this model with the weights as the response
            model = deepcopy(self)
            model._y = self._reshape_and_retype(self.w).copy()
            model = model.fit()

            if not lineage_probability_conf_int:
                y = model.predict()
            elif _is_any_gam_mgcv(model):
                y = model.predict(
                    level=lineage_probability_conf_int if isinstance(
                        lineage_probability_conf_int, float) else 0.95)
            else:
                y = model.predict()
                model.confidence_interval()

                ax.fill_between(
                    model.x_test.squeeze(),
                    model.conf_int[:, 0],
                    model.conf_int[:, 1],
                    alpha=lineage_alpha,
                    color=lineage_probability_color,
                    linestyle="--",
                )

            handle = ax.plot(
                model.x_test,
                y,
                color=lineage_probability_color,
                lw=lw,
                linestyle="--",
                zorder=-1,
                label="probability",
            )

            # `loc=None` disables the legend
            if kwargs.get("loc", "best") is not None:
                ax.legend(handles=handle, **kwargs)

        if (cbar and not hide_cells and not same_plot
                and not np.allclose(self.w_all, 1.0)):
            norm = mcolors.Normalize(vmin=vmin, vmax=vmax)
            divider = make_axes_locatable(ax)
            cax = divider.append_axes("right", size="2%", pad=0.1)
            _ = mpl.colorbar.ColorbarBase(
                cax,
                norm=norm,
                cmap=abs_prob_cmap,
                ticks=np.linspace(norm.vmin, norm.vmax, 5),
            )

        if save is not None:
            save_fig(fig, save)

        if return_fig:
            return fig
Beispiel #8
0
def heatmap(
    adata: AnnData,
    model: _input_model_type,
    genes: Sequence[str],
    lineages: Optional[Union[str, Sequence[str]]] = None,
    backward: bool = False,
    mode: str = HeatmapMode.LINEAGES.s,
    time_key: str = "latent_time",
    time_range: Optional[Union[_time_range_type,
                               List[_time_range_type]]] = None,
    callback: _callback_type = None,
    cluster_key: Optional[Union[str, Sequence[str]]] = None,
    show_absorption_probabilities: bool = False,
    cluster_genes: bool = False,
    keep_gene_order: bool = False,
    scale: bool = True,
    n_convolve: Optional[int] = 5,
    show_all_genes: bool = False,
    cbar: bool = True,
    lineage_height: float = 0.33,
    fontsize: Optional[float] = None,
    xlabel: Optional[str] = None,
    cmap: mcolors.ListedColormap = cm.viridis,
    dendrogram: bool = True,
    return_genes: bool = False,
    return_models: bool = False,
    n_jobs: Optional[int] = 1,
    backend: str = _DEFAULT_BACKEND,
    show_progress_bar: bool = True,
    figsize: Optional[Tuple[float, float]] = None,
    dpi: Optional[int] = None,
    save: Optional[Union[str, Path]] = None,
    **kwargs,
) -> Optional[Union[Dict[str, pd.DataFrame], Tuple[_return_model_type, Dict[
        str, pd.DataFrame]]]]:
    """
    Plot a heatmap of smoothed gene expression along specified lineages.

    Parameters
    ----------
    %(adata)s
    %(model)s
    %(genes)s
    lineages
        Names of the lineages for which to plot. If `None`, plot all lineages.
    %(backward)s
    mode
        Valid options are:

            - `{m.LINEAGES.s!r}` - group by ``genes`` for each lineage in ``lineages``.
            - `{m.GENES.s!r}` - group by ``lineages`` for each gene in ``genes``.
    time_key
        Key in ``adata.obs`` where the pseudotime is stored.
    %(time_ranges)s
    %(model_callback)s
    cluster_key
        Key(s) in ``adata.obs`` containing categorical observations to be plotted on top of the heatmap.
        Only available when ``mode={m.LINEAGES.s!r}``.
    show_absorption_probabilities
        Whether to also plot absorption probabilities alongside the smoothed expression.
        Only available when ``mode={m.LINEAGES.s!r}``.
    cluster_genes
        Whether to cluster genes using :func:`seaborn.clustermap` when ``mode='lineages'``.
    keep_gene_order
        Whether to keep the gene order for later lineages after the first was sorted.
        Only available when ``cluster_genes=False`` and ``mode={m.LINEAGES.s!r}``.
    scale
        Whether to normalize the gene expression to the `0-1` range.
    n_convolve
        Size of the convolution window when smoothing absorption probabilities.
    show_all_genes
        Whether to show all genes on y-axis.
    cbar
        Whether to show the colorbar.
    lineage_height
        Height of a bar when ``mode={m.GENES.s!r}``.
    fontsize
        Size of the title's font.
    xlabel
        Label on the x-axis. If `None`, it is determined based on ``time_key``.
    cmap
        Colormap to use when visualizing the smoothed expression.
    dendrogram
        Whether to show dendrogram when ``cluster_genes=True``.
    return_genes
        Whether to return the sorted or clustered genes. Only available when ``mode={m.LINEAGES.s!r}``.
    %(return_models)s
    %(parallel)s
    %(plotting)s
    kwargs
        Keyword arguments for :meth:`cellrank.ul.models.BaseModel.prepare`.

    Returns
    -------
    %(plots_or_returns_models)s
    :class:`pandas.DataFrame`
        If ``return_genes=True`` and ``mode={m.LINEAGES.s!r}``, returns :class:`pandas.DataFrame`
        containing the clustered or sorted genes.
    """

    import seaborn as sns

    def find_indices(series: pd.Series, values) -> Tuple[Any]:
        """Return the index labels of the entries of ``series`` whose values are nearest to ``values``."""
        def find_nearest(array: np.ndarray, value: float) -> int:
            # `array` must be sorted; pick whichever neighbor of the insertion point is closer
            ix = np.searchsorted(array, value, side="left")
            if ix > 0 and (ix == len(array) or fabs(value - array[ix - 1]) <
                           fabs(value - array[ix])):
                return ix - 1
            return ix

        # `find_nearest` yields positions within the sorted values, hence index
        # positionally and explicitly (plain `[]` with integers is ambiguous in pandas)
        series = series.iloc[np.argsort(series.values)]
        sorted_values = series.values

        return tuple(
            series.iloc[[find_nearest(sorted_values, v) for v in values]].index)

    def subset_lineage(lname: str, rng: np.ndarray) -> np.ndarray:
        """Get values of lineage ``lname`` for the cells nearest in time to ``rng``, optionally smoothed."""
        time_series = adata.obs[time_key]
        ixs = find_indices(time_series, rng)

        lin = adata[ixs, :].obsm[lineage_key][lname]

        lin = lin.X.copy().squeeze()
        if n_convolve is not None:
            # box filter of width `n_convolve`
            lin = convolve(lin,
                           np.ones(n_convolve) / n_convolve,
                           mode="nearest")

        return lin

    def create_col_colors(lname: str,
                          rng: np.ndarray) -> Tuple[np.ndarray, Cmap, Norm]:
        """Create per-column hex colors for lineage ``lname``, together with the colormap and norm used."""
        color = adata.obsm[lineage_key][lname].colors[0]
        lin = subset_lineage(lname, rng)

        # keep hue and value of the lineage color, but use full saturation
        h, _, v = mcolors.rgb_to_hsv(mcolors.to_rgb(color))
        end_color = mcolors.hsv_to_rgb([h, 1, v])

        lineage_cmap = mcolors.LinearSegmentedColormap.from_list(
            "lineage_cmap", ["#ffffff", end_color], N=len(rng))
        norm = mcolors.Normalize(vmin=np.min(lin), vmax=np.max(lin))
        scalar_map = cm.ScalarMappable(cmap=lineage_cmap, norm=norm)

        return (
            np.array([mcolors.to_hex(c) for c in scalar_map.to_rgba(lin)]),
            lineage_cmap,
            norm,
        )

    def create_col_categorical_color(cluster_key: str,
                                     rng: np.ndarray) -> np.ndarray:
        """Create per-column hex colors from the categorical ``adata.obs[cluster_key]``.

        Raises a ``TypeError`` if the observation is not categorical. If no colors
        are stored in ``adata.uns``, new ones are created and saved there.
        """
        if not is_categorical_dtype(adata.obs[cluster_key]):
            raise TypeError(
                f"Expected `adata.obs[{cluster_key!r}]` to be categorical, "
                f"found `{adata.obs[cluster_key].dtype.name!r}`.")

        color_key = f"{cluster_key}_colors"
        if color_key not in adata.uns:
            logg.warning(
                f"Color key `{color_key!r}` not found in `adata.uns`. Creating new colors"
            )
            colors = _create_categorical_colors(
                len(adata.obs[cluster_key].cat.categories))
            adata.uns[color_key] = colors
        else:
            colors = adata.uns[color_key]

        time_series = adata.obs[time_key]
        ixs = find_indices(time_series, rng)

        mapper = dict(zip(adata.obs[cluster_key].cat.categories, colors))

        return np.array([
            mcolors.to_hex(mapper[v])
            for v in adata[ixs, :].obs[cluster_key].values
        ])

    def create_cbar(
        ax,
        x_delta: float,
        cmap: Cmap,
        norm: Norm,
        label: Optional[str] = None,
    ) -> Ax:
        """Create a thin colorbar in an inset axis, shifted by ``x_delta`` relative to ``ax``."""
        cax = inset_axes(
            ax,
            width="1%",
            height="100%",
            loc="lower right",
            bbox_to_anchor=(x_delta, 0, 1, 1),
            bbox_transform=ax.transAxes,
        )

        _ = mpl.colorbar.ColorbarBase(
            cax,
            cmap=cmap,
            norm=norm,
            label=label,
            ticks=np.linspace(norm.vmin, norm.vmax, 5),
        )

        return cax

    @valuedispatch
    def _plot_heatmap(_mode: HeatmapMode) -> Fig:
        pass

    @_plot_heatmap.register(HeatmapMode.GENES)
    def _() -> Tuple[Fig, None]:
        """Plot one axis per gene, each containing one colored bar per lineage."""
        def color_fill_rec(ax,
                           xs,
                           y1,
                           y2,
                           colors=None,
                           cmap=cmap,
                           **kwargs) -> None:
            """Draw one rectangle per point in ``xs``, spanning ``y1`` to ``y2``."""
            colors = colors if cmap is None else cmap(colors)

            n_rects = len(xs)
            x, top = 0, y2
            for i, (color, x, bottom, top) in enumerate(zip(colors, xs, y1,
                                                            y2)):
                # the width is the distance to the next point; the last
                # rectangle has no successor, so it reuses the previous spacing
                dx = (xs[i + 1] - xs[i]) if i < n_rects - 1 else (xs[-1] -
                                                                  xs[-2])
                ax.add_patch(
                    plt.Rectangle((x, bottom),
                                  dx,
                                  top - bottom,
                                  color=color,
                                  ec=color,
                                  **kwargs))

            # invisible point at the last rectangle's corner
            # NOTE(review): presumably needed so that the axis limits include the patches - confirm
            ax.plot(x, top, lw=0)

        fig, axes = plt.subplots(
            nrows=len(genes) + show_absorption_probabilities,
            figsize=(12, len(genes) + len(lineages) * lineage_height)
            if figsize is None else figsize,
            dpi=dpi,
            constrained_layout=True,
        )

        if not isinstance(axes, Iterable):
            axes = [axes]
        axes = np.ravel(axes)

        if show_absorption_probabilities:
            # reuse the models of the first gene, only their test points are needed
            data["absorption probability"] = data[next(iter(data.keys()))]

        for ax, (gene, models) in zip(axes, data.items()):
            if scale:
                vmin, vmax = 0, 1
            else:
                c = np.array([m.y_test for m in models.values()])
                vmin, vmax = np.nanmin(c), np.nanmax(c)

            norm = mcolors.Normalize(vmin=vmin, vmax=vmax)

            # `ix` is the bottom of the current bar, `ys` collects the bar boundaries
            ix = 0
            ys = [ix]

            if gene == "absorption probability":
                norm = mcolors.Normalize(vmin=0, vmax=1)
                for ln, x in ((ln, m.x_test) for ln, m in models.items()):
                    y = np.ones_like(x)
                    c = subset_lineage(ln, x.squeeze())

                    color_fill_rec(ax,
                                   x,
                                   y * ix,
                                   y * (ix + lineage_height),
                                   colors=norm(c))

                    ix += lineage_height
                    ys.append(ix)
            else:
                for x, c in ((m.x_test, m.y_test) for m in models.values()):
                    y = np.ones_like(x)
                    c = _min_max_scale(c) if scale else c

                    color_fill_rec(ax,
                                   x,
                                   y * ix,
                                   y * (ix + lineage_height),
                                   colors=norm(c))

                    ix += lineage_height
                    ys.append(ix)

            xs = np.array([m.x_test for m in models.values()])
            x_min, x_max = np.min(xs), np.max(xs)
            ax.set_xticks(np.linspace(x_min, x_max, _N_XTICKS))

            # place the ticks in the middle of each bar
            ax.set_yticks(np.array(ys[:-1]) + lineage_height / 2)
            ax.spines["left"].set_position(
                ("data", 0)
            )  # move the left spine to the rectangles to get nicer yticks
            ax.set_yticklabels(models.keys(), ha="right")

            ax.set_title(gene, fontdict={"fontsize": fontsize})
            ax.set_ylabel("lineage")

            for pos in ["top", "bottom", "left", "right"]:
                ax.spines[pos].set_visible(False)

            if cbar:
                cax, _ = mpl.colorbar.make_axes(ax)
                _ = mpl.colorbar.ColorbarBase(
                    cax,
                    ticks=np.linspace(vmin, vmax, 5),
                    norm=norm,
                    cmap=cmap,
                    label="value" if gene == "absorption probability" else
                    "scaled expression" if scale else "expression",
                )

            ax.tick_params(
                top=False,
                bottom=False,
                left=True,
                right=False,
                labelleft=True,
                labelbottom=False,
            )

        # only the last (bottom-most) axis shows the x-axis ticks and label
        ax.xaxis.set_major_formatter(FormatStrFormatter("%.3f"))
        ax.tick_params(
            top=False,
            bottom=True,
            left=True,
            right=False,
            labelleft=True,
            labelbottom=True,
        )
        ax.set_xlabel(xlabel)

        return fig, None

    @_plot_heatmap.register(HeatmapMode.LINEAGES)
    def _() -> Tuple[List[Fig], pd.DataFrame]:
        """Plot one clustermap per lineage, with genes as rows."""
        data_t = defaultdict(dict)  # transpose
        for gene, lns in data.items():
            for ln, y in lns.items():
                data_t[ln][gene] = y

        figs = []
        gene_order = None
        sorted_genes = pd.DataFrame() if return_genes else None

        for lname, models in data_t.items():
            xs = np.array([m.x_test for m in models.values()])
            x_min, x_max = np.nanmin(xs), np.nanmax(xs)

            df = pd.DataFrame([m.y_test for m in models.values()],
                              index=models.keys())
            df.index.name = "genes"

            if not cluster_genes:
                if gene_order is not None:
                    df = df.loc[gene_order]
                else:
                    # sort the genes by the position of their (scaled) expression peak
                    max_sort = np.argsort(
                        np.argmax(df.apply(_min_max_scale, axis=1).values,
                                  axis=1))
                    df = df.iloc[max_sort, :]
                    if keep_gene_order:
                        gene_order = df.index

            cat_colors = None
            if cluster_key is not None:
                cat_colors = np.stack(
                    [
                        create_col_categorical_color(
                            c, np.linspace(x_min, x_max, df.shape[1]))
                        for c in cluster_key
                    ],
                    axis=0,
                )

            if show_absorption_probabilities:
                col_colors, col_cmap, col_norm = create_col_colors(
                    lname, np.linspace(x_min, x_max, df.shape[1]))
                if cat_colors is not None:
                    col_colors = np.vstack([cat_colors, col_colors[None, :]])
            else:
                col_colors, col_cmap, col_norm = cat_colors, None, None

            row_cluster = cluster_genes and df.shape[0] > 1
            show_clust = row_cluster and dendrogram

            g = sns.clustermap(
                df,
                cmap=cmap,
                figsize=(10, min(len(genes) / 8 +
                                 1, 10)) if figsize is None else figsize,
                xticklabels=False,
                row_cluster=row_cluster,
                col_colors=col_colors,
                colors_ratio=0,
                col_cluster=False,
                cbar_pos=None,
                yticklabels=show_all_genes or "auto",
                standard_scale=0 if scale else None,
            )

            if cbar:
                cax = create_cbar(
                    g.ax_heatmap,
                    0.1,
                    cmap=cmap,
                    norm=mcolors.Normalize(
                        vmin=0 if scale else np.min(df.values),
                        vmax=1 if scale else np.max(df.values),
                    ),
                    label="scaled expression" if scale else "expression",
                )
                g.fig.add_axes(cax)

                if col_cmap is not None and col_norm is not None:
                    cax = create_cbar(
                        g.ax_heatmap,
                        0.25,
                        cmap=col_cmap,
                        norm=col_norm,
                        label="absorption probability",
                    )
                    g.fig.add_axes(cax)

            if g.ax_col_colors:
                main_bbox = _get_ax_bbox(g.fig, g.ax_heatmap)
                n_bars = show_absorption_probabilities + (
                    len(cluster_key) if cluster_key is not None else 0)
                _set_ax_height_to_cm(
                    g.fig,
                    g.ax_col_colors,
                    height=min(
                        5,
                        max(n_bars * main_bbox.height / len(df),
                            0.25 * n_bars)),
                )
                g.ax_col_colors.set_title(lname,
                                          fontdict={"fontsize": fontsize})
            else:
                g.ax_heatmap.set_title(lname, fontdict={"fontsize": fontsize})

            g.ax_col_dendrogram.set_visible(
                False)  # gets rid of top free space

            g.ax_heatmap.yaxis.tick_left()
            g.ax_heatmap.yaxis.set_label_position("right")

            g.ax_heatmap.set_xlabel(xlabel)
            g.ax_heatmap.set_xticks(np.linspace(0, len(df.columns), _N_XTICKS))
            g.ax_heatmap.set_xticklabels(
                list(
                    map(lambda n: round(n, 3),
                        np.linspace(x_min, x_max, _N_XTICKS))))

            if show_clust:
                # robustly show dendrogram, because gene names can be long
                g.ax_row_dendrogram.set_visible(True)
                dendro_box = g.ax_row_dendrogram.get_position()

                pad = 0.005
                bb = g.ax_heatmap.yaxis.get_tightbbox(
                    g.fig.canvas.get_renderer()).transformed(
                        g.fig.transFigure.inverted())

                dendro_box.x0 = bb.x0 - dendro_box.width - pad
                dendro_box.x1 = bb.x0 - pad

                g.ax_row_dendrogram.set_position(dendro_box)
            else:
                g.ax_row_dendrogram.set_visible(False)

            if return_genes:
                # clustered order if a row dendrogram exists, otherwise the sorted order
                sorted_genes[lname] = (df.index[g.dendrogram_row.reordered_ind]
                                       if hasattr(g, "dendrogram_row")
                                       and g.dendrogram_row is not None else
                                       df.index)

            figs.append(g)

        return figs, sorted_genes

    mode = HeatmapMode(mode)

    lineage_key = str(AbsProbKey.BACKWARD if backward else AbsProbKey.FORWARD)
    if lineage_key not in adata.obsm:
        raise KeyError(
            f"Lineages key `{lineage_key!r}` not found in `adata.obsm`.")

    if lineages is None:
        lineages = adata.obsm[lineage_key].names
    elif isinstance(lineages, str):
        lineages = [lineages]
    lineages = _unique_order_preserving(lineages)

    # the result is unused; NOTE(review): presumably this indexing validates the lineage names - confirm
    _ = adata.obsm[lineage_key][lineages]

    if cluster_key is not None:
        if isinstance(cluster_key, str):
            cluster_key = [cluster_key]
        cluster_key = _unique_order_preserving(cluster_key)

    if isinstance(genes, str):
        genes = [genes]
    genes = _unique_order_preserving(genes)
    _check_collection(adata,
                      genes,
                      "var_names",
                      use_raw=kwargs.get("use_raw", False))

    kwargs["backward"] = backward
    kwargs["time_key"] = time_key
    models = _create_models(model, genes, lineages)
    all_models, data, genes, lineages = _fit_bulk(
        models,
        _create_callbacks(adata, callback, genes, lineages, **kwargs),
        genes,
        lineages,
        time_range,
        return_models=True,  # always return (better error messages)
        filter_all_failed=True,
        parallel_kwargs={
            "show_progress_bar": show_progress_bar,
            "n_jobs": _get_n_cores(n_jobs, len(genes)),
            "backend": _get_backend(models, backend),
        },
        **kwargs,
    )

    xlabel = time_key if xlabel is None else xlabel

    logg.debug(f"Plotting `{mode.s!r}` heatmap")
    fig, genes = _plot_heatmap(mode)

    if save is not None and fig is not None:
        if not isinstance(fig, Iterable):
            save_fig(fig, save)
        elif len(fig) == 1:
            save_fig(fig[0], save)
        else:
            # multiple figures: `save` is used as a directory, one file per lineage
            for ln, f in zip(lineages, fig):
                save_fig(f, os.path.join(save, f"lineage_{ln}"))

    if return_genes and mode == HeatmapMode.LINEAGES:
        return (all_models, genes) if return_models else genes
    elif return_models:
        return all_models
def cluster_lineage(
        adata: AnnData,
        model: _model_type,
        genes: Sequence[str],
        lineage: str,
        backward: bool = False,
        time_range: _time_range_type = None,
        clusters: Optional[Sequence[str]] = None,
        n_points: int = 200,
        time_key: str = "latent_time",
        cluster_key: str = "clusters",
        norm: bool = True,
        recompute: bool = False,
        callback: _callback_type = None,
        ncols: int = 3,
        sharey: Union[str, bool] = False,
        key_added: Optional[str] = None,
        show_progress_bar: bool = True,
        n_jobs: Optional[int] = 1,
        backend: str = _DEFAULT_BACKEND,
        figsize: Optional[Tuple[float, float]] = None,
        dpi: Optional[int] = None,
        save: Optional[Union[str, Path]] = None,
        pca_kwargs: Dict = MappingProxyType({"svd_solver": "arpack"}),
        neighbors_kwargs: Dict = MappingProxyType({"use_rep": "X"}),
        louvain_kwargs: Dict = MappingProxyType({}),
        **kwargs,
) -> None:
    """
    Cluster gene expression trends within a lineage and plot the clusters.

    This function is based on Palantir, see [Setty19]_. It can be used to discover modules of genes that drive
    development along a given lineage. Consider running this function on a subset of genes which are potential lineage
    drivers, identified e.g. by running :func:`cellrank.tl.lineage_drivers`.

    Parameters
    ----------
    %(adata)s
    %(model)s
    %(genes)s
    lineage
        Name of the lineage for which to cluster the genes.
    %(backward)s
    %(time_ranges)s
    clusters
        Cluster identifiers to plot. If `None`, all clusters will be considered.
        Useful when plotting previously computed clusters.
    n_points
        Number of points used for prediction.
    time_key
        Key in ``adata.obs`` where the pseudotime is stored.
    cluster_key
        Key in ``adata.obs`` where the clustering is stored.
    norm
        Whether to z-normalize each trend to have zero mean, unit variance.
    recompute
        If `True`, recompute the clustering, otherwise try to find already existing one.
    %(model_callback)s
    ncols
        Number of columns for the plot.
    sharey
        Whether to share y-axis across multiple plots.
    key_added
        Postfix to add when saving the results to ``adata.uns``.
    %(parallel)s
    %(plotting)s
    pca_kwargs
        Keyword arguments for :func:`scanpy.pp.pca`.
    neighbors_kwargs
        Keyword arguments for :func:`scanpy.pp.neighbors`.
    louvain_kwargs
        Keyword arguments for :func:`scanpy.tl.louvain`.
    **kwargs
        Keyword arguments for :meth:`cellrank.ul.models.BaseModel.prepare`.

    Returns
    -------
    %(just_plots)s

        Updates ``adata.uns`` with the following key:

            - ``lineage_{lineage}_trend_{key_added}`` - an :class:`anndata.AnnData` object
              of shape ``(n_genes, n_points)`` containing the clustered genes.
    """

    import scanpy as sc
    from anndata import AnnData as _AnnData

    lineage_key = str(AbsProbKey.BACKWARD if backward else AbsProbKey.FORWARD)
    if lineage_key not in adata.obsm:
        raise KeyError(
            f"Lineages key `{lineage_key!r}` not found in `adata.obsm`.")

    # the result is unused; NOTE(review): presumably this indexing validates the lineage name - confirm
    _ = adata.obsm[lineage_key][lineage]

    genes = _unique_order_preserving(genes)
    _check_collection(adata, genes, "var_names", kwargs.get("use_raw", False))

    key_to_add = f"lineage_{lineage}_trend"
    if key_added is not None:
        logg.debug(f"Adding key `{key_added!r}`")
        key_to_add += f"_{key_added}"

    if recompute or key_to_add not in adata.uns:
        kwargs["time_key"] = time_key  # kwargs for the model.prepare
        kwargs["n_test_points"] = n_points
        kwargs["backward"] = backward

        models = _create_models(model, genes, [lineage])
        callbacks = _create_callbacks(adata, callback, genes, [lineage],
                                      **kwargs)

        backend = _get_backend(model, backend)
        n_jobs = _get_n_cores(n_jobs, len(genes))

        start = logg.info(f"Computing gene trends using `{n_jobs}` core(s)")
        trends = parallelize(
            _cluster_lineages_helper,
            genes,
            as_array=True,
            unit="gene",
            n_jobs=n_jobs,
            backend=backend,
            extractor=np.vstack,
            show_progress_bar=show_progress_bar,
        )(models, callbacks, lineage, time_range, **kwargs)
        logg.info("    Finish", time=start)

        # transpose so that each gene is a column, since the scaler works column-wise
        trends = trends.T
        if norm:
            logg.debug("Normalizing using `StandardScaler`")
            # use the returned array: with `copy=False`, in-place scaling is only
            # attempted, not guaranteed, so discarding the result could lose it
            trends = StandardScaler(copy=False).fit_transform(trends)

        # back to genes as observations
        trends = _AnnData(trends.T)
        trends.obs_names = genes

        # sanity check
        if trends.n_obs != len(genes):
            raise RuntimeError(
                f"Expected to find `{len(genes)}` genes, found `{trends.n_obs}`."
            )
        if n_points is not None and trends.n_vars != n_points:
            raise RuntimeError(
                f"Expected to find `{n_points}` points, found `{trends.n_vars}`."
            )

        pca_kwargs = dict(pca_kwargs)
        # default value; derived from the actual shape of the trends, which equals
        # `(len(genes), n_points)` after the checks above, but also works if `n_points=None`
        n_comps = pca_kwargs.pop("n_comps",
                                 min(50, trends.n_vars, trends.n_obs) - 1)

        sc.pp.pca(trends, n_comps=n_comps, **pca_kwargs)
        sc.pp.neighbors(trends, **neighbors_kwargs)

        louvain_kwargs = dict(louvain_kwargs)
        louvain_kwargs["key_added"] = cluster_key
        sc.tl.louvain(trends, **louvain_kwargs)

        adata.uns[key_to_add] = trends
    else:
        logg.info(f"Loading data from `adata.uns[{key_to_add!r}]`")
        trends = adata.uns[key_to_add]

    if clusters is None:
        if cluster_key not in trends.obs:
            raise KeyError(f"Invalid cluster key `{cluster_key!r}`.")
        clusters = trends.obs[cluster_key].cat.categories

    nrows = int(np.ceil(len(clusters) / ncols))
    fig, axes = plt.subplots(
        nrows,
        ncols,
        figsize=(ncols * 10, nrows * 10) if figsize is None else figsize,
        sharey=sharey,
        dpi=dpi,
    )

    if not isinstance(axes, Iterable):
        axes = [axes]
    axes = np.ravel(axes)

    j = 0
    for j, (ax, c) in enumerate(zip(axes, clusters)):  # noqa
        data = trends[trends.obs[cluster_key] == c].X
        mean, sd = np.mean(data, axis=0), np.std(data, axis=0)

        # individual trends in the background
        for trend in data:
            ax.plot(trend, color="gray", lw=0.5)

        # cluster mean with a band of one standard deviation
        ax.plot(mean, lw=2, color="black")
        ax.plot(mean - sd, lw=1.5, color="black", linestyle="--")
        ax.plot(mean + sd, lw=1.5, color="black", linestyle="--")
        ax.fill_between(range(len(mean)),
                        mean - sd,
                        mean + sd,
                        color="black",
                        alpha=0.1)

        ax.set_title(f"Cluster {c}")
        ax.set_xticks([])

        if not sharey:
            ax.set_yticks([])

    # remove the unused axes of the grid
    for unused_ax in axes[j + 1:]:
        unused_ax.remove()

    if save is not None:
        save_fig(fig, save)
Beispiel #10
0
    def plot_single_flow(
        self,
        cluster: str,
        cluster_key: str,
        time_key: str,
        clusters: Optional[Sequence[Any]] = None,
        time_points: Optional[Sequence[Union[int, float]]] = None,
        min_flow: float = 0,
        remove_empty_clusters: bool = True,
        ascending: Optional[bool] = False,
        legend_loc: Optional[str] = "upper right out",
        alpha: Optional[float] = 0.8,
        xticks_step_size: Optional[int] = 1,
        figsize: Optional[Tuple[float, float]] = None,
        dpi: Optional[int] = None,
        save: Optional[Union[str, Path]] = None,
        show: bool = True,
    ) -> Optional[plt.Axes]:
        """
        Visualize outgoing flow from a cluster of cells :cite:`mittnenzweig:21`.

        Parameters
        ----------
        cluster
            Cluster whose outgoing flow is visualized.
        cluster_key
            Key in :attr:`adata` ``.obs`` where clustering is stored.
        time_key
            Key in :attr:`adata` ``.obs`` where experimental time is stored.
        clusters
            Restrict the visualization to these clusters. If `None`, use all clusters.
        time_points
            Restrict the visualization to these time points. If `None`, use all time points.
        %(flow.parameters)s
        %(plotting)s
        show
            If `False`, return :class:`matplotlib.pyplot.Axes`.

        Returns
        -------
        :class:`matplotlib.pyplot.Axes`
            The axis object if ``show=False``.
        %(just_plots)s

        Notes
        -----
        This function is a Python reimplementation of the following
        `original R function <https://github.com/tanaylab/embflow/blob/main/scripts/generate_paper_figures/plot_vein.r>`_
        with some minor stylistic differences.
        The results from :cite:`mittnenzweig:21` will not be recreated by this function, since the flow there was
        computed using the Metacell model :cite:`baran:19`, whereas the transition matrix is used here.
        """  # noqa: E501
        # the flow is derived from the transition matrix, which must be computed beforehand
        if self._transition_matrix is None:
            raise RuntimeError(
                "Compute transition matrix first as `.compute_transition_matrix()`."
            )

        plotter = FlowPlotter(
            self.adata,
            self.transition_matrix,
            cluster_key,
            time_key,
        ).prepare(cluster, clusters, time_points)

        plot_options = {
            "min_flow": min_flow,
            "remove_empty_clusters": remove_empty_clusters,
            "ascending": ascending,
            "alpha": alpha,
            "xticks_step_size": xticks_step_size,
            "legend_loc": legend_loc,
            "figsize": figsize,
            "dpi": dpi,
        }
        ax = plotter.plot(**plot_options)

        if save is not None:
            save_fig(ax.figure, save)

        # the axis is handed back only when the caller opted out of showing
        return None if show else ax
Beispiel #11
0
    def plot_random_walks(
        self,
        n_sims: int,
        max_iter: Union[int, float] = 0.25,
        seed: Optional[int] = None,
        successive_hits: int = 0,
        start_ixs: Indices_t = None,
        stop_ixs: Indices_t = None,
        basis: str = "umap",
        cmap: Union[str, LinearSegmentedColormap] = "gnuplot",
        linewidth: float = 1.0,
        linealpha: float = 0.3,
        ixs_legend_loc: Optional[str] = None,
        n_jobs: Optional[int] = None,
        backend: str = "loky",
        show_progress_bar: bool = True,
        figsize: Optional[Tuple[float, float]] = None,
        dpi: Optional[int] = None,
        save: Optional[Union[str, Path]] = None,
        **kwargs: Any,
    ) -> None:
        """
        Plot random walks in an embedding.

        This method simulates random walks on the Markov chain defined though the corresponding transition matrix. The
        method is intended to give qualitative rather than quantitative insights into the transition matrix. Random
        walks are simulated by iteratively choosing the next cell based on the current cell's transition probabilities.

        Parameters
        ----------
        n_sims
            Number of random walks to simulate.
        %(rw_sim.parameters)s
        start_ixs
            Cells from which to sample the starting points. If `None`, use all cells.
            %(rw_ixs)s
            For example ``{'clusters': ['Ngn3 low EP', 'Ngn3 high EP']}`` means that starting points for random walks
            will be samples uniformly from the these clusters.
        stop_ixs
            Cells which when hit, the random walk is terminated. If `None`, terminate after ``max_iters``.
            %(rw_ixs)s
            For example ``{'clusters': ['Alpha', 'Beta']}`` and ``succesive_hits=3`` means that the random walk will
            stop prematurely after cells in the above specified clusters have been visited successively 3 times in a
            row.
        basis
            Basis in :attr:`anndata.AnnData.obsm` to use as an embedding.
        cmap
            Colormap for the random walk lines.
        linewidth
            Width of the random walk lines.
        linealpha
            Alpha value of the random walk lines.
        ixs_legend_loc
            Legend location for the start/top indices.
        %(parallel)s
        %(plotting)s
        kwargs
            Keyword arguments for :func:`scvelo.pl.scatter`.

        Returns
        -------
        %(just_plots)s
        For each random walk, the first/last cell is marked by the start/end colors of ``cmap``.
        """
        def create_ixs(ixs: Indices_t, *, kind: str) -> Optional[np.ndarray]:
            """
            Translate a cell selection into positional indices into ``self.adata``.

            ``ixs`` may be `None`, a single-entry dict mapping a categorical key in ``adata.obs`` to the
            selected labels, a single observation name, or a collection of observation names.
            ``kind`` is only used in the warning message.

            Returns `None` when ``ixs`` is `None` or when the selection matches no cells.
            """
            if ixs is None:
                return None

            if isinstance(ixs, dict):
                if len(ixs) != 1:
                    raise ValueError(
                        f"Expected to find only 1 cluster key, found `{len(ixs)}`."
                    )
                # exactly one entry is guaranteed at this point
                (cluster_key, labels), = ixs.items()
                if cluster_key not in self.adata.obs:
                    raise KeyError(
                        f"Unable to find `adata.obs[{cluster_key!r}]`.")
                if not is_categorical_dtype(self.adata.obs[cluster_key]):
                    raise TypeError(
                        f"Expected `adata.obs[{cluster_key!r}]` to be categorical, "
                        f"found `{infer_dtype(self.adata.obs[cluster_key])}`.")
                selected = np.isin(self.adata.obs[cluster_key], labels)
            elif isinstance(ixs, str):
                selected = self.adata.obs_names == ixs
            else:
                selected = np.isin(self.adata.obs_names, ixs)

            positions = np.where(selected)[0]
            if len(positions) == 0:
                logg.warning(
                    f"No {kind} indices have been selected, using `None`")
                return None

            return positions

        if self._transition_matrix is None:
            raise RuntimeError(
                "Compute transition matrix first as `.compute_transition_matrix()`."
            )
        emb = _get_basis(self.adata, basis)

        if isinstance(cmap, str):
            cmap = plt.get_cmap(cmap)
        if not isinstance(cmap, LinearSegmentedColormap):
            if not hasattr(cmap, "colors"):
                raise AttributeError(
                    "Unable to create a colormap, `cmap` does not have attribute `colors`."
                )
            cmap = LinearSegmentedColormap.from_list("random_walk",
                                                     colors=cmap.colors,
                                                     N=max_iter)

        start_ixs = create_ixs(start_ixs, kind="start")
        stop_ixs = create_ixs(stop_ixs, kind="stop")
        rw = RandomWalk(self.transition_matrix,
                        start_ixs=start_ixs,
                        stop_ixs=stop_ixs)
        sims = rw.simulate_many(
            n_sims=n_sims,
            max_iter=max_iter,
            seed=seed,
            n_jobs=n_jobs,
            backend=backend,
            successive_hits=successive_hits,
            show_progress_bar=show_progress_bar,
        )

        fig, ax = plt.subplots(figsize=figsize, dpi=dpi)
        scv.pl.scatter(self.adata, basis=basis, show=False, ax=ax, **kwargs)

        logg.info("Plotting random walks")
        for sim in sims:
            x = emb[sim][:, 0]
            y = emb[sim][:, 1]
            points = np.array([x, y]).T.reshape(-1, 1, 2)
            segments = np.concatenate([points[:-1], points[1:]], axis=1)
            n_seg = len(segments)

            lc = LineCollection(
                segments,
                linewidths=linewidth,
                colors=[cmap(float(i) / n_seg) for i in range(n_seg)],
                alpha=linealpha,
                zorder=2,
            )
            ax.add_collection(lc)

        for ix in [0, -1]:
            ixs = [sim[ix] for sim in sims]
            plot_outline(
                x=emb[ixs][:, 0],
                y=emb[ixs][:, 1],
                outline_color=("black", to_hex(cmap(float(abs(ix))))),
                kwargs={
                    "s": kwargs.get("s", default_size(self.adata)) * 1.1,
                    "alpha": 0.9,
                },
                ax=ax,
                zorder=4,
            )

        if ixs_legend_loc not in (None, "none"):
            from cellrank.pl._utils import _position_legend

            h1 = ax.scatter([], [], color=cmap(0.0), label="start")
            h2 = ax.scatter([], [], color=cmap(1.0), label="stop")
            legend = ax.get_legend()
            if legend is not None:
                ax.add_artist(legend)
            _position_legend(ax, legend_loc=ixs_legend_loc, handles=[h1, h2])

        if save is not None:
            save_fig(fig, save)
def circular_projection(
    adata: AnnData,
    keys: Union[str, Sequence[str]],
    backward: bool = False,
    lineages: Optional[Union[str, Sequence[str]]] = None,
    early_cells: Optional[Union[Mapping[str, Sequence[str]],
                                Sequence[str]]] = None,
    lineage_order: Optional[Literal["default", "optimal"]] = None,
    metric: Union[str, Callable, np.ndarray, pd.DataFrame] = "correlation",
    normalize_by_mean: bool = True,
    ncols: int = 4,
    space: float = 0.25,
    use_raw: bool = False,
    text_kwargs: Mapping[str, Any] = MappingProxyType({}),
    labeldistance: float = 1.25,
    labelrot: Union[Literal["default", "best"], float] = "best",
    show_edges: bool = True,
    key_added: Optional[str] = None,
    figsize: Optional[Tuple[float, float]] = None,
    dpi: Optional[int] = None,
    save: Optional[Union[str, Path]] = None,
    **kwargs,
) -> None:
    r"""
    Plot absorption probabilities on a circular embedding as done in [Velten17]_.

    Every lineage is placed at a corner of a regular polygon inscribed in the unit circle and each cell is
    positioned at the average of these corners, weighted by its (optionally mean-normalized) absorption
    probabilities.

    Parameters
    ----------
    %(adata)s
    keys
        Keys in :attr:`anndata.AnnData.obs` or :attr:`anndata.AnnData.var_names`. Additional keys are:

            - `'kl_divergence'` - as in [Velten17]_, computes KL-divergence between the fate probabilities of a cell
              and the average fate probabilities. See ``early_cells`` for more information.
            - `'entropy'` - as in [Setty19]_, computes entropy over a cell's fate probabilities.

    %(backward)s
    lineages
        Lineages to plot. If `None`, plot all lineages.
    early_cells
        Cell ids or a mask marking early cells used to define the average fate probabilities. If `None`, use all cells.
        Only used when `'kl_divergence'` is in ``keys``. If a :class:`dict`, key specifies a cluster key in
        :attr:`anndata.AnnData.obs` and the values specify cluster labels containing early cells.
    lineage_order
        Can be one of the following:

            - `None` - it will be determined automatically, based on the number of lineages.
            - `'optimal'` - order the lineages optimally by solving the Travelling salesman problem (TSP).
              Recommended for <= `20` lineages.
            - `'default'` - use the order as specified in ``lineages``.

    metric
        Metric to use when constructing pairwise distance matrix when ``lineage_order = 'optimal'``. For available
        options, see :func:`sklearn.metrics.pairwise_distances`.
    normalize_by_mean
        If `True`, normalize each lineage by its mean probability, as done in [Velten17]_.
    ncols
        Number of columns when plotting multiple ``keys``.
    space
        Horizontal and vertical space between the panels, passed to :func:`matplotlib.pyplot.subplots_adjust`.
    use_raw
        Whether to access :attr:`anndata.AnnData.raw` when there are ``keys`` in :attr:`anndata.AnnData.var_names`.
    text_kwargs
        Keyword arguments for :func:`matplotlib.pyplot.text`.
    labeldistance
        Distance at which the lineage labels will be drawn.
    labelrot
        How to rotate the labels. Valid options are:

            - `'best'` - rotate labels so that they are easily readable.
            - `'default'` - use :mod:`matplotlib`'s default.
            - `None` - same as `'default'`.

        If a :class:`float`, all labels will be rotated by this many degrees.
    show_edges
        Whether to show the edges surrounding the simplex.
    key_added
        Key in :attr:`anndata.AnnData.obsm` where to add the circular embedding. If `None`, it will be set to
        `'X_fate_simplex_{fwd,bwd}'`, based on ``backward``.
    %(plotting)s
    kwargs
        Keyword arguments for :func:`scvelo.pl.scatter`. The ``colorbar`` entry (default `True`) is applied
        to every panel.

    Returns
    -------
    %(just_plots)s
        Also updates ``adata`` with the following fields:

            - :attr:`anndata.AnnData.obsm` ``['{key_added}']``: the circular projection.
            - :attr:`anndata.AnnData.obs` ``['to_{initial,terminal}_states_{method}']``: the priming degree,
              if a method is present in ``keys``.

    Raises
    ------
    ValueError
        If ``labeldistance`` is negative, if no valid ``keys`` remain or if fewer than `3` lineages are selected.
    KeyError
        If the absorption probabilities are not found in :attr:`anndata.AnnData.obsm`.
    """
    if labeldistance is not None and labeldistance < 0:
        raise ValueError(
            f"Expected `labeldistance` to be non-negative, found `{labeldistance}`."
        )

    if labelrot is None:
        labelrot = LabelRot.DEFAULT
    if isinstance(labelrot, str):
        labelrot = LabelRot(labelrot)

    suffix = "bwd" if backward else "fwd"
    if key_added is None:
        key_added = "X_fate_simplex_" + suffix

    if isinstance(keys, str):
        keys = (keys, )

    keys = _unique_order_preserving(keys)
    # keep the keys found in `.obs` / `.var_names`, then append the priming degree methods
    keys_ = _check_collection(
        adata, keys, "obs", key_name="Observation",
        raise_exc=False) + _check_collection(adata,
                                             keys,
                                             "var_names",
                                             key_name="Gene",
                                             raise_exc=False,
                                             use_raw=use_raw)
    haystack = {s.s for s in PrimingDegree}
    keys = keys_ + [k for k in keys if k in haystack]
    keys = _unique_order_preserving(keys)

    if not len(keys):
        raise ValueError("No valid keys have been selected.")

    lineage_key = str(AbsProbKey.BACKWARD if backward else AbsProbKey.FORWARD)
    if lineage_key not in adata.obsm:
        raise KeyError(
            f"Lineages key `{lineage_key!r}` not found in `adata.obsm`.")

    probs = adata.obsm[lineage_key]

    if isinstance(lineages, str):
        lineages = (lineages, )
    elif lineages is None:
        lineages = probs.names

    probs: Lineage = adata.obsm[lineage_key][lineages]
    n_lin = probs.shape[1]
    if n_lin <= 2:
        raise ValueError(f"Expected at least `3` lineages, found `{n_lin}`")

    X = probs.X.copy()
    if normalize_by_mean:
        X /= np.mean(X, axis=0)[None, :]
        X /= X.sum(1)[:, None]
        # this happens when cells for sel. lineages sum to 1 (or when the lineage average is 0, which is unlikely)
        X = np.nan_to_num(X, nan=1.0 / n_lin, copy=False)

    if lineage_order is None:
        lineage_order = LineageOrder.OPTIMAL if n_lin <= 15 else LineageOrder.DEFAULT
        logg.debug(f"Set ordering to `{lineage_order}`")
    lineage_order = LineageOrder(lineage_order)

    if lineage_order == LineageOrder.OPTIMAL:
        logg.info(f"Solving TSP for `{n_lin}` states")
        _, order = _get_optimal_order(X, metric=metric)
    else:
        order = np.arange(n_lin)

    probs = probs[:, order]
    X = X[:, order]

    # corners of the regular polygon on the unit circle, one per lineage
    angle_vec = np.linspace(0, 2 * np.pi, n_lin, endpoint=False)
    corner_x = np.cos(angle_vec)
    corner_y = np.sin(angle_vec)

    # a cell's position is the weighted average of the corners
    adata.obsm[key_added] = np.c_[np.sum(X * corner_x, axis=1),
                                  np.sum(X * corner_y, axis=1)]

    nrows = int(np.ceil(len(keys) / ncols))
    fig, ax = plt.subplots(
        nrows=nrows,
        ncols=ncols,
        figsize=(ncols * 5, nrows * 5) if figsize is None else figsize,
        dpi=dpi,
    )

    fig.subplots_adjust(wspace=space, hspace=space)
    axes = np.ravel([ax])

    text_kwargs = dict(text_kwargs)
    text_kwargs["ha"] = "center"
    text_kwargs["va"] = "center"

    # extract once, before the loop; popping inside the loop would only honor the
    # user's choice for the first panel and fall back to `True` for the rest
    colorbar = kwargs.pop("colorbar", True)

    _i = 0
    for _i, (k, ax) in enumerate(zip(keys, axes)):
        try:
            PrimingDegree(k)
        except ValueError:
            # not a priming degree method, `k` is plotted as-is
            pass
        else:
            # errors from the computation itself are deliberately not swallowed
            logg.debug(f"Calculating priming degree using `method={k}`")
            val = probs.priming_degree(method=k, early_cells=early_cells)
            k = f"{lineage_key}_{k}"
            adata.obs[k] = val

        scv.pl.scatter(
            adata,
            basis=key_added,
            color=k,
            show=False,
            ax=ax,
            use_raw=use_raw,
            norm=None,  # no custom color normalization is applied
            colorbar=colorbar,
            **kwargs,
        )

        # invisible pie chart, used only to place the lineage labels around the circle
        patches, texts = ax.pie(
            np.ones_like(angle_vec),
            labeldistance=labeldistance,
            rotatelabels=True,
            labels=probs.names[::-1],
            startangle=-360 / len(angle_vec) / 2,
            counterclock=False,
            textprops=text_kwargs,
        )

        for patch in patches:
            patch.set_visible(False)

        # clockwise
        for color, text in zip(probs.colors[::-1], texts):
            if isinstance(labelrot, (int, float)):
                text.set_rotation(labelrot)
            elif labelrot == LabelRot.BEST:
                rot = text.get_rotation()
                text.set_rotation(rot + 90 + (1 - rot // 180) * 180)
            elif labelrot != LabelRot.DEFAULT:
                raise NotImplementedError(
                    f"Label rotation `{labelrot}` is not yet implemented.")
            text.set_color(color)

        if not show_edges:
            continue

        # connect neighboring corners by a line fading from one lineage color to the next
        for i, color in enumerate(probs.colors):
            nxt = (i + 1) % n_lin
            edge_x = 1.04 * np.linspace(corner_x[i], corner_x[nxt], _N)
            edge_y = 1.04 * np.linspace(corner_y[i], corner_y[nxt], _N)
            points = np.array([edge_x, edge_y]).T.reshape(-1, 1, 2)
            segments = np.concatenate([points[:-1], points[1:]], axis=1)

            cmap = LinearSegmentedColormap.from_list(
                "abs_prob_cmap", [color, probs.colors[nxt]], N=_N)
            lc = LineCollection(segments, cmap=cmap, zorder=-1)
            lc.set_array(np.linspace(0, 1, _N))
            lc.set_linewidth(2)
            ax.add_collection(lc)

    # drop the panels of the grid that were not needed
    for j in range(_i + 1, len(axes)):
        axes[j].remove()

    if save is not None:
        save_fig(fig, save)
    def plot_lineage_drivers(
        self,
        lineage: str,
        n_genes: int = 8,
        ncols: Optional[int] = None,
        use_raw: bool = False,
        title_fmt: str = "{gene} qval={qval:.4e}",
        figsize: Optional[Tuple[float, float]] = None,
        dpi: Optional[int] = None,
        save: Optional[Union[str, Path]] = None,
        **kwargs,
    ) -> None:
        """
        Show the top driver genes of a lineage, one scatter panel per gene.

        The drivers must have been computed beforehand by :meth:`compute_lineage_drivers`.

        Parameters
        ----------
        lineage
            Name of the lineage whose driver genes are shown.
        n_genes
            How many genes to plot, taken from the top after sorting by correlation in descending order.
        ncols
            Number of columns in the panel grid. If `None`, `4` columns are used.
        use_raw
            Whether to look in :paramref:`adata` ``.raw.var`` or :paramref:`adata` ``.var``.
        title_fmt
            Format of the panel titles. Supported fields are `{gene}`, `{pval}`, `{qval}` and `{corr}` for
            gene name, p-value, q-value and correlation, respectively. A statistic that is not available
            is formatted as NaN.
        %(plotting)s
        kwargs
            Keyword arguments for :func:`scvelo.pl.scatter`. A ``save`` entry is discarded.

        Returns
        -------
        %(just_plots)s
        """

        def prepare_format(
            gene: str,
            *,
            pval: Optional[float],
            qval: Optional[float],
            corr: Optional[float],
        ) -> Dict[str, Any]:
            # supply only the fields that the title format actually references
            fields = {}
            if "{gene" in title_fmt:
                fields["gene"] = gene
            for name, stat in (("pval", pval), ("qval", qval), ("corr", corr)):
                if "{" + name in title_fmt:
                    fields[name] = np.nan if stat is None else float(stat)

            return fields

        lin_drivers = self._get(P.LIN_DRIVERS)
        if lin_drivers is None:
            raise RuntimeError(
                f"Compute `.{P.LIN_DRIVERS}` first as `.compute_lineage_drivers()`."
            )

        key = f"{lineage} corr"
        if key not in lin_drivers:
            raise KeyError(
                f"Lineage `{key!r}` not found in `{list(lin_drivers.columns)}`."
            )

        if n_genes <= 0:
            raise ValueError(f"Expected `n_genes` to be positive, found `{n_genes}`.")

        # saving is handled at the end of this method, not by scvelo
        kwargs.pop("save", None)
        genes = lin_drivers.sort_values(by=key, ascending=False).head(n_genes)

        if ncols is None:
            ncols = 4
        nrows = int(np.ceil(len(genes) / ncols))
        if figsize is None:
            figsize = (ncols * 6, nrows * 4)

        fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize, dpi=dpi)
        axes = np.ravel([axes])

        last_used = 0
        for last_used, (gene, ax) in enumerate(zip(genes.index, axes)):
            row = genes.loc[gene]
            stats = {
                name: row.get(f"{lineage} {name}", None)
                for name in ("pval", "qval", "corr")
            }
            scv.pl.scatter(
                self.adata,
                color=gene,
                ncols=ncols,
                use_raw=use_raw,
                ax=ax,
                show=False,
                title=title_fmt.format(**prepare_format(gene, **stats)),
                **kwargs,
            )

        # the grid may have more panels than there are genes
        for spare_ax in axes[last_used + 1:]:
            spare_ax.remove()

        if save is not None:
            save_fig(fig, save)
Beispiel #14
0
    def plot(
        self,
        figsize: Tuple[float, float] = (15, 10),
        same_plot: bool = False,
        hide_cells: bool = False,
        perc: Tuple[float, float] = None,
        abs_prob_cmap: mcolors.ListedColormap = cm.viridis,
        cell_color: str = "black",
        lineage_color: str = "black",
        alpha: float = 0.8,
        lineage_alpha: float = 0.2,
        title: Optional[str] = None,
        size: int = 15,
        lw: float = 2,
        show_cbar: bool = True,
        margins: float = 0.015,
        xlabel: str = "pseudotime",
        ylabel: str = "expression",
        show_conf_int: bool = True,
        dpi: int = None,
        fig: mpl.figure.Figure = None,
        ax: mpl.axes.Axes = None,
        return_fig: bool = False,
        save: Optional[str] = None,
    ) -> Optional[mpl.figure.Figure]:
        """
        Draw the smoothed expression trend, optionally with the cells and a confidence band.

        Parameters
        ----------
        figsize
            Size of the figure; only used when a new figure is created.
        same_plot
            Whether all trends are drawn into the same plot. If `True`, cells use ``cell_color`` and
            no colorbar is added.
        hide_cells
            If `True`, do not draw the cells.
        perc
            Percentile by which to clip the absorption probabilities.
        abs_prob_cmap
            Colormap used when cells are colored by their absorption probabilities.
        cell_color
            Cell color used when the cells are not colored by absorption probabilities.
        lineage_color
            Color of the trend line and of the confidence band.
        alpha
            Alpha channel for the cells.
        lineage_alpha
            Alpha channel for the confidence band.
        title
            Title of the plot, also used as the label of the trend line.
            If `None`, it is ``'{gene} @ {lineage}'``.
        size
            Size of the points.
        lw
            Line width of the trend line.
        show_cbar
            Whether to show the colorbar.
        margins
            Margins around the plot.
        xlabel
            Label on the x-axis.
        ylabel
            Label on the y-axis.
        show_conf_int
            Whether to show the confidence interval, if one is available.
        dpi
            Dots per inch. If `None`, the figure's value is left untouched.
        fig
            Figure to use. A new figure and ax are created if either ``fig`` or ``ax`` is `None`.
        ax: :class:`matplotlib.axes.Axes`
            Ax to use. A new figure and ax are created if either ``fig`` or ``ax`` is `None`.
        return_fig
            If `True`, return the figure object.
        save
            Filename where to save the plot. If `None`, just shows the plots.

        Returns
        -------
        %(just_plots)s
        The figure is returned when ``return_fig`` is `True`.
        """

        if fig is None or ax is None:
            fig, ax = plt.subplots(figsize=figsize, constrained_layout=True)
        if dpi is not None:
            fig.set_dpi(dpi)

        vmin, vmax = _minmax(self.w, perc)

        if not hide_cells:
            xs, ys = self.x_all.squeeze(), self.y_all.squeeze()
            # uniform weights carry no information, so fall back to a single color
            if same_plot or np.allclose(self.w_all, 1.0):
                point_colors = cell_color
            else:
                point_colors = self.w_all.squeeze()
            ax.scatter(
                xs,
                ys,
                c=point_colors,
                s=size,
                cmap=abs_prob_cmap,
                vmin=vmin,
                vmax=vmax,
                alpha=alpha,
            )

        if title is None:
            title = f"{self._gene} @ {self._lineage}"

        ax.plot(self.x_test, self.y_test, color=lineage_color, lw=lw, label=title)

        ax.set_title(title)
        ax.set_ylabel(ylabel)
        ax.set_xlabel(xlabel)
        ax.margins(margins)

        if show_conf_int:
            band = self.conf_int
            if band is not None:
                ax.fill_between(
                    self.x_test.squeeze(),
                    band[:, 0],
                    band[:, 1],
                    alpha=lineage_alpha,
                    color=lineage_color,
                    linestyle="--",
                )

        # a colorbar only makes sense when the cells were colored by their weights
        cbar_requested = show_cbar and not (hide_cells or same_plot)
        if cbar_requested and not np.allclose(self.w_all, 1):
            cax = make_axes_locatable(ax).append_axes("right", size="2.5%", pad=0.1)
            mpl.colorbar.ColorbarBase(
                cax,
                norm=mcolors.Normalize(vmin=vmin, vmax=vmax),
                cmap=abs_prob_cmap,
                label="absorption probability",
            )

        if save is not None:
            save_fig(fig, save)

        if return_fig:
            return fig
Beispiel #15
0
    def plot_schur_matrix(
        self,
        title: Optional[str] = "schur matrix",
        cmap: str = "viridis",
        figsize: Optional[Tuple[float, float]] = None,
        dpi: Optional[float] = 80,
        save: Optional[Union[str, Path]] = None,
        **kwargs,
    ) -> None:
        """
        Plot the Schur matrix.

        The matrix is drawn as an annotated heatmap. Entries below the diagonal that are
        (numerically) zero are hidden and the color range is taken from the visible entries only.

        Parameters
        ----------
        title
            Title of the figure. If `None`, no title is set.
        cmap
            Colormap to use.
        %(plotting)s
        **kwargs
            Keyword arguments for :func:`seaborn.heatmap`. If not given, ``fmt`` defaults to `'0.2f'`.

        Returns
        -------
        %(just_plots)s

        Raises
        ------
        RuntimeError
            If the Schur matrix has not been computed yet.
        """

        from seaborn import heatmap

        schur_matrix = getattr(self, P.SCHUR_MAT.s)

        if schur_matrix is None:
            raise RuntimeError(
                f"Compute Schur matrix first as `.{F.COMPUTE.fmt(P.SCHUR)}()`."
            )

        fig, ax = plt.subplots(
            figsize=schur_matrix.shape if figsize is None else figsize,
            dpi=dpi)

        # dedicated colorbar axis; with square=True the default colorbar is a bit bigger
        divider = make_axes_locatable(ax)
        cbar_ax = divider.append_axes("right", size="2%", pad=0.1)

        # builtin `bool`: the `np.bool` alias was deprecated in NumPy 1.20 and removed in 1.24
        mask = np.zeros_like(schur_matrix, dtype=bool)
        # hide the strictly lower triangle ...
        mask[np.tril_indices_from(mask, k=-1)] = True
        # ... except for the entries that are not (numerically) zero
        mask[~np.isclose(schur_matrix, 0.0)] = False

        visible = schur_matrix[~mask]
        vmin, vmax = np.min(visible), np.max(visible)

        kwargs.setdefault("fmt", "0.2f")
        heatmap(
            schur_matrix,
            cmap=cmap,
            square=True,
            annot=True,
            vmin=vmin,
            vmax=vmax,
            cbar_ax=cbar_ax,
            cbar_kws={"ticks": np.linspace(vmin, vmax, 10)},
            mask=mask,
            xticklabels=[],
            yticklabels=[],
            ax=ax,
            **kwargs,
        )

        if title is not None:
            ax.set_title(title)

        if save is not None:
            save_fig(fig, path=save)
Beispiel #16
0
def cluster_fates(
    adata: AnnData,
    mode: str = ClusterFatesMode.PAGA_PIE.s,
    backward: bool = False,
    lineages: Optional[Union[str, Sequence[str]]] = None,
    cluster_key: Optional[str] = "clusters",
    clusters: Optional[Union[str, Sequence[str]]] = None,
    basis: Optional[str] = None,
    cbar: bool = True,
    ncols: Optional[int] = None,
    sharey: bool = False,
    fmt: str = "0.2f",
    xrot: float = 90,
    legend_kwargs: Mapping[str, Any] = MappingProxyType({"loc": "best"}),
    figsize: Optional[Tuple[float, float]] = None,
    dpi: Optional[int] = None,
    save: Optional[Union[str, Path]] = None,
    **kwargs,
) -> None:
    """
    Plot aggregate lineage probabilities at a cluster level.

    This can be used to investigate how likely a certain cluster is to go to the %(terminal)s states, or in turn to have
    descended from the %(initial)s states. For mode `{m.PAGA.s!r}` and `{m.PAGA_PIE.s!r}`, we use *PAGA*, see [Wolf19]_.

    Parameters
    ----------
    %(adata)s
    mode
        Type of plot to show. Valid options are:

            - `{m.BAR.s!r}` - barplot, one panel per cluster.
            - `{m.PAGA.s!r}` - scanpy's PAGA, one per %(initial_or_terminal)s state, colored in by fate.
            - `{m.PAGA_PIE.s!r}` - scanpy's PAGA with pie charts indicating aggregated fates.
            - `{m.VIOLIN.s!r}` - violin plots, one per %(initial_or_terminal)s state.
            - `{m.HEATMAP.s!r}` - a heatmap, showing average fates per cluster.
            - `{m.CLUSTERMAP.s!r}` - same as a heatmap, but with a dendrogram.
    %(backward)s
    lineages
        Lineages for which to visualize absorption probabilities. If `None`, use all lineages.
    cluster_key
        Key in ``adata.obs`` containing the clusters.
    clusters
        Clusters to visualize. If `None`, all clusters will be plotted.
    basis
        Basis for scatterplot to use when ``mode={m.PAGA_PIE.s!r}``. If `None`, don't show the scatterplot.
    cbar
        Whether to show colorbar when ``mode={m.PAGA_PIE.s!r}``.
    ncols
        Number of columns when ``mode={m.BAR.s!r}`` or ``mode={m.PAGA.s!r}``.
    sharey
        Whether to share y-axis when ``mode={m.BAR.s!r}``.
    fmt
        Format when using ``mode={m.HEATMAP.s!r}`` or ``mode={m.CLUSTERMAP.s!r}``.
    xrot
        Rotation of the labels on the x-axis.
    figsize
        Size of the figure.
    legend_kwargs
        Keyword arguments for :func:`matplotlib.axes.Axes.legend`, such as `'loc'` for legend position.
        For ``mode={m.PAGA_PIE.s!r}`` and ``basis='...'``, this controls the placement of the
        absorption probabilities legend.
    %(plotting)s
    **kwargs
        Keyword arguments for :func:`scvelo.pl.paga`, :func:`scanpy.pl.violin` or :func:`matplotlib.pyplot.bar`,
        depending on the value of ``mode``.

    Returns
    -------
    %(just_plots)s
    """

    from scanpy.plotting import violin
    from scvelo.plotting import paga

    from seaborn import heatmap, clustermap

    @valuedispatch
    def plot(mode: ClusterFatesMode, *_args, **_kwargs):
        raise NotImplementedError(mode.value)

    @plot.register(ClusterFatesMode.BAR)
    def _():
        cols = 4 if ncols is None else ncols
        n_rows = ceil(len(clusters) / cols)
        fig = plt.figure(None, (3.5 * cols,
                                5 * n_rows) if figsize is None else figsize,
                         dpi=dpi)
        fig.tight_layout()

        gs = plt.GridSpec(n_rows, cols, figure=fig, wspace=0.5, hspace=0.5)

        ax = None
        colors = list(adata.obsm[lk][:, lin_names].colors)

        for g, k in zip(gs, d.keys()):
            current_ax = fig.add_subplot(g, sharey=ax)
            current_ax.bar(
                x=np.arange(len(lin_names)),
                height=d[k][0],
                color=colors,
                yerr=d[k][1],
                ecolor="black",
                capsize=10,
                **kwargs,
            )
            if sharey:
                ax = current_ax

            current_ax.set_xticks(np.arange(len(lin_names)))
            current_ax.set_xticklabels(lin_names, rotation=xrot)
            if not is_all:
                current_ax.set_xlabel(points)
            current_ax.set_ylabel("absorption probability")
            current_ax.set_title(k)

        return fig

    @plot.register(ClusterFatesMode.PAGA)
    def _():
        kwargs["save"] = None
        kwargs["show"] = False
        if "cmap" not in kwargs:
            kwargs["cmap"] = cm.viridis

        cols = len(lin_names) if ncols is None else ncols
        nrows = ceil(len(lin_names) / cols)
        fig, axes = plt.subplots(
            nrows,
            cols,
            figsize=(7 * cols, 4 * nrows) if figsize is None else figsize,
            constrained_layout=True,
            dpi=dpi,
        )
        # fig.tight_layout()  can't use this because colorbar.make_axes fails

        i = 0
        axes = [axes] if not isinstance(axes, np.ndarray) else np.ravel(axes)
        vmin, vmax = np.inf, -np.inf

        if basis is not None:
            kwargs["basis"] = basis
            kwargs["scatter_flag"] = True
            kwargs["color"] = cluster_key

        for i, (ax, lineage_name) in enumerate(zip(axes, lin_names)):
            colors = [v[0][i] for v in d.values()]
            kwargs["ax"] = ax
            kwargs["colors"] = tuple(colors)
            kwargs["title"] = f"{dir_prefix} {lineage_name}"

            vmin = np.min(colors + [vmin])
            vmax = np.max(colors + [vmax])

            paga(adata, **kwargs)

        if cbar:
            norm = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax)
            cax, _ = mpl.colorbar.make_axes(ax, aspect=60)
            _ = mpl.colorbar.ColorbarBase(
                cax,
                ticks=np.linspace(norm.vmin, norm.vmax, 5),
                norm=norm,
                cmap=kwargs["cmap"],
                label="average absorption probability",
            )

        for ax in axes[i + 1:]:  # noqa
            ax.remove()

        return fig

    @plot.register(ClusterFatesMode.PAGA_PIE)
    def _():
        """
        Handle ``ClusterFatesMode.PAGA_PIE``.

        Calls :func:`paga` with ``node_colors`` set to one ordered
        ``lineage color -> mean value`` mapping per entry of ``d`` and
        optionally draws a cluster legend and a lineage legend.

        Returns
        -------
        The created figure.
        """
        lineage_colors = list(adata.obsm[lk][:, lin_names].colors)
        # keyed by the position of the cluster in `d`
        colors = {
            node: odict(zip(lineage_colors, mean))
            for node, (mean, _std) in enumerate(d.values())
        }

        fig, ax = plt.subplots(figsize=figsize, dpi=dpi)
        fig.tight_layout()

        # the colorbar has to be disabled; showing is left to the caller
        kwargs.update(ax=ax, show=False, colorbar=False, node_colors=colors)
        kwargs.pop("save", None)  # we will handle saving
        kwargs.setdefault("transitions", "transitions_confidence")

        if "legend_loc" not in kwargs:
            orig_ll = None
            kwargs["legend_loc"] = "on data"
        else:
            orig_ll = kwargs["legend_loc"]
            if orig_ll != "on data":
                kwargs["legend_loc"] = "none"  # we will handle legend

        if basis is not None:
            kwargs.update(basis=basis, scatter_flag=True, color=cluster_key)

        ax = paga(adata, **kwargs)
        ax.set_title(kwargs.get("title", cluster_key))

        if basis is not None and orig_ll not in ("none", "on data", None):
            # empty scatters serve purely as legend handles for the clusters
            cluster_handles = [
                ax.scatter([], [], label=cluster_name, c=color)
                for cluster_name, color in zip(
                    adata.obs[f"{cluster_key}"].cat.categories,
                    adata.uns[f"{cluster_key}_colors"],
                )
            ]
            fig.add_artist(
                _position_legend(
                    ax,
                    legend_loc=orig_ll,
                    handles=cluster_handles,
                    **{k: v
                       for k, v in legend_kwargs.items() if k != "loc"},
                    title=cluster_key,
                ))

        if legend_kwargs.get("loc", None) not in ("none", "on data", None):
            # we need to use these, because scvelo can have its own handles and
            # they would be plotted here
            lineage_handles = [
                ax.scatter([], [], label=lineage_name, c=color)
                for lineage_name, color in zip(lin_names, colors[0].keys())
            ]
            if len(colors[0].keys()) != len(adata.obsm[lk].names):
                # only a subset of the lineages was selected
                lineage_handles.append(
                    ax.scatter([], [], label="Rest", c="grey"))

            fig.add_artist(
                _position_legend(
                    ax,
                    legend_loc=legend_kwargs["loc"],
                    handles=lineage_handles,
                    **{k: v
                       for k, v in legend_kwargs.items() if k != "loc"},
                    title=points,
                ))

        return fig

    @plot.register(ClusterFatesMode.VIOLIN)
    def _():
        """
        Handle ``ClusterFatesMode.VIOLIN``.

        Draws one violin plot per lineage in ``lin_names``, grouped by
        ``cluster_key``, on a grid of subplots.

        Returns
        -------
        The created figure.
        """
        # these are controlled by this function; saving is handled by the caller
        for reserved in ("ax", "keys", "save"):
            kwargs.pop(reserved, None)
        kwargs.update(show=False, groupby=cluster_key, rotation=xrot)

        n_lineages = len(lin_names)
        cols = n_lineages if ncols is None else ncols
        nrows = ceil(n_lineages / cols)
        size = figsize if figsize is not None else (6 * cols, 4 * nrows)

        fig, axes = plt.subplots(nrows,
                                 cols,
                                 figsize=size,
                                 sharey=sharey,
                                 dpi=dpi)
        fig.tight_layout()
        fig.subplots_adjust(wspace=0.2, hspace=0.3)

        # a single subplot is returned as a bare object, not as an array
        axes = np.ravel(axes if isinstance(axes, np.ndarray) else [axes])

        with RandomKeys(adata, n_lineages, where="obs") as keys:
            last = 0
            for last, (name, key, ax) in enumerate(zip(lin_names, keys,
                                                       axes)):
                # store the lineage values under a temporary `.obs` key
                adata.obs[key] = adata.obsm[lk][name].X
                ax.set_title(f"{dir_prefix} {name}")
                violin(adata,
                       ylabel="absorption probability",
                       keys=key,
                       ax=ax,
                       **kwargs)
            # drop the subplots that were not used
            for unused in axes[last + 1:]:  # noqa
                unused.remove()

        return fig

    def plot_violin_no_cluster_key():
        """
        Plot one violin per lineage using all cells, without any clustering.

        The values in ``adata.obsm[lk]`` are flattened lineage by lineage into
        a temporary :class:`anndata.AnnData` object with two ``.obs`` columns:
        the values themselves and a categorical column naming the lineage.

        Returns
        -------
        The created figure.
        """
        from anndata import AnnData as _AnnData

        kwargs.pop("ax", None)
        kwargs.pop("keys", None)  # don't care
        kwargs.pop("save", None)

        kwargs["show"] = False
        kwargs["groupby"] = points
        kwargs["xlabel"] = None
        kwargs["rotation"] = xrot

        # lineage-major flattening: all cells of the 1st lineage, then the 2nd, ...
        data = np.ravel(adata.obsm[lk].X.T)[..., np.newaxis]
        # the matrix content is irrelevant, only `.obs` is plotted
        tmp = _AnnData(csr_matrix(data.shape, dtype=np.float32))
        tmp.obs["absorption probability"] = data

        # labels follow the same lineage-major order as `data`
        ordered = [f"{dir_prefix.lower()} {n}" for n in adata.obsm[lk].names]
        labels = pd.Series(np.repeat(ordered,
                                     adata.n_obs)).astype("category")
        # `astype("category")` sorts the categories; restore the lineage order so
        # that it matches the colors below. The returned value is used because
        # the `inplace` argument was removed in pandas 2.0.
        tmp.obs[points] = labels.cat.reorder_categories(ordered).values
        tmp.uns[f"{points}_colors"] = adata.obsm[lk].colors

        fig, ax = plt.subplots(figsize=figsize if figsize is not None else
                               (8, 6),
                               dpi=dpi)
        ax.set_title(points.capitalize())

        violin(tmp, keys=["absorption probability"], ax=ax, **kwargs)

        return fig

    @plot.register(ClusterFatesMode.HEATMAP)
    def _():
        """
        Handle ``ClusterFatesMode.HEATMAP``.

        Plots the per-cluster means stored in ``d`` as an annotated heatmap with
        lineages as rows and clusters as columns. A clustermap is drawn instead
        when ``use_clustermap`` is set.

        Returns
        -------
        The created figure.
        """
        means = [mean for mean, _std in d.values()]
        data = pd.DataFrame(means, columns=lin_names, index=clusters).T

        title = kwargs.pop("title", "average fate per cluster")
        vmin = data.values.min()
        vmax = data.values.max()
        cbar_kws = {
            "label": "probability",
            "ticks": np.linspace(vmin, vmax, 5),
            "format": "%.3f",
        }
        kwargs.setdefault("cmap", "viridis")

        if not use_clustermap:
            fig, ax = plt.subplots(figsize=figsize, dpi=dpi)
            hm_ax = heatmap(
                data,
                vmin=vmin,
                vmax=vmax,
                annot=True,
                fmt=fmt,
                cbar=cbar,
                cbar_kws=cbar_kws,
                ax=ax,
                **kwargs,
            )
            ax.set_title(title)
            ax.set_xlabel(cluster_key)
            ax.set_ylabel("lineage")
        else:
            kwargs["cbar_pos"] = (0, 0.9, 0.025, 0.15) if cbar else None
            n_rows, n_cols = data.shape[0], data.shape[1]
            longest = float(max(data.shape))

            grid = clustermap(
                data,
                annot=True,
                vmin=vmin,
                vmax=vmax,
                fmt=fmt,
                row_colors=adata.obsm[lk][lin_names].colors,
                # scale each dendrogram by the relative extent of its dimension
                dendrogram_ratio=(
                    0.15 * n_rows / longest,
                    0.15 * n_cols / longest,
                ),
                cbar_kws=cbar_kws,
                figsize=figsize,
                **kwargs,
            )
            grid.ax_heatmap.set_xlabel(cluster_key)
            grid.ax_heatmap.set_ylabel("lineage")
            grid.ax_col_dendrogram.set_title(title)

            fig = grid.fig
            hm_ax = grid.ax_heatmap

        hm_ax.set_xticklabels(hm_ax.get_xticklabels(), rotation=xrot)
        hm_ax.set_yticklabels(hm_ax.get_yticklabels(), rotation=0)

        return fig

    mode = ClusterFatesMode(mode)

    if cluster_key is not None:
        if cluster_key not in adata.obs:
            raise KeyError(f"Key `{cluster_key!r}` not found in `adata.obs`.")
    elif mode not in (mode.BAR, mode.VIOLIN):
        raise ValueError(
            f"Not specifying cluster key is only available for modes "
            f"`{ClusterFatesMode.BAR!r}` and `{ClusterFatesMode.VIOLIN!r}`, found `mode={mode!r}`."
        )

    if backward:
        lk = AbsProbKey.BACKWARD.s
        points = TerminalStatesPlot.BACKWARD.s
        dir_prefix = DirPrefix.BACKWARD.s
    else:
        lk = AbsProbKey.FORWARD.s
        points = TerminalStatesPlot.FORWARD.s
        dir_prefix = DirPrefix.FORWARD.s

    if cluster_key is not None:
        is_all = False
        if clusters is not None:
            if isinstance(clusters, str):
                clusters = [clusters]
            clusters = _unique_order_preserving(clusters)
            if mode in (mode.PAGA, mode.PAGA_PIE):
                logg.debug(
                    f"Setting `clusters` to all available ones because of `mode={mode!r}`"
                )
                clusters = list(adata.obs[cluster_key].cat.categories)
            else:
                for cname in clusters:
                    if cname not in adata.obs[cluster_key].cat.categories:
                        raise KeyError(
                            f"Cluster `{cname!r}` not found in `adata.obs[{cluster_key!r}]`."
                        )
        else:
            clusters = list(adata.obs[cluster_key].cat.categories)
    else:
        is_all = True
        clusters = [points]

    if lk not in adata.obsm:
        raise KeyError(f"Lineage key `{lk!r}` not found in `adata.obsm`.")

    if lineages is not None:
        if isinstance(lineages, str):
            lineages = [lineages]
        lin_names = _unique_order_preserving(lineages)
    else:
        # must be list for `sc.pl.violin`, else cats str
        lin_names = list(adata.obsm[lk].names)
    _ = adata.obsm[lk][lin_names]

    if mode == mode.VIOLIN and not is_all:
        adata = adata[np.isin(adata.obs[cluster_key], clusters)].copy()

    d = odict()
    for name in clusters:
        mask = (np.ones((adata.n_obs, ), dtype=np.bool) if is_all else
                (adata.obs[cluster_key] == name).values)
        mask = np.array(mask, dtype=np.bool)
        data = adata.obsm[lk][mask, lin_names].X
        mean = np.nanmean(data, axis=0)
        std = np.nanstd(data, axis=0) / np.sqrt(data.shape[0])
        d[name] = [mean, std]

    logg.debug(f"Plotting in mode `{mode!r}`")
    use_clustermap = False
    if mode == mode.CLUSTERMAP:
        use_clustermap = True
        mode = mode.HEATMAP
    elif (mode in (ClusterFatesMode.PAGA, ClusterFatesMode.PAGA_PIE)
          and "paga" not in adata.uns):
        raise KeyError("Compute PAGA first as `scvelo.tl.paga()`.")

    fig = (plot_violin_no_cluster_key() if mode == ClusterFatesMode.VIOLIN
           and cluster_key is None else plot(mode))

    if save is not None:
        save_fig(fig, save)

    fig.show()
Beispiel #17
0
def log_odds(
    adata: AnnData,
    lineage_1: str,
    lineage_2: Optional[str] = None,
    time_key: str = "exp_time",
    backward: bool = False,
    keys: Optional[Union[str, Sequence[str]]] = None,
    threshold: Optional[Union[float, Sequence]] = None,
    threshold_color: str = "red",
    layer: Optional[str] = None,
    use_raw: bool = False,
    size: float = 2.0,
    cmap: str = "viridis",
    alpha: Optional[float] = 0.8,
    ncols: Optional[int] = None,
    fontsize: Optional[Union[float, str]] = None,
    xticks_step_size: Optional[int] = 1,
    legend_loc: Optional[str] = "best",
    jitter: Union[bool, float] = True,
    seed: Optional[int] = None,
    figsize: Optional[Tuple[float, float]] = None,
    dpi: Optional[int] = None,
    save: Optional[Union[str, Path]] = None,
    show: bool = True,
    **kwargs: Any,
) -> Optional[Union[Axes, Sequence[Axes]]]:
    """
    Plot log-odds ratio between lineages.

    Log-odds are plotted as a function of the experimental time.

    Parameters
    ----------
    %(adata)s
    lineage_1
        The first lineage for which to compute the log-odds.
    lineage_2
        The second lineage for which to compute the log-odds. If `None`, use the rest of the lineages.
    time_key
        Key in :attr:`anndata.AnnData.obs` containing the experimental time.
    %(backward)s
    keys
        Key in :attr:`anndata.AnnData.obs` or :attr:`anndata.AnnData.var_names`.
    threshold
        Visualize whether total expression per cell is greater than ``threshold``.
        If a :class:`typing.Sequence`, it should be the same length as ``keys``.
    threshold_color
        Color to use when plotting thresholded expression values.
    layer
        Which layer to use to get expression values. If `None` or `'X'`, use :attr:`anndata.AnnData.X`.
    use_raw
        Whether to access :attr:`anndata.AnnData.raw`. If `True`, ``layer`` is ignored.
    size
        Size of the dots.
    cmap
        Colormap to use for continuous variables in ``keys``.
    alpha
        Alpha values for the dots.
    ncols
        Number of columns.
    fontsize
        Size of the font for the title, x- and y-label.
    xticks_step_size
        Show only every n-th tick on x-axis. If `None`, don't show any ticks.
    legend_loc
        Position of the legend. If `None`, do not show the legend.
    jitter
        Amount of jitter to apply along x-axis.
    seed
        Seed for ``jitter`` to ensure reproducibility.
    %(plotting)s
    show
        If `False`, return :class:`matplotlib.pyplot.Axes` or a sequence of them.
    kwargs
        Keyword arguments for :func:`seaborn.stripplot`.

    Returns
    -------
    :class:`matplotlib.pyplot.Axes`
        The axis object(s) if ``show=False``.
    %(just_plots)s
    """
    from cellrank.tl.kernels._utils import _ensure_numeric_ordered

    def decorate(ax: Axes,
                 *,
                 title: Optional[str] = None,
                 show_ylabel: bool = True) -> None:
        """Set the axis labels, the title and the x-ticks of ``ax``."""
        ax.set_xlabel(time_key, fontsize=fontsize)
        ax.set_title(title, fontdict={"fontsize": fontsize})
        ax.set_ylabel(ylabel if show_ylabel else "", fontsize=fontsize)

        if xticks_step_size is None:
            ax.set_xticks([])
        else:
            step = max(1, xticks_step_size)
            ax.set_xticks(np.arange(0, n_cats, step))
            # the strips are drawn in `order` (reversed when `backward=True`),
            # so the labels must follow `order` and not the raw categories
            ax.set_xticklabels(order[::step])

    def cont_palette(values: np.ndarray) -> Tuple[np.ndarray, ScalarMappable]:
        """Map continuous ``values`` to hex colors; NaNs are shown in grey."""
        cm = copy(plt.get_cmap(cmap))
        cm.set_bad("grey")
        sm = ScalarMappable(cmap=cm,
                            norm=Normalize(vmin=np.nanmin(values),
                                           vmax=np.nanmax(values)))
        return np.array([to_hex(v) for v in (sm.to_rgba(values))]), sm

    def get_data(
        key: str,
        thresh: Optional[float] = None,
    ) -> Tuple[Optional[str], Optional[np.ndarray], Optional[np.ndarray],
               ScalarMappable]:
        """
        Resolve ``key`` into the arguments needed to color the strip plot.

        Returns the hue column name, the palette, the boolean threshold mask
        and the scalar mappable for the colorbar; each may be `None`.
        """
        try:
            # categorical observation
            _, palette = _get_categorical_colors(adata, key)
            df[key] = adata.obs[key].values[mask]
            df[key] = df[key].cat.remove_unused_categories()
            try:
                # seaborn doesn't like numeric categories
                df[key] = df[key].astype(float)
                palette = {float(k): v for k, v in palette.items()}
            except ValueError:
                pass
            # otherwise seaborn plots all
            palette = {k: palette[k] for k in df[key].unique()}
            hue, thresh_mask, sm = key, None, None
        except TypeError:
            # handled as a continuous observation
            palette, hue, thresh_mask, sm = (
                cont_palette(adata.obs[key].values[mask])[0],
                None,
                None,
                None,
            )
        except KeyError:
            # not found above, try to interpret `key` as a gene
            try:
                # fmt: off
                if thresh is None:
                    values = adata.raw.obs_vector(
                        key) if use_raw else adata.obs_vector(key, layer=layer)
                    palette, sm = cont_palette(values)
                    hue, thresh_mask = None, None
                else:
                    if use_raw:
                        values = np.asarray(
                            adata.raw[:, key].X[mask].sum(1)).squeeze()
                    elif layer not in (None, "X"):
                        values = np.asarray(
                            adata[:,
                                  key].layers[layer][mask].sum(1)).squeeze()
                    else:
                        values = np.asarray(
                            adata[:, key].X[mask].sum(1)).squeeze()
                    thresh_mask = values > thresh
                    hue, palette, sm = None, None, None
                # fmt: on
            except KeyError as e:
                # hide the chained lookup errors from the user
                raise e from None

        return hue, palette, thresh_mask, sm

    np.random.seed(seed)
    # the orientation is fixed by this function
    _ = kwargs.pop("orient", None)
    if use_raw and adata.raw is None:
        logg.warning("No raw attribute set. Setting `use_raw=False`")
        use_raw = False

    # define log-odds
    ln_key = str(AbsProbKey.BACKWARD if backward else AbsProbKey.FORWARD)
    if ln_key not in adata.obsm:
        raise KeyError(f"Lineages key `{ln_key!r}` not found in `adata.obsm`.")
    time = _ensure_numeric_ordered(adata, time_key)
    order = time.cat.categories[::-1 if backward else 1]

    fate1 = adata.obsm[ln_key][lineage_1].X.squeeze(-1)
    if lineage_2 is None:
        fate2 = 1 - fate1
        ylabel = rf"$\log{{\frac{{{lineage_1}}}{{rest}}}}$"
    else:
        fate2 = adata.obsm[ln_key][lineage_2].X.squeeze(-1)
        ylabel = rf"$\log{{\frac{{{lineage_1}}}{{{lineage_2}}}}}$"

    # fmt: off
    df = pd.DataFrame({
        "log_odds":
        np.log(
            np.divide(fate1, fate2, where=fate2 != 0, out=np.zeros_like(fate1))
            + 1e-12),
        time_key:
        time,
    })
    # keep only the cells where the log-odds are well defined
    mask = (fate1 != 0) & (fate2 != 0)
    # explicit copy, because `get_data` adds columns to this frame
    df = df[mask].copy()
    n_cats = len(df[time_key].cat.categories)
    # fmt: on

    if keys is None:
        if figsize is None:
            figsize = np.array([n_cats, n_cats * 4 / 6]) / 2

        fig, ax = plt.subplots(figsize=figsize, dpi=dpi, tight_layout=True)
        # `x` and `y` are keyword-only in seaborn >= 0.12
        ax = sns.stripplot(
            x=time_key,
            y="log_odds",
            data=df,
            order=order,
            jitter=jitter,
            color="k",
            size=size,
            ax=ax,
            **kwargs,
        )
        decorate(ax)
        if save is not None:
            save_fig(fig, save)
        return None if show else ax

    if isinstance(keys, str):
        keys = (keys, )
    if not len(keys):
        raise ValueError("No keys have been selected.")
    keys = _unique_order_preserving(keys)

    # broadcast a scalar threshold (or `None`) to all keys
    if not isinstance(threshold, Iterable):
        threshold = (threshold, ) * len(keys)
    if len(threshold) != len(keys):
        raise ValueError(
            f"Expected `threshold` to be of length `{len(keys)}`, found `{len(threshold)}`."
        )

    ncols = max(len(keys) if ncols is None else ncols, 1)
    nrows = int(np.ceil(len(keys) / ncols))
    if figsize is None:
        figsize = np.array([n_cats * ncols, n_cats * nrows * 4 / 6]) / 2

    fig, axes = plt.subplots(
        nrows=nrows,
        ncols=ncols,
        figsize=figsize,
        dpi=dpi,
        constrained_layout=True,
        sharey="all",
    )
    axes = np.ravel([axes])

    i = 0
    for i, (key, ax, thresh) in enumerate(zip(keys, axes, threshold)):
        hue, palette, thresh_mask, sm = get_data(key, thresh)
        # only the first column shows the y-label
        show_ylabel = i % ncols == 0

        ax = sns.stripplot(
            x=time_key,
            y="log_odds",
            data=df if thresh_mask is None else df[~thresh_mask],
            hue=hue,
            order=order,
            jitter=jitter,
            color="black",
            palette=palette,
            size=size,
            alpha=alpha
            if alpha is not None else None if thresh_mask is None else 0.8,
            ax=ax,
            **kwargs,
        )
        if thresh_mask is not None:
            # overlay the cells above the threshold as larger, highlighted dots
            sns.stripplot(
                x=time_key,
                y="log_odds",
                data=df[thresh_mask],
                hue=hue,
                order=order,
                jitter=jitter,
                color=threshold_color,
                palette=palette,
                size=size * 2,
                alpha=0.9,
                ax=ax,
                **kwargs,
            )
            key = rf"${key} > {thresh}$"
        if sm is not None:
            cax = ax.inset_axes([1.02, 0, 0.025, 1], transform=ax.transAxes)
            fig.colorbar(sm, ax=ax, cax=cax)
        else:
            if legend_loc in (None, "none"):
                legend = ax.get_legend()
                if legend is not None:
                    legend.remove()
            else:
                handles, labels = ax.get_legend_handles_labels()
                if len(handles):
                    _position_legend(ax,
                                     legend_loc=legend_loc,
                                     handles=handles,
                                     labels=labels)

        decorate(ax, title=key, show_ylabel=show_ylabel)

    # remove the subplots that were not used
    for ax in axes[i + 1:]:
        ax.remove()
    axes = axes[:i + 1]

    if save is not None:
        save_fig(fig, save)

    return None if show else axes[0] if len(axes) == 1 else axes
Beispiel #18
0
    def plot_pie(
        self,
        reduction: Callable,
        title: Optional[str] = None,
        legend_loc: Optional[str] = "on data",
        legend_kwargs: Mapping = MappingProxyType({}),
        figsize: Optional[Tuple[float, float]] = None,
        dpi: Optional[float] = None,
        save: Optional[Union[Path, str]] = None,
        **kwargs,
    ) -> None:
        """
        Show aggregated lineage probabilities as a pie chart.

        Parameters
        ----------
        reduction
            Function that will be applied lineage-wise.
        title
            Title of the figure. If `None`, the ``__name__`` of ``reduction`` is used, if it has one.
        legend_loc
            Location of the legend. If `None`, it is not shown.
        legend_kwargs
            Keyword arguments for :func:`matplotlib.axes.Axes.legend`.
        %(plotting)s
        **kwargs
            Keyword arguments for :meth:`matplotlib.axes.Axes.pie`.

        Returns
        -------
        %(just_plots)s
        """

        if len(self.names) == 1:
            raise ValueError("Cannot plot pie chart for only 1 lineage.")

        fig, ax = plt.subplots(figsize=figsize, dpi=dpi)

        autopct_found = "autopct" in kwargs
        # the default format is a placeholder: unless the user supplied `autopct`,
        # the wedge texts are overwritten below with the reduced values
        autopct = kwargs.pop("autopct") if autopct_found else "{:.1f}%".format

        if title is None:
            title = getattr(reduction, "__name__", None)

        reduced = reduction(self, axis=int(self._is_transposed)).X.squeeze()
        fractions = reduced / np.sum(reduced)

        # the 3rd return value is present only when `autopct` is not None
        wedges, _labels, *extra = ax.pie(
            fractions.squeeze(),
            labels=self.names if legend_loc == "on data" else None,
            autopct=autopct,
            wedgeprops={"edgecolor": "w"},
            colors=self.colors,
            **kwargs,
        )

        if extra:
            for name, autotext in zip(self.names, extra[0]):
                ix = self._names_to_ixs[name]
                # text color chosen from the wedge color
                autotext.set_color(_get_bg_fg_colors(self.colors[ix])[1])
                if not autopct_found:
                    autotext.set_text(f"{reduced[ix]:.4f}")

        if legend_loc not in (None, "none", "on data"):
            ax.legend(
                wedges,
                self.names,
                title="lineages",
                loc=legend_loc,
                **legend_kwargs,
            )

        ax.set_title(title)
        ax.set_aspect("equal")

        fig.show()

        if save is not None:
            save_fig(fig, save)
Beispiel #19
0
def cluster_lineage(
    adata: AnnData,
    model: _input_model_type,
    genes: Sequence[str],
    lineage: str,
    backward: bool = False,
    time_range: _time_range_type = None,
    clusters: Optional[Sequence[str]] = None,
    n_points: int = 200,
    time_key: str = "latent_time",
    norm: bool = True,
    recompute: bool = False,
    callback: _callback_type = None,
    ncols: int = 3,
    sharey: Union[str, bool] = False,
    key: Optional[str] = None,
    random_state: Optional[int] = None,
    use_leiden: bool = False,
    show_progress_bar: bool = True,
    n_jobs: Optional[int] = 1,
    backend: str = _DEFAULT_BACKEND,
    figsize: Optional[Tuple[float, float]] = None,
    dpi: Optional[int] = None,
    save: Optional[Union[str, Path]] = None,
    pca_kwargs: Dict = MappingProxyType({"svd_solver": "arpack"}),
    neighbors_kwargs: Dict = MappingProxyType({"use_rep": "X"}),
    clustering_kwargs: Dict = MappingProxyType({}),
    return_models: bool = False,
    **kwargs,
) -> Optional[_return_model_type]:
    """
    Cluster gene expression trends within a lineage and plot the clusters.

    This function is based on Palantir, see [Setty19]_. It can be used to discover modules of genes that drive
    development along a given lineage. Consider running this function on a subset of genes which are potential
    lineage drivers, identified e.g. by running :func:`cellrank.tl.lineage_drivers`.

    The trends are fitted, optionally z-normalized, reduced by PCA and clustered on a neighbor graph. Each cluster
    is drawn in its own subplot showing the individual trends together with their mean and standard deviation.

    Parameters
    ----------
    %(adata)s
    %(model)s
    %(genes)s
    lineage
        Name of the lineage for which to cluster the genes.
    %(backward)s
    %(time_ranges)s
    clusters
        Cluster identifiers to plot. If `None`, all clusters will be considered. Useful when
        plotting previously computed clusters.
    n_points
        Number of points used for prediction.
    time_key
        Key in ``adata.obs`` where the pseudotime is stored.
    norm
        Whether to z-normalize each trend to have zero mean, unit variance.
    recompute
        If `True`, recompute the clustering, otherwise try to find already existing one.
    %(model_callback)s
    ncols
        Number of columns for the plot.
    sharey
        Whether to share y-axis across multiple plots.
    key
        Key in ``adata.uns`` where to save the results. If `None`, it will be saved as ``lineage_{lineage}_trend`` .
    random_state
        Random seed for reproducibility.
    use_leiden
        Whether to use :func:`scanpy.tl.leiden` for clustering or :func:`scanpy.tl.louvain`.
    %(parallel)s
    %(plotting)s
    pca_kwargs
        Keyword arguments for :func:`scanpy.pp.pca`.
    neighbors_kwargs
        Keyword arguments for :func:`scanpy.pp.neighbors`.
    clustering_kwargs
        Keyword arguments for :func:`scanpy.tl.louvain` or :func:`scanpy.tl.leiden`.
    %(return_models)s
    **kwargs:
        Keyword arguments for :meth:`cellrank.ul.models.BaseModel.prepare`.

    Returns
    -------
    %(plots_or_returns_models)s

        Also updates ``adata.uns`` with the following:

            - ``key`` or ``lineage_{lineage}_trend`` - an :class:`anndata.AnnData` object of
              shape `(n_genes, n_points)` containing the clustered genes.
    """

    import scanpy as sc
    from anndata import AnnData as _AnnData

    ln_key = str(AbsProbKey.BACKWARD if backward else AbsProbKey.FORWARD)
    if ln_key not in adata.obsm:
        raise KeyError(
            f"Lineages key `{ln_key!r}` not found in `adata.obsm`.")

    # accessed only to validate that the lineage exists
    _ = adata.obsm[ln_key][lineage]

    genes = _unique_order_preserving(genes)
    _check_collection(adata, genes, "var_names", kwargs.get("use_raw", False))

    if key is None:
        key = f"lineage_{lineage}_trend"

    if not recompute and key in adata.uns:
        # reuse the previously stored clustering
        all_models = None
        logg.info(f"Loading data from `adata.uns[{key!r}]`")
        trends = adata.uns[key]
    else:
        kwargs.update(backward=backward,
                      time_key=time_key,
                      n_test_points=n_points)
        models = _create_models(model, genes, [lineage])
        all_models, models, genes, _ = _fit_bulk(
            models,
            _create_callbacks(adata, callback, genes, [lineage], **kwargs),
            genes,
            lineage,
            time_range,
            return_models=True,  # always return (better error messages)
            filter_all_failed=True,
            parallel_kwargs={
                "show_progress_bar": show_progress_bar,
                "n_jobs": _get_n_cores(n_jobs, len(genes)),
                "backend": _get_backend(models, backend),
            },
            **kwargs,
        )

        # stacked one row per model, then transposed so that genes are columns
        trends = np.vstack(
            [fitted[lineage].y_test for fitted in models.values()]).T

        if norm:
            logg.debug("Normalizing trends")
            # in-place, column-wise (i.e. per gene) scaling
            StandardScaler(copy=False).fit_transform(trends)

        # back to one row per gene
        trends = _AnnData(trends.T)
        trends.obs_names = genes

        # sanity check
        if trends.n_obs != len(genes):
            raise RuntimeError(
                f"Expected to find `{len(genes)}` genes, found `{trends.n_obs}`."
            )
        if trends.n_vars != n_points:
            raise RuntimeError(
                f"Expected to find `{n_points}` points, found `{trends.n_vars}`."
            )

        # derive one integer seed shared by all the steps below
        random_state = np.random.mtrand.RandomState(random_state).randint(
            2**16)

        pca_kwargs = dict(pca_kwargs)
        pca_kwargs.setdefault("n_comps", min(50, n_points, len(genes)) - 1)
        pca_kwargs.setdefault("random_state", random_state)
        sc.pp.pca(trends, **pca_kwargs)

        neighbors_kwargs = dict(neighbors_kwargs)
        neighbors_kwargs.setdefault("random_state", random_state)
        sc.pp.neighbors(trends, **neighbors_kwargs)

        clustering_kwargs = dict(clustering_kwargs)
        clustering_kwargs["key_added"] = "clusters"
        clustering_kwargs.setdefault("random_state", random_state)

        # try the requested algorithm, use the other one if its dependency is missing
        preferred, alternative = (("leiden", "louvain") if use_leiden else
                                  ("louvain", "leiden"))
        try:
            getattr(sc.tl, preferred)(trends, **clustering_kwargs)
        except ImportError as e:
            logg.warning(str(e))
            getattr(sc.tl, alternative)(trends, **clustering_kwargs)

        logg.info(f"Saving data to `adata.uns[{key!r}]`")
        adata.uns[key] = trends

    if "clusters" not in trends.obs:
        raise KeyError(
            "Unable to find the clustering in `trends.obs['clusters']`.")

    available = trends.obs["clusters"].cat.categories
    if clusters is None:
        clusters = available
    for c in clusters:
        if c not in available:
            raise ValueError(
                f"Invalid cluster name `{c!r}`. "
                f"Valid options are `{list(trends.obs['clusters'].cat.categories)}`."
            )

    nrows = int(np.ceil(len(clusters) / ncols))
    size = (ncols * 10, nrows * 10) if figsize is None else figsize
    fig, axes = plt.subplots(nrows,
                             ncols,
                             figsize=size,
                             sharey=sharey,
                             dpi=dpi)

    # a single subplot is returned as a bare object
    axes = np.ravel(axes if isinstance(axes, Iterable) else [axes])

    j = 0
    for j, (ax, c) in enumerate(zip(axes, clusters)):  # noqa
        data = trends[trends.obs["clusters"] == c].X
        mean = np.mean(data, axis=0)
        sd = np.sqrt(np.var(data, axis=0))
        lower, upper = mean - sd, mean + sd

        # individual gene trends
        for row in range(data.shape[0]):
            ax.plot(data[row], color="gray", lw=0.5)

        # cluster mean and the +- 1 standard deviation band
        ax.plot(mean, lw=2, color="black")
        ax.plot(lower, lw=1.5, color="black", linestyle="--")
        ax.plot(upper, lw=1.5, color="black", linestyle="--")
        ax.fill_between(range(len(mean)),
                        lower,
                        upper,
                        color="black",
                        alpha=0.1)

        ax.set_title(f"Cluster {c}")
        ax.set_xticks([])

        if not sharey:
            ax.set_yticks([])

    # remove the subplots that were not used
    for unused in axes[j + 1:]:
        unused.remove()

    if save is not None:
        save_fig(fig, save)

    if return_models:
        return all_models
Beispiel #20
0
def gene_trends(
        adata: AnnData,
        model: _model_type,
        genes: Union[str, Sequence[str]],
        lineages: Optional[Union[str, Sequence[str]]] = None,
        backward: bool = False,
        data_key: str = "X",
        time_key: str = "latent_time",
        time_range: Optional[Union[_time_range_type,
                                   List[_time_range_type]]] = None,
        callback: _callback_type = None,
        conf_int: bool = True,
        same_plot: bool = False,
        hide_cells: bool = False,
        perc: Optional[Union[Tuple[float, float],
                             Sequence[Tuple[float, float]]]] = None,
        lineage_cmap: Optional[matplotlib.colors.ListedColormap] = None,
        abs_prob_cmap: matplotlib.colors.ListedColormap = cm.viridis,
        cell_color: str = "black",
        cell_alpha: float = 0.6,
        lineage_alpha: float = 0.2,
        size: float = 15,
        lw: float = 2,
        show_cbar: bool = True,
        margins: float = 0.015,
        sharex: Optional[Union[str, bool]] = None,
        sharey: Optional[Union[str, bool]] = None,
        gene_as_title: Optional[bool] = None,
        legend_loc: Optional[str] = "best",
        ncols: int = 2,
        suptitle: Optional[str] = None,
        n_jobs: Optional[int] = 1,
        backend: str = _DEFAULT_BACKEND,
        show_progres_bar: bool = True,
        figsize: Optional[Tuple[float, float]] = None,
        dpi: Optional[int] = None,
        save: Optional[Union[str, Path]] = None,
        plot_kwargs: Mapping = MappingProxyType({}),
        **kwargs,
) -> None:
    """
    Plot gene expression trends along lineages.

    Each lineage is defined via its lineage weights which we compute using :func:`cellrank.tl.lineages`. This
    function accepts any model based off :class:`cellrank.ul.models.BaseModel` to fit gene expression,
    where we take the lineage weights into account in the loss function.

    Parameters
    ----------
    %(adata)s
    %(model)s
    %(genes)s
    lineages
        Names of the lineages to plot. If `None`, plot all lineages. If every supplied lineage is `None`,
        a single unnamed lineage is used and the colorbar is disabled.
    %(backward)s
    data_key
        Key in ``adata.layers`` or `'X'` for ``adata.X`` where the data is stored. If `'obs'`, ``genes`` are
        looked up in ``adata.obs`` instead of ``adata.var_names``.
    time_key
        Key in ``adata.obs`` where the pseudotime is stored. Also used as the x-axis label unless
        ``plot_kwargs['xlabel']`` is given.
    %(time_ranges)s
    %(model_callback)s
    conf_int
        Whether to compute and show confidence intervals.
    same_plot
        Whether to plot all lineages for each gene in the same plot.
    hide_cells
        If `True`, hide all cells.
    perc
        Percentile for colors. Valid values are in interval `[0, 100]`.
        This can improve visualization. Can be specified individually for each lineage.
    lineage_cmap
        Colormap to use when coloring in the lineages. If `None` and ``same_plot``, use the corresponding colors
        in ``adata.uns``, otherwise use `'black'`.
    abs_prob_cmap
        Colormap to use when visualizing the absorption probabilities for each lineage.
        Only used when ``same_plot=False``.
    cell_color
        Color of the cells when not visualizing absorption probabilities. Only used when ``same_plot=True``.
    cell_alpha
        Alpha channel for cells.
    lineage_alpha
        Alpha channel for lineage confidence intervals.
    size
        Size of the points.
    lw
        Line width of the smoothed values.
    show_cbar
        Whether to show colorbar. Always shown when percentiles for lineages differ. Only used when ``same_plot=False``.
    margins
        Margins around the plot.
    sharex
        Whether to share x-axis. Valid options are `'row'`, `'col'`, `'all'` or `'none'`.
        If `None`, it is `'all'` when ``same_plot=True`` and `'col'` otherwise.
    sharey
        Whether to share y-axis. Valid options are `'row'`, `'col'`, `'all'` or `'none'`.
        If `None`, it is `'none'` when ``same_plot=True`` or ``hide_cells=True`` and `'row'` otherwise.
    gene_as_title
        Whether to show gene names as titles instead of on the y-axis.
        If `None`, it defaults to the value of ``same_plot``.
    legend_loc
        Location of the legend displaying lineages. Only used when `same_plot=True`.
    ncols
        Number of columns of the plot when plotting multiple genes. Only used when ``same_plot=True``.
    suptitle
        Suptitle of the figure. If `None`, no suptitle is added.
    %(parallel)s
    %(plotting)s
    plot_kwargs
        Keyword arguments for :meth:`cellrank.ul.models.BaseModel.plot`.
    **kwargs
        Keyword arguments for :meth:`cellrank.ul.models.BaseModel.prepare`.

    Returns
    -------
    %(just_plots)s

    Raises
    ------
    KeyError
        If the lineage key is not present in ``adata.obsm``.
    ValueError
        If ``time_range`` is a collection whose length differs from the number of lineages.
    """

    if isinstance(genes, str):
        genes = [genes]
    genes = _unique_order_preserving(genes)

    # `data_key='obs'` means the requested "genes" are columns of `adata.obs`
    _check_collection(
        adata,
        genes,
        "var_names" if data_key != "obs" else "obs",
        use_raw=kwargs.get("use_raw", False),
    )

    ln_key = str(AbsProbKey.BACKWARD if backward else AbsProbKey.FORWARD)
    if ln_key not in adata.obsm:
        raise KeyError(f"Lineages key `{ln_key!r}` not found in `adata.obsm`.")

    if lineages is None:
        lineages = adata.obsm[ln_key].names
    elif isinstance(lineages, str):
        lineages = [lineages]
    elif all(ln is None for ln in lineages):
        # no lineage, all the weights are 1
        lineages = [None]
        show_cbar = False
        logg.debug("All lineages are `None`, setting the weights to `1`")
    lineages = _unique_order_preserving(lineages)

    # The result is discarded; presumably this indexing only validates the
    # lineage names (TODO confirm against the lineage container). It is done
    # before the figure is created so that a failure leaves no orphan figure.
    _ = adata.obsm[ln_key][[lin for lin in lineages if lin is not None]]

    # a single time range (or `None`) is broadcast to every lineage
    if isinstance(time_range, (tuple, float, int, type(None))):
        time_range = [time_range] * len(lineages)
    elif len(time_range) != len(lineages):
        raise ValueError(
            f"Expected time ranges to be of length `{len(lineages)}`, found `{len(time_range)}`."
        )

    n_genes = len(genes)
    if same_plot:
        # one axes per gene, laid out on a grid with at most `ncols` columns
        gene_as_title = True if gene_as_title is None else gene_as_title
        sharex = "all" if sharex is None else sharex
        sharey = "none" if sharey is None else sharey
        ncols = n_genes if ncols >= n_genes else ncols
        nrows = int(np.ceil(n_genes / ncols))
    else:
        # one row per gene, one column per lineage
        gene_as_title = False if gene_as_title is None else gene_as_title
        sharex = "col" if sharex is None else sharex
        if sharey is None:
            sharey = "none" if hide_cells else "row"
        nrows = n_genes
        ncols = len(lineages)

    fig, axes = plt.subplots(
        nrows=nrows,
        ncols=ncols,
        sharex=sharex,
        sharey=sharey,
        figsize=(6 * ncols, 4 * nrows) if figsize is None else figsize,
        constrained_layout=True,
        dpi=dpi,  # the figure is created here, so the requested DPI must be applied here
    )
    # always work with a 2D array, even for a single axes/row/column
    axes = np.reshape(axes, (-1, ncols))

    kwargs["time_key"] = time_key
    kwargs["data_key"] = data_key
    kwargs["backward"] = backward
    callbacks = _create_callbacks(adata, callback, genes, lineages, **kwargs)

    # added only after the callbacks are created; original note: prepare doesn't take or need this
    kwargs["conf_int"] = conf_int
    models = _create_models(model, genes, lineages)

    plot_kwargs = dict(plot_kwargs)
    if plot_kwargs.get("xlabel", None) is None:
        plot_kwargs["xlabel"] = time_key

    n_jobs = _get_n_cores(n_jobs, n_genes)
    backend = _get_backend(model, backend)

    start = logg.info(f"Computing trends using `{n_jobs}` core(s)")
    models = parallelize(
        _fit_gene_trends,
        genes,
        unit="gene" if data_key != "obs" else "obs",
        backend=backend,
        n_jobs=n_jobs,
        # merge the per-chunk dictionaries into a single dictionary
        extractor=lambda modelss: {
            k: v
            for m in modelss for k, v in m.items()
        },
        # NOTE: the parameter name `show_progres_bar` (sic) is kept for backward compatibility
        show_progress_bar=show_progres_bar,
    )(models, callbacks, lineages, time_range, **kwargs)
    logg.info("    Finish", time=start)

    logg.info("Plotting trends")

    for cnt, gene in enumerate(genes):
        # row-major position of the `cnt`-th gene on the axes grid
        row, col = divmod(cnt, ncols)

        if same_plot:
            gene_axes = axes[row, col]
            # show x-ticks only when there is no populated axes directly below
            show_xticks_and_label = (row + 1) * ncols + col >= n_genes
        else:
            # whole row of axes (one per lineage) belongs to this gene
            gene_axes = axes[cnt]
            show_xticks_and_label = cnt == len(axes) - 1

        _trends_helper(
            adata,
            models,
            gene=gene,
            lineage_names=lineages,
            ln_key=ln_key,
            same_plot=same_plot,
            hide_cells=hide_cells,
            perc=perc,
            lineage_cmap=lineage_cmap,
            abs_prob_cmap=abs_prob_cmap,
            cell_color=cell_color,
            alpha=cell_alpha,
            lineage_alpha=lineage_alpha,
            size=size,
            lw=lw,
            show_cbar=show_cbar,
            margins=margins,
            sharey=sharey,
            gene_as_title=gene_as_title,
            legend_loc=legend_loc,
            dpi=dpi,
            figsize=figsize,
            fig=fig,
            axes=gene_axes,
            show_ylabel=col == 0,
            show_lineage=cnt == 0 or same_plot,
            show_xticks_and_label=show_xticks_and_label,
            **plot_kwargs,
        )

    if same_plot:
        # the grid may have more cells than genes; drop the unused trailing axes
        for ax in np.ravel(axes)[n_genes:]:
            ax.remove()

    if suptitle is not None:
        fig.suptitle(suptitle)

    if save is not None:
        save_fig(fig, save)