示例#1
0
def cluster_lineage(
    adata: AnnData,
    model: _input_model_type,
    genes: Sequence[str],
    lineage: str,
    backward: bool = False,
    time_range: _time_range_type = None,
    clusters: Optional[Sequence[str]] = None,
    n_points: int = 200,
    time_key: str = "latent_time",
    norm: bool = True,
    recompute: bool = False,
    callback: _callback_type = None,
    ncols: int = 3,
    sharey: Union[str, bool] = False,
    key: Optional[str] = None,
    random_state: Optional[int] = None,
    use_leiden: bool = False,
    show_progress_bar: bool = True,
    n_jobs: Optional[int] = 1,
    backend: str = _DEFAULT_BACKEND,
    figsize: Optional[Tuple[float, float]] = None,
    dpi: Optional[int] = None,
    save: Optional[Union[str, Path]] = None,
    pca_kwargs: Dict = MappingProxyType({"svd_solver": "arpack"}),
    neighbors_kwargs: Dict = MappingProxyType({"use_rep": "X"}),
    clustering_kwargs: Dict = MappingProxyType({}),
    return_models: bool = False,
    **kwargs,
) -> Optional[_return_model_type]:
    """
    Cluster gene expression trends within a lineage and plot the clusters.

    This function is based on Palantir, see [Setty19]_. It can be used to discover modules of genes that drive
    development along a given lineage. Consider running this function on a subset of genes which are potential
    lineage drivers, identified e.g. by running :func:`cellrank.tl.lineage_drivers`.

    Parameters
    ----------
    %(adata)s
    %(model)s
    %(genes)s
    lineage
        Name of the lineage for which to cluster the genes.
    %(backward)s
    %(time_ranges)s
    clusters
        Cluster identifiers to plot. If `None`, all clusters will be considered. Useful when
        plotting previously computed clusters.
    n_points
        Number of points used for prediction.
    time_key
        Key in ``adata.obs`` where the pseudotime is stored.
    norm
        Whether to z-normalize each trend to have zero mean, unit variance.
    recompute
        If `True`, recompute the clustering, otherwise try to find already existing one.
    %(model_callback)s
    ncols
        Number of columns for the plot.
    sharey
        Whether to share y-axis across multiple plots.
    key
        Key in ``adata.uns`` where to save the results. If `None`, it will be saved as ``lineage_{lineage}_trend`` .
    random_state
        Random seed for reproducibility.
    use_leiden
        Whether to use :func:`scanpy.tl.leiden` for clustering or :func:`scanpy.tl.louvain`.
        If the requested method raises an :class:`ImportError`, the other one is tried instead.
    %(parallel)s
    %(plotting)s
    pca_kwargs
        Keyword arguments for :func:`scanpy.pp.pca`.
    neighbors_kwargs
        Keyword arguments for :func:`scanpy.pp.neighbors`.
    clustering_kwargs
        Keyword arguments for :func:`scanpy.tl.louvain` or :func:`scanpy.tl.leiden`.
    %(return_models)s
    **kwargs:
        Keyword arguments for :meth:`cellrank.ul.models.BaseModel.prepare`.

    Returns
    -------
    %(plots_or_returns_models)s

        Also updates ``adata.uns`` with the following:

            - ``key`` or ``lineage_{lineage}_trend`` - an :class:`anndata.AnnData` object of
              shape `(n_genes, n_points)` containing the clustered genes.

    Raises
    ------
    KeyError
        If the lineages are not found in ``adata.obsm`` or if the (cached) trends
        object has no ``'clusters'`` column in its ``.obs``.
    ValueError
        If any of the requested ``clusters`` is not a valid cluster category.
    RuntimeError
        If the fitted trends do not have the expected number of genes or points.
    """

    import scanpy as sc
    from anndata import AnnData as _AnnData

    lineage_key = str(AbsProbKey.BACKWARD if backward else AbsProbKey.FORWARD)
    if lineage_key not in adata.obsm:
        raise KeyError(
            f"Lineages key `{lineage_key!r}` not found in `adata.obsm`.")

    # only accessed for validation: fails early if `lineage` is not a valid name
    _ = adata.obsm[lineage_key][lineage]

    genes = _unique_order_preserving(genes)
    _check_collection(adata, genes, "var_names", kwargs.get("use_raw", False))

    if key is None:
        key = f"lineage_{lineage}_trend"

    if recompute or key not in adata.uns:
        # these are forwarded to the callbacks and to the bulk fitting below
        kwargs["backward"] = backward
        kwargs["time_key"] = time_key
        kwargs["n_test_points"] = n_points
        models = _create_models(model, genes, [lineage])
        # note: `models` and `genes` are re-bound to what `_fit_bulk` returns
        all_models, models, genes, _ = _fit_bulk(
            models,
            _create_callbacks(adata, callback, genes, [lineage], **kwargs),
            genes,
            lineage,
            time_range,
            return_models=True,  # always return (better error messages)
            filter_all_failed=True,
            parallel_kwargs={
                "show_progress_bar": show_progress_bar,
                "n_jobs": _get_n_cores(n_jobs, len(genes)),
                "backend": _get_backend(models, backend),
            },
            **kwargs,
        )

        # stacked per gene, then transposed so that each gene is a column;
        # the sanity checks below expect `trends.T` to be `(n_genes, n_points)`
        trends = np.vstack(
            [model[lineage].y_test for model in models.values()]).T

        if norm:
            logg.debug("Normalizing trends")
            # `copy=False` only *tries* to scale in place; keep the returned array
            # so that the normalization is not silently lost if a copy was made
            trends = StandardScaler(copy=False).fit_transform(trends)

        trends = _AnnData(trends.T)
        trends.obs_names = genes

        # sanity check
        if trends.n_obs != len(genes):
            raise RuntimeError(
                f"Expected to find `{len(genes)}` genes, found `{trends.n_obs}`."
            )
        if trends.n_vars != n_points:
            raise RuntimeError(
                f"Expected to find `{n_points}` points, found `{trends.n_vars}`."
            )

        # derive one integer seed which is shared by PCA, neighbors and clustering
        random_state = np.random.mtrand.RandomState(random_state).randint(
            2**16)

        # the defaults are read-only mappings, copy them before adding entries
        pca_kwargs = dict(pca_kwargs)
        pca_kwargs.setdefault("n_comps", min(50, n_points, len(genes)) - 1)
        pca_kwargs.setdefault("random_state", random_state)
        sc.pp.pca(trends, **pca_kwargs)

        neighbors_kwargs = dict(neighbors_kwargs)
        neighbors_kwargs.setdefault("random_state", random_state)
        sc.pp.neighbors(trends, **neighbors_kwargs)

        clustering_kwargs = dict(clustering_kwargs)
        # always overridden, because the plotting code below reads `trends.obs['clusters']`
        clustering_kwargs["key_added"] = "clusters"
        clustering_kwargs.setdefault("random_state", random_state)
        try:
            if use_leiden:
                sc.tl.leiden(trends, **clustering_kwargs)
            else:
                sc.tl.louvain(trends, **clustering_kwargs)
        except ImportError as e:
            # the requested clustering package is missing, fall back to the other one
            logg.warning(str(e))
            if use_leiden:
                sc.tl.louvain(trends, **clustering_kwargs)
            else:
                sc.tl.leiden(trends, **clustering_kwargs)

        logg.info(f"Saving data to `adata.uns[{key!r}]`")
        adata.uns[key] = trends
    else:
        # nothing is fitted when re-using cached trends, so there are no models to return
        all_models = None
        logg.info(f"Loading data from `adata.uns[{key!r}]`")
        trends = adata.uns[key]

    if "clusters" not in trends.obs:
        raise KeyError(
            "Unable to find the clustering in `trends.obs['clusters']`.")

    if clusters is None:
        clusters = trends.obs["clusters"].cat.categories
    for c in clusters:
        if c not in trends.obs["clusters"].cat.categories:
            raise ValueError(
                f"Invalid cluster name `{c!r}`. "
                f"Valid options are `{list(trends.obs['clusters'].cat.categories)}`."
            )

    nrows = int(np.ceil(len(clusters) / ncols))
    fig, axes = plt.subplots(
        nrows,
        ncols,
        figsize=(ncols * 10, nrows * 10) if figsize is None else figsize,
        sharey=sharey,
        dpi=dpi,
    )

    # `plt.subplots` returns a bare axis for a 1x1 grid
    if not isinstance(axes, Iterable):
        axes = [axes]
    axes = np.ravel(axes)

    j = 0
    for j, (ax, c) in enumerate(zip(axes, clusters)):  # noqa
        data = trends[trends.obs["clusters"] == c].X
        mean, sd = np.mean(data, axis=0), np.var(data, axis=0)
        sd = np.sqrt(sd)

        # individual gene trends in the background
        for i in range(data.shape[0]):
            ax.plot(data[i], color="gray", lw=0.5)

        # cluster mean with a +- 1 standard deviation band
        ax.plot(mean, lw=2, color="black")
        ax.plot(mean - sd, lw=1.5, color="black", linestyle="--")
        ax.plot(mean + sd, lw=1.5, color="black", linestyle="--")
        ax.fill_between(range(len(mean)),
                        mean - sd,
                        mean + sd,
                        color="black",
                        alpha=0.1)

        ax.set_title(f"Cluster {c}")
        ax.set_xticks([])

        if not sharey:
            ax.set_yticks([])

    # remove the axes of the grid which were not used by any cluster
    for j in range(j + 1, len(axes)):
        axes[j].remove()

    if save is not None:
        save_fig(fig, save)

    if return_models:
        return all_models
