예제 #1
0
def _run_in_parallel(fn: Callable, conn: csr_matrix, **kwargs) -> Any:
    """
    Run ``fn`` over the rows of ``conn`` through :func:`parallelize`.

    Parameters
    ----------
    fn
        Worker function. Its ``__name__`` selects the row ordering and the progress-bar unit.
    conn
        Sparse matrix; its ``indices`` and ``indptr`` are forwarded as keyword arguments.
    **kwargs
        Keyword arguments, split between :func:`parallelize` and ``fn`` by ``_filter_kwargs``.

    Returns
    -------
    Whatever the callable created by :func:`parallelize` returns; the extractor handed to it
    concatenates the partial results along the last axis and passes them to ``_reconstruct_one``.

    Raises
    ------
    RuntimeError
        If ``fn`` is named ``_run_stochastic`` and ``_HAS_JAX`` is false.
    """
    fn_name = fn.__name__

    if fn_name != "_run_stochastic":
        # every other worker processes the rows in a random order
        ixs = np.arange(conn.shape[0])
        np.random.shuffle(ixs)
    elif _HAS_JAX:
        # rows ordered by decreasing number of non-zero entries
        nnz_per_row = np.array((conn != 0).sum(1)).ravel()
        ixs = np.argsort(nnz_per_row)[::-1]
    else:
        raise RuntimeError(
            "Install `jax` and `jaxlib` as `pip install jax jaxlib`.")

    many_samples = fn_name == "_run_mc" and kwargs.get("n_samples", 1) > 1
    unit = "sample" if many_samples else "cell"

    kwargs["indices"] = conn.indices
    kwargs["indptr"] = conn.indptr

    def _merge(res):
        # glue the per-job chunks together before reconstructing the final result
        return _reconstruct_one(np.concatenate(res, axis=-1), conn, ixs)

    runner = parallelize(
        fn,
        ixs,
        as_array=False,
        extractor=_merge,
        unit=unit,
        **_filter_kwargs(parallelize, **kwargs),
    )

    return runner(**_filter_kwargs(fn, **kwargs))
예제 #2
0
def _(
    x: Union[np.ndarray, spmatrix],
    library_size: np.ndarray,
    ref_ix: Optional[int] = None,
    **kwargs,
) -> np.ndarray:
    """
    Compute per-row factors of ``x`` against a reference row using ``_calc_factor_weighted``.

    Parameters
    ----------
    x
        Data matrix; rows are compared against the reference row.
    library_size
        Array indexed like the rows of ``x``.
    ref_ix
        Index of the reference row. If `None`, the row whose 0.75-quantile factor
        (from ``_calc_factor_quant``) is closest to the mean of these factors is used.
    **kwargs
        Keyword arguments forwarded to ``_calc_factor_weighted``.

    Returns
    -------
    The concatenated results of the parallel jobs.
    """
    if ref_ix is None:
        quantile_factors = _calc_factor_quant(x, library_size=library_size, p=0.75)
        deviation = np.abs(quantile_factors - np.mean(quantile_factors))
        ref_ix = np.argmin(deviation)

    ref = x[ref_ix]
    if issparse(ref):
        # densify the single sparse row and drop the leading axis -> (genes,)
        ref = ref.A.squeeze(0)

    runner = parallelize(
        _calc_factor_weighted,
        collection=np.arange(x.shape[0]),
        show_progress_bar=False,
        as_array=False,
        extractor=np.concatenate,
        backend="threading",
        n_jobs=4,
    )

    return runner(
        obs_=x,
        obs_lib_size_=library_size,
        ref=ref,
        ref_lib_size=library_size[ref_ix],
        **kwargs,
    )
예제 #3
0
    def simulate_many(
        self,
        n_sims: int,
        max_iter: Union[int, float] = 0.25,
        seed: Optional[int] = None,
        successive_hits: int = 0,
        n_jobs: Optional[int] = None,
        backend: str = "loky",
        show_progress_bar: bool = True,
    ) -> List[np.ndarray]:
        """
        Simulate many random walks in parallel.

        Parameters
        ----------
        n_sims
            Number of random walks to simulate. Must be positive.
        %(rw_sim.params)s
        %(parallel)s

        Returns
        -------
        List of arrays of shape ``(max_iter + 1,)`` of states that have been visited. If ``stop_ixs`` was specified,
        the arrays may have smaller shape.

        Raises
        ------
        ValueError
            If ``n_sims`` is not positive.
        """
        if n_sims <= 0:
            raise ValueError(
                f"Expected number of simulations to be positive, found `{n_sims}`."
            )

        max_iter = self._max_iter(max_iter)
        start = logg.info(
            f"Simulating `{n_sims}` random walks of maximum length `{max_iter}`"
        )

        runner = parallelize(
            self._simulate_many,
            collection=np.arange(n_sims),
            n_jobs=n_jobs,
            backend=backend,
            show_progress_bar=show_progress_bar,
            as_array=False,
            unit="sim",
        )
        per_job = runner(
            max_iter=max_iter, seed=seed, successive_hits=successive_hits
        )
        # each job returns a collection of walks; flatten them into a single list
        walks = [walk for job in per_job for walk in job]

        logg.info("    Finish", time=start)

        return walks
예제 #4
0
    def test_more_jobs_than_work(self, n_jobs: int):
        """Parallelize over a 3-row sparse collection and check that every row yields 42."""

        def callback(data, **_: Any):
            # each chunk must still be a CSR matrix with all 100 columns
            assert isinstance(data, csr_matrix)
            assert data.shape[1] == 100

            n_rows = data.shape[0]
            return [42] * n_rows

        collection = srand(3, 100, format="csr")
        runner = parallelize(
            callback,
            collection=collection,
            n_jobs=n_jobs,
            show_progress_bar=False,
            extractor=np.concatenate,
        )
        res = runner()

        np.testing.assert_array_equal(res, 42)
예제 #5
0
    def bias_knn(
        self,
        conn: csr_matrix,
        pseudotime: np.ndarray,
        n_jobs: Optional[int] = None,
        backend: str = "loky",
        show_progress_bar: bool = True,
        **kwargs: Any,
    ) -> csr_matrix:
        """
        Bias the cell-cell connectivities of a KNN graph.

        Parameters
        ----------
        conn
            Sparse matrix of shape ``(n_cells, n_cells)`` containing the nearest neighbor connectivities.
        pseudotime
            Pseudotemporal ordering of cells.
        %(parallel)s
        kwargs
            Keyword arguments forwarded to the per-chunk helper.

        Returns
        -------
        The biased connectivities, with explicit zeros removed.
        """
        runner = parallelize(
            self._bias_knn_helper,
            np.arange(conn.shape[0]),
            as_array=False,
            unit="cell",
            n_jobs=n_jobs,
            backend=backend,
            show_progress_bar=show_progress_bar,
        )
        chunks = runner(conn, pseudotime, **kwargs)

        # every chunk is a (data, indices, indptr) triple; merge each component across chunks
        data, indices, indptr = (np.concatenate(part) for part in zip(*chunks))

        biased = csr_matrix((data, indices, indptr))
        biased.eliminate_zeros()

        return biased
