def _run_in_parallel(fn: Callable, conn: csr_matrix, **kwargs) -> Any:
    """
    Run ``fn`` in parallel over the row indices of ``conn`` and reassemble the output.

    Parameters
    ----------
    fn
        Worker function. If its name is ``_run_stochastic``, `jax` must be installed.
    conn
        Sparse matrix whose rows are distributed among the jobs; its ``indices`` and
        ``indptr`` are added to the keyword arguments.
    kwargs
        Keyword arguments; each of ``parallelize`` and ``fn`` only receives the ones
        selected for it by ``_filter_kwargs``.

    Returns
    -------
    The result of ``_reconstruct_one`` applied to the per-job outputs concatenated
    along their last axis.

    Raises
    ------
    RuntimeError
        If ``fn`` is ``_run_stochastic`` and `jax` is not available.
    """
    name = fn.__name__

    if name == "_run_stochastic":
        if not _HAS_JAX:
            raise RuntimeError("Install `jax` and `jaxlib` as `pip install jax jaxlib`.")
        # rows with the largest number of non-zero entries are scheduled first
        nnz_per_row = np.array((conn != 0).sum(1)).ravel()
        order = np.argsort(nnz_per_row)[::-1]
    else:
        # random row order; this draws from NumPy's global random state
        order = np.arange(conn.shape[0])
        np.random.shuffle(order)

    many_samples = name == "_run_mc" and kwargs.get("n_samples", 1) > 1
    unit = "sample" if many_samples else "cell"

    # expose the CSR structure of `conn` (subject to `_filter_kwargs` below)
    kwargs.update(indices=conn.indices, indptr=conn.indptr)

    def _extract(results):
        # presumably `_reconstruct_one` undoes the row permutation given by `order`
        return _reconstruct_one(np.concatenate(results, axis=-1), conn, order)

    run = parallelize(
        fn,
        order,
        as_array=False,
        extractor=_extract,
        unit=unit,
        **_filter_kwargs(parallelize, **kwargs),
    )
    return run(**_filter_kwargs(fn, **kwargs))
def _(
    x: Union[np.ndarray, spmatrix],
    library_size: np.ndarray,
    ref_ix: Optional[int] = None,
    **kwargs,
) -> np.ndarray:
    """
    Compute weighted factors for every row of ``x`` relative to a reference row.

    Parameters
    ----------
    x
        Dense or sparse matrix; each row is handled separately.
    library_size
        Library size of each row of ``x``.
    ref_ix
        Index of the reference row. If `None`, pick the row whose upper-quantile
        factor (``p=0.75``) is closest to the mean of all such factors.
    kwargs
        Keyword arguments forwarded to ``_calc_factor_weighted``.

    Returns
    -------
    The per-row results of ``_calc_factor_weighted``, concatenated.
    """
    if ref_ix is None:
        f75 = _calc_factor_quant(x, library_size=library_size, p=0.75)
        ref_ix = np.argmin(np.abs(f75 - np.mean(f75)))

    ref = x[ref_ix]
    if issparse(ref):
        # `.A` is deprecated (and removed in recent SciPy); `.toarray()` is always available.
        # `ravel` handles both a `(1, genes)` row and a 1-D sparse row.
        ref = ref.toarray().ravel()  # (genes,)

    return parallelize(
        _calc_factor_weighted,
        collection=np.arange(x.shape[0]),
        show_progress_bar=False,
        as_array=False,
        extractor=np.concatenate,
        backend="threading",
        n_jobs=4,
    )(
        obs_=x,
        obs_lib_size_=library_size,
        ref=ref,
        ref_lib_size=library_size[ref_ix],
        **kwargs,
    )
def simulate_many(
    self,
    n_sims: int,
    max_iter: Union[int, float] = 0.25,
    seed: Optional[int] = None,
    successive_hits: int = 0,
    n_jobs: Optional[int] = None,
    backend: str = "loky",
    show_progress_bar: bool = True,
) -> List[np.ndarray]:
    """
    Run a batch of random walk simulations, possibly in parallel.

    Parameters
    ----------
    n_sims
        Number of random walks to simulate.
    %(rw_sim.params)s
    %(parallel)s

    Returns
    -------
    List of arrays of shape ``(max_iter + 1,)`` of states that have been visited.
    If ``stop_ixs`` was specified, the arrays may have smaller shape.
    """
    if n_sims <= 0:
        raise ValueError(
            f"Expected number of simulations to be positive, found `{n_sims}`."
        )

    n_steps = self._max_iter(max_iter)
    start = logg.info(
        f"Simulating `{n_sims}` random walks of maximum length `{n_steps}`"
    )

    runner = parallelize(
        self._simulate_many,
        collection=np.arange(n_sims),
        n_jobs=n_jobs,
        backend=backend,
        show_progress_bar=show_progress_bar,
        as_array=False,
        unit="sim",
    )
    per_job = runner(max_iter=n_steps, seed=seed, successive_hits=successive_hits)

    # each job yields a collection of walks; flatten them into a single list
    walks = [walk for chunk in per_job for walk in chunk]

    logg.info(" Finish", time=start)
    return walks
def test_more_jobs_than_work(self, n_jobs: int):
    # `parallelize` must still process every row when there are more jobs than rows
    def callback(data, **_: Any):
        assert isinstance(data, csr_matrix)
        assert data.shape[1] == 100

        return [42] * data.shape[0]

    res = parallelize(
        callback,
        collection=srand(3, 100, format="csr"),
        n_jobs=n_jobs,
        show_progress_bar=False,
        extractor=np.concatenate,
    )()

    # comparing against a scalar passes vacuously for an empty array,
    # so first make sure that each of the 3 rows produced exactly one value
    assert res.shape == (3,)
    np.testing.assert_array_equal(res, 42)
def bias_knn(
    self,
    conn: csr_matrix,
    pseudotime: np.ndarray,
    n_jobs: Optional[int] = None,
    backend: str = "loky",
    show_progress_bar: bool = True,
    **kwargs: Any,
) -> csr_matrix:
    """
    Bias cell-cell connectivities of a KNN graph.

    Parameters
    ----------
    conn
        Sparse matrix of shape ``(n_cells, n_cells)`` containing the nearest neighbor connectivities.
    pseudotime
        Pseudotemporal ordering of cells.
    %(parallel)s

    Returns
    -------
    The biased connectivities, of the same shape as ``conn`` and with explicit zeros removed.
    """
    # each job processes a chunk of row indices and returns its `(data, indices, indptr)` parts
    res = parallelize(
        self._bias_knn_helper,
        np.arange(conn.shape[0]),
        as_array=False,
        unit="cell",
        n_jobs=n_jobs,
        backend=backend,
        show_progress_bar=show_progress_bar,
    )(conn, pseudotime, **kwargs)
    data, indices, indptr = zip(*res)

    # pass the shape explicitly: otherwise scipy infers the number of columns from the
    # largest stored column index, which is too small when the trailing columns are empty
    biased = csr_matrix(
        (np.concatenate(data), np.concatenate(indices), np.concatenate(indptr)),
        shape=conn.shape,
    )
    biased.eliminate_zeros()

    return biased