示例#2
0
def heatmap(
    adata: AnnData,
    model: _input_model_type,
    genes: Sequence[str],
    lineages: Optional[Union[str, Sequence[str]]] = None,
    backward: bool = False,
    mode: str = HeatmapMode.LINEAGES.s,
    time_key: str = "latent_time",
    time_range: Optional[Union[_time_range_type,
                               List[_time_range_type]]] = None,
    callback: _callback_type = None,
    cluster_key: Optional[Union[str, Sequence[str]]] = None,
    show_absorption_probabilities: bool = False,
    cluster_genes: bool = False,
    keep_gene_order: bool = False,
    scale: bool = True,
    n_convolve: Optional[int] = 5,
    show_all_genes: bool = False,
    cbar: bool = True,
    lineage_height: float = 0.33,
    fontsize: Optional[float] = None,
    xlabel: Optional[str] = None,
    cmap: mcolors.ListedColormap = cm.viridis,
    dendrogram: bool = True,
    return_genes: bool = False,
    return_models: bool = False,
    n_jobs: Optional[int] = 1,
    backend: str = _DEFAULT_BACKEND,
    show_progress_bar: bool = True,
    figsize: Optional[Tuple[float, float]] = None,
    dpi: Optional[int] = None,
    save: Optional[Union[str, Path]] = None,
    **kwargs,
) -> Optional[Union[Dict[str, pd.DataFrame], Tuple[_return_model_type, Dict[
        str, pd.DataFrame]]]]:
    """
    Plot a heatmap of smoothed gene expression along specified lineages.

    Parameters
    ----------
    %(adata)s
    %(model)s
    %(genes)s
    lineages
        Names of the lineages for which to plot. If `None`, plot all lineages.
    %(backward)s
    mode
        Valid options are:

            - `{m.LINEAGES.s!r}` - group by ``genes`` for each lineage in ``lineages``.
            - `{m.GENES.s!r}` - group by ``lineages`` for each gene in ``genes``.
    time_key
        Key in ``adata.obs`` where the pseudotime is stored.
    %(time_ranges)s
    %(model_callback)s
    cluster_key
        Key(s) in ``adata.obs`` containing categorical observations to be plotted on top of the heatmap.
        Only available when ``mode={m.LINEAGES.s!r}``.
    show_absorption_probabilities
        Whether to also plot absorption probabilities alongside the smoothed expression.
        Only available when ``mode={m.LINEAGES.s!r}``.
    cluster_genes
        Whether to cluster genes using :func:`seaborn.clustermap` when ``mode='lineages'``.
    keep_gene_order
        Whether to keep the gene order for later lineages after the first was sorted.
        Only available when ``cluster_genes=False`` and ``mode={m.LINEAGES.s!r}``.
    scale
        Whether to normalize the gene expression to the `0-1` range.
    n_convolve
        Size of the convolution window when smoothing absorption probabilities.
        If `None`, no smoothing is applied.
    show_all_genes
        Whether to show all genes on y-axis.
    cbar
        Whether to show the colorbar.
    lineage_height
        Height of a bar when ``mode={m.GENES.s!r}``.
    fontsize
        Size of the title's font.
    xlabel
        Label on the x-axis. If `None`, it is determined based on ``time_key``.
    cmap
        Colormap to use when visualizing the smoothed expression.
    dendrogram
        Whether to show dendrogram when ``cluster_genes=True``.
    return_genes
        Whether to return the sorted or clustered genes. Only available when ``mode={m.LINEAGES.s!r}``.
    %(return_models)s
    %(parallel)s
    %(plotting)s
    kwargs
        Keyword arguments for :meth:`cellrank.ul.models.BaseModel.prepare`.

    Returns
    -------
    %(plots_or_returns_models)s
    :class:`pandas.DataFrame`
        If ``return_genes=True`` and ``mode={m.LINEAGES.s!r}``, returns :class:`pandas.DataFrame`
        containing the clustered or sorted genes.
    """

    import seaborn as sns

    def find_indices(series: pd.Series, values) -> Tuple[Any]:
        # for each entry of `values`, return the index label of the entry in
        # `series` whose value is the closest one
        def find_nearest(array: np.ndarray, value: float) -> int:
            # `array` is expected to be sorted, since `np.searchsorted` is used
            ix = np.searchsorted(array, value, side="left")
            # the left neighbor wins if we ran past the end or if it is strictly closer
            if ix > 0 and (ix == len(array) or fabs(value - array[ix - 1]) <
                           fabs(value - array[ix])):
                return ix - 1
            return ix

        # sort by value, so that the binary search above is valid
        series = series[np.argsort(series.values)]

        return tuple(series[[find_nearest(series.values, v)
                             for v in values]].index)

    def subset_lineage(lname: str, rng: np.ndarray) -> np.ndarray:
        # absorption probabilities of lineage `lname` for the cells whose
        # pseudotime is the closest to the time points in `rng`
        time_series = adata.obs[time_key]
        ixs = find_indices(time_series, rng)

        lin = adata[ixs, :].obsm[lineage_key][lname]

        lin = lin.X.copy().squeeze()
        if n_convolve is not None:
            # moving average with a window of size `n_convolve`
            lin = convolve(lin,
                           np.ones(n_convolve) / n_convolve,
                           mode="nearest")

        return lin

    def create_col_colors(lname: str,
                          rng: np.ndarray) -> Tuple[np.ndarray, Cmap, Norm]:
        # column colors (hex) encoding the absorption probabilities of `lname`,
        # together with the colormap and the normalization used to create them
        color = adata.obsm[lineage_key][lname].colors[0]
        lin = subset_lineage(lname, rng)

        # keep hue and value of the lineage color, but use full saturation
        h, _, v = mcolors.rgb_to_hsv(mcolors.to_rgb(color))
        end_color = mcolors.hsv_to_rgb([h, 1, v])

        lineage_cmap = mcolors.LinearSegmentedColormap.from_list(
            "lineage_cmap", ["#ffffff", end_color], N=len(rng))
        norm = mcolors.Normalize(vmin=np.min(lin), vmax=np.max(lin))
        scalar_map = cm.ScalarMappable(cmap=lineage_cmap, norm=norm)

        return (
            np.array([mcolors.to_hex(c) for c in scalar_map.to_rgba(lin)]),
            lineage_cmap,
            norm,
        )

    def create_col_categorical_color(cluster_key: str,
                                     rng: np.ndarray) -> np.ndarray:
        # column colors (hex) of the categorical annotation `adata.obs[cluster_key]`
        # for the cells whose pseudotime is the closest to the time points in `rng`
        if not is_categorical_dtype(adata.obs[cluster_key]):
            raise TypeError(
                f"Expected `adata.obs[{cluster_key!r}]` to be categorical, "
                f"found `{adata.obs[cluster_key].dtype.name!r}`.")

        color_key = f"{cluster_key}_colors"
        if color_key not in adata.uns:
            logg.warning(
                f"Color key `{color_key!r}` not found in `adata.uns`. Creating new colors"
            )
            colors = _create_categorical_colors(
                len(adata.obs[cluster_key].cat.categories))
            # side effect: the newly created colors are stored in `adata.uns`
            adata.uns[color_key] = colors
        else:
            colors = adata.uns[color_key]

        time_series = adata.obs[time_key]
        ixs = find_indices(time_series, rng)

        mapper = dict(zip(adata.obs[cluster_key].cat.categories, colors))

        return np.array([
            mcolors.to_hex(mapper[v])
            for v in adata[ixs, :].obs[cluster_key].values
        ])

    def create_cbar(
        ax,
        x_delta: float,
        cmap: Cmap,
        norm: Norm,
        label: Optional[str] = None,
    ) -> Ax:
        # thin (1% wide), full-height inset colorbar, shifted by `x_delta`
        # in the axes coordinates of `ax`
        cax = inset_axes(
            ax,
            width="1%",
            height="100%",
            loc="lower right",
            bbox_to_anchor=(x_delta, 0, 1, 1),
            bbox_transform=ax.transAxes,
        )

        _ = mpl.colorbar.ColorbarBase(
            cax,
            cmap=cmap,
            norm=norm,
            label=label,
            ticks=np.linspace(norm.vmin, norm.vmax, 5),
        )

        return cax

    # NOTE(review): `valuedispatch` is defined elsewhere; presumably it selects the
    # registered implementation based on the value of `mode` - confirm
    @valuedispatch
    def _plot_heatmap(_mode: HeatmapMode) -> Fig:
        pass

    @_plot_heatmap.register(HeatmapMode.GENES)
    def _() -> Tuple[Fig, None]:
        def color_fill_rec(ax,
                           xs,
                           y1,
                           y2,
                           colors=None,
                           cmap=cmap,
                           **kwargs) -> None:
            # draw one colored rectangle per point, each spanning from its x-value
            # to the next one, vertically between `y1` and `y2`
            colors = colors if cmap is None else cmap(colors)

            x = 0
            for i, (color, x, y1, y2) in enumerate(zip(colors, xs, y1, y2)):
                # width up to the next point; the last rectangle has no successor and
                # re-uses the previous width (the original compared `i` against `len(x)`,
                # i.e. the current element, instead of the number of points)
                dx = (xs[i + 1] - xs[i]) if i < len(xs) - 1 else (xs[-1] - xs[-2])
                ax.add_patch(
                    plt.Rectangle((x, y1),
                                  dx,
                                  y2 - y1,
                                  color=color,
                                  ec=color,
                                  **kwargs))

            # invisible line; presumably so that the axis limits cover the rectangles
            ax.plot(x, y2, lw=0)

        fig, axes = plt.subplots(
            nrows=len(genes) + show_absorption_probabilities,
            figsize=(12, len(genes) + len(lineages) * lineage_height)
            if figsize is None else figsize,
            dpi=dpi,
            constrained_layout=True,
        )

        # `plt.subplots` returns a bare axis when there is only one row
        if not isinstance(axes, Iterable):
            axes = [axes]
        axes = np.ravel(axes)

        if show_absorption_probabilities:
            # re-use the models of the first gene: only their test points and the
            # lineage names are needed to plot the absorption probabilities
            data["absorption probability"] = data[next(iter(data.keys()))]

        for ax, (gene, models) in zip(axes, data.items()):
            if scale:
                vmin, vmax = 0, 1
            else:
                c = np.array([m.y_test for m in models.values()])
                vmin, vmax = np.nanmin(c), np.nanmax(c)

            norm = mcolors.Normalize(vmin=vmin, vmax=vmax)

            # `ix` is the bottom of the current bar, `ys` collects the bar boundaries
            ix = 0
            ys = [ix]

            if gene == "absorption probability":
                norm = mcolors.Normalize(vmin=0, vmax=1)
                for ln, x in ((ln, m.x_test) for ln, m in models.items()):
                    y = np.ones_like(x)
                    c = subset_lineage(ln, x.squeeze())

                    color_fill_rec(ax,
                                   x,
                                   y * ix,
                                   y * (ix + lineage_height),
                                   colors=norm(c))

                    ix += lineage_height
                    ys.append(ix)
            else:
                for x, c in ((m.x_test, m.y_test) for m in models.values()):
                    y = np.ones_like(x)
                    c = _min_max_scale(c) if scale else c

                    color_fill_rec(ax,
                                   x,
                                   y * ix,
                                   y * (ix + lineage_height),
                                   colors=norm(c))

                    ix += lineage_height
                    ys.append(ix)

            xs = np.array([m.x_test for m in models.values()])
            x_min, x_max = np.min(xs), np.max(xs)
            ax.set_xticks(np.linspace(x_min, x_max, _N_XTICKS))

            # one tick per lineage, placed in the middle of its bar
            ax.set_yticks(np.array(ys[:-1]) + lineage_height / 2)
            ax.spines["left"].set_position(
                ("data", 0)
            )  # move the left spine to the rectangles to get nicer yticks
            ax.set_yticklabels(models.keys(), ha="right")

            ax.set_title(gene, fontdict={"fontsize": fontsize})
            ax.set_ylabel("lineage")

            for pos in ["top", "bottom", "left", "right"]:
                ax.spines[pos].set_visible(False)

            if cbar:
                cax, _ = mpl.colorbar.make_axes(ax)
                _ = mpl.colorbar.ColorbarBase(
                    cax,
                    ticks=np.linspace(vmin, vmax, 5),
                    norm=norm,
                    cmap=cmap,
                    label="value" if gene == "absorption probability" else
                    "scaled expression" if scale else "expression",
                )

            ax.tick_params(
                top=False,
                bottom=False,
                left=True,
                right=False,
                labelleft=True,
                labelbottom=False,
            )

        # only the last (bottom-most) axis gets the x-axis ticks and the label
        ax.xaxis.set_major_formatter(FormatStrFormatter("%.3f"))
        ax.tick_params(
            top=False,
            bottom=True,
            left=True,
            right=False,
            labelleft=True,
            labelbottom=True,
        )
        ax.set_xlabel(xlabel)

        return fig, None

    @_plot_heatmap.register(HeatmapMode.LINEAGES)
    def _() -> Tuple[List[Fig], pd.DataFrame]:
        # `data` maps gene -> lineage -> model, here we need lineage -> gene -> model
        data_t = defaultdict(dict)  # transpose
        for gene, lns in data.items():
            for ln, y in lns.items():
                data_t[ln][gene] = y

        figs = []
        gene_order = None
        sorted_genes = pd.DataFrame() if return_genes else None

        for lname, models in data_t.items():
            xs = np.array([m.x_test for m in models.values()])
            x_min, x_max = np.nanmin(xs), np.nanmax(xs)

            # rows are genes, columns are the test points
            df = pd.DataFrame([m.y_test for m in models.values()],
                              index=models.keys())
            df.index.name = "genes"

            if not cluster_genes:
                if gene_order is not None:
                    # re-use the order which was determined for the first lineage
                    df = df.loc[gene_order]
                else:
                    # sort the genes by the position of their (scaled) maximum
                    max_sort = np.argsort(
                        np.argmax(df.apply(_min_max_scale, axis=1).values,
                                  axis=1))
                    df = df.iloc[max_sort, :]
                    if keep_gene_order:
                        gene_order = df.index

            cat_colors = None
            if cluster_key is not None:
                # one row of column colors per categorical annotation
                cat_colors = np.stack(
                    [
                        create_col_categorical_color(
                            c, np.linspace(x_min, x_max, df.shape[1]))
                        for c in cluster_key
                    ],
                    axis=0,
                )

            if show_absorption_probabilities:
                col_colors, col_cmap, col_norm = create_col_colors(
                    lname, np.linspace(x_min, x_max, df.shape[1]))
                if cat_colors is not None:
                    # absorption probabilities go below the categorical annotations
                    col_colors = np.vstack([cat_colors, col_colors[None, :]])
            else:
                col_colors, col_cmap, col_norm = cat_colors, None, None

            # clustering needs at least 2 genes
            row_cluster = cluster_genes and df.shape[0] > 1
            show_clust = row_cluster and dendrogram

            g = sns.clustermap(
                df,
                cmap=cmap,
                figsize=(10, min(len(genes) / 8 +
                                 1, 10)) if figsize is None else figsize,
                xticklabels=False,
                row_cluster=row_cluster,
                col_colors=col_colors,
                colors_ratio=0,
                col_cluster=False,
                cbar_pos=None,
                yticklabels=show_all_genes or "auto",
                standard_scale=0 if scale else None,
            )

            if cbar:
                cax = create_cbar(
                    g.ax_heatmap,
                    0.1,
                    cmap=cmap,
                    norm=mcolors.Normalize(
                        vmin=0 if scale else np.min(df.values),
                        vmax=1 if scale else np.max(df.values),
                    ),
                    label="scaled expression" if scale else "expression",
                )
                g.fig.add_axes(cax)

                if col_cmap is not None and col_norm is not None:
                    # second colorbar for the absorption probabilities
                    cax = create_cbar(
                        g.ax_heatmap,
                        0.25,
                        cmap=col_cmap,
                        norm=col_norm,
                        label="absorption probability",
                    )
                    g.fig.add_axes(cax)

            if g.ax_col_colors:
                main_bbox = _get_ax_bbox(g.fig, g.ax_heatmap)
                # number of column-color bars shown above the heatmap
                n_bars = show_absorption_probabilities + (
                    len(cluster_key) if cluster_key is not None else 0)
                _set_ax_height_to_cm(
                    g.fig,
                    g.ax_col_colors,
                    height=min(
                        5,
                        max(n_bars * main_bbox.height / len(df),
                            0.25 * n_bars)),
                )
                g.ax_col_colors.set_title(lname,
                                          fontdict={"fontsize": fontsize})
            else:
                g.ax_heatmap.set_title(lname, fontdict={"fontsize": fontsize})

            g.ax_col_dendrogram.set_visible(
                False)  # gets rid of top free space

            g.ax_heatmap.yaxis.tick_left()
            g.ax_heatmap.yaxis.set_label_position("right")

            g.ax_heatmap.set_xlabel(xlabel)
            # ticks are placed in column coordinates, but labeled with the time values
            g.ax_heatmap.set_xticks(np.linspace(0, len(df.columns), _N_XTICKS))
            g.ax_heatmap.set_xticklabels(
                list(
                    map(lambda n: round(n, 3),
                        np.linspace(x_min, x_max, _N_XTICKS))))

            if show_clust:
                # robustly show dendrogram, because gene names can be long
                g.ax_row_dendrogram.set_visible(True)
                dendro_box = g.ax_row_dendrogram.get_position()

                pad = 0.005
                bb = g.ax_heatmap.yaxis.get_tightbbox(
                    g.fig.canvas.get_renderer()).transformed(
                        g.fig.transFigure.inverted())

                # place the dendrogram to the left of the y-axis tick labels
                dendro_box.x0 = bb.x0 - dendro_box.width - pad
                dendro_box.x1 = bb.x0 - pad

                g.ax_row_dendrogram.set_position(dendro_box)
            else:
                g.ax_row_dendrogram.set_visible(False)

            if return_genes:
                # use the dendrogram order if the genes were clustered,
                # otherwise the order of the (sorted) data frame
                sorted_genes[lname] = (df.index[g.dendrogram_row.reordered_ind]
                                       if hasattr(g, "dendrogram_row")
                                       and g.dendrogram_row is not None else
                                       df.index)

            figs.append(g)

        return figs, sorted_genes

    mode = HeatmapMode(mode)

    lineage_key = str(AbsProbKey.BACKWARD if backward else AbsProbKey.FORWARD)
    if lineage_key not in adata.obsm:
        raise KeyError(
            f"Lineages key `{lineage_key!r}` not found in `adata.obsm`.")

    if lineages is None:
        lineages = adata.obsm[lineage_key].names
    elif isinstance(lineages, str):
        lineages = [lineages]
    lineages = _unique_order_preserving(lineages)

    # only accessed for validation: fails early if any lineage name is invalid
    _ = adata.obsm[lineage_key][lineages]

    if cluster_key is not None:
        if isinstance(cluster_key, str):
            cluster_key = [cluster_key]
        cluster_key = _unique_order_preserving(cluster_key)

    if isinstance(genes, str):
        genes = [genes]
    genes = _unique_order_preserving(genes)
    _check_collection(adata,
                      genes,
                      "var_names",
                      use_raw=kwargs.get("use_raw", False))

    # these are forwarded to the callbacks and to the bulk fitting below
    kwargs["backward"] = backward
    kwargs["time_key"] = time_key
    models = _create_models(model, genes, lineages)
    # note: `genes` and `lineages` are re-bound to what `_fit_bulk` returns;
    # `data` is the gene -> lineage -> model mapping used by the plotting closures
    all_models, data, genes, lineages = _fit_bulk(
        models,
        _create_callbacks(adata, callback, genes, lineages, **kwargs),
        genes,
        lineages,
        time_range,
        return_models=True,  # always return (better error messages)
        filter_all_failed=True,
        parallel_kwargs={
            "show_progress_bar": show_progress_bar,
            "n_jobs": _get_n_cores(n_jobs, len(genes)),
            "backend": _get_backend(models, backend),
        },
        **kwargs,
    )

    xlabel = time_key if xlabel is None else xlabel

    logg.debug(f"Plotting `{mode.s!r}` heatmap")
    # `genes` is re-bound: sorted/clustered genes (or `None`) as returned by the plotting function
    fig, genes = _plot_heatmap(mode)

    if save is not None and fig is not None:
        if not isinstance(fig, Iterable):
            save_fig(fig, save)
        elif len(fig) == 1:
            save_fig(fig[0], save)
        else:
            # multiple figures: `save` is used as a directory, one file per lineage
            for ln, f in zip(lineages, fig):
                save_fig(f, os.path.join(save, f"lineage_{ln}"))

    if return_genes and mode == HeatmapMode.LINEAGES:
        return (all_models, genes) if return_models else genes
    elif return_models:
        return all_models