예제 #6
0
def _solve_lin_system(
    mat_a: Union[np.ndarray, spmatrix],
    mat_b: Union[np.ndarray, spmatrix],
    solver: str = _DEFAULT_SOLVER,
    use_petsc: bool = False,
    preconditioner: Optional[str] = None,
    n_jobs: Optional[int] = None,
    backend: str = _DEFAULT_BACKEND,
    tol: float = 1e-5,
    use_eye: bool = False,
    show_progress_bar: bool = True,
) -> np.ndarray:
    """
    Solve ``mat_a * x = mat_b`` using either a direct or an iterative method.

    With an iterative solver, every column of ``mat_b`` is treated as a separate linear problem
    and the problems are distributed over parallel jobs. This pays off when ``mat_a`` and ``mat_b``
    are sparse and ``mat_b`` has few columns.

    With the direct solver, all sub-problems are solved at once. This is preferable when ``mat_a``
    is dense or very small, or when ``mat_b`` has very many columns.

    Parameters
    ----------
    mat_a
        Matrix of shape `n x n`. We make no assumptions on ``mat_a`` being symmetric or positive definite.
    mat_b
        Matrix of shape `n x m`, with m << n.
    solver
        Solver to use for the linear problem. Options are `'direct', 'gmres', 'lgmres', 'bicgstab' or 'gcrotmk'`
        when ``use_petsc=False`` or one of `petsc4py.PETSc.KSP.Type` otherwise.

        Information on the :mod:`scipy` iterative solvers can be found in :func:`scipy.sparse.linalg` or
        for the :mod:`petsc4py` solver in https://www.mcs.anl.gov/petsc/documentation/linearsolvertable.html.
    use_petsc
        Whether to use solvers from :mod:`petsc4py` instead of :mod:`scipy`. Recommended for large problems.
        If :mod:`petsc4py` cannot be imported, a warning is logged once and the default
        :mod:`scipy` solver is used instead.
    preconditioner
        Preconditioner to use when ``use_petsc=True``. For available preconditioners, see `petsc4py.PETSc.PC.Type`.
    n_jobs
        Number of parallel jobs to use for the iterative solvers. For small, quickly-solvable problems,
        we recommend high number (>=8) of cores in order to fully saturate them.
    backend
        Which backend to use for multiprocessing. See :class:`joblib.Parallel` for valid options.
    tol
        The relative convergence tolerance, relative decrease in the (possibly preconditioned) residual norm.
    use_eye
        Solve ``(I - mat_a) * x = mat_b`` instead.
    show_progress_bar
        Whether to show progress bar when the solver isn't a direct one.

    Returns
    -------
    :class:`numpy.ndarray`
        Matrix of shape `n x m`. Each column corresponds to the solution of one of the sub-problems
        defined via columns in ``mat_b``.

    Raises
    ------
    ValueError
        If :mod:`petsc4py` is not used and ``solver`` is neither `'direct'` nor a known iterative solver.
    """
    global _PETSC_ERROR_MSG_SHOWN

    def _stack_solutions(
            res_converged: List[Tuple[np.ndarray,
                                      int]]) -> Tuple[np.ndarray, int]:
        # merge the per-job solutions column-wise and count the converged problems
        solutions, flags = zip(*res_converged)
        return np.hstack(solutions), sum(flags)

    n_jobs = _get_n_cores(n_jobs, n_jobs=None)

    if use_petsc:
        try:
            from petsc4py import PETSc  # noqa
        except ImportError:
            # warn only once per session, then silently fall back to scipy
            if not _PETSC_ERROR_MSG_SHOWN:
                _PETSC_ERROR_MSG_SHOWN = True
                logg.warning(_PETSC_ERROR_MSG.format(_DEFAULT_SOLVER))
            solver, use_petsc = _DEFAULT_SOLVER, False

    if use_eye:
        n = mat_a.shape[0]
        identity = speye(n) if issparse(mat_a) else np.eye(n)
        mat_a = identity - mat_a

    if solver == "direct":
        if use_petsc:
            logg.debug("Solving the linear system directly using `PETSc`")
            return _petsc_mat_solve(
                mat_a,
                mat_b,
                solver=solver,
                preconditioner=preconditioner,
                tol=tol,
            )

        if issparse(mat_a):
            logg.debug("Densifying `A` for `scipy` direct solver")
            mat_a = mat_a.toarray()
        if issparse(mat_b):
            logg.debug("Densifying `B` for `scipy` direct solver")
            mat_b = mat_b.toarray()

        logg.debug("Solving the linear system directly using `scipy`")

        return solve(mat_a, mat_b)

    if not use_petsc and solver not in _AVAIL_ITER_SOLVERS:
        raise ValueError(f"Invalid solver `{solver!r}`.")

    if use_petsc:
        if not isspmatrix_csr(mat_a):
            mat_a = csr_matrix(mat_a)

        # one problem per row after transposing
        mat_b = mat_b.T
        if not isspmatrix_csc(mat_b):
            mat_b = csc_matrix(mat_b)

        solver_name = "gmres" if solver is None else solver
        pc_name = "no" if preconditioner is None else preconditioner
        logg.debug(
            f"Solving the linear system using `PETSc` solver `{solver_name!r}` "
            f"on `{n_jobs}` core(s) with {pc_name} preconditioner and "
            f"`tol={tol}`")

        # can't pass PETSc matrix - not pickleable
        worker = _solve_many_sparse_problems_petsc
        worker_kwargs = {
            "solver": solver,
            "preconditioner": preconditioner,
            "tol": tol,
        }
    else:
        if not issparse(mat_a):
            logg.debug("Sparsifying `A` for iterative solver")
            mat_a = csr_matrix(mat_a)

        # one problem per row after transposing
        mat_b = mat_b.T
        if not issparse(mat_b):
            logg.debug("Sparsifying `B` for iterative solver")
            mat_b = csr_matrix(mat_b)

        logg.debug(
            f"Solving the linear system using `scipy` solver `{solver!r}` on `{n_jobs} cores(s)` with `tol={tol}`"
        )

        worker = _solve_many_sparse_problems
        worker_kwargs = {"solver": _AVAIL_ITER_SOLVERS[solver], "tol": tol}

    # as_array causes an issue, because it's called like this np.array([(NxM), (NxK), ....]
    # in the end, we want array of shape Nx(M + K + ...) - this is ensured by the extractor
    mat_x, n_converged = parallelize(
        worker,
        mat_b,
        n_jobs=n_jobs,
        backend=backend,
        as_array=False,
        extractor=_stack_solutions,
        show_progress_bar=show_progress_bar,
    )(mat_a, **worker_kwargs)

    n_problems = mat_b.shape[0]
    if n_converged != n_problems:
        logg.warning(
            f"`{n_problems - n_converged}` solution(s) did not converge")

    return mat_x