def _solve_lin_system(
    mat_a: Union[np.ndarray, spmatrix],
    mat_b: Union[np.ndarray, spmatrix],
    solver: str = _DEFAULT_SOLVER,
    use_petsc: bool = False,
    preconditioner: Optional[str] = None,
    n_jobs: Optional[int] = None,
    backend: str = _DEFAULT_BACKEND,
    tol: float = 1e-5,
    use_eye: bool = False,
    show_progress_bar: bool = True,
) -> np.ndarray:
    """
    Solve ``mat_a * x = mat_b`` efficiently using either iterative or direct methods.

    This is a utility function which is optimized for the case of ``mat_a`` and ``mat_b`` being sparse,
    and columns in ``mat_b`` being related. In that case, we can treat each column of ``mat_b`` as a
    separate linear problem and solve that efficiently using iterative solvers that exploit sparsity.

    If the columns of ``mat_b`` are related, we can use the solution of the previous problem as an
    initial guess for the next problem. Further, we parallelize the individual problems for each
    column in ``mat_b`` and solve them on separate cores.

    In case ``mat_a`` is either not sparse, or very small, or ``mat_b`` has very many columns, it makes
    sense to use a direct solver instead which computes a matrix factorization and thereby solves all
    sub-problems at the same time.

    Parameters
    ----------
    mat_a
        Matrix of shape `n x n`. We make no assumptions on ``mat_a`` being symmetric or positive definite.
    mat_b
        Matrix of shape `n x m`, with m << n.
    solver
        Solver to use for the linear problem. Options are `'direct', 'gmres', 'lgmres', 'bicgstab' or 'gcrotmk'`
        when ``use_petsc=False``, or `'direct'` and one of `petsc4py.PETSc.KSP.Type` when ``use_petsc=True``.

        Information on the :mod:`scipy` iterative solvers can be found in :func:`scipy.sparse.linalg` or
        for the :mod:`petsc4py` solver in https://www.mcs.anl.gov/petsc/documentation/linearsolvertable.html.
    use_petsc
        Whether to use solvers from :mod:`petsc4py` instead of :mod:`scipy`. Recommended for large problems.
        If :mod:`petsc4py` cannot be imported, a warning is logged (only the first time) and the
        default solver is used without PETSc.
    preconditioner
        Preconditioner to use when ``use_petsc=True``.
        For available preconditioners, see `petsc4py.PETSc.PC.Type`.
    n_jobs
        Number of parallel jobs used by the iterative solvers (both :mod:`scipy` and :mod:`petsc4py`).
        For small, quickly-solvable problems, we recommend a high number (>=8) of cores in order to
        fully saturate them.
    backend
        Which backend to use for multiprocessing. See :class:`joblib.Parallel` for valid options.
    tol
        The relative convergence tolerance, i.e. the relative decrease in the (possibly preconditioned)
        residual norm.
    use_eye
        Solve ``(I - mat_a) * x = mat_b`` instead.
    show_progress_bar
        Whether to show progress bar when the solver isn't a direct one.

    Returns
    -------
    :class:`numpy.ndarray`
        Matrix of shape `n x m`. Each column corresponds to the solution of one of the sub-problems
        defined via columns in ``mat_b``. For iterative solvers, a warning is logged if some of the
        sub-problems did not converge.

    Raises
    ------
    ValueError
        If ``solver`` is not valid for the selected backend.
    """

    def extractor(
        res_converged: List[Tuple[np.ndarray, int]]
    ) -> Tuple[np.ndarray, int]:
        # every job returns `(solutions, n_converged)`: stack the solutions column-wise
        # and add up the number of converged sub-problems
        res, converged = zip(*res_converged)
        return np.hstack(res), sum(converged)

    # NOTE(review): `_get_n_cores` is defined elsewhere; presumably turns `n_jobs` into a positive count
    n_jobs = _get_n_cores(n_jobs, n_jobs=None)

    if use_petsc:
        try:
            from petsc4py import PETSc  # noqa
        except ImportError:
            global _PETSC_ERROR_MSG_SHOWN
            # the module-level flag ensures the warning is only logged the first time
            if not _PETSC_ERROR_MSG_SHOWN:
                _PETSC_ERROR_MSG_SHOWN = True
                logg.warning(_PETSC_ERROR_MSG.format(_DEFAULT_SOLVER))
            # fall back to the default solver without PETSc
            solver = _DEFAULT_SOLVER
            use_petsc = False

    if use_eye:
        mat_a = (
            speye(mat_a.shape[0]) if issparse(mat_a) else np.eye(mat_a.shape[0])
        ) - mat_a

    if solver == "direct":
        if use_petsc:
            logg.debug("Solving the linear system directly using `PETSc`")
            return _petsc_mat_solve(
                mat_a, mat_b, solver=solver, preconditioner=preconditioner, tol=tol
            )

        # the `scipy` direct solver works on dense arrays
        if issparse(mat_a):
            logg.debug("Densifying `A` for `scipy` direct solver")
            mat_a = mat_a.toarray()
        if issparse(mat_b):
            logg.debug("Densifying `B` for `scipy` direct solver")
            mat_b = mat_b.toarray()

        logg.debug("Solving the linear system directly using `scipy`")

        return solve(mat_a, mat_b)

    if use_petsc:
        if not isspmatrix_csr(mat_a):
            mat_a = csr_matrix(mat_a)
        # transpose so that each row of `mat_b` is one sub-problem (one column of the original)
        mat_b = mat_b.T
        if not isspmatrix_csc(mat_b):
            mat_b = csc_matrix(mat_b)

        # `as_array=True` would call `np.array([(NxM), (NxK), ...])` on the per-job results, which fails;
        # we want one array of shape Nx(M + K + ...), which is what `extractor` produces
        logg.debug(
            f"Solving the linear system using `PETSc` solver `{('gmres' if solver is None else solver)!r}` "
            f"on `{n_jobs}` core(s) with {'no' if preconditioner is None else preconditioner} preconditioner and "
            f"`tol={tol}`"
        )

        # can't pass PETSc matrix - not pickleable
        mat_x, n_converged = parallelize(
            _solve_many_sparse_problems_petsc,
            mat_b,
            n_jobs=n_jobs,
            backend=backend,
            as_array=False,
            extractor=extractor,
            show_progress_bar=show_progress_bar,
        )(mat_a, solver=solver, preconditioner=preconditioner, tol=tol)
    elif solver in _AVAIL_ITER_SOLVERS:
        if not issparse(mat_a):
            logg.debug("Sparsifying `A` for iterative solver")
            mat_a = csr_matrix(mat_a)

        # transpose so that each row of `mat_b` is one sub-problem (one column of the original)
        mat_b = mat_b.T
        if not issparse(mat_b):
            logg.debug("Sparsifying `B` for iterative solver")
            mat_b = csr_matrix(mat_b)

        logg.debug(
            f"Solving the linear system using `scipy` solver `{solver!r}` on `{n_jobs} cores(s)` with `tol={tol}`"
        )

        mat_x, n_converged = parallelize(
            _solve_many_sparse_problems,
            mat_b,
            n_jobs=n_jobs,
            backend=backend,
            as_array=False,
            extractor=extractor,
            show_progress_bar=show_progress_bar,
        )(mat_a, solver=_AVAIL_ITER_SOLVERS[solver], tol=tol)
    else:
        raise ValueError(f"Invalid solver `{solver!r}`.")

    # `mat_b` is transposed at this point, so its rows are the sub-problems
    if n_converged != mat_b.shape[0]:
        logg.warning(f"`{mat_b.shape[0] - n_converged}` solution(s) did not converge")

    return mat_x