示例#3
0
def cluster_lineage(
        adata: AnnData,
        model: _model_type,
        genes: Sequence[str],
        lineage: str,
        backward: bool = False,
        time_range: _time_range_type = None,
        clusters: Optional[Sequence[str]] = None,
        n_points: int = 200,
        time_key: str = "latent_time",
        cluster_key: str = "clusters",
        norm: bool = True,
        recompute: bool = False,
        callback: _callback_type = None,
        ncols: int = 3,
        sharey: Union[str, bool] = False,
        key_added: Optional[str] = None,
        show_progress_bar: bool = True,
        n_jobs: Optional[int] = 1,
        backend: str = _DEFAULT_BACKEND,
        figsize: Optional[Tuple[float, float]] = None,
        dpi: Optional[int] = None,
        save: Optional[Union[str, Path]] = None,
        pca_kwargs: Dict = MappingProxyType({"svd_solver": "arpack"}),
        neighbors_kwargs: Dict = MappingProxyType({"use_rep": "X"}),
        louvain_kwargs: Dict = MappingProxyType({}),
        **kwargs,
) -> None:
    """
    Cluster gene expression trends within a lineage and plot the clusters.

    This function is based on Palantir, see [Setty19]_. It can be used to discover modules of genes that drive
    development along a given lineage. Consider running this function on a subset of genes which are potential lineage
    drivers, identified e.g. by running :func:`cellrank.tl.lineage_drivers`.

    Parameters
    ----------
    %(adata)s
    %(model)s
    %(genes)s
    lineage
        Name of the lineage for which to cluster the genes.
    %(backward)s
    %(time_ranges)s
    clusters
        Cluster identifiers to plot. If `None`, all clusters will be considered.
        Useful when plotting previously computed clusters.
    n_points
        Number of points used for prediction.
    time_key
        Key in ``adata.obs`` where the pseudotime is stored.
    cluster_key
        Key in the ``.obs`` of the trends :class:`anndata.AnnData` object (the one saved
        to ``adata.uns``) where the clustering of the genes is stored.
    norm
        Whether to z-normalize each trend to have zero mean, unit variance.
    recompute
        If `True`, recompute the clustering, otherwise try to find already existing one.
    %(model_callback)s
    ncols
        Number of columns for the plot.
    sharey
        Whether to share y-axis across multiple plots.
    key_added
        Postfix to add when saving the results to ``adata.uns``.
    %(parallel)s
    %(plotting)s
    pca_kwargs
        Keyword arguments for :func:`scanpy.pp.pca`.
    neighbors_kwargs
        Keyword arguments for :func:`scanpy.pp.neighbors`.
    louvain_kwargs
        Keyword arguments for :func:`scanpy.tl.louvain`.
    **kwargs:
        Keyword arguments for :meth:`cellrank.ul.models.BaseModel.prepare`.

    Returns
    -------
    %(just_plots)s

        Updates ``adata.uns`` with the following key:

            - ``lineage_{lineage}_trend_{key_added}`` - an :class:`anndata.AnnData` object
              of shape ``(n_genes, n_points)`` containing the clustered genes.

    Raises
    ------
    KeyError
        If the lineages are not found in ``adata.obsm`` or if ``cluster_key``
        is not found in the ``.obs`` of the trends object.
    RuntimeError
        If the fitted trends do not have the expected number of genes or points.
    """

    import scanpy as sc
    from anndata import AnnData as _AnnData

    lineage_key = str(AbsProbKey.BACKWARD if backward else AbsProbKey.FORWARD)
    if lineage_key not in adata.obsm:
        raise KeyError(
            f"Lineages key `{lineage_key!r}` not found in `adata.obsm`.")

    # only accessed for validation: fails early if `lineage` is not a valid name
    _ = adata.obsm[lineage_key][lineage]

    genes = _unique_order_preserving(genes)
    _check_collection(adata, genes, "var_names", kwargs.get("use_raw", False))

    key_to_add = f"lineage_{lineage}_trend"
    if key_added is not None:
        logg.debug(f"Adding key `{key_added!r}`")
        key_to_add += f"_{key_added}"

    if recompute or key_to_add not in adata.uns:
        kwargs["time_key"] = time_key  # kwargs for the model.prepare
        kwargs["n_test_points"] = n_points
        kwargs["backward"] = backward

        models = _create_models(model, genes, [lineage])
        callbacks = _create_callbacks(adata, callback, genes, [lineage],
                                      **kwargs)

        backend = _get_backend(model, backend)
        n_jobs = _get_n_cores(n_jobs, len(genes))

        start = logg.info(f"Computing gene trends using `{n_jobs}` core(s)")
        # the per-gene results are combined using `np.vstack`
        trends = parallelize(
            _cluster_lineages_helper,
            genes,
            as_array=True,
            unit="gene",
            n_jobs=n_jobs,
            backend=backend,
            extractor=np.vstack,
            show_progress_bar=show_progress_bar,
        )(models, callbacks, lineage, time_range, **kwargs)
        logg.info("    Finish", time=start)

        # transposed so that each gene is a column; the sanity checks below
        # expect `trends.T` to be `(n_genes, n_points)`
        trends = trends.T
        if norm:
            logg.debug("Normalizing using `StandardScaler`")
            # `copy=False` only *tries* to scale in place; keep the returned array
            # so that the normalization is not silently lost if a copy was made
            trends = StandardScaler(copy=False).fit_transform(trends)

        trends = _AnnData(trends.T)
        trends.obs_names = genes

        # sanity check
        if trends.n_obs != len(genes):
            raise RuntimeError(
                f"Expected to find `{len(genes)}` genes, found `{trends.n_obs}`."
            )
        if n_points is not None and trends.n_vars != n_points:
            raise RuntimeError(
                f"Expected to find `{n_points}` points, found `{trends.n_vars}`."
            )

        # the default is a read-only mapping, copy it before removing entries
        pca_kwargs = dict(pca_kwargs)
        n_comps = pca_kwargs.pop(
            "n_comps",
            min(50, kwargs.get("n_test_points"), len(genes)) -
            1)  # default value

        sc.pp.pca(trends, n_comps=n_comps, **pca_kwargs)
        sc.pp.neighbors(trends, **neighbors_kwargs)

        louvain_kwargs = dict(louvain_kwargs)
        # always overridden, because the plotting code below reads `trends.obs[cluster_key]`
        louvain_kwargs["key_added"] = cluster_key
        sc.tl.louvain(trends, **louvain_kwargs)

        adata.uns[key_to_add] = trends
    else:
        logg.info(f"Loading data from `adata.uns[{key_to_add!r}]`")
        trends = adata.uns[key_to_add]

    # validate the key also when `clusters` were specified explicitly, before any
    # figure is created (cached trends may have been clustered under a different key)
    if cluster_key not in trends.obs:
        raise KeyError(f"Invalid cluster key `{cluster_key!r}`.")
    if clusters is None:
        clusters = trends.obs[cluster_key].cat.categories

    nrows = int(np.ceil(len(clusters) / ncols))
    fig, axes = plt.subplots(
        nrows,
        ncols,
        figsize=(ncols * 10, nrows * 10) if figsize is None else figsize,
        sharey=sharey,
        dpi=dpi,
    )

    # `plt.subplots` returns a bare axis for a 1x1 grid
    if not isinstance(axes, Iterable):
        axes = [axes]
    axes = np.ravel(axes)

    j = 0
    for j, (ax, c) in enumerate(zip(axes, clusters)):  # noqa
        data = trends[trends.obs[cluster_key] == c].X
        mean, sd = np.mean(data, axis=0), np.var(data, axis=0)
        sd = np.sqrt(sd)

        # individual gene trends in the background
        for i in range(data.shape[0]):
            ax.plot(data[i], color="gray", lw=0.5)

        # cluster mean with a +- 1 standard deviation band
        ax.plot(mean, lw=2, color="black")
        ax.plot(mean - sd, lw=1.5, color="black", linestyle="--")
        ax.plot(mean + sd, lw=1.5, color="black", linestyle="--")
        ax.fill_between(range(len(mean)),
                        mean - sd,
                        mean + sd,
                        color="black",
                        alpha=0.1)

        ax.set_title(f"Cluster {c}")
        ax.set_xticks([])

        if not sharey:
            ax.set_yticks([])

    # remove the axes of the grid which were not used by any cluster
    for j in range(j + 1, len(axes)):
        axes[j].remove()

    if save is not None:
        save_fig(fig, save)