예제 #7
0
def heatmap(
    adata: AnnData,
    model: _model_type,
    genes: Sequence[str],
    lineages: Optional[Union[str, Sequence[str]]] = None,
    backward: bool = False,
    mode: str = HeatmapMode.LINEAGES.s,
    time_key: str = "latent_time",
    time_range: Optional[Union[_time_range_type, List[_time_range_type]]] = None,
    callback: _callback_type = None,
    cluster_key: Optional[Union[str, Sequence[str]]] = None,
    show_absorption_probabilities: bool = False,
    cluster_genes: bool = False,
    keep_gene_order: bool = False,
    scale: bool = True,
    n_convolve: Optional[int] = 5,
    show_all_genes: bool = False,
    show_cbar: bool = True,
    lineage_height: float = 0.33,
    fontsize: Optional[float] = None,
    xlabel: Optional[str] = None,
    cmap: mcolors.ListedColormap = cm.viridis,
    show_dendrogram: bool = True,
    return_genes: bool = False,
    n_jobs: Optional[int] = 1,
    backend: str = _DEFAULT_BACKEND,
    show_progress_bar: bool = True,
    ext: str = "png",
    figsize: Optional[Tuple[float, float]] = None,
    dpi: Optional[int] = None,
    save: Optional[Union[str, Path]] = None,
    **kwargs,
) -> Optional[Dict[str, pd.DataFrame]]:
    """
    Plot a heatmap of smoothed gene expression along specified lineages.

    Parameters
    ----------
    %(adata)s
    %(model)s
    %(genes)s
    lineages
        Names of the lineages for which to plot. If `None`, plot all lineages.
    %(backward)s
    mode
        Valid options are:

            - `{m.LINEAGES.s!r}` - group by ``genes`` for each lineage in ``lineages``.
            - `{m.GENES.s!r}` - group by ``lineages`` for each gene in ``genes``.
    time_key
        Key in ``adata.obs`` where the pseudotime is stored.
    %(time_ranges)s
    %(model_callback)s
    cluster_key
        Key(s) in ``adata.obs`` containing categorical observations to be plotted on top of the heatmap.
        Only available when ``mode={m.LINEAGES.s!r}``.
    show_absorption_probabilities
        Whether to also plot absorption probabilities alongside the smoothed expression.
        Only available when ``mode={m.LINEAGES.s!r}``.
    cluster_genes
        Whether to cluster genes using :func:`seaborn.clustermap` when ``mode='lineages'``.
    keep_gene_order
        Whether to keep the gene order for later lineages after the first was sorted.
        Only available when ``cluster_genes=False`` and ``mode={m.LINEAGES.s!r}``.
    scale
        Whether to normalize the gene expression `0-1` range.
    n_convolve
        Size of the convolution window when smoothing absorption probabilities.
    show_all_genes
        Whether to show all genes on y-axis.
    show_cbar
        Whether to show the colorbar.
    lineage_height
        Height of a bar when ``mode={m.GENES.s!r}``.
    fontsize
        Size of the title's font.
    xlabel
        Label on the x-axis. If `None`, it is determined based on ``time_key``.
    cmap
        Colormap to use when visualizing the smoothed expression.
    show_dendrogram
        Whether to show dendrogram when ``cluster_genes=True``.
    return_genes
        Whether to return the sorted or clustered genes.
        Only available when ``mode={m.LINEAGES.s!r}``.
    %(parallel)s
    %(plotting)s
    **kwargs
        Keyword arguments for :meth:`cellrank.ul.models.BaseModel.prepare`.

    Returns
    -------
    %(just_plots)s
    :class:`pandas.DataFrame`
        If ``return_genes=True`` and ``mode={m.LINEAGES.s!r}``, returns :class:`pandas.DataFrame`
        containing the clustered or sorted genes.
    """

    import seaborn as sns

    def find_indices(series: pd.Series, values) -> Tuple[Any, ...]:
        """Return, for each of ``values``, the index label of the closest entry in ``series``."""

        def find_nearest(array: np.ndarray, value: float) -> int:
            # `array` must be sorted; choose between the insertion point and its left neighbor
            ix = np.searchsorted(array, value, side="left")
            if ix > 0 and (
                ix == len(array)
                or fabs(value - array[ix - 1]) < fabs(value - array[ix])
            ):
                return ix - 1
            return ix

        # `np.argsort` and `find_nearest` yield positions, not labels; `.iloc` makes this explicit,
        # because plain `series[...]` is label-based for integer indices and its positional
        # fallback for other indices is deprecated in recent pandas versions
        series = series.iloc[np.argsort(series.values)]
        sorted_values = series.values

        return tuple(
            series.iloc[[find_nearest(sorted_values, v) for v in values]].index
        )

    def subset_lineage(lname: str, rng: np.ndarray) -> np.ndarray:
        """Return the (optionally smoothed) values of lineage ``lname`` at the cells closest to ``rng``."""
        time_series = adata.obs[time_key]
        ixs = find_indices(time_series, rng)

        lin = adata[ixs, :].obsm[lineage_key][lname]

        lin = lin.X.copy().squeeze()
        if n_convolve is not None:
            # moving average with a window of `n_convolve` points
            lin = convolve(lin, np.ones(n_convolve) / n_convolve, mode="nearest")

        return lin

    def create_col_colors(lname: str, rng: np.ndarray) -> Tuple[np.ndarray, Cmap, Norm]:
        """Create the column colors, colormap and norm for the values of lineage ``lname``."""
        color = adata.obsm[lineage_key][lname].colors[0]
        lin = subset_lineage(lname, rng)

        # keep hue and value of the lineage color, but use full saturation for the end color
        h, _, v = mcolors.rgb_to_hsv(mcolors.to_rgb(color))
        end_color = mcolors.hsv_to_rgb([h, 1, v])

        lineage_cmap = mcolors.LinearSegmentedColormap.from_list(
            "lineage_cmap", ["#ffffff", end_color], N=len(rng)
        )
        norm = mcolors.Normalize(vmin=np.min(lin), vmax=np.max(lin))
        scalar_map = cm.ScalarMappable(cmap=lineage_cmap, norm=norm)

        return (
            np.array([mcolors.to_hex(c) for c in scalar_map.to_rgba(lin)]),
            lineage_cmap,
            norm,
        )

    def create_col_categorical_color(cluster_key: str, rng: np.ndarray) -> np.ndarray:
        """Create the column colors for the categorical annotation ``adata.obs[cluster_key]``."""
        if not is_categorical_dtype(adata.obs[cluster_key]):
            raise TypeError(
                f"Expected `adata.obs[{cluster_key!r}]` to be categorical, "
                f"found `{adata.obs[cluster_key].dtype.name!r}`."
            )

        color_key = f"{cluster_key}_colors"
        if color_key not in adata.uns:
            logg.warning(
                f"Color key `{color_key!r}` not found in `adata.uns`. Creating new colors"
            )
            colors = _create_categorical_colors(
                len(adata.obs[cluster_key].cat.categories)
            )
            # note: the newly created colors are stored in `adata`
            adata.uns[color_key] = colors
        else:
            colors = adata.uns[color_key]

        time_series = adata.obs[time_key]
        ixs = find_indices(time_series, rng)

        mapper = dict(zip(adata.obs[cluster_key].cat.categories, colors))

        return np.array(
            [mcolors.to_hex(mapper[v]) for v in adata[ixs, :].obs[cluster_key].values]
        )

    def create_cbar(ax, x_delta: float, cmap, norm, label=None) -> Ax:
        """Create a narrow colorbar in an inset axes anchored to ``ax``, shifted by ``x_delta``."""
        cax = inset_axes(
            ax,
            width="1%",
            height="100%",
            loc="lower right",
            bbox_to_anchor=(x_delta, 0, 1, 1),
            bbox_transform=ax.transAxes,
        )

        _ = mpl.colorbar.ColorbarBase(cax, cmap=cmap, norm=norm, label=label)

        return cax

    @valuedispatch
    def _plot_heatmap(_mode: HeatmapMode) -> Fig:
        pass

    @_plot_heatmap.register(HeatmapMode.GENES)
    def _() -> Tuple[Fig, None]:
        def color_fill_rec(ax, xs, y1, y2, colors=None, cmap=cmap, **kwargs) -> None:
            colors = colors if cmap is None else cmap(colors)

            x = 0
            for i, (color, x, y1, y2) in enumerate(zip(colors, xs, y1, y2)):
                # the width of a rectangle is the distance to the next point; the last rectangle
                # reuses the width of the previous one. The bound must come from `xs`, not from
                # the current element `x`.
                dx = (xs[i + 1] - xs[i]) if i < len(xs) - 1 else (xs[-1] - xs[-2])
                ax.add_patch(
                    plt.Rectangle((x, y1), dx, y2 - y1, color=color, ec=color, **kwargs)
                )

            # invisible line, plotted after adding the patches
            ax.plot(x, y2, lw=0)

        fig, axes = plt.subplots(
            nrows=len(genes) + show_absorption_probabilities,
            figsize=(12, len(genes) + len(lineages) * lineage_height)
            if figsize is None
            else figsize,
            dpi=dpi,
            constrained_layout=True,
        )

        if not isinstance(axes, Iterable):
            axes = [axes]
        axes = np.ravel(axes)

        if show_absorption_probabilities:
            # reuse the models of the first gene; only their `x_test` is used for this extra row
            data["absorption probability"] = data[next(iter(data.keys()))]

        for ax, (gene, models) in zip(axes, data.items()):
            if scale:
                norm = mcolors.Normalize(vmin=0, vmax=1)
            else:
                c = np.array([m.y_test for m in models.values()])
                c_min, c_max = np.nanmin(c), np.nanmax(c)
                norm = mcolors.Normalize(vmin=c_min, vmax=c_max)

            ix = 0
            ys = [ix]

            if gene == "absorption probability":
                norm = mcolors.Normalize(vmin=0, vmax=1)
                for ln, x in ((ln, m.x_test) for ln, m in models.items()):
                    y = np.ones_like(x)
                    c = subset_lineage(ln, x.squeeze())

                    color_fill_rec(
                        ax, x, y * ix, y * (ix + lineage_height), colors=norm(c)
                    )

                    ix += lineage_height
                    ys.append(ix)
            else:
                for x, c in ((m.x_test, m.y_test) for m in models.values()):
                    y = np.ones_like(x)
                    c = _min_max_scale(c) if scale else c

                    color_fill_rec(
                        ax, x, y * ix, y * (ix + lineage_height), colors=norm(c)
                    )

                    ix += lineage_height
                    ys.append(ix)

            xs = np.array([m.x_test for m in models.values()])
            x_min, x_max = np.min(xs), np.max(xs)
            ax.set_xticks(np.linspace(x_min, x_max, _N_XTICKS))

            # one tick per lineage, centered on its bar
            ax.set_yticks(np.array(ys[:-1]) + lineage_height / 2)
            ax.spines["left"].set_position(
                ("data", 0)
            )  # move the left spine to the rectangles to get nicer yticks
            ax.set_yticklabels(lineages, ha="right")

            ax.set_title(gene, fontdict=dict(fontsize=fontsize))
            ax.set_ylabel("lineage")

            for pos in ["top", "bottom", "left", "right"]:
                ax.spines[pos].set_visible(False)

            cax, _ = mpl.colorbar.make_axes(ax)
            _ = mpl.colorbar.ColorbarBase(
                cax,
                norm=norm,
                cmap=cmap,
                label="value" if gene == "absorption probability" else "expression",
            )

            ax.tick_params(
                top=False,
                bottom=False,
                left=True,
                right=False,
                labelleft=True,
                labelbottom=False,
            )

        # only the last axes gets x-tick labels and the x-axis label
        ax.xaxis.set_major_formatter(FormatStrFormatter("%.3f"))
        ax.tick_params(
            top=False,
            bottom=True,
            left=True,
            right=False,
            labelleft=True,
            labelbottom=True,
        )
        ax.set_xlabel(xlabel)

        return fig, None

    @_plot_heatmap.register(HeatmapMode.LINEAGES)
    def _() -> Tuple[List[Fig], pd.DataFrame]:
        data_t = defaultdict(dict)  # transpose
        for gene, lns in data.items():
            for ln, y in lns.items():
                data_t[ln][gene] = y

        figs = []
        gene_order = None
        sorted_genes = pd.DataFrame() if return_genes else None

        for lname, models in data_t.items():
            xs = np.array([m.x_test for m in models.values()])
            x_min, x_max = np.nanmin(xs), np.nanmax(xs)

            df = pd.DataFrame([m.y_test for m in models.values()], index=genes)
            df.index.name = "genes"

            if not cluster_genes:
                if gene_order is not None:
                    df = df.loc[gene_order]
                else:
                    # order the genes by the position of their (min-max scaled) maximum
                    max_sort = np.argsort(
                        np.argmax(df.apply(_min_max_scale, axis=1).values, axis=1)
                    )
                    df = df.iloc[max_sort, :]
                    if keep_gene_order:
                        gene_order = df.index

            cat_colors = None
            if cluster_key is not None:
                cat_colors = np.stack(
                    [
                        create_col_categorical_color(
                            c, np.linspace(x_min, x_max, df.shape[1])
                        )
                        for c in cluster_key
                    ],
                    axis=0,
                )

            if show_absorption_probabilities:
                col_colors, col_cmap, col_norm = create_col_colors(
                    lname, np.linspace(x_min, x_max, df.shape[1])
                )
                if cat_colors is not None:
                    col_colors = np.vstack([cat_colors, col_colors[None, :]])
            else:
                col_colors, col_cmap, col_norm = cat_colors, None, None

            # clustering requires more than one gene
            row_cluster = cluster_genes and df.shape[0] > 1
            show_clust = row_cluster and show_dendrogram

            g = sns.clustermap(
                df,
                cmap=cmap,
                figsize=(10, min(len(genes) / 8 + 1, 10))
                if figsize is None
                else figsize,
                xticklabels=False,
                cbar_kws={"label": "expression"},
                row_cluster=row_cluster,
                col_colors=col_colors,
                colors_ratio=0,
                col_cluster=False,
                cbar_pos=None,
                yticklabels=show_all_genes or "auto",
                standard_scale=0 if scale else None,
            )

            if show_cbar:
                cax = create_cbar(
                    g.ax_heatmap,
                    0.1,
                    cmap=cmap,
                    norm=mcolors.Normalize(
                        vmin=0 if scale else np.min(df.values),
                        vmax=1 if scale else np.max(df.values),
                    ),
                    label="expression",
                )
                g.fig.add_axes(cax)

                if col_cmap is not None and col_norm is not None:
                    cax = create_cbar(
                        g.ax_heatmap,
                        0.25,
                        cmap=col_cmap,
                        norm=col_norm,
                        label="absorption probability",
                    )
                    g.fig.add_axes(cax)

            if g.ax_col_colors:
                main_bbox = _get_ax_bbox(g.fig, g.ax_heatmap)
                n_bars = show_absorption_probabilities + (
                    len(cluster_key) if cluster_key is not None else 0
                )
                _set_ax_height_to_cm(
                    g.fig,
                    g.ax_col_colors,
                    height=min(
                        5, max(n_bars * main_bbox.height / len(df), 0.25 * n_bars)
                    ),
                )
                g.ax_col_colors.set_title(lname, fontdict=dict(fontsize=fontsize))
            else:
                g.ax_heatmap.set_title(lname, fontdict=dict(fontsize=fontsize))

            g.ax_col_dendrogram.set_visible(False)  # gets rid of top free space

            g.ax_heatmap.yaxis.tick_left()
            g.ax_heatmap.yaxis.set_label_position("right")

            g.ax_heatmap.set_xlabel(xlabel)
            g.ax_heatmap.set_xticks(np.linspace(0, len(df.columns), _N_XTICKS))
            g.ax_heatmap.set_xticklabels(
                list(map(lambda n: round(n, 3), np.linspace(x_min, x_max, _N_XTICKS)))
            )

            if show_clust:
                # robustly show dendrogram, because gene names can be long
                g.ax_row_dendrogram.set_visible(True)
                dendro_box = g.ax_row_dendrogram.get_position()

                pad = 0.005
                bb = g.ax_heatmap.yaxis.get_tightbbox(
                    g.fig.canvas.get_renderer()
                ).transformed(g.fig.transFigure.inverted())

                # place the dendrogram directly to the left of the y-axis labels
                dendro_box.x0 = bb.x0 - dendro_box.width - pad
                dendro_box.x1 = bb.x0 - pad

                g.ax_row_dendrogram.set_position(dendro_box)
            else:
                g.ax_row_dendrogram.set_visible(False)

            if return_genes:
                sorted_genes[lname] = (
                    df.index[g.dendrogram_row.reordered_ind]
                    if hasattr(g, "dendrogram_row") and g.dendrogram_row is not None
                    else df.index
                )

            figs.append(g)

        return figs, sorted_genes

    mode = HeatmapMode(mode)

    lineage_key = str(AbsProbKey.BACKWARD if backward else AbsProbKey.FORWARD)
    if lineage_key not in adata.obsm:
        raise KeyError(f"Lineages key `{lineage_key!r}` not found in `adata.obsm`.")

    if lineages is None:
        lineages = adata.obsm[lineage_key].names
    elif isinstance(lineages, str):
        lineages = [lineages]
    lineages = _unique_order_preserving(lineages)

    # the result is discarded; the lookup is done only for its side effect of failing early
    _ = adata.obsm[lineage_key][lineages]

    if cluster_key is not None:
        if isinstance(cluster_key, str):
            cluster_key = [cluster_key]
        cluster_key = _unique_order_preserving(cluster_key)

    if isinstance(genes, str):
        genes = [genes]
    genes = _unique_order_preserving(genes)
    _check_collection(adata, genes, "var_names", use_raw=kwargs.get("use_raw", False))

    # a single time range is shared by all lineages
    if isinstance(time_range, (tuple, float, int, type(None))):
        time_range = [time_range] * len(lineages)
    elif len(time_range) != len(lineages):
        raise ValueError(
            f"Expected time ranges to be of length `{len(lineages)}`, found `{len(time_range)}`."
        )

    xlabel = time_key if xlabel is None else xlabel
    models = _create_models(model, genes, lineages)

    kwargs["backward"] = backward
    kwargs["time_key"] = time_key
    callbacks = _create_callbacks(adata, callback, genes, lineages, **kwargs)

    backend = _get_backend(model, backend)
    n_jobs = _get_n_cores(n_jobs, len(genes))

    start = logg.info(f"Computing trends using `{n_jobs}` core(s)")
    data = parallelize(
        _fit_gene_trends,
        genes,
        unit="gene",
        backend=backend,
        n_jobs=n_jobs,
        # merge the dictionaries returned by the individual jobs into one
        extractor=lambda data: {k: v for d in data for k, v in d.items()},
        show_progress_bar=show_progress_bar,
    )(models, callbacks, lineages, time_range, **kwargs)
    logg.info("    Finish", time=start)

    logg.debug(f"Plotting `{mode.s!r}` heatmap")
    fig, genes = _plot_heatmap(mode)

    if save is not None and fig is not None:
        if not isinstance(fig, Iterable):
            save_fig(fig, save, ext=ext)
        elif len(fig) == 1:
            save_fig(fig[0], save, ext=ext)
        else:
            # several figures: `save` is used as a directory, with one file per lineage
            for ln, f in zip(lineages, fig):
                save_fig(f, os.path.join(save, f"lineage_{ln}"), ext=ext)

    if return_genes and mode == HeatmapMode.LINEAGES:
        return genes