def heatmap(
    adata: AnnData,
    model: _model_type,
    genes: Sequence[str],
    lineages: Optional[Union[str, Sequence[str]]] = None,
    backward: bool = False,
    mode: str = HeatmapMode.LINEAGES.s,
    time_key: str = "latent_time",
    time_range: Optional[Union[_time_range_type, List[_time_range_type]]] = None,
    callback: _callback_type = None,
    cluster_key: Optional[Union[str, Sequence[str]]] = None,
    show_absorption_probabilities: bool = False,
    cluster_genes: bool = False,
    keep_gene_order: bool = False,
    scale: bool = True,
    n_convolve: Optional[int] = 5,
    show_all_genes: bool = False,
    show_cbar: bool = True,
    lineage_height: float = 0.33,
    fontsize: Optional[float] = None,
    xlabel: Optional[str] = None,
    cmap: mcolors.ListedColormap = cm.viridis,
    show_dendrogram: bool = True,
    return_genes: bool = False,
    n_jobs: Optional[int] = 1,
    backend: str = _DEFAULT_BACKEND,
    show_progress_bar: bool = True,
    ext: str = "png",
    figsize: Optional[Tuple[float, float]] = None,
    dpi: Optional[int] = None,
    save: Optional[Union[str, Path]] = None,
    **kwargs,
) -> Optional[Dict[str, pd.DataFrame]]:
    """
    Plot a heatmap of smoothed gene expression along specified lineages.

    Parameters
    ----------
    %(adata)s
    %(model)s
    %(genes)s
    lineages
        Names of the lineages for which to plot. If `None`, plot all lineages.
    %(backward)s
    mode
        Valid options are:

            - `{m.LINEAGES.s!r}` - group by ``genes`` for each lineage in ``lineages``.
            - `{m.GENES.s!r}` - group by ``lineages`` for each gene in ``genes``.
    time_key
        Key in ``adata.obs`` where the pseudotime is stored.
    %(time_ranges)s
    %(model_callback)s
    cluster_key
        Key(s) in ``adata.obs`` containing categorical observations to be plotted on top of the heatmap.
        Only available when ``mode={m.LINEAGES.s!r}``.
    show_absorption_probabilities
        Whether to also plot absorption probabilities alongside the smoothed expression.
        When ``mode={m.GENES.s!r}``, they are shown as an additional row.
    cluster_genes
        Whether to cluster genes using :func:`seaborn.clustermap` when ``mode={m.LINEAGES.s!r}``.
    keep_gene_order
        Whether to keep the gene order for later lineages after the first was sorted.
        Only available when ``cluster_genes=False`` and ``mode={m.LINEAGES.s!r}``.
    scale
        Whether to normalize the gene expression to the `0-1` range.
    n_convolve
        Size of the convolution window when smoothing absorption probabilities.
        If `None`, no smoothing is done.
    show_all_genes
        Whether to show all genes on y-axis.
    show_cbar
        Whether to show the colorbar.
    lineage_height
        Height of a bar when ``mode={m.GENES.s!r}``.
    fontsize
        Size of the title's font.
    xlabel
        Label on the x-axis. If `None`, it is determined based on ``time_key``.
    cmap
        Colormap to use when visualizing the smoothed expression.
    show_dendrogram
        Whether to show dendrogram when ``cluster_genes=True``.
    return_genes
        Whether to return the sorted or clustered genes. Only available when ``mode={m.LINEAGES.s!r}``.
    %(parallel)s
    %(plotting)s
    **kwargs
        Keyword arguments for :meth:`cellrank.ul.models.BaseModel.prepare`.

    Returns
    -------
    %(just_plots)s
    :class:`pandas.DataFrame`
        If ``return_genes=True`` and ``mode={m.LINEAGES.s!r}``, returns :class:`pandas.DataFrame`
        containing the clustered or sorted genes.
    """

    import seaborn as sns

    def find_indices(series: pd.Series, values) -> Tuple[Any]:
        """Return the index labels of the entries in ``series`` closest to each of ``values``."""

        def find_nearest(array: np.ndarray, value: float) -> int:
            # `array` is sorted; choose whichever neighbour of the insertion point is closer
            ix = np.searchsorted(array, value, side="left")
            if ix > 0 and (
                ix == len(array) or fabs(value - array[ix - 1]) < fabs(value - array[ix])
            ):
                return ix - 1
            return ix

        # both lookups are positional: use `.iloc`, since `Series.__getitem__` with integers
        # is label-based for integer indices and deprecated as a positional fallback otherwise
        series = series.iloc[np.argsort(series.values)]
        return tuple(series.iloc[[find_nearest(series.values, v) for v in values]].index)

    def subset_lineage(lname: str, rng: np.ndarray) -> np.ndarray:
        """Absorption probabilities of ``lname`` at the cells nearest to ``rng``, optionally smoothed."""
        time_series = adata.obs[time_key]
        ixs = find_indices(time_series, rng)
        lin = adata[ixs, :].obsm[lineage_key][lname]

        lin = lin.X.copy().squeeze()
        if n_convolve is not None:
            lin = convolve(lin, np.ones(n_convolve) / n_convolve, mode="nearest")

        return lin

    def create_col_colors(lname: str, rng: np.ndarray) -> Tuple[np.ndarray, Cmap, Norm]:
        """Hex colors encoding the absorption probabilities of ``lname``, plus the colormap and norm used."""
        color = adata.obsm[lineage_key][lname].colors[0]
        lin = subset_lineage(lname, rng)

        h, _, v = mcolors.rgb_to_hsv(mcolors.to_rgb(color))
        end_color = mcolors.hsv_to_rgb([h, 1, v])

        lineage_cmap = mcolors.LinearSegmentedColormap.from_list(
            "lineage_cmap", ["#ffffff", end_color], N=len(rng)
        )
        norm = mcolors.Normalize(vmin=np.min(lin), vmax=np.max(lin))
        scalar_map = cm.ScalarMappable(cmap=lineage_cmap, norm=norm)

        return (
            np.array([mcolors.to_hex(c) for c in scalar_map.to_rgba(lin)]),
            lineage_cmap,
            norm,
        )

    def create_col_categorical_color(cluster_key: str, rng: np.ndarray) -> np.ndarray:
        """Hex colors of the categories in ``adata.obs[cluster_key]`` at the cells nearest to ``rng``."""
        if not is_categorical_dtype(adata.obs[cluster_key]):
            raise TypeError(
                f"Expected `adata.obs[{cluster_key!r}]` to be categorical, "
                f"found `{adata.obs[cluster_key].dtype.name!r}`."
            )

        color_key = f"{cluster_key}_colors"
        if color_key not in adata.uns:
            logg.warning(
                f"Color key `{color_key!r}` not found in `adata.uns`. Creating new colors"
            )
            colors = _create_categorical_colors(
                len(adata.obs[cluster_key].cat.categories)
            )
            adata.uns[color_key] = colors
        else:
            colors = adata.uns[color_key]

        time_series = adata.obs[time_key]
        ixs = find_indices(time_series, rng)

        mapper = dict(zip(adata.obs[cluster_key].cat.categories, colors))

        return np.array(
            [mcolors.to_hex(mapper[v]) for v in adata[ixs, :].obs[cluster_key].values]
        )

    def create_cbar(ax, x_delta: float, cmap, norm, label=None) -> Ax:
        """Add a thin colorbar to the right of ``ax``, shifted by ``x_delta`` in axes coordinates."""
        cax = inset_axes(
            ax,
            width="1%",
            height="100%",
            loc="lower right",
            bbox_to_anchor=(x_delta, 0, 1, 1),
            bbox_transform=ax.transAxes,
        )

        _ = mpl.colorbar.ColorbarBase(cax, cmap=cmap, norm=norm, label=label)

        return cax

    @valuedispatch
    def _plot_heatmap(_mode: HeatmapMode) -> Fig:
        pass

    @_plot_heatmap.register(HeatmapMode.GENES)
    def _() -> Tuple[Fig, None]:
        def color_fill_rec(ax, xs, y1, y2, colors=None, cmap=cmap, **kwargs) -> None:
            """Draw one rectangle per point of ``xs``, spanning ``y1`` to ``y2`` and colored by ``colors``."""
            colors = colors if cmap is None else cmap(colors)
            n_points = len(xs)

            x, y_top = 0, y2
            for i, (color, x, y_bot, y_top) in enumerate(zip(colors, xs, y1, y2)):
                # width up to the next point; the last rectangle reuses the previous spacing
                dx = (xs[i + 1] - xs[i]) if i < n_points - 1 else (xs[-1] - xs[-2])
                ax.add_patch(
                    plt.Rectangle(
                        (x, y_bot), dx, y_top - y_bot, color=color, ec=color, **kwargs
                    )
                )

            # invisible line at the last point, presumably to make autoscaling include the patches
            ax.plot(x, y_top, lw=0)

        fig, axes = plt.subplots(
            nrows=len(genes) + show_absorption_probabilities,
            figsize=(12, len(genes) + len(lineages) * lineage_height)
            if figsize is None
            else figsize,
            dpi=dpi,
            constrained_layout=True,
        )

        if not isinstance(axes, Iterable):
            axes = [axes]
        axes = np.ravel(axes)

        if show_absorption_probabilities:
            # reuse the models of the first gene only for their `x_test` values
            data["absorption probability"] = data[next(iter(data.keys()))]

        for ax, (gene, models) in zip(axes, data.items()):
            if scale:
                norm = mcolors.Normalize(vmin=0, vmax=1)
            else:
                c = np.array([m.y_test for m in models.values()])
                c_min, c_max = np.nanmin(c), np.nanmax(c)
                norm = mcolors.Normalize(vmin=c_min, vmax=c_max)

            ix = 0
            ys = [ix]

            if gene == "absorption probability":
                norm = mcolors.Normalize(vmin=0, vmax=1)
                for ln, x in ((ln, m.x_test) for ln, m in models.items()):
                    y = np.ones_like(x)
                    c = subset_lineage(ln, x.squeeze())

                    color_fill_rec(
                        ax, x, y * ix, y * (ix + lineage_height), colors=norm(c)
                    )

                    ix += lineage_height
                    ys.append(ix)
            else:
                for x, c in ((m.x_test, m.y_test) for m in models.values()):
                    y = np.ones_like(x)
                    c = _min_max_scale(c) if scale else c

                    color_fill_rec(
                        ax, x, y * ix, y * (ix + lineage_height), colors=norm(c)
                    )

                    ix += lineage_height
                    ys.append(ix)

            xs = np.array([m.x_test for m in models.values()])
            x_min, x_max = np.min(xs), np.max(xs)
            ax.set_xticks(np.linspace(x_min, x_max, _N_XTICKS))

            ax.set_yticks(np.array(ys[:-1]) + lineage_height / 2)
            ax.spines["left"].set_position(
                ("data", 0)
            )  # move the left spine to the rectangles to get nicer yticks
            ax.set_yticklabels(lineages, ha="right")

            ax.set_title(gene, fontdict=dict(fontsize=fontsize))
            ax.set_ylabel("lineage")

            for pos in ["top", "bottom", "left", "right"]:
                ax.spines[pos].set_visible(False)

            cax, _ = mpl.colorbar.make_axes(ax)
            _ = mpl.colorbar.ColorbarBase(
                cax,
                norm=norm,
                cmap=cmap,
                label="value" if gene == "absorption probability" else "expression",
            )

            ax.tick_params(
                top=False,
                bottom=False,
                left=True,
                right=False,
                labelleft=True,
                labelbottom=False,
            )

        # only the last (bottom) axes shows the x-axis tick labels
        ax.xaxis.set_major_formatter(FormatStrFormatter("%.3f"))
        ax.tick_params(
            top=False,
            bottom=True,
            left=True,
            right=False,
            labelleft=True,
            labelbottom=True,
        )
        ax.set_xlabel(xlabel)

        return fig, None

    @_plot_heatmap.register(HeatmapMode.LINEAGES)
    def _() -> Tuple[List[Fig], pd.DataFrame]:
        data_t = defaultdict(dict)  # transpose
        for gene, lns in data.items():
            for ln, y in lns.items():
                data_t[ln][gene] = y

        figs = []
        gene_order = None
        sorted_genes = pd.DataFrame() if return_genes else None

        for lname, models in data_t.items():
            xs = np.array([m.x_test for m in models.values()])
            x_min, x_max = np.nanmin(xs), np.nanmax(xs)

            df = pd.DataFrame([m.y_test for m in models.values()], index=genes)
            df.index.name = "genes"

            if not cluster_genes:
                if gene_order is not None:
                    df = df.loc[gene_order]
                else:
                    # order genes by the position of their (scaled) maximum
                    max_sort = np.argsort(
                        np.argmax(df.apply(_min_max_scale, axis=1).values, axis=1)
                    )
                    df = df.iloc[max_sort, :]
                    if keep_gene_order:
                        gene_order = df.index

            cat_colors = None
            if cluster_key is not None:
                cat_colors = np.stack(
                    [
                        create_col_categorical_color(
                            c, np.linspace(x_min, x_max, df.shape[1])
                        )
                        for c in cluster_key
                    ],
                    axis=0,
                )

            if show_absorption_probabilities:
                col_colors, col_cmap, col_norm = create_col_colors(
                    lname, np.linspace(x_min, x_max, df.shape[1])
                )
                if cat_colors is not None:
                    col_colors = np.vstack([cat_colors, col_colors[None, :]])
            else:
                col_colors, col_cmap, col_norm = cat_colors, None, None

            row_cluster = cluster_genes and df.shape[0] > 1
            show_clust = row_cluster and show_dendrogram

            g = sns.clustermap(
                df,
                cmap=cmap,
                figsize=(10, min(len(genes) / 8 + 1, 10))
                if figsize is None
                else figsize,
                xticklabels=False,
                cbar_kws={"label": "expression"},
                row_cluster=row_cluster,
                col_colors=col_colors,
                colors_ratio=0,
                col_cluster=False,
                cbar_pos=None,
                yticklabels=show_all_genes or "auto",
                standard_scale=0 if scale else None,
            )

            if show_cbar:
                cax = create_cbar(
                    g.ax_heatmap,
                    0.1,
                    cmap=cmap,
                    norm=mcolors.Normalize(
                        vmin=0 if scale else np.min(df.values),
                        vmax=1 if scale else np.max(df.values),
                    ),
                    label="expression",
                )
                g.fig.add_axes(cax)

                if col_cmap is not None and col_norm is not None:
                    cax = create_cbar(
                        g.ax_heatmap,
                        0.25,
                        cmap=col_cmap,
                        norm=col_norm,
                        label="absorption probability",
                    )
                    g.fig.add_axes(cax)

            if g.ax_col_colors:
                main_bbox = _get_ax_bbox(g.fig, g.ax_heatmap)
                n_bars = show_absorption_probabilities + (
                    len(cluster_key) if cluster_key is not None else 0
                )
                _set_ax_height_to_cm(
                    g.fig,
                    g.ax_col_colors,
                    height=min(
                        5, max(n_bars * main_bbox.height / len(df), 0.25 * n_bars)
                    ),
                )
                g.ax_col_colors.set_title(lname, fontdict=dict(fontsize=fontsize))
            else:
                g.ax_heatmap.set_title(lname, fontdict=dict(fontsize=fontsize))

            g.ax_col_dendrogram.set_visible(False)  # gets rid of top free space

            g.ax_heatmap.yaxis.tick_left()
            g.ax_heatmap.yaxis.set_label_position("right")

            g.ax_heatmap.set_xlabel(xlabel)
            g.ax_heatmap.set_xticks(np.linspace(0, len(df.columns), _N_XTICKS))
            g.ax_heatmap.set_xticklabels(
                list(map(lambda n: round(n, 3), np.linspace(x_min, x_max, _N_XTICKS)))
            )

            if show_clust:
                # robustly show dendrogram, because gene names can be long
                g.ax_row_dendrogram.set_visible(True)
                dendro_box = g.ax_row_dendrogram.get_position()

                pad = 0.005
                bb = g.ax_heatmap.yaxis.get_tightbbox(
                    g.fig.canvas.get_renderer()
                ).transformed(g.fig.transFigure.inverted())

                dendro_box.x0 = bb.x0 - dendro_box.width - pad
                dendro_box.x1 = bb.x0 - pad

                g.ax_row_dendrogram.set_position(dendro_box)
            else:
                g.ax_row_dendrogram.set_visible(False)

            if return_genes:
                sorted_genes[lname] = (
                    df.index[g.dendrogram_row.reordered_ind]
                    if hasattr(g, "dendrogram_row") and g.dendrogram_row is not None
                    else df.index
                )

            figs.append(g)

        return figs, sorted_genes

    mode = HeatmapMode(mode)

    lineage_key = str(AbsProbKey.BACKWARD if backward else AbsProbKey.FORWARD)
    if lineage_key not in adata.obsm:
        raise KeyError(f"Lineages key `{lineage_key!r}` not found in `adata.obsm`.")

    if lineages is None:
        lineages = adata.obsm[lineage_key].names
    elif isinstance(lineages, str):
        lineages = [lineages]
    lineages = _unique_order_preserving(lineages)

    # accessed only to validate the requested lineages; the result is unused
    _ = adata.obsm[lineage_key][lineages]

    if cluster_key is not None:
        if isinstance(cluster_key, str):
            cluster_key = [cluster_key]
        cluster_key = _unique_order_preserving(cluster_key)

    if isinstance(genes, str):
        genes = [genes]
    genes = _unique_order_preserving(genes)
    _check_collection(adata, genes, "var_names", use_raw=kwargs.get("use_raw", False))

    if isinstance(time_range, (tuple, float, int, type(None))):
        time_range = [time_range] * len(lineages)
    elif len(time_range) != len(lineages):
        raise ValueError(
            f"Expected time ranges to be of length `{len(lineages)}`, found `{len(time_range)}`."
        )

    xlabel = time_key if xlabel is None else xlabel

    models = _create_models(model, genes, lineages)
    kwargs["backward"] = backward
    kwargs["time_key"] = time_key
    callbacks = _create_callbacks(adata, callback, genes, lineages, **kwargs)

    backend = _get_backend(model, backend)
    n_jobs = _get_n_cores(n_jobs, len(genes))

    start = logg.info(f"Computing trends using `{n_jobs}` core(s)")
    # `data` is read by the plotting closures above: gene -> lineage -> fitted model
    data = parallelize(
        _fit_gene_trends,
        genes,
        unit="gene",
        backend=backend,
        n_jobs=n_jobs,
        extractor=lambda data: {k: v for d in data for k, v in d.items()},
        show_progress_bar=show_progress_bar,
    )(models, callbacks, lineages, time_range, **kwargs)
    logg.info(" Finish", time=start)
    logg.debug(f"Plotting `{mode.s!r}` heatmap")

    fig, genes = _plot_heatmap(mode)

    if save is not None and fig is not None:
        if not isinstance(fig, Iterable):
            save_fig(fig, save, ext=ext)
        elif len(fig) == 1:
            save_fig(fig[0], save, ext=ext)
        else:
            for ln, f in zip(lineages, fig):
                save_fig(f, os.path.join(save, f"lineage_{ln}"), ext=ext)

    if return_genes and mode == HeatmapMode.LINEAGES:
        return genes