示例#4
0
def gene_trends(
    adata: AnnData,
    model: _input_model_type,
    genes: Union[str, Sequence[str]],
    lineages: Optional[Union[str, Sequence[str]]] = None,
    backward: bool = False,
    data_key: str = "X",
    time_key: str = "latent_time",
    transpose: bool = False,
    time_range: Optional[Union[_time_range_type,
                               List[_time_range_type]]] = None,
    callback: _callback_type = None,
    conf_int: Union[bool, float] = True,
    same_plot: bool = False,
    hide_cells: bool = False,
    perc: Optional[Union[Tuple[float, float], Sequence[Tuple[float,
                                                             float]]]] = None,
    lineage_cmap: Optional[matplotlib.colors.ListedColormap] = None,
    abs_prob_cmap: matplotlib.colors.ListedColormap = cm.viridis,
    cell_color: Optional[str] = None,
    cell_alpha: float = 0.6,
    lineage_alpha: float = 0.2,
    size: float = 15,
    lw: float = 2,
    cbar: bool = True,
    margins: float = 0.015,
    sharex: Optional[Union[str, bool]] = None,
    sharey: Optional[Union[str, bool]] = None,
    gene_as_title: Optional[bool] = None,
    legend_loc: Optional[str] = "best",
    obs_legend_loc: Optional[str] = "best",
    ncols: int = 2,
    suptitle: Optional[str] = None,
    return_models: bool = False,
    n_jobs: Optional[int] = 1,
    backend: str = _DEFAULT_BACKEND,
    show_progress_bar: bool = True,
    figsize: Optional[Tuple[float, float]] = None,
    dpi: Optional[int] = None,
    save: Optional[Union[str, Path]] = None,
    plot_kwargs: Mapping = MappingProxyType({}),
    **kwargs,
) -> Optional[_return_model_type]:
    """
    Plot gene expression trends along lineages.

    Each lineage is defined via its lineage weights which we compute using :func:`cellrank.tl.lineages`. This
    function accepts any model based off :class:`cellrank.ul.models.BaseModel` to fit gene expression,
    where we take the lineage weights into account in the loss function.

    Parameters
    ----------
    %(adata)s
    %(model)s
    %(genes)s
    lineages
        Names of the lineages to plot. If `None`, plot all lineages.
    %(backward)s
    data_key
        Key in ``adata.layers`` or `'X'` for ``adata.X`` where the data is stored.
    time_key
        Key in ``adata.obs`` where the pseudotime is stored.
    %(time_ranges)s
    transpose
        If ``same_plot=True``, group the trends by ``lineages`` instead of ``genes``. This enforces ``hide_cells=True``.
        If ``same_plot=False``, show ``lineages`` in rows and ``genes`` in columns.
    %(model_callback)s
    conf_int
        Whether to compute and show confidence interval. If the :paramref:`model` is :class:`cellrank.ul.models.GAMR`,
        it can also specify the confidence level, the default is `0.95`.
    same_plot
        Whether to plot all lineages for each gene in the same plot.
    hide_cells
        If `True`, hide all cells.
    perc
        Percentile for colors. Valid values are in interval `[0, 100]`.
        This can improve visualization. Can be specified individually for each lineage.
    lineage_cmap
        Categorical colormap to use when coloring in the lineages. If `None` and ``same_plot``,
        use the corresponding colors in ``adata.uns``, otherwise use `'black'`.
    abs_prob_cmap
        Continuous colormap to use when visualizing the absorption probabilities for each lineage.
        Only used when ``same_plot=False``.
    cell_color
        Key in :attr:`anndata.AnnData.obs` or :attr:`anndata.AnnData.var_names` used for coloring the cells.
    cell_alpha
        Alpha channel for cells.
    lineage_alpha
        Alpha channel for lineage confidence intervals.
    size
        Size of the points.
    lw
        Line width of the smoothed values.
    cbar
        Whether to show colorbar. Always shown when percentiles for lineages differ. Only used when ``same_plot=False``.
    margins
        Margins around the plot.
    sharex
        Whether to share x-axis. Valid options are `'row'`, `'col'` or `'none'`.
    sharey
        Whether to share y-axis. Valid options are `'row'`, `'col'` or `'none'`.
    gene_as_title
        Whether to show gene names as titles instead of on the y-axis.
    legend_loc
        Location of the legend displaying lineages. Only used when `same_plot=True`.
    obs_legend_loc
        Location of the legend when ``cell_color`` corresponds to a categorical variable.
    ncols
        Number of columns of the plot when plotting multiple genes. Only used when ``same_plot=True``.
    suptitle
        Suptitle of the figure.
    %(return_models)s
    %(parallel)s
    %(plotting)s
    plot_kwargs
        Keyword arguments for :meth:`cellrank.ul.models.BaseModel.plot`.
    kwargs
        Keyword arguments for :meth:`cellrank.ul.models.BaseModel.prepare`.

    Returns
    -------
    %(plots_or_returns_models)s
    """

    if isinstance(genes, str):
        genes = [genes]
    genes = _unique_order_preserving(genes)
    # With ``data_key='obs'`` the names are checked against ``adata.obs``, otherwise against ``adata.var_names``.
    _check_collection(
        adata,
        genes,
        "var_names" if data_key != "obs" else "obs",
        use_raw=kwargs.get("use_raw", False),
    )

    ln_key = str(AbsProbKey.BACKWARD if backward else AbsProbKey.FORWARD)
    if ln_key not in adata.obsm:
        raise KeyError(f"Lineages key `{ln_key!r}` not found in `adata.obsm`.")

    if lineages is None:
        lineages = adata.obsm[ln_key].names
    elif isinstance(lineages, str):
        lineages = [lineages]
    elif all(ln is None
             for ln in lineages):  # no lineage, all the weights are 1
        lineages = [None]
        cbar = False
        logg.debug("All lineages are `None`, setting the weights to `1`")
    lineages = _unique_order_preserving(lineages)

    # A single time range (or `None`) is broadcast to every lineage.
    if isinstance(time_range, (tuple, float, int, type(None))):
        time_range = [time_range] * len(lineages)
    elif len(time_range) != len(lineages):
        raise ValueError(
            f"Expected time ranges to be of length `{len(lineages)}`, found `{len(time_range)}`."
        )

    kwargs["time_key"] = time_key
    kwargs["data_key"] = data_key
    kwargs["backward"] = backward
    kwargs["conf_int"] = conf_int  # prepare doesnt take or need this
    models = _create_models(model, genes, lineages)

    # `_fit_bulk` hands back possibly updated `genes` and `lineages` alongside the models.
    all_models, models, genes, lineages = _fit_bulk(
        models,
        _create_callbacks(adata, callback, genes, lineages, **kwargs),
        genes,
        lineages,
        time_range,
        return_models=True,
        filter_all_failed=False,
        parallel_kwargs={
            "show_progress_bar": show_progress_bar,
            "n_jobs": _get_n_cores(n_jobs, len(genes)),
            "backend": _get_backend(models, backend),
        },
        **kwargs,
    )

    lineages = sorted(lineages)
    tmp = adata.obsm[ln_key][lineages].colors
    if lineage_cmap is None and not transpose:
        lineage_cmap = tmp

    # Work on a private copy: the default is a read-only mapping and the caller's dict must not be mutated.
    plot_kwargs = dict(plot_kwargs)
    if transpose:
        # Swap the nesting of the model dictionaries and the roles of genes/lineages.
        all_models = pd.DataFrame(all_models).T.to_dict()
        models = pd.DataFrame(models).T.to_dict()
        genes, lineages = lineages, genes
        hide_cells = same_plot or hide_cells
    else:
        # information overload otherwise
        plot_kwargs["lineage_probability"] = False
        plot_kwargs["lineage_probability_conf_int"] = False

    # Boolean table of the models (truthiness as given by `bool(model)`); for every column,
    # `start_rows`/`end_rows` hold the first/last row index whose entry is truthy.
    # NOTE(review): presumably rows are the (possibly swapped) `genes` and columns the `lineages` - confirm.
    tmp = pd.DataFrame(models).T.astype(bool)
    start_rows = np.argmax(tmp.values, axis=0)
    end_rows = tmp.shape[0] - np.argmax(tmp[::-1].values, axis=0) - 1

    if same_plot:
        gene_as_title = True if gene_as_title is None else gene_as_title
        sharex = "all" if sharex is None else sharex
        if sharey is None:
            sharey = "row" if plot_kwargs.get("lineage_probability",
                                              False) else "none"
        # never create more columns than there are panels
        ncols = min(ncols, len(genes))
        nrows = int(np.ceil(len(genes) / ncols))
    else:
        gene_as_title = False if gene_as_title is None else gene_as_title
        sharex = "col" if sharex is None else sharex
        if sharey is None:
            sharey = ("row" if not hide_cells or plot_kwargs.get(
                "lineage_probability", False) else "none")
        nrows = len(genes)
        ncols = len(lineages)

    if plot_kwargs.get("xlabel", None) is None:
        plot_kwargs["xlabel"] = time_key

    fig, axes = plt.subplots(
        nrows=nrows,
        ncols=ncols,
        sharex=sharex,
        sharey=sharey,
        figsize=(6 * ncols, 4 * nrows) if figsize is None else figsize,
        tight_layout=True,
        dpi=dpi,
    )
    # `plt.subplots` squeezes singleton dimensions; always work with a 2D grid
    axes = np.reshape(axes, (nrows, ncols))

    cnt = 0
    plot_kwargs["obs_legend_loc"] = None if same_plot else obs_legend_loc

    logg.info("Plotting trends")
    for row in range(nrows):
        for col in range(ncols):
            if cnt >= len(genes):
                break
            gene = genes[cnt]
            if (same_plot and plot_kwargs.get("lineage_probability", False)
                    and transpose):
                lpc = adata.obsm[ln_key][gene].colors[0]
            else:
                lpc = None

            if same_plot:
                # show the observation legend only in the top-right panel
                plot_kwargs["obs_legend_loc"] = (obs_legend_loc if row == 0
                                                 and col == ncols - 1
                                                 else None)

            _trends_helper(
                models,
                gene=gene,
                lineage_names=lineages,
                transpose=transpose,
                same_plot=same_plot,
                hide_cells=hide_cells,
                perc=perc,
                lineage_cmap=lineage_cmap,
                abs_prob_cmap=abs_prob_cmap,
                lineage_probability_color=lpc,
                cell_color=cell_color,
                alpha=cell_alpha,
                lineage_alpha=lineage_alpha,
                size=size,
                lw=lw,
                cbar=cbar,
                margins=margins,
                sharey=sharey,
                gene_as_title=gene_as_title,
                legend_loc=legend_loc,
                figsize=figsize,
                fig=fig,
                # one axis per gene when sharing a plot, otherwise a whole row of axes
                axes=axes[row, col] if same_plot else axes[cnt],
                show_ylabel=col == 0,
                # element-wise comparisons: one flag per column of the model table
                show_lineage=same_plot or (cnt == start_rows),
                show_xticks_and_label=((row + 1) * ncols + col >= len(genes))
                if same_plot else (cnt == end_rows),
                **plot_kwargs,
            )
            cnt += 1

            if not same_plot:
                # the observation legend is only passed for the first plotted row
                plot_kwargs["obs_legend_loc"] = None

    if same_plot:
        # drop the unused axes of an incompletely filled grid (no-op when the grid is full)
        for ax in np.ravel(axes)[cnt:]:
            ax.remove()

    fig.suptitle(suptitle, y=1.05)

    if save is not None:
        save_fig(fig, save)

    if return_models:
        return all_models