예제 #8
0
def cluster_lineage(
        adata: AnnData,
        model: _model_type,
        genes: Sequence[str],
        lineage: str,
        backward: bool = False,
        time_range: _time_range_type = None,
        clusters: Optional[Sequence[str]] = None,
        n_points: int = 200,
        time_key: str = "latent_time",
        cluster_key: str = "clusters",
        norm: bool = True,
        recompute: bool = False,
        callback: _callback_type = None,
        ncols: int = 3,
        sharey: Union[str, bool] = False,
        key_added: Optional[str] = None,
        show_progress_bar: bool = True,
        n_jobs: Optional[int] = 1,
        backend: str = _DEFAULT_BACKEND,
        figsize: Optional[Tuple[float, float]] = None,
        dpi: Optional[int] = None,
        save: Optional[Union[str, Path]] = None,
        pca_kwargs: Dict = MappingProxyType({"svd_solver": "arpack"}),
        neighbors_kwargs: Dict = MappingProxyType({"use_rep": "X"}),
        louvain_kwargs: Dict = MappingProxyType({}),
        **kwargs,
) -> None:
    """
    Cluster gene expression trends within a lineage and plot the clusters.

    This function is based on Palantir, see [Setty19]_. It can be used to discover modules of genes that drive
    development along a given lineage. Consider running this function on a subset of genes which are potential lineage
    drivers, identified e.g. by running :func:`cellrank.tl.lineage_drivers`.

    Parameters
    ----------
    %(adata)s
    %(model)s
    %(genes)s
    lineage
        Name of the lineage for which to cluster the genes.
    %(backward)s
    %(time_ranges)s
    clusters
        Cluster identifiers to plot. If `None`, all clusters will be considered.
        Useful when plotting previously computed clusters.
    n_points
        Number of points used for prediction.
    time_key
        Key in ``adata.obs`` where the pseudotime is stored.
    cluster_key
        Key in ``adata.obs`` where the clustering is stored.
    norm
        Whether to z-normalize each trend to have zero mean, unit variance.
    recompute
        If `True`, recompute the clustering, otherwise try to find already existing one.
    %(model_callback)s
    ncols
        Number of columns for the plot.
    sharey
        Whether to share y-axis across multiple plots.
    key_added
        Postfix to add when saving the results to ``adata.uns``.
    %(parallel)s
    %(plotting)s
    pca_kwargs
        Keyword arguments for :func:`scanpy.pp.pca`.
    neighbors_kwargs
        Keyword arguments for :func:`scanpy.pp.neighbors`.
    louvain_kwargs
        Keyword arguments for :func:`scanpy.tl.louvain`.
    **kwargs:
        Keyword arguments for :meth:`cellrank.ul.models.BaseModel.prepare`.

    Returns
    -------
    %(just_plots)s

        Updates ``adata.uns`` with the following key:

            - ``lineage_{lineage}_trend_{key_added}`` - an :class:`anndata.AnnData` object
              of shape ``(n_genes, n_points)`` containing the clustered genes.

    Raises
    ------
    KeyError
        If the lineages are not found in ``adata.obsm`` or if ``cluster_key`` is missing from the trends object.
    RuntimeError
        If the fitted trends do not have the expected number of genes or points.
    """

    import scanpy as sc
    from anndata import AnnData as _AnnData

    lineage_key = str(AbsProbKey.BACKWARD if backward else AbsProbKey.FORWARD)
    if lineage_key not in adata.obsm:
        raise KeyError(
            f"Lineages key `{lineage_key!r}` not found in `adata.obsm`.")

    # indexed only for validation (presumably raises for an unknown lineage); the result is not needed
    _ = adata.obsm[lineage_key][lineage]

    # a bare string would otherwise be iterated character by character
    if isinstance(genes, str):
        genes = [genes]
    genes = _unique_order_preserving(genes)
    _check_collection(adata,
                      genes,
                      "var_names",
                      use_raw=kwargs.get("use_raw", False))

    key_to_add = f"lineage_{lineage}_trend"
    if key_added is not None:
        logg.debug(f"Adding key `{key_added!r}`")
        key_to_add += f"_{key_added}"

    if recompute or key_to_add not in adata.uns:
        kwargs["time_key"] = time_key  # kwargs for the model.prepare
        kwargs["n_test_points"] = n_points
        kwargs["backward"] = backward

        models = _create_models(model, genes, [lineage])
        callbacks = _create_callbacks(adata, callback, genes, [lineage],
                                      **kwargs)

        backend = _get_backend(model, backend)
        n_jobs = _get_n_cores(n_jobs, len(genes))

        start = logg.info(f"Computing gene trends using `{n_jobs}` core(s)")
        trends = parallelize(
            _cluster_lineages_helper,
            genes,
            as_array=True,
            unit="gene",
            n_jobs=n_jobs,
            backend=backend,
            extractor=np.vstack,
            show_progress_bar=show_progress_bar,
        )(models, callbacks, lineage, time_range, **kwargs)
        logg.info("    Finish", time=start)

        # the scaler normalizes columns, hence genes are temporarily put on the column axis
        trends = trends.T
        if norm:
            logg.debug("Normalizing using `StandardScaler`")
            # `copy=False` is only a hint, a copy may still be returned, so always keep the result
            trends = StandardScaler(copy=False).fit_transform(trends)

        trends = _AnnData(trends.T)
        trends.obs_names = genes

        # sanity check
        if trends.n_obs != len(genes):
            raise RuntimeError(
                f"Expected to find `{len(genes)}` genes, found `{trends.n_obs}`."
            )
        if n_points is not None and trends.n_vars != n_points:
            raise RuntimeError(
                f"Expected to find `{n_points}` points, found `{trends.n_vars}`."
            )

        pca_kwargs = dict(pca_kwargs)
        # default value; use the actual number of points so that `n_points=None` does not break `min`
        n_comps = pca_kwargs.pop(
            "n_comps",
            min(50, trends.n_vars, len(genes)) - 1)

        sc.pp.pca(trends, n_comps=n_comps, **pca_kwargs)
        sc.pp.neighbors(trends, **neighbors_kwargs)

        louvain_kwargs = dict(louvain_kwargs)
        louvain_kwargs["key_added"] = cluster_key
        sc.tl.louvain(trends, **louvain_kwargs)

        adata.uns[key_to_add] = trends
    else:
        logg.info(f"Loading data from `adata.uns[{key_to_add!r}]`")
        trends = adata.uns[key_to_add]

    if clusters is None:
        if cluster_key not in trends.obs:
            raise KeyError(f"Invalid cluster key `{cluster_key!r}`.")
        clusters = trends.obs[cluster_key].cat.categories

    nrows = int(np.ceil(len(clusters) / ncols))
    fig, axes = plt.subplots(
        nrows,
        ncols,
        figsize=(ncols * 10, nrows * 10) if figsize is None else figsize,
        sharey=sharey,
        dpi=dpi,
    )

    # a 1x1 grid yields a single axes object instead of an array
    if not isinstance(axes, Iterable):
        axes = [axes]
    axes = np.ravel(axes)

    n_used = 0
    for n_used, (ax, c) in enumerate(zip(axes, clusters), start=1):
        data = trends[trends.obs[cluster_key] == c].X
        mean, sd = np.mean(data, axis=0), np.std(data, axis=0)

        # individual gene trends in the background
        for i in range(data.shape[0]):
            ax.plot(data[i], color="gray", lw=0.5)

        # cluster mean with a +- 1 standard deviation band
        ax.plot(mean, lw=2, color="black")
        ax.plot(mean - sd, lw=1.5, color="black", linestyle="--")
        ax.plot(mean + sd, lw=1.5, color="black", linestyle="--")
        ax.fill_between(range(len(mean)),
                        mean - sd,
                        mean + sd,
                        color="black",
                        alpha=0.1)

        ax.set_title(f"Cluster {c}")
        ax.set_xticks([])

        if not sharey:
            ax.set_yticks([])

    # drop the grid cells which did not receive any cluster
    for ax in axes[n_used:]:
        ax.remove()

    if save is not None:
        save_fig(fig, save)