def cluster_lineage(
    adata: AnnData,
    model: _model_type,
    genes: Sequence[str],
    lineage: str,
    backward: bool = False,
    time_range: _time_range_type = None,
    clusters: Optional[Sequence[str]] = None,
    n_points: int = 200,
    time_key: str = "latent_time",
    cluster_key: str = "clusters",
    norm: bool = True,
    recompute: bool = False,
    callback: _callback_type = None,
    ncols: int = 3,
    sharey: Union[str, bool] = False,
    key_added: Optional[str] = None,
    show_progress_bar: bool = True,
    n_jobs: Optional[int] = 1,
    backend: str = _DEFAULT_BACKEND,
    figsize: Optional[Tuple[float, float]] = None,
    dpi: Optional[int] = None,
    save: Optional[Union[str, Path]] = None,
    pca_kwargs: Dict = MappingProxyType({"svd_solver": "arpack"}),
    neighbors_kwargs: Dict = MappingProxyType({"use_rep": "X"}),
    louvain_kwargs: Dict = MappingProxyType({}),
    **kwargs,
) -> None:
    """
    Cluster gene expression trends within a lineage and plot the clusters.

    This function is based on Palantir, see [Setty19]_. It can be used to discover modules of genes that drive
    development along a given lineage. Consider running this function on a subset of genes which are potential
    lineage drivers, identified e.g. by running :func:`cellrank.tl.lineage_drivers`.

    Parameters
    ----------
    %(adata)s
    %(model)s
    %(genes)s
    lineage
        Name of the lineage for which to cluster the genes.
    %(backward)s
    %(time_ranges)s
    clusters
        Cluster identifiers to plot. If `None`, all clusters will be considered. Useful when
        plotting previously computed clusters.
    n_points
        Number of points used for prediction.
    time_key
        Key in ``adata.obs`` where the pseudotime is stored.
    cluster_key
        Key in ``.obs`` of the trends object (stored in ``adata.uns``) under which the
        Louvain clustering is saved and from which it is read.
    norm
        Whether to z-normalize each trend to have zero mean, unit variance.
    recompute
        If `True`, recompute the clustering, otherwise try to find already existing one.
    %(model_callback)s
    ncols
        Number of columns for the plot.
    sharey
        Whether to share y-axis across multiple plots.
    key_added
        Postfix to add when saving the results to ``adata.uns``.
    %(parallel)s
    %(plotting)s
    pca_kwargs
        Keyword arguments for :func:`scanpy.pp.pca`. If `n_comps` is not specified,
        it defaults to one less than the smallest of `50`, the number of points and the number of genes.
    neighbors_kwargs
        Keyword arguments for :func:`scanpy.pp.neighbors`.
    louvain_kwargs
        Keyword arguments for :func:`scanpy.tl.louvain`.
    **kwargs
        Keyword arguments for :meth:`cellrank.ul.models.BaseModel.prepare`.

    Returns
    -------
    %(just_plots)s

        Updates ``adata.uns`` with the following key:

            - ``lineage_{lineage}_trend_{key_added}`` - an :class:`anndata.AnnData` object of
              shape ``(n_genes, n_points)`` containing the clustered genes.
    """

    import scanpy as sc
    from anndata import AnnData as _AnnData

    lineage_key = str(AbsProbKey.BACKWARD if backward else AbsProbKey.FORWARD)
    if lineage_key not in adata.obsm:
        raise KeyError(f"Lineages key `{lineage_key!r}` not found in `adata.obsm`.")

    # accessed only to validate ``lineage``; the result is unused
    _ = adata.obsm[lineage_key][lineage]

    # a single gene name must not be split into its characters
    if isinstance(genes, str):
        genes = [genes]
    genes = _unique_order_preserving(genes)
    _check_collection(adata, genes, "var_names", use_raw=kwargs.get("use_raw", False))

    key_to_add = f"lineage_{lineage}_trend"
    if key_added is not None:
        logg.debug(f"Adding key `{key_added!r}`")
        key_to_add += f"_{key_added}"

    if recompute or key_to_add not in adata.uns:
        kwargs["time_key"] = time_key  # kwargs for the model.prepare
        kwargs["n_test_points"] = n_points
        kwargs["backward"] = backward

        models = _create_models(model, genes, [lineage])
        callbacks = _create_callbacks(adata, callback, genes, [lineage], **kwargs)

        backend = _get_backend(model, backend)
        n_jobs = _get_n_cores(n_jobs, len(genes))

        start = logg.info(f"Computing gene trends using `{n_jobs}` core(s)")
        trends = parallelize(
            _cluster_lineages_helper,
            genes,
            as_array=True,
            unit="gene",
            n_jobs=n_jobs,
            backend=backend,
            extractor=np.vstack,
            show_progress_bar=show_progress_bar,
        )(models, callbacks, lineage, time_range, **kwargs)
        logg.info(" Finish", time=start)

        # after transposing, every column is the trend of one gene
        trends = trends.T
        if norm:
            logg.debug("Normalizing using `StandardScaler`")
            # keep the returned array: `copy=False` does not guarantee in-place scaling
            trends = StandardScaler().fit_transform(trends)

        trends = _AnnData(trends.T)
        trends.obs_names = genes

        # sanity check
        if trends.n_obs != len(genes):
            raise RuntimeError(
                f"Expected to find `{len(genes)}` genes, found `{trends.n_obs}`."
            )
        if n_points is not None and trends.n_vars != n_points:
            raise RuntimeError(
                f"Expected to find `{n_points}` points, found `{trends.n_vars}`."
            )

        pca_kwargs = dict(pca_kwargs)
        # default value; `trends.n_vars` equals `n_points` when it is set, but is never `None`
        n_comps = pca_kwargs.pop("n_comps", min(50, trends.n_vars, len(genes)) - 1)

        sc.pp.pca(trends, n_comps=n_comps, **pca_kwargs)
        sc.pp.neighbors(trends, **neighbors_kwargs)

        louvain_kwargs = dict(louvain_kwargs)
        louvain_kwargs["key_added"] = cluster_key
        sc.tl.louvain(trends, **louvain_kwargs)

        adata.uns[key_to_add] = trends
    else:
        logg.info(f"Loading data from `adata.uns[{key_to_add!r}]`")
        trends = adata.uns[key_to_add]

    # validate the key also when ``clusters`` was given explicitly
    if cluster_key not in trends.obs:
        raise KeyError(f"Invalid cluster key `{cluster_key!r}`.")
    if clusters is None:
        clusters = trends.obs[cluster_key].cat.categories

    nrows = int(np.ceil(len(clusters) / ncols))
    fig, axes = plt.subplots(
        nrows,
        ncols,
        figsize=(ncols * 10, nrows * 10) if figsize is None else figsize,
        sharey=sharey,
        dpi=dpi,
    )

    if not isinstance(axes, Iterable):
        axes = [axes]
    axes = np.ravel(axes)

    for ax, c in zip(axes, clusters):
        data = trends[trends.obs[cluster_key] == c].X
        mean, sd = np.mean(data, axis=0), np.std(data, axis=0)

        # individual trends in gray, mean +- standard deviation in black
        for row in data:
            ax.plot(row, color="gray", lw=0.5)

        ax.plot(mean, lw=2, color="black")
        ax.plot(mean - sd, lw=1.5, color="black", linestyle="--")
        ax.plot(mean + sd, lw=1.5, color="black", linestyle="--")
        ax.fill_between(
            range(len(mean)), mean - sd, mean + sd, color="black", alpha=0.1
        )

        ax.set_title(f"Cluster {c}")
        ax.set_xticks([])

        if not sharey:
            ax.set_yticks([])

    # remove the axes which were not used by any cluster
    for ax in axes[len(clusters):]:
        ax.remove()

    if save is not None:
        save_fig(fig, save)