示例#5
0
def gene_trends(
        adata: AnnData,
        model: _model_type,
        genes: Union[str, Sequence[str]],
        lineages: Optional[Union[str, Sequence[str]]] = None,
        backward: bool = False,
        data_key: str = "X",
        time_key: str = "latent_time",
        time_range: Optional[Union[_time_range_type,
                                   List[_time_range_type]]] = None,
        callback: _callback_type = None,
        conf_int: bool = True,
        same_plot: bool = False,
        hide_cells: bool = False,
        perc: Optional[Union[Tuple[float, float],
                             Sequence[Tuple[float, float]]]] = None,
        lineage_cmap: Optional[matplotlib.colors.ListedColormap] = None,
        abs_prob_cmap: matplotlib.colors.ListedColormap = cm.viridis,
        cell_color: str = "black",
        cell_alpha: float = 0.6,
        lineage_alpha: float = 0.2,
        size: float = 15,
        lw: float = 2,
        show_cbar: bool = True,
        margins: float = 0.015,
        sharex: Optional[Union[str, bool]] = None,
        sharey: Optional[Union[str, bool]] = None,
        gene_as_title: Optional[bool] = None,
        legend_loc: Optional[str] = "best",
        ncols: int = 2,
        suptitle: Optional[str] = None,
        n_jobs: Optional[int] = 1,
        backend: str = _DEFAULT_BACKEND,
        show_progres_bar: bool = True,
        figsize: Optional[Tuple[float, float]] = None,
        dpi: Optional[int] = None,
        save: Optional[Union[str, Path]] = None,
        plot_kwargs: Mapping = MappingProxyType({}),
        **kwargs,
) -> None:
    """
    Plot gene expression trends along lineages.

    Each lineage is defined via its lineage weights which we compute using :func:`cellrank.tl.lineages`. This
    function accepts any model based off :class:`cellrank.ul.models.BaseModel` to fit gene expression,
    where we take the lineage weights into account in the loss function.

    Parameters
    ----------
    %(adata)s
    %(model)s
    %(genes)s
    lineages
        Names of the lineages to plot. If `None`, plot all lineages.
    %(backward)s
    data_key
        Key in ``adata.layers`` or `'X'` for ``adata.X`` where the data is stored.
    time_key
        Key in ``adata.obs`` where the pseudotime is stored.
    %(time_ranges)s
    %(model_callback)s
    conf_int
        Whether to compute and show confidence intervals.
    same_plot
        Whether to plot all lineages for each gene in the same plot.
    hide_cells
        If `True`, hide all cells.
    perc
        Percentile for colors. Valid values are in interval `[0, 100]`.
        This can improve visualization. Can be specified individually for each lineage.
    lineage_cmap
        Colormap to use when coloring in the lineages. If `None` and ``same_plot``, use the corresponding colors
        in ``adata.uns``, otherwise use `'black'`.
    abs_prob_cmap
        Colormap to use when visualizing the absorption probabilities for each lineage.
        Only used when ``same_plot=False``.
    cell_color
        Color of the cells when not visualizing absorption probabilities. Only used when ``same_plot=True``.
    cell_alpha
        Alpha channel for cells.
    lineage_alpha
        Alpha channel for lineage confidence intervals.
    size
        Size of the points.
    lw
        Line width of the smoothed values.
    show_cbar
        Whether to show colorbar. Always shown when percentiles for lineages differ. Only used when ``same_plot=False``.
    margins
        Margins around the plot.
    sharex
        Whether to share x-axis. Valid options are `'row'`, `'col'` or `'none'`.
    sharey
        Whether to share y-axis. Valid options are `'row'`, `'col'` or `'none'`.
    gene_as_title
        Whether to show gene names as titles instead of on the y-axis.
    legend_loc
        Location of the legend displaying lineages. Only used when `same_plot=True`.
    ncols
        Number of columns of the plot when plotting multiple genes. Only used when ``same_plot=True``.
    suptitle
        Suptitle of the figure.
    %(parallel)s
    %(plotting)s
    plot_kwargs
        Keyword arguments for :meth:`cellrank.ul.models.BaseModel.plot`.
    **kwargs
        Keyword arguments for :meth:`cellrank.ul.models.BaseModel.prepare`.

    Returns
    -------
    %(just_plots)s
    """

    if isinstance(genes, str):
        genes = [genes]
    genes = _unique_order_preserving(genes)

    if data_key != "obs":
        _check_collection(adata,
                          genes,
                          "var_names",
                          use_raw=kwargs.get("use_raw", False))
    else:
        _check_collection(adata,
                          genes,
                          "obs",
                          use_raw=kwargs.get("use_raw", False))

    ln_key = str(AbsProbKey.BACKWARD if backward else AbsProbKey.FORWARD)
    if ln_key not in adata.obsm:
        raise KeyError(f"Lineages key `{ln_key!r}` not found in `adata.obsm`.")

    if lineages is None:
        lineages = adata.obsm[ln_key].names
    elif isinstance(lineages, str):
        lineages = [lineages]
    elif all(ln is None
             for ln in lineages):  # no lineage, all the weights are 1
        lineages = [None]
        show_cbar = False
        logg.debug("All lineages are `None`, setting the weights to `1`")
    lineages = _unique_order_preserving(lineages)

    if same_plot:
        gene_as_title = True if gene_as_title is None else gene_as_title
        sharex = "all" if sharex is None else sharex
        sharey = "none" if sharey is None else sharey
        ncols = len(genes) if ncols >= len(genes) else ncols
        nrows = int(np.ceil(len(genes) / ncols))
    else:
        gene_as_title = False if gene_as_title is None else gene_as_title
        sharex = "col" if sharex is None else sharex
        sharey = (
            "none" if hide_cells else "row") if sharey is None else sharey
        nrows = len(genes)
        ncols = len(lineages)

    # Finish the input checks before any figure is created, so that invalid
    # input does not leave an orphaned, empty figure behind.
    # The indexing result is discarded; presumably it raises for unknown lineage names.
    _ = adata.obsm[ln_key][[lin for lin in lineages if lin is not None]]

    # A single time range (or `None`) is broadcast to every lineage.
    if isinstance(time_range, (tuple, float, int, type(None))):
        time_range = [time_range] * len(lineages)
    elif len(time_range) != len(lineages):
        raise ValueError(
            f"Expected time ranges to be of length `{len(lineages)}`, found `{len(time_range)}`."
        )

    # NOTE(review): `dpi` is only forwarded to `_trends_helper` below, not to
    # `plt.subplots` - confirm that the helper applies it to the figure.
    fig, axes = plt.subplots(
        nrows=nrows,
        ncols=ncols,
        sharex=sharex,
        sharey=sharey,
        figsize=(6 * ncols, 4 * nrows) if figsize is None else figsize,
        constrained_layout=True,
    )
    # `plt.subplots` squeezes singleton dimensions; always work with a 2D grid
    axes = np.reshape(axes, (-1, ncols))

    kwargs["time_key"] = time_key
    kwargs["data_key"] = data_key
    kwargs["backward"] = backward
    # the callbacks are created before `conf_int` is added to `kwargs`
    callbacks = _create_callbacks(adata, callback, genes, lineages, **kwargs)

    kwargs["conf_int"] = conf_int  # prepare doesnt take or need this
    models = _create_models(model, genes, lineages)

    # Private copy: the default is a read-only mapping and the caller's dict must not be mutated.
    plot_kwargs = dict(plot_kwargs)
    if plot_kwargs.get("xlabel", None) is None:
        plot_kwargs["xlabel"] = time_key

    n_jobs = _get_n_cores(n_jobs, len(genes))
    backend = _get_backend(model, backend)

    start = logg.info(f"Computing trends using `{n_jobs}` core(s)")
    models = parallelize(
        _fit_gene_trends,
        genes,
        unit="gene" if data_key != "obs" else "obs",
        backend=backend,
        n_jobs=n_jobs,
        # merge the dictionaries passed to the extractor into a single one
        extractor=lambda modelss:
        {k: v
         for m in modelss for k, v in m.items()},
        show_progress_bar=show_progres_bar,
    )(models, callbacks, lineages, time_range, **kwargs)
    logg.info("    Finish", time=start)

    logg.info("Plotting trends")

    cnt = 0
    for row in range(len(axes)):
        for col in range(len(axes[row])):
            if cnt >= len(genes):
                break
            gene = genes[cnt]

            _trends_helper(
                adata,
                models,
                gene=gene,
                lineage_names=lineages,
                ln_key=ln_key,
                same_plot=same_plot,
                hide_cells=hide_cells,
                perc=perc,
                lineage_cmap=lineage_cmap,
                abs_prob_cmap=abs_prob_cmap,
                cell_color=cell_color,
                alpha=cell_alpha,
                lineage_alpha=lineage_alpha,
                size=size,
                lw=lw,
                show_cbar=show_cbar,
                margins=margins,
                sharey=sharey,
                gene_as_title=gene_as_title,
                legend_loc=legend_loc,
                dpi=dpi,
                figsize=figsize,
                fig=fig,
                # one axis per gene when sharing a plot, otherwise a whole row of axes
                axes=axes[row, col] if same_plot else axes[cnt],
                show_ylabel=col == 0,
                show_lineage=cnt == 0 or same_plot,
                show_xticks_and_label=((row + 1) * ncols + col >= len(genes))
                if same_plot else (cnt == len(axes) - 1),
                **plot_kwargs,
            )
            cnt += 1

    if same_plot:
        # drop the unused axes of an incompletely filled grid (no-op when the grid is full)
        for ax in np.ravel(axes)[cnt:]:
            ax.remove()

    fig.suptitle(suptitle)

    if save is not None:
        save_fig(fig, save)