예제 #9
0
파일: _utils.py 프로젝트: changwn/cellrank
def _fit_bulk(
    models: Mapping[str, Mapping[str, Callable]],
    callbacks: Mapping[str, Mapping[str, Callable]],
    genes: Union[str, Sequence[str]],
    lineages: Union[str, Sequence[str]],
    time_range: _time_range_type,
    parallel_kwargs: dict,
    return_models: bool = False,
    filter_all_failed: bool = True,
    **kwargs,
) -> Tuple[_return_model_type, _return_model_type, Sequence[str],
           Sequence[str]]:
    """
    Fit models for given genes and lineages.

    Parameters
    ----------
    models
        Gene and lineage specific estimators.
    callbacks
        Functions which are called to prepare the ``models``.
    genes
        Genes for which to fit the ``models``.
    lineages
        Lineages for which to fit the ``models``.
    time_range
        Possibly ``lineages`` specific start- and endtimes. A single value is used for every lineage.
    parallel_kwargs
        Keyword arguments for :func:`cellrank.ul._utils.parallelize`. The key `'n_jobs'` defaults to `1`.
        The dictionary itself is left unmodified. The ``extractor`` is always set by this function.
    return_models
        Whether to return the full models or just a dictionary of dictionaries of :class:`collections.namedtuple`,
        `(x_test, y_test)`. This is highly discouraged because no meaningful error messages will be produced.
    filter_all_failed
        Whether to filter out all models which have failed.
    **kwargs
        Keyword arguments passed on to the parallelized fitting helper. If it contains ``data_key='obs'``,
        the progress unit is `'obs'` instead of `'gene'`.

    Returns
    -------
    :class:`dict`
        All the models, including the failed ones. It is a nested dictionary where keys are the ``genes`` and the values
        is again a :class:`dict`, where keys are ``lineages`` and values are the failed or fitted models or
        the :class:`collections.namedtuple`, based on ``return_models=True``.
    :class:`dict`
        Same as above, but can contain failed models if ``filter_all_failed=False``. In that case, it is guaranteed
        that this dictionary will contain only genes which have been successfully fitted for at least 1 lineage.
        If ``return_models=True``, the models are just a :class:`collections.namedtuple` of `(x_test, y_test)`.
    :class:`tuple`
        All the genes of the filtered models.
    :class:`tuple`
        All the lineage of the filtered models.

    Raises
    ------
    ValueError
        If ``time_range`` is a sequence whose length differs from the number of ``lineages``.
    """

    if isinstance(genes, str):
        genes = [genes]

    if isinstance(lineages, str):
        lineages = [lineages]

    # a single range (or `None`) is shared by all lineages
    if isinstance(time_range, (tuple, float, int, type(None))):
        time_range = [time_range] * len(lineages)
    elif len(time_range) != len(lineages):
        raise ValueError(
            f"Expected time ranges to be of length `{len(lineages)}`, found `{len(time_range)}`."
        )

    # copy first, so that popping does not modify the dictionary owned by the caller
    parallel_kwargs = dict(parallel_kwargs)
    n_jobs = parallel_kwargs.pop("n_jobs", 1)
    parallel_kwargs.setdefault(
        "unit", "gene" if kwargs.get("data_key", "gene") != "obs" else "obs")
    # the merged-dictionary format is what `_filter_models` consumes, so it cannot be overridden
    parallel_kwargs["extractor"] = lambda modelss: {
        k: v
        for m in modelss for k, v in m.items()
    }

    start = logg.info(f"Computing trends using `{n_jobs}` core(s)")
    # forward the remaining options (e.g. backend, progress bar) instead of silently dropping them
    models = parallelize(
        _fit_bulk_helper,
        genes,
        n_jobs=n_jobs,
        **parallel_kwargs,
    )(
        models=models,
        callbacks=callbacks,
        lineages=lineages,
        time_range=time_range,
        return_models=return_models,
        **kwargs,
    )
    logg.info("    Finish", time=start)

    return _filter_models(models,
                          return_models=return_models,
                          filter_all_failed=filter_all_failed)