def _fit_bulk(
    models: Mapping[str, Mapping[str, Callable]],
    callbacks: Mapping[str, Mapping[str, Callable]],
    genes: Union[str, Sequence[str]],
    lineages: Union[str, Sequence[str]],
    time_range: _time_range_type,
    parallel_kwargs: dict,
    return_models: bool = False,
    filter_all_failed: bool = True,
    **kwargs,
) -> Tuple[_return_model_type, _return_model_type, Sequence[str], Sequence[str]]:
    """
    Fit models for given genes and lineages.

    Parameters
    ----------
    models
        Gene and lineage specific estimators.
    callbacks
        Functions which are called to prepare the ``models``.
    genes
        Genes for which to fit the ``models``.
    lineages
        Lineages for which to fit the ``models``.
    time_range
        Possibly ``lineages`` specific start- and endtimes.
    parallel_kwargs
        Keyword arguments for :func:`cellrank.ul._utils.parallelize`. `n_jobs` defaults to `1`.
        The mapping itself is not modified.
    return_models
        Whether to return the full models or just a dictionary of dictionaries of
        :class:`collections.namedtuple`, `(x_test, y_test)`. This is highly discouraged because no
        meaningful error messages will be produced.
    filter_all_failed
        Whether to filter out all models which have failed.

    Returns
    -------
    :class:`dict`
        All the models, including the failed ones. It is a nested dictionary where keys are the ``genes`` and the
        values is again a :class:`dict`, where keys are ``lineages`` and values are the failed or fitted models
        if ``return_models=True``, otherwise the :class:`collections.namedtuple`.
    :class:`dict`
        Same as above, but can contain failed models if ``filter_all_failed=False``. In that case, it is guaranteed
        that this dictionary will contain only genes which have been successfully fitted for at least 1 lineage.
        If ``return_models=False``, the models are just a :class:`collections.namedtuple` of `(x_test, y_test)`.
    :class:`tuple`
        All the genes of the filtered models.
    :class:`tuple`
        All the lineage of the filtered models.

    Raises
    ------
    ValueError
        If ``time_range`` is a sequence whose length differs from the number of ``lineages``.
    """
    if isinstance(genes, str):
        genes = [genes]
    if isinstance(lineages, str):
        lineages = [lineages]
    if isinstance(time_range, (tuple, float, int, type(None))):
        time_range = [time_range] * len(lineages)
    elif len(time_range) != len(lineages):
        raise ValueError(
            f"Expected time ranges to be of length `{len(lineages)}`, found `{len(time_range)}`."
        )

    # copy, so that popping `n_jobs` does not mutate the caller's mapping
    parallel_kwargs = {} if parallel_kwargs is None else dict(parallel_kwargs)
    n_jobs = parallel_kwargs.pop("n_jobs", 1)
    parallel_kwargs.setdefault(
        "unit", "gene" if kwargs.get("data_key", "gene") != "obs" else "obs"
    )

    start = logg.info(f"Computing trends using `{n_jobs}` core(s)")
    # the remaining parallel options (e.g. backend, progress bar) are forwarded as documented
    models = parallelize(
        _fit_bulk_helper,
        genes,
        n_jobs=n_jobs,
        extractor=lambda modelss: {k: v for m in modelss for k, v in m.items()},
        **parallel_kwargs,
    )(
        models=models,
        callbacks=callbacks,
        lineages=lineages,
        time_range=time_range,
        return_models=return_models,
        **kwargs,
    )
    logg.info(" Finish", time=start)

    return _filter_models(
        models, return_models=return_models, filter_all_failed=filter_all_failed
    )