예제 #10
0
def gene_trends(
        adata: AnnData,
        model: _model_type,
        genes: Union[str, Sequence[str]],
        lineages: Optional[Union[str, Sequence[str]]] = None,
        backward: bool = False,
        data_key: str = "X",
        time_key: str = "latent_time",
        time_range: Optional[Union[_time_range_type,
                                   List[_time_range_type]]] = None,
        callback: _callback_type = None,
        conf_int: bool = True,
        same_plot: bool = False,
        hide_cells: bool = False,
        perc: Optional[Union[Tuple[float, float],
                             Sequence[Tuple[float, float]]]] = None,
        lineage_cmap: Optional[matplotlib.colors.ListedColormap] = None,
        abs_prob_cmap: matplotlib.colors.ListedColormap = cm.viridis,
        cell_color: str = "black",
        cell_alpha: float = 0.6,
        lineage_alpha: float = 0.2,
        size: float = 15,
        lw: float = 2,
        show_cbar: bool = True,
        margins: float = 0.015,
        sharex: Optional[Union[str, bool]] = None,
        sharey: Optional[Union[str, bool]] = None,
        gene_as_title: Optional[bool] = None,
        legend_loc: Optional[str] = "best",
        ncols: int = 2,
        suptitle: Optional[str] = None,
        n_jobs: Optional[int] = 1,
        backend: str = _DEFAULT_BACKEND,
        show_progres_bar: bool = True,
        figsize: Optional[Tuple[float, float]] = None,
        dpi: Optional[int] = None,
        save: Optional[Union[str, Path]] = None,
        plot_kwargs: Mapping = MappingProxyType({}),
        **kwargs,
) -> None:
    """
    Plot gene expression trends along lineages.

    Each lineage is defined via its lineage weights which we compute using :func:`cellrank.tl.lineages`. This
    function accepts any model based off :class:`cellrank.ul.models.BaseModel` to fit gene expression,
    where we take the lineage weights into account in the loss function.

    Parameters
    ----------
    %(adata)s
    %(model)s
    %(genes)s
    lineages
        Names of the lineages to plot. If `None`, plot all lineages.
    %(backward)s
    data_key
        Key in ``adata.layers`` or `'X'` for ``adata.X`` where the data is stored.
    time_key
        Key in ``adata.obs`` where the pseudotime is stored.
    %(time_ranges)s
    %(model_callback)s
    conf_int
        Whether to compute and show confidence intervals.
    same_plot
        Whether to plot all lineages for each gene in the same plot.
    hide_cells
        If `True`, hide all cells.
    perc
        Percentile for colors. Valid values are in interval `[0, 100]`.
        This can improve visualization. Can be specified individually for each lineage.
    lineage_cmap
        Colormap to use when coloring in the lineages. If `None` and ``same_plot``, use the corresponding colors
        in ``adata.uns``, otherwise use `'black'`.
    abs_prob_cmap
        Colormap to use when visualizing the absorption probabilities for each lineage.
        Only used when ``same_plot=False``.
    cell_color
        Color of the cells when not visualizing absorption probabilities. Only used when ``same_plot=True``.
    cell_alpha
        Alpha channel for cells.
    lineage_alpha
        Alpha channel for lineage confidence intervals.
    size
        Size of the points.
    lw
        Line width of the smoothed values.
    show_cbar
        Whether to show colorbar. Always shown when percentiles for lineages differ. Only used when ``same_plot=False``.
    margins
        Margins around the plot.
    sharex
        Whether to share x-axis. Valid options are `'row'`, `'col'` or `'none'`.
    sharey
        Whether to share y-axis. Valid options are `'row'`, `'col'` or `'none'`.
    gene_as_title
        Whether to show gene names as titles instead on y-axis.
    legend_loc
        Location of the legend displaying lineages. Only used when `same_plot=True`.
    ncols
        Number of columns of the plot when plotting multiple genes. Only used when ``same_plot=True``.
    suptitle
        Suptitle of the figure. If `None`, no suptitle is added.
    %(parallel)s
    %(plotting)s
    plot_kwargs
        Keyword arguments for :meth:`cellrank.ul.models.BaseModel.plot`.
    **kwargs
        Keyword arguments for :meth:`cellrank.ul.models.BaseModel.prepare`.

    Returns
    -------
    %(just_plots)s
    """

    if isinstance(genes, str):
        genes = [genes]
    genes = _unique_order_preserving(genes)

    # with `data_key='obs'`, the "genes" are looked up in `adata.obs` rather than among the variable names
    _check_collection(
        adata,
        genes,
        "var_names" if data_key != "obs" else "obs",
        use_raw=kwargs.get("use_raw", False),
    )

    ln_key = str(AbsProbKey.BACKWARD if backward else AbsProbKey.FORWARD)
    if ln_key not in adata.obsm:
        raise KeyError(f"Lineages key `{ln_key!r}` not found in `adata.obsm`.")

    if lineages is None:
        lineages = adata.obsm[ln_key].names
    elif isinstance(lineages, str):
        lineages = [lineages]
    elif all(map(lambda ln: ln is None,
                 lineages)):  # no lineage, all the weights are 1
        lineages = [None]
        show_cbar = False
        logg.debug("All lineages are `None`, setting the weights to `1`")
    lineages = _unique_order_preserving(lineages)

    # validate the inputs before any figure is created, so that invalid input does not leave an empty figure behind
    # the indexing is done only for validation (presumably raises for unknown lineages)
    _ = adata.obsm[ln_key][[lin for lin in lineages if lin is not None]]

    # a single range (or `None`) is shared by all lineages
    if isinstance(time_range, (tuple, float, int, type(None))):
        time_range = [time_range] * len(lineages)
    elif len(time_range) != len(lineages):
        raise ValueError(
            f"Expected time ranges to be of length `{len(lineages)}`, found `{len(time_range)}`."
        )

    if same_plot:
        # one panel per gene, laid out in a grid with at most `ncols` columns
        gene_as_title = True if gene_as_title is None else gene_as_title
        sharex = "all" if sharex is None else sharex
        sharey = "none" if sharey is None else sharey
        ncols = min(ncols, len(genes))
        nrows = int(np.ceil(len(genes) / ncols))
    else:
        # one row per gene, one column per lineage
        gene_as_title = False if gene_as_title is None else gene_as_title
        sharex = "col" if sharex is None else sharex
        sharey = (
            "none" if hide_cells else "row") if sharey is None else sharey
        nrows = len(genes)
        ncols = len(lineages)

    fig, axes = plt.subplots(
        nrows=nrows,
        ncols=ncols,
        sharex=sharex,
        sharey=sharey,
        figsize=(6 * ncols, 4 * nrows) if figsize is None else figsize,
        dpi=dpi,  # the figure is created here, so this is where `dpi` takes effect
        constrained_layout=True,
    )
    # always work with a 2-dimensional grid, regardless of what `plt.subplots` squeezed
    axes = np.reshape(axes, (-1, ncols))

    kwargs["time_key"] = time_key
    kwargs["data_key"] = data_key
    kwargs["backward"] = backward
    callbacks = _create_callbacks(adata, callback, genes, lineages, **kwargs)

    kwargs["conf_int"] = conf_int  # prepare doesnt take or need this
    models = _create_models(model, genes, lineages)

    plot_kwargs = dict(plot_kwargs)
    if plot_kwargs.get("xlabel", None) is None:
        plot_kwargs["xlabel"] = time_key

    n_jobs = _get_n_cores(n_jobs, len(genes))
    backend = _get_backend(model, backend)

    start = logg.info(f"Computing trends using `{n_jobs}` core(s)")
    models = parallelize(
        _fit_gene_trends,
        genes,
        unit="gene" if data_key != "obs" else "obs",
        backend=backend,
        n_jobs=n_jobs,
        extractor=lambda modelss:
        {k: v
         for m in modelss for k, v in m.items()},
        # NOTE: the parameter name (sic) is part of the public signature and is kept for compatibility
        show_progress_bar=show_progres_bar,
    )(models, callbacks, lineages, time_range, **kwargs)
    logg.info("    Finish", time=start)

    logg.info("Plotting trends")

    cnt = 0  # index of the gene which is plotted next
    for row in range(len(axes)):
        for col in range(len(axes[row])):
            if cnt >= len(genes):
                break
            gene = genes[cnt]

            _trends_helper(
                adata,
                models,
                gene=gene,
                lineage_names=lineages,
                ln_key=ln_key,
                same_plot=same_plot,
                hide_cells=hide_cells,
                perc=perc,
                lineage_cmap=lineage_cmap,
                abs_prob_cmap=abs_prob_cmap,
                cell_color=cell_color,
                alpha=cell_alpha,
                lineage_alpha=lineage_alpha,
                size=size,
                lw=lw,
                show_cbar=show_cbar,
                margins=margins,
                sharey=sharey,
                gene_as_title=gene_as_title,
                legend_loc=legend_loc,
                dpi=dpi,
                figsize=figsize,
                fig=fig,
                # a single axes when all lineages share a plot, otherwise the whole row of the gene
                axes=axes[row, col] if same_plot else axes[cnt],
                show_ylabel=col == 0,
                show_lineage=cnt == 0 or same_plot,
                show_xticks_and_label=((row + 1) * ncols + col >= len(genes))
                if same_plot else (cnt == len(axes) - 1),
                **plot_kwargs,
            )
            cnt += 1

    # remove the trailing grid cells which did not receive any gene
    if same_plot:
        for ax in np.ravel(axes)[cnt:]:
            ax.remove()

    if suptitle is not None:
        fig.suptitle(suptitle)

    if save is not None:
        save_fig(fig, save)