def gene_trends(
    adata: AnnData,
    model: _model_type,
    genes: Union[str, Sequence[str]],
    lineages: Optional[Union[str, Sequence[str]]] = None,
    backward: bool = False,
    data_key: str = "X",
    time_key: str = "latent_time",
    time_range: Optional[Union[_time_range_type, List[_time_range_type]]] = None,
    callback: _callback_type = None,
    conf_int: bool = True,
    same_plot: bool = False,
    hide_cells: bool = False,
    perc: Optional[Union[Tuple[float, float], Sequence[Tuple[float, float]]]] = None,
    lineage_cmap: Optional[matplotlib.colors.ListedColormap] = None,
    abs_prob_cmap: matplotlib.colors.ListedColormap = cm.viridis,
    cell_color: str = "black",
    cell_alpha: float = 0.6,
    lineage_alpha: float = 0.2,
    size: float = 15,
    lw: float = 2,
    show_cbar: bool = True,
    margins: float = 0.015,
    sharex: Optional[Union[str, bool]] = None,
    sharey: Optional[Union[str, bool]] = None,
    gene_as_title: Optional[bool] = None,
    legend_loc: Optional[str] = "best",
    ncols: int = 2,
    suptitle: Optional[str] = None,
    n_jobs: Optional[int] = 1,
    backend: str = _DEFAULT_BACKEND,
    show_progres_bar: bool = True,
    figsize: Optional[Tuple[float, float]] = None,
    dpi: Optional[int] = None,
    save: Optional[Union[str, Path]] = None,
    plot_kwargs: Mapping = MappingProxyType({}),
    **kwargs,
) -> None:
    """
    Plot gene expression trends along lineages.

    Each lineage is defined via its lineage weights which we compute using :func:`cellrank.tl.lineages`.
    This function accepts any model based off :class:`cellrank.ul.models.BaseModel` to fit gene expression,
    where we take the lineage weights into account in the loss function.

    Parameters
    ----------
    %(adata)s
    %(model)s
    %(genes)s
    lineages
        Names of the lineages to plot. If `None`, plot all lineages.
    %(backward)s
    data_key
        Key in ``adata.layers`` or `'X'` for ``adata.X`` where the data is stored.
    time_key
        Key in ``adata.obs`` where the pseudotime is stored.
    %(time_ranges)s
    %(model_callback)s
    conf_int
        Whether to compute and show confidence intervals.
    same_plot
        Whether to plot all lineages for each gene in the same plot.
    hide_cells
        If `True`, hide all cells.
    perc
        Percentile for colors. Valid values are in interval `[0, 100]`.
        This can improve visualization. Can be specified individually for each lineage.
    lineage_cmap
        Colormap to use when coloring in the lineages. If `None` and ``same_plot``,
        use the corresponding colors in ``adata.uns``, otherwise use `'black'`.
    abs_prob_cmap
        Colormap to use when visualizing the absorption probabilities for each lineage.
        Only used when ``same_plot=False``.
    cell_color
        Color of the cells when not visualizing absorption probabilities. Only used when ``same_plot=True``.
    cell_alpha
        Alpha channel for cells.
    lineage_alpha
        Alpha channel for lineage confidence intervals.
    size
        Size of the points.
    lw
        Line width of the smoothed values.
    show_cbar
        Whether to show colorbar. Always shown when percentiles for lineages differ.
        Only used when ``same_plot=False``.
    margins
        Margins around the plot.
    sharex
        Whether to share x-axis. Valid options are `'row'`, `'col'` or `'none'`.
    sharey
        Whether to share y-axis. Valid options are `'row'`, `'col'` or `'none'`.
    gene_as_title
        Whether to show gene names as titles instead of on the y-axis.
    legend_loc
        Location of the legend displaying lineages. Only used when ``same_plot=True``.
    ncols
        Number of columns of the plot when plotting multiple genes. Only used when ``same_plot=True``.
    suptitle
        Suptitle of the figure.
    %(parallel)s
    %(plotting)s
    plot_kwargs
        Keyword arguments for :meth:`cellrank.ul.models.BaseModel.plot`.
    **kwargs
        Keyword arguments for :meth:`cellrank.ul.models.BaseModel.prepare`.

    Returns
    -------
    %(just_plots)s
    """
    # ---- validate and normalize inputs (before any figure is created) ----
    if isinstance(genes, str):
        genes = [genes]
    genes = _unique_order_preserving(genes)

    _check_collection(
        adata,
        genes,
        "var_names" if data_key != "obs" else "obs",
        use_raw=kwargs.get("use_raw", False),
    )

    ln_key = str(AbsProbKey.BACKWARD if backward else AbsProbKey.FORWARD)
    if ln_key not in adata.obsm:
        raise KeyError(f"Lineages key `{ln_key!r}` not found in `adata.obsm`.")

    if lineages is None:
        lineages = adata.obsm[ln_key].names
    elif isinstance(lineages, str):
        lineages = [lineages]
    elif all(ln is None for ln in lineages):
        # no lineage, all the weights are 1
        lineages = [None]
        show_cbar = False
        logg.debug("All lineages are `None`, setting the weights to `1`")
    lineages = _unique_order_preserving(lineages)

    # The result is discarded; the indexing is kept only for its side effect of
    # (presumably) raising on lineage names that are not in `adata.obsm[ln_key]`.
    _ = adata.obsm[ln_key][[lin for lin in lineages if lin is not None]]

    if isinstance(time_range, (tuple, float, int, type(None))):
        # a single range (or no range at all) is shared by every lineage
        time_range = [time_range] * len(lineages)
    elif len(time_range) != len(lineages):
        raise ValueError(
            f"Expected time ranges to be of length `{len(lineages)}`, found `{len(time_range)}`."
        )

    # ---- grid layout ----
    if same_plot:
        # one subplot per gene, laid out in at most `ncols` columns
        gene_as_title = True if gene_as_title is None else gene_as_title
        sharex = "all" if sharex is None else sharex
        sharey = "none" if sharey is None else sharey
        ncols = min(ncols, len(genes))
        nrows = int(np.ceil(len(genes) / ncols))
    else:
        # one row per gene, one column per lineage
        gene_as_title = False if gene_as_title is None else gene_as_title
        sharex = "col" if sharex is None else sharex
        if sharey is None:
            sharey = "none" if hide_cells else "row"
        nrows = len(genes)
        ncols = len(lineages)

    # ---- fit the models ----
    kwargs["time_key"] = time_key
    kwargs["data_key"] = data_key
    kwargs["backward"] = backward
    callbacks = _create_callbacks(adata, callback, genes, lineages, **kwargs)

    kwargs["conf_int"] = conf_int  # prepare doesnt take or need this
    models = _create_models(model, genes, lineages)

    plot_kwargs = dict(plot_kwargs)
    if plot_kwargs.get("xlabel", None) is None:
        plot_kwargs["xlabel"] = time_key

    n_jobs = _get_n_cores(n_jobs, len(genes))
    backend = _get_backend(model, backend)

    start = logg.info(f"Computing trends using `{n_jobs}` core(s)")
    models = parallelize(
        _fit_gene_trends,
        genes,
        unit="gene" if data_key != "obs" else "obs",
        backend=backend,
        n_jobs=n_jobs,
        # each worker returns a partial {gene: {lineage: model}} dict; merge them
        extractor=lambda modelss: {k: v for m in modelss for k, v in m.items()},
        show_progress_bar=show_progres_bar,
    )(models, callbacks, lineages, time_range, **kwargs)
    logg.info(" Finish", time=start)

    # ---- plot ----
    # The figure is created only after validation and fitting have succeeded, so a
    # failure in either step does not leave an empty figure registered with pyplot.
    fig, axes = plt.subplots(
        nrows=nrows,
        ncols=ncols,
        sharex=sharex,
        sharey=sharey,
        figsize=(6 * ncols, 4 * nrows) if figsize is None else figsize,
        constrained_layout=True,
    )
    axes = np.reshape(axes, (-1, ncols))

    logg.info("Plotting trends")

    cnt = 0
    for row in range(len(axes)):
        for col in range(len(axes[row])):
            if cnt >= len(genes):
                break
            _trends_helper(
                adata,
                models,
                gene=genes[cnt],
                lineage_names=lineages,
                ln_key=ln_key,
                same_plot=same_plot,
                hide_cells=hide_cells,
                perc=perc,
                lineage_cmap=lineage_cmap,
                abs_prob_cmap=abs_prob_cmap,
                cell_color=cell_color,
                alpha=cell_alpha,
                lineage_alpha=lineage_alpha,
                size=size,
                lw=lw,
                show_cbar=show_cbar,
                margins=margins,
                sharey=sharey,
                gene_as_title=gene_as_title,
                legend_loc=legend_loc,
                dpi=dpi,
                figsize=figsize,
                fig=fig,
                # same_plot: a single subplot; otherwise the whole row of lineage subplots
                axes=axes[row, col] if same_plot else axes[cnt],
                show_ylabel=col == 0,
                show_lineage=cnt == 0 or same_plot,
                # x ticks/labels only where no gene is plotted in the cell below
                show_xticks_and_label=((row + 1) * ncols + col >= len(genes))
                if same_plot
                else (cnt == len(axes) - 1),
                **plot_kwargs,
            )
            cnt += 1

    if same_plot:
        # drop the unused trailing subplots of the grid
        for ax in np.ravel(axes)[cnt:]:
            ax.remove()

    fig.suptitle(suptitle)

    if save is not None:
        save_fig(fig, save)