def _maybe_convert_names(
    self,
    names: Iterable[Union[int, str, bool]],
    is_singleton: bool = False,
    default: Optional[Union[int, str]] = None,
    make_unique: bool = True,
) -> Union[int, List[int], List[bool]]:
    """
    Convert lineage names to the indices stored in ``self._names_to_ixs``.

    Parameters
    ----------
    names
        Entries to convert. Strings are looked up in ``self._names_to_ixs``,
        any other entry is passed through unchanged. If every entry is a boolean,
        the entries are returned as a list without any conversion.
    is_singleton
        Whether to return only the first converted entry instead of a list.
        Not applied when all entries are booleans.
    default
        Fallback used when a string is not a known name. A string fallback is
        itself looked up in ``self._names_to_ixs``, any other value is used directly.
    make_unique
        Whether to deduplicate the result while keeping the order.

    Returns
    -------
    The converted entries, or the first one if ``is_singleton=True``.

    Raises
    ------
    KeyError
        If a string is not a known name and no usable ``default`` is available.
    """
    # `names` is traversed twice below; a one-shot iterator would otherwise lose
    # the elements consumed by the boolean check
    names = list(names)

    if all(isinstance(n, (bool, np.bool_)) for n in names):
        return names

    res = []
    for name in names:
        if isinstance(name, str):
            if name in self._names_to_ixs:
                name = self._names_to_ixs[name]
            elif default is not None:
                if isinstance(default, str):
                    if default not in self._names_to_ixs:
                        raise KeyError(
                            f"Invalid lineage name: `{name}`. "
                            f"Valid names are: `{list(self.names)}`."
                        )
                    name = self._names_to_ixs[default]
                else:
                    name = default
            else:
                raise KeyError(
                    f"Invalid lineage name `{name}`. "
                    f"Valid names are: `{list(self.names)}`."
                )
        res.append(name)

    if make_unique:
        res = _unique_order_preserving(res)

    return res[0] if is_singleton else res
def _mixer(self, rows, mixtures):
    """
    Build a new :class:`Lineage` whose columns are sums of existing lineages.

    Parameters
    ----------
    rows
        Row selector, forwarded as-is to ``self[rows, key]``.
    mixtures
        Entries describing the columns of the result. String entries are parsed by
        ``_convert_lineage_name`` and converted to indices, a ``Lin`` enum entry
        (allowed at most once) selects the lineages not mentioned otherwise and
        any other entry is used as a single key.

    Returns
    -------
    The combined lineage, one column per mixture.

    Raises
    ------
    ValueError
        If ``Lin`` is used more than once, if two mixtures share a lineage
        or if the ``Lin`` value is neither ``Lin.OTHERS`` nor ``Lin.REST``.
    """
    names, colors, res = [], [], []

    def update_entries(key):
        # an empty selection contributes no column
        if not key:
            return
        res.append(self[rows, key].X.sum(1))
        names.append(" or ".join(self.names[key]))
        colors.append(_compute_mean_color(self.colors[key]))

    lin_kind = [mixture for mixture in mixtures if isinstance(mixture, Lin)]
    if len(lin_kind) > 1:
        raise ValueError(
            f"`Lin` enum is allowed only once in the expression, found `{lin_kind}`."
        )

    keys = []
    for mixture in mixtures:
        if isinstance(mixture, Lin):
            continue
        if isinstance(mixture, str):
            converted = self._maybe_convert_names(
                _convert_lineage_name(mixture), default=mixture
            )
            keys.append(tuple(converted))
        else:
            keys.append((mixture,))
    keys = _unique_order_preserving(keys)

    # check the `keys` are unique
    for c1, c2 in combinations([set(ks) for ks in keys], 2):
        overlap = c1 & c2
        if overlap:
            raise ValueError(
                f"Found overlapping keys: `{self.names[list(overlap)]}`."
            )

    seen = set()
    for key in keys:
        key = list(key)
        seen.update(self.names[key])
        update_entries(key)

    if len(lin_kind) == 1:
        (lin_kind,) = lin_kind
        remaining = [i for i, n in enumerate(self.names) if n not in seen]

        if lin_kind == Lin.OTHERS:
            # every remaining lineage gets its own column
            for ix in remaining:
                update_entries([ix])
        elif lin_kind == Lin.REST:
            # all remaining lineages are summed into a single column
            update_entries(remaining)
            if remaining:
                names[-1] = str(lin_kind)
        else:
            raise ValueError(f"Invalid `Lin` enum `{lin_kind}`.")

    return Lineage(np.stack(res, axis=-1), names=names, colors=colors)
def _get_sorted_colors(
    adata: AnnData,
    cluster_key: Union[str, Sequence[str]],
    time_key: Optional[str] = None,
    tmin: float = -np.inf,
    tmax: float = np.inf,
) -> List[np.ndarray]:
    """
    Get per-observation colors or values for each cluster key, optionally sorted by time.

    Parameters
    ----------
    adata
        Annotated data object.
    cluster_key
        Key(s) in ``adata.obs``. Duplicates are removed, keeping the order.
    time_key
        Key in ``adata.obs`` used to subset the observations to ``[tmin, tmax]``
        and to sort them. If `None`, the original order is kept.
    tmin
        Minimum time, inclusive. Only used when ``time_key`` is specified.
    tmax
        Maximum time, inclusive. Only used when ``time_key`` is specified.

    Returns
    -------
    One array per unique key: hex colors when ``_get_categorical_colors`` succeeds,
    otherwise the raw numeric values.

    Raises
    ------
    KeyError
        If ``time_key`` is not in ``adata.obs``.
    ValueError
        If no observation falls within the time range.
    TypeError
        If ``_get_categorical_colors`` raises :class:`TypeError` for a key
        and that key is not numeric either.
    """
    if time_key is not None:
        if time_key not in adata.obs:
            raise KeyError(f"Unable to find time in `adata.obs[{time_key!r}]`.")
        adata = adata[(adata.obs[time_key] >= tmin) & (adata.obs[time_key] <= tmax)]
        if not adata.n_obs:
            raise ValueError(
                f"Specified time range `{[tmin, tmax]}` does not contain any data."
            )
        order = np.argsort(adata.obs[time_key].values)
    else:
        order = np.arange(adata.n_obs)

    if isinstance(cluster_key, str):
        cluster_key = (cluster_key,)
    cluster_key = _unique_order_preserving(cluster_key)

    res = []
    for ck in cluster_key:
        try:
            _, mapper = _get_categorical_colors(adata, ck)
            res.append(
                np.array(
                    [mcolors.to_hex(mapper[v]) for v in adata.obs[ck].values[order]]
                )
            )
        except TypeError:
            # presumably raised for non-categorical keys; fall back to numeric values
            if not is_numeric_dtype(adata.obs[ck]):
                # report the offending key, not the whole collection of keys
                raise TypeError(
                    f"Expected `adata.obs[{ck!r}]` to be numeric, "
                    f"found `{infer_dtype(adata.obs[ck])}`."
                )
            res.append(np.asarray(adata.obs[ck])[order])

    return res
def _create_callbacks(
    adata: AnnData,
    callback: Optional[Callable],
    obs: Sequence[str],
    lineages: Sequence[Optional[str]],
    perform_sanity_check: Optional[bool] = None,
    **kwargs,
) -> Dict[str, Dict[str, Callable]]:
    """
    Create callbacks for each gene and lineage.

    Parameters
    ----------
    %(adata)s
    callback
        Gene and lineage specific prepare callbacks. Either a callable shared by everything,
        or a dictionary keyed by gene whose values are callables or dictionaries keyed by lineage.
        The key `'*'` specifies the callback for the remaining genes or lineages.
        The dictionary is not modified.
    obs
        Sequence of observations, such as genes.
    lineages
        Sequence of lineages.
    perform_sanity_check
        Whether to check if all callbacks have the correct signature. This is done by instantiating
        dummy model and running the function. We're assuming that the callback isn't really a pricey operation.
        If `None`, it is only performed for non-default callbacks.
    **kwargs
        Keyword arguments for ``callback`` when performing the sanity check.

    Returns
    -------
    The created callbacks.

    Raises
    ------
    TypeError
        If ``callback`` is neither callable nor a dictionary, or if a `'*'` entry is not callable.
    RuntimeError
        If the sanity check fails for any gene and lineage.
    """

    def process_lineages(
        obs_name: str, lin_names: Optional[Union[Callable, Dict[Optional[str], Any]]]
    ):
        if lin_names is None:
            lin_names = _default_model_callback

        if callable(lin_names):
            # sharing the same models for all lineages
            for lin_name in lineages:
                callbacks[obs_name][lin_name] = lin_names
            return

        # do not pop; the entry can also be explicitly specified as None
        lin_rest_callback = (
            lin_names.get("*", _default_model_callback) or _default_model_callback
        )

        for lin_name, cb in lin_names.items():
            if lin_name == "*":
                continue
            callbacks[obs_name][lin_name] = cb

        if callable(lin_rest_callback):
            for lin_name in lineages - set(callbacks[obs_name].keys()):
                callbacks[obs_name][lin_name] = lin_rest_callback
        else:
            raise TypeError(
                f"Expected the callback for the rest of lineages to be `callable`, "
                f"found `{type(lin_rest_callback).__name__!r}`."
            )

    def maybe_sanity_check(callbacks: Dict[str, Dict[str, Callable]]) -> None:
        if not perform_sanity_check:
            return

        from sklearn.svm import SVR

        logg.debug("Performing callback sanity checks")
        for gene, lineage_callbacks in callbacks.items():
            for lineage, cb in lineage_callbacks.items():
                # create the model here because the callback can search the attribute
                dummy_model = SKLearnModel(adata, model=SVR())
                try:
                    model = cb(dummy_model, gene=gene, lineage=lineage, **kwargs)
                    # explicit raises instead of `assert`, which is stripped under `python -O`
                    if model is not dummy_model:
                        raise AssertionError(
                            "Creation of new models is not allowed. "
                            "Ensure that callback returns the same model."
                        )
                    if not model.prepared:
                        raise AssertionError(
                            "Model is not prepared. Ensure that callback calls `.prepare()`."
                        )
                    if model._gene != gene:
                        raise AssertionError(
                            f"Callback modified the gene from `{gene!r}` to `{model._gene!r}`."
                        )
                    if model._lineage != lineage:
                        raise AssertionError(
                            f"Callback modified the lineage from `{lineage!r}` to `{model._lineage!r}`."
                        )
                except Exception as e:
                    raise RuntimeError(
                        f"Callback validation failed for gene `{gene!r}` and lineage `{lineage!r}`."
                    ) from e

    if callback is None:
        callback = _default_model_callback

    if perform_sanity_check is None:
        perform_sanity_check = callback is not _default_model_callback

    if callable(callback):
        callbacks = {o: {lin: copy(callback) for lin in lineages} for o in obs}
        maybe_sanity_check(callbacks)
        return callbacks

    lineages, obs = (
        set(_unique_order_preserving(lineages)),
        set(_unique_order_preserving(obs)),
    )
    callbacks = defaultdict(dict)

    if isinstance(callback, dict):
        # `'*'` is a wildcard, not a gene: keep it out of the explicit specification
        # and leave the caller's dictionary untouched
        explicit = {name: spec for name, spec in callback.items() if name != "*"}
        for obs_name, lin_names in explicit.items():
            process_lineages(obs_name, lin_names)

        # can be specified as None
        obs_rest_callback = (
            callback.get("*", _default_model_callback) or _default_model_callback
        )

        if callable(obs_rest_callback):
            for obs_name in obs - set(explicit):
                process_lineages(obs_name, obs_rest_callback)
        else:
            raise TypeError(
                f"Expected the callback for the rest of genes to be `callable`, "
                f"found `{type(obs_rest_callback).__name__!r}`."
            )
    else:
        raise TypeError(
            f"Class `{type(callback).__name__!r}` must be callable` or a dictionary of callables."
        )

    maybe_sanity_check(callbacks)

    return callbacks
def _create_models(
    model: _model_type, obs: Sequence[str], lineages: Sequence[Optional[str]]
) -> Dict[str, Dict[str, BaseModel]]:
    """
    Create models for each gene and lineage.

    Parameters
    ----------
    model
        Either a single model which is copied for every gene and lineage, or a dictionary
        keyed by gene whose values are models or dictionaries keyed by lineage.
        The key `'*'` specifies the model for the remaining genes or lineages.
        The dictionary is not modified.
    obs
        Sequence of observations, such as genes.
    lineages
        Sequence of lineages.

    Returns
    -------
    The created models.

    Raises
    ------
    RuntimeError
        If the specification does not cover all genes or all lineages.
    TypeError
        If ``model`` is neither a model nor a dictionary.
    """

    def process_lineages(
        obs_name: str, lin_names: Union[BaseModel, Dict[Optional[str], Any]]
    ):
        if isinstance(lin_names, BaseModel):
            # same model specification for all lineages; copy it, as every other branch does,
            # so that no two (gene, lineage) pairs share one mutable instance
            for lin_name in lineages:
                models[obs_name][lin_name] = copy(lin_names)
            return

        lin_rest_model = lin_names.get("*", None)  # do not pop

        for lin_name, mod in lin_names.items():
            if lin_name == "*":
                continue
            models[obs_name][lin_name] = copy(mod)

        missing = lineages - set(models[obs_name].keys())
        if not missing:
            # every requested lineage was specified explicitly, no fallback is needed
            return
        if lin_rest_model is None:
            raise RuntimeError(_ERROR_INCOMPLETE_SPEC.format(" lineage ", obs_name))
        for lin_name in missing:
            models[obs_name][lin_name] = copy(lin_rest_model)

    if isinstance(model, BaseModel):
        return {o: {lin: copy(model) for lin in lineages} for o in obs}

    lineages, obs = (
        set(_unique_order_preserving(lineages)),
        set(_unique_order_preserving(obs)),
    )
    models = defaultdict(dict)

    if isinstance(model, dict):
        # read the wildcard without popping it: the caller may reuse the dictionary
        obs_rest_model = model.get("*", None)
        explicit = {name: spec for name, spec in model.items() if name != "*"}

        for obs_name, lin_names in explicit.items():
            process_lineages(obs_name, lin_names)

        if obs_rest_model is not None:
            for obs_name in obs - set(explicit):
                process_lineages(obs_name, obs_rest_model)
        elif set(explicit) != obs:
            raise RuntimeError(_ERROR_INCOMPLETE_SPEC.format(" ", "genes"))
    else:
        raise TypeError(
            f"Class `{type(model).__name__!r}` must be of type `cellrank.ul.BaseModel` or a dictionary of such models."
        )

    return models
def cluster_lineage(
    adata: AnnData,
    model: _input_model_type,
    genes: Sequence[str],
    lineage: str,
    backward: bool = False,
    time_range: _time_range_type = None,
    clusters: Optional[Sequence[str]] = None,
    n_points: int = 200,
    time_key: str = "latent_time",
    norm: bool = True,
    recompute: bool = False,
    callback: _callback_type = None,
    ncols: int = 3,
    sharey: Union[str, bool] = False,
    key: Optional[str] = None,
    random_state: Optional[int] = None,
    use_leiden: bool = False,
    show_progress_bar: bool = True,
    n_jobs: Optional[int] = 1,
    backend: str = _DEFAULT_BACKEND,
    figsize: Optional[Tuple[float, float]] = None,
    dpi: Optional[int] = None,
    save: Optional[Union[str, Path]] = None,
    pca_kwargs: Dict = MappingProxyType({"svd_solver": "arpack"}),
    neighbors_kwargs: Dict = MappingProxyType({"use_rep": "X"}),
    clustering_kwargs: Dict = MappingProxyType({}),
    return_models: bool = False,
    **kwargs,
) -> Optional[_return_model_type]:
    """
    Cluster gene expression trends within a lineage and plot the clusters.

    This function is based on Palantir, see [Setty19]_. It can be used to discover modules of genes that drive
    development along a given lineage. Consider running this function on a subset of genes which are potential lineage
    drivers, identified e.g. by running :func:`cellrank.tl.lineage_drivers`.

    Parameters
    ----------
    %(adata)s
    %(model)s
    %(genes)s
    lineage
        Name of the lineage for which to cluster the genes.
    %(backward)s
    %(time_ranges)s
    clusters
        Cluster identifiers to plot. If `None`, all clusters will be considered.
        Useful when plotting previously computed clusters.
    n_points
        Number of points used for prediction.
    time_key
        Key in ``adata.obs`` where the pseudotime is stored.
    norm
        Whether to z-normalize each trend to have zero mean, unit variance.
    recompute
        If `True`, recompute the clustering, otherwise try to find already existing one.
    %(model_callback)s
    ncols
        Number of columns for the plot.
    sharey
        Whether to share y-axis across multiple plots.
    key
        Key in ``adata.uns`` where to save the results. If `None`, it will be saved as ``lineage_{lineage}_trend`` .
    random_state
        Random seed for reproducibility.
    use_leiden
        Whether to use :func:`scanpy.tl.leiden` for clustering or :func:`scanpy.tl.louvain`.
    %(parallel)s
    %(plotting)s
    pca_kwargs
        Keyword arguments for :func:`scanpy.pp.pca`.
    neighbors_kwargs
        Keyword arguments for :func:`scanpy.pp.neighbors`.
    clustering_kwargs
        Keyword arguments for :func:`scanpy.tl.louvain` or :func:`scanpy.tl.leiden`.
    %(return_models)s
    **kwargs
        Keyword arguments for :meth:`cellrank.ul.models.BaseModel.prepare`.

    Returns
    -------
    %(plots_or_returns_models)s

        Also updates ``adata.uns`` with the following:

            - ``key`` or ``lineage_{lineage}_trend`` - an :class:`anndata.AnnData` object of
              shape `(n_genes, n_points)` containing the clustered genes.
    """
    import scanpy as sc
    from anndata import AnnData as _AnnData

    lineage_key = str(AbsProbKey.BACKWARD if backward else AbsProbKey.FORWARD)
    if lineage_key not in adata.obsm:
        raise KeyError(f"Lineages key `{lineage_key!r}` not found in `adata.obsm`.")

    # accessed only for validation: fails early if the lineage name is invalid
    _ = adata.obsm[lineage_key][lineage]

    genes = _unique_order_preserving(genes)
    _check_collection(adata, genes, "var_names", kwargs.get("use_raw", False))

    if key is None:
        key = f"lineage_{lineage}_trend"

    if recompute or key not in adata.uns:
        kwargs["backward"] = backward
        kwargs["time_key"] = time_key
        kwargs["n_test_points"] = n_points

        models = _create_models(model, genes, [lineage])
        all_models, models, genes, _ = _fit_bulk(
            models,
            _create_callbacks(adata, callback, genes, [lineage], **kwargs),
            genes,
            lineage,
            time_range,
            return_models=True,  # always return (better error messages)
            filter_all_failed=True,
            parallel_kwargs={
                "show_progress_bar": show_progress_bar,
                "n_jobs": _get_n_cores(n_jobs, len(genes)),
                "backend": _get_backend(models, backend),
            },
            **kwargs,
        )

        # `n_genes, n_test_points`
        trends = np.vstack([m[lineage].y_test for m in models.values()]).T

        if norm:
            logg.debug("Normalizing trends")
            # `copy=False` only tries to scale in-place (a copy is made e.g. on dtype
            # conversion), so always use the returned array
            trends = StandardScaler(copy=False).fit_transform(trends)

        trends = _AnnData(trends.T)
        trends.obs_names = genes

        # sanity check
        if trends.n_obs != len(genes):
            raise RuntimeError(
                f"Expected to find `{len(genes)}` genes, found `{trends.n_obs}`."
            )
        if trends.n_vars != n_points:
            raise RuntimeError(
                f"Expected to find `{n_points}` points, found `{trends.n_vars}`."
            )

        # derive one integer seed shared by PCA, the neighbor graph and the clustering
        random_state = np.random.mtrand.RandomState(random_state).randint(2**16)

        # the keyword-argument defaults are read-only mappings, work on copies
        pca_kwargs = dict(pca_kwargs)
        pca_kwargs.setdefault("n_comps", min(50, n_points, len(genes)) - 1)
        pca_kwargs.setdefault("random_state", random_state)
        sc.pp.pca(trends, **pca_kwargs)

        neighbors_kwargs = dict(neighbors_kwargs)
        neighbors_kwargs.setdefault("random_state", random_state)
        sc.pp.neighbors(trends, **neighbors_kwargs)

        clustering_kwargs = dict(clustering_kwargs)
        clustering_kwargs["key_added"] = "clusters"
        clustering_kwargs.setdefault("random_state", random_state)
        try:
            if use_leiden:
                sc.tl.leiden(trends, **clustering_kwargs)
            else:
                sc.tl.louvain(trends, **clustering_kwargs)
        except ImportError as e:
            # the requested algorithm is not installed, fall back to the other one
            logg.warning(str(e))
            if use_leiden:
                sc.tl.louvain(trends, **clustering_kwargs)
            else:
                sc.tl.leiden(trends, **clustering_kwargs)

        logg.info(f"Saving data to `adata.uns[{key!r}]`")
        adata.uns[key] = trends
    else:
        all_models = None
        logg.info(f"Loading data from `adata.uns[{key!r}]`")
        trends = adata.uns[key]

    if "clusters" not in trends.obs:
        raise KeyError("Unable to find the clustering in `trends.obs['clusters']`.")

    if clusters is None:
        clusters = trends.obs["clusters"].cat.categories
    for c in clusters:
        if c not in trends.obs["clusters"].cat.categories:
            raise ValueError(
                f"Invalid cluster name `{c!r}`. "
                f"Valid options are `{list(trends.obs['clusters'].cat.categories)}`."
            )

    nrows = int(np.ceil(len(clusters) / ncols))
    fig, axes = plt.subplots(
        nrows,
        ncols,
        figsize=(ncols * 10, nrows * 10) if figsize is None else figsize,
        sharey=sharey,
        dpi=dpi,
    )

    if not isinstance(axes, Iterable):
        axes = [axes]
    axes = np.ravel(axes)

    for ax, c in zip(axes, clusters):
        data = trends[trends.obs["clusters"] == c].X
        mean, sd = np.mean(data, axis=0), np.std(data, axis=0)

        # individual trends in the background, mean +- standard deviation on top
        for trend in data:
            ax.plot(trend, color="gray", lw=0.5)

        ax.plot(mean, lw=2, color="black")
        ax.plot(mean - sd, lw=1.5, color="black", linestyle="--")
        ax.plot(mean + sd, lw=1.5, color="black", linestyle="--")
        ax.fill_between(
            range(len(mean)), mean - sd, mean + sd, color="black", alpha=0.1
        )

        ax.set_title(f"Cluster {c}")
        ax.set_xticks([])

        if not sharey:
            ax.set_yticks([])

    # drop the axes of the grid which were not used by any cluster
    for ax in axes[len(clusters):]:
        ax.remove()

    if save is not None:
        save_fig(fig, save)

    if return_models:
        return all_models
def heatmap(
    adata: AnnData,
    model: _input_model_type,
    genes: Sequence[str],
    lineages: Optional[Union[str, Sequence[str]]] = None,
    backward: bool = False,
    mode: str = HeatmapMode.LINEAGES.s,
    time_key: str = "latent_time",
    time_range: Optional[Union[_time_range_type, List[_time_range_type]]] = None,
    callback: _callback_type = None,
    cluster_key: Optional[Union[str, Sequence[str]]] = None,
    show_absorption_probabilities: bool = False,
    cluster_genes: bool = False,
    keep_gene_order: bool = False,
    scale: bool = True,
    n_convolve: Optional[int] = 5,
    show_all_genes: bool = False,
    cbar: bool = True,
    lineage_height: float = 0.33,
    fontsize: Optional[float] = None,
    xlabel: Optional[str] = None,
    cmap: mcolors.ListedColormap = cm.viridis,
    dendrogram: bool = True,
    return_genes: bool = False,
    return_models: bool = False,
    n_jobs: Optional[int] = 1,
    backend: str = _DEFAULT_BACKEND,
    show_progress_bar: bool = True,
    figsize: Optional[Tuple[float, float]] = None,
    dpi: Optional[int] = None,
    save: Optional[Union[str, Path]] = None,
    **kwargs,
) -> Optional[
    Union[
        Dict[str, pd.DataFrame],
        Tuple[_return_model_type, Dict[str, pd.DataFrame]],
    ]
]:
    """
    Plot a heatmap of smoothed gene expression along specified lineages.

    Parameters
    ----------
    %(adata)s
    %(model)s
    %(genes)s
    lineages
        Names of the lineages for which to plot. If `None`, plot all lineages.
    %(backward)s
    mode
        Valid options are:

            - `{m.LINEAGES.s!r}` - group by ``genes`` for each lineage in ``lineages``.
            - `{m.GENES.s!r}` - group by ``lineages`` for each gene in ``genes``.
    time_key
        Key in ``adata.obs`` where the pseudotime is stored.
    %(time_ranges)s
    %(model_callback)s
    cluster_key
        Key(s) in ``adata.obs`` containing categorical observations to be plotted on top of the heatmap.
        Only available when ``mode={m.LINEAGES.s!r}``.
    show_absorption_probabilities
        Whether to also plot absorption probabilities alongside the smoothed expression.
        Only available when ``mode={m.LINEAGES.s!r}``.
    cluster_genes
        Whether to cluster genes using :func:`seaborn.clustermap` when ``mode='lineages'``.
    keep_gene_order
        Whether to keep the gene order for later lineages after the first was sorted.
        Only available when ``cluster_genes=False`` and ``mode={m.LINEAGES.s!r}``.
    scale
        Whether to normalize the gene expression `0-1` range.
    n_convolve
        Size of the convolution window when smoothing absorption probabilities.
    show_all_genes
        Whether to show all genes on y-axis.
    cbar
        Whether to show the colorbar.
    lineage_height
        Height of a bar when ``mode={m.GENES.s!r}``.
    fontsize
        Size of the title's font.
    xlabel
        Label on the x-axis. If `None`, it is determined based on ``time_key``.
    cmap
        Colormap to use when visualizing the smoothed expression.
    dendrogram
        Whether to show dendrogram when ``cluster_genes=True``.
    return_genes
        Whether to return the sorted or clustered genes. Only available when ``mode={m.LINEAGES.s!r}``.
    %(return_models)s
    %(parallel)s
    %(plotting)s
    kwargs
        Keyword arguments for :meth:`cellrank.ul.models.BaseModel.prepare`.

    Returns
    -------
    %(plots_or_returns_models)s
    :class:`pandas.DataFrame`
        If ``return_genes=True`` and ``mode={m.LINEAGES.s!r}``, returns :class:`pandas.DataFrame`
        containing the clustered or sorted genes.
    """
    import seaborn as sns

    def find_indices(series: pd.Series, values) -> Tuple[Any]:
        # for every value, get the index label of the observation with the closest time

        def find_nearest(array: np.ndarray, value: float) -> int:
            ix = np.searchsorted(array, value, side="left")
            if ix > 0 and (
                ix == len(array)
                or fabs(value - array[ix - 1]) < fabs(value - array[ix])
            ):
                return ix - 1
            return ix

        # the integers below are positions, not labels: `.iloc` keeps the access
        # positional regardless of the type of the index
        series = series.iloc[np.argsort(series.values)]
        sorted_values = series.values
        return tuple(
            series.iloc[[find_nearest(sorted_values, v) for v in values]].index
        )

    def subset_lineage(lname: str, rng: np.ndarray) -> np.ndarray:
        # absorption probabilities of the cells closest in time to the points in `rng`
        time_series = adata.obs[time_key]
        ixs = find_indices(time_series, rng)

        lin = adata[ixs, :].obsm[lineage_key][lname]
        lin = lin.X.copy().squeeze()
        if n_convolve is not None:
            lin = convolve(lin, np.ones(n_convolve) / n_convolve, mode="nearest")

        return lin

    def create_col_colors(lname: str, rng: np.ndarray) -> Tuple[np.ndarray, Cmap, Norm]:
        color = adata.obsm[lineage_key][lname].colors[0]
        lin = subset_lineage(lname, rng)

        # colormap from white to the fully saturated color of the lineage
        h, _, v = mcolors.rgb_to_hsv(mcolors.to_rgb(color))
        end_color = mcolors.hsv_to_rgb([h, 1, v])

        lineage_cmap = mcolors.LinearSegmentedColormap.from_list(
            "lineage_cmap", ["#ffffff", end_color], N=len(rng)
        )
        norm = mcolors.Normalize(vmin=np.min(lin), vmax=np.max(lin))
        scalar_map = cm.ScalarMappable(cmap=lineage_cmap, norm=norm)

        return (
            np.array([mcolors.to_hex(c) for c in scalar_map.to_rgba(lin)]),
            lineage_cmap,
            norm,
        )

    def create_col_categorical_color(cluster_key: str, rng: np.ndarray) -> np.ndarray:
        if not is_categorical_dtype(adata.obs[cluster_key]):
            raise TypeError(
                f"Expected `adata.obs[{cluster_key!r}]` to be categorical, "
                f"found `{adata.obs[cluster_key].dtype.name!r}`."
            )

        color_key = f"{cluster_key}_colors"
        if color_key not in adata.uns:
            logg.warning(
                f"Color key `{color_key!r}` not found in `adata.uns`. Creating new colors"
            )
            colors = _create_categorical_colors(
                len(adata.obs[cluster_key].cat.categories)
            )
            adata.uns[color_key] = colors
        else:
            colors = adata.uns[color_key]

        time_series = adata.obs[time_key]
        ixs = find_indices(time_series, rng)

        mapper = dict(zip(adata.obs[cluster_key].cat.categories, colors))

        return np.array(
            [mcolors.to_hex(mapper[v]) for v in adata[ixs, :].obs[cluster_key].values]
        )

    def create_cbar(
        ax,
        x_delta: float,
        cmap: Cmap,
        norm: Norm,
        label: Optional[str] = None,
    ) -> Ax:
        cax = inset_axes(
            ax,
            width="1%",
            height="100%",
            loc="lower right",
            bbox_to_anchor=(x_delta, 0, 1, 1),
            bbox_transform=ax.transAxes,
        )

        _ = mpl.colorbar.ColorbarBase(
            cax,
            cmap=cmap,
            norm=norm,
            label=label,
            ticks=np.linspace(norm.vmin, norm.vmax, 5),
        )

        return cax

    @valuedispatch
    def _plot_heatmap(_mode: HeatmapMode) -> Fig:
        pass

    @_plot_heatmap.register(HeatmapMode.GENES)
    def _() -> Tuple[Fig, None]:
        def color_fill_rec(ax, xs, y1, y2, colors=None, cmap=cmap, **kwargs) -> None:
            colors = colors if cmap is None else cmap(colors)

            x, top = 0, y2
            for i, (color, x, bottom, top) in enumerate(zip(colors, xs, y1, y2)):
                # width up to the next point; the last rectangle reuses the previous width
                dx = (xs[i + 1] - xs[i]) if i < len(xs) - 1 else (xs[-1] - xs[-2])
                ax.add_patch(
                    plt.Rectangle(
                        (x, bottom), dx, top - bottom, color=color, ec=color, **kwargs
                    )
                )

            # invisible line, presumably so that the axis limits include the rectangles
            ax.plot(x, top, lw=0)

        fig, axes = plt.subplots(
            nrows=len(genes) + show_absorption_probabilities,
            figsize=(12, len(genes) + len(lineages) * lineage_height)
            if figsize is None
            else figsize,
            dpi=dpi,
            constrained_layout=True,
        )

        if not isinstance(axes, Iterable):
            axes = [axes]
        axes = np.ravel(axes)

        if show_absorption_probabilities:
            # only the fitted time points of these models are used for this extra row
            data["absorption probability"] = data[next(iter(data.keys()))]

        for ax, (gene, models) in zip(axes, data.items()):
            if scale:
                vmin, vmax = 0, 1
            else:
                c = np.array([m.y_test for m in models.values()])
                vmin, vmax = np.nanmin(c), np.nanmax(c)

            norm = mcolors.Normalize(vmin=vmin, vmax=vmax)

            ix = 0
            ys = [ix]

            if gene == "absorption probability":
                # probabilities live in [0, 1], regardless of the expression range;
                # also keeps the colorbar ticks below consistent with the norm
                vmin, vmax = 0, 1
                norm = mcolors.Normalize(vmin=vmin, vmax=vmax)
                for ln, x in ((ln, m.x_test) for ln, m in models.items()):
                    y = np.ones_like(x)
                    c = subset_lineage(ln, x.squeeze())

                    color_fill_rec(
                        ax, x, y * ix, y * (ix + lineage_height), colors=norm(c)
                    )

                    ix += lineage_height
                    ys.append(ix)
            else:
                for x, c in ((m.x_test, m.y_test) for m in models.values()):
                    y = np.ones_like(x)
                    c = _min_max_scale(c) if scale else c

                    color_fill_rec(
                        ax, x, y * ix, y * (ix + lineage_height), colors=norm(c)
                    )

                    ix += lineage_height
                    ys.append(ix)

            xs = np.array([m.x_test for m in models.values()])
            x_min, x_max = np.min(xs), np.max(xs)
            ax.set_xticks(np.linspace(x_min, x_max, _N_XTICKS))

            ax.set_yticks(np.array(ys[:-1]) + lineage_height / 2)
            # move the left spine to the rectangles to get nicer yticks
            ax.spines["left"].set_position(("data", 0))
            ax.set_yticklabels(models.keys(), ha="right")

            ax.set_title(gene, fontdict={"fontsize": fontsize})
            ax.set_ylabel("lineage")

            for pos in ["top", "bottom", "left", "right"]:
                ax.spines[pos].set_visible(False)

            if cbar:
                cax, _ = mpl.colorbar.make_axes(ax)
                _ = mpl.colorbar.ColorbarBase(
                    cax,
                    ticks=np.linspace(vmin, vmax, 5),
                    norm=norm,
                    cmap=cmap,
                    label="value"
                    if gene == "absorption probability"
                    else "scaled expression"
                    if scale
                    else "expression",
                )

            ax.tick_params(
                top=False,
                bottom=False,
                left=True,
                right=False,
                labelleft=True,
                labelbottom=False,
            )

        # only the last axis shows the ticks and the label of the x-axis
        ax.xaxis.set_major_formatter(FormatStrFormatter("%.3f"))
        ax.tick_params(
            top=False,
            bottom=True,
            left=True,
            right=False,
            labelleft=True,
            labelbottom=True,
        )
        ax.set_xlabel(xlabel)

        return fig, None

    @_plot_heatmap.register(HeatmapMode.LINEAGES)
    def _() -> Tuple[List[Fig], pd.DataFrame]:
        data_t = defaultdict(dict)  # transpose
        for gene, lns in data.items():
            for ln, y in lns.items():
                data_t[ln][gene] = y

        figs = []
        gene_order = None
        sorted_genes = pd.DataFrame() if return_genes else None

        for lname, models in data_t.items():
            xs = np.array([m.x_test for m in models.values()])
            x_min, x_max = np.nanmin(xs), np.nanmax(xs)

            df = pd.DataFrame([m.y_test for m in models.values()], index=models.keys())
            df.index.name = "genes"

            if not cluster_genes:
                if gene_order is not None:
                    df = df.loc[gene_order]
                else:
                    # sort the genes by the position of the peak of their scaled trend
                    max_sort = np.argsort(
                        np.argmax(df.apply(_min_max_scale, axis=1).values, axis=1)
                    )
                    df = df.iloc[max_sort, :]
                    if keep_gene_order:
                        gene_order = df.index

            cat_colors = None
            if cluster_key is not None:
                cat_colors = np.stack(
                    [
                        create_col_categorical_color(
                            c, np.linspace(x_min, x_max, df.shape[1])
                        )
                        for c in cluster_key
                    ],
                    axis=0,
                )

            if show_absorption_probabilities:
                col_colors, col_cmap, col_norm = create_col_colors(
                    lname, np.linspace(x_min, x_max, df.shape[1])
                )
                if cat_colors is not None:
                    col_colors = np.vstack([cat_colors, col_colors[None, :]])
            else:
                col_colors, col_cmap, col_norm = cat_colors, None, None

            row_cluster = cluster_genes and df.shape[0] > 1
            show_clust = row_cluster and dendrogram

            g = sns.clustermap(
                df,
                cmap=cmap,
                figsize=(10, min(len(genes) / 8 + 1, 10))
                if figsize is None
                else figsize,
                xticklabels=False,
                row_cluster=row_cluster,
                col_colors=col_colors,
                colors_ratio=0,
                col_cluster=False,
                cbar_pos=None,
                yticklabels=show_all_genes or "auto",
                standard_scale=0 if scale else None,
            )

            if cbar:
                cax = create_cbar(
                    g.ax_heatmap,
                    0.1,
                    cmap=cmap,
                    norm=mcolors.Normalize(
                        vmin=0 if scale else np.min(df.values),
                        vmax=1 if scale else np.max(df.values),
                    ),
                    label="scaled expression" if scale else "expression",
                )
                g.fig.add_axes(cax)

                if col_cmap is not None and col_norm is not None:
                    cax = create_cbar(
                        g.ax_heatmap,
                        0.25,
                        cmap=col_cmap,
                        norm=col_norm,
                        label="absorption probability",
                    )
                    g.fig.add_axes(cax)

            if g.ax_col_colors:
                main_bbox = _get_ax_bbox(g.fig, g.ax_heatmap)
                n_bars = show_absorption_probabilities + (
                    len(cluster_key) if cluster_key is not None else 0
                )
                _set_ax_height_to_cm(
                    g.fig,
                    g.ax_col_colors,
                    height=min(
                        5, max(n_bars * main_bbox.height / len(df), 0.25 * n_bars)
                    ),
                )
                g.ax_col_colors.set_title(lname, fontdict={"fontsize": fontsize})
            else:
                g.ax_heatmap.set_title(lname, fontdict={"fontsize": fontsize})

            g.ax_col_dendrogram.set_visible(False)  # gets rid of top free space

            g.ax_heatmap.yaxis.tick_left()
            g.ax_heatmap.yaxis.set_label_position("right")

            g.ax_heatmap.set_xlabel(xlabel)
            g.ax_heatmap.set_xticks(np.linspace(0, len(df.columns), _N_XTICKS))
            g.ax_heatmap.set_xticklabels(
                [round(n, 3) for n in np.linspace(x_min, x_max, _N_XTICKS)]
            )

            if show_clust:
                # robustly show dendrogram, because gene names can be long
                g.ax_row_dendrogram.set_visible(True)
                dendro_box = g.ax_row_dendrogram.get_position()

                pad = 0.005
                bb = g.ax_heatmap.yaxis.get_tightbbox(
                    g.fig.canvas.get_renderer()
                ).transformed(g.fig.transFigure.inverted())

                dendro_box.x0 = bb.x0 - dendro_box.width - pad
                dendro_box.x1 = bb.x0 - pad

                g.ax_row_dendrogram.set_position(dendro_box)
            else:
                g.ax_row_dendrogram.set_visible(False)

            if return_genes:
                sorted_genes[lname] = (
                    df.index[g.dendrogram_row.reordered_ind]
                    if hasattr(g, "dendrogram_row") and g.dendrogram_row is not None
                    else df.index
                )

            figs.append(g)

        return figs, sorted_genes

    mode = HeatmapMode(mode)

    lineage_key = str(AbsProbKey.BACKWARD if backward else AbsProbKey.FORWARD)
    if lineage_key not in adata.obsm:
        raise KeyError(f"Lineages key `{lineage_key!r}` not found in `adata.obsm`.")

    if lineages is None:
        lineages = adata.obsm[lineage_key].names
    elif isinstance(lineages, str):
        lineages = [lineages]
    lineages = _unique_order_preserving(lineages)

    # accessed only for validation: fails early if any lineage name is invalid
    _ = adata.obsm[lineage_key][lineages]

    if cluster_key is not None:
        if isinstance(cluster_key, str):
            cluster_key = [cluster_key]
        cluster_key = _unique_order_preserving(cluster_key)

    if isinstance(genes, str):
        genes = [genes]
    genes = _unique_order_preserving(genes)
    _check_collection(
        adata, genes, "var_names", use_raw=kwargs.get("use_raw", False)
    )

    kwargs["backward"] = backward
    kwargs["time_key"] = time_key

    models = _create_models(model, genes, lineages)
    # `data`, `genes` and `lineages` are read by the plotting closures defined above
    all_models, data, genes, lineages = _fit_bulk(
        models,
        _create_callbacks(adata, callback, genes, lineages, **kwargs),
        genes,
        lineages,
        time_range,
        return_models=True,  # always return (better error messages)
        filter_all_failed=True,
        parallel_kwargs={
            "show_progress_bar": show_progress_bar,
            "n_jobs": _get_n_cores(n_jobs, len(genes)),
            "backend": _get_backend(models, backend),
        },
        **kwargs,
    )

    xlabel = time_key if xlabel is None else xlabel

    logg.debug(f"Plotting `{mode.s!r}` heatmap")
    fig, genes = _plot_heatmap(mode)

    if save is not None and fig is not None:
        if not isinstance(fig, Iterable):
            save_fig(fig, save)
        elif len(fig) == 1:
            save_fig(fig[0], save)
        else:
            # one figure per lineage, each saved under its own name
            for ln, f in zip(lineages, fig):
                save_fig(f, os.path.join(save, f"lineage_{ln}"))

    if return_genes and mode == HeatmapMode.LINEAGES:
        return (all_models, genes) if return_models else genes
    elif return_models:
        return all_models
def log_odds(
    adata: AnnData,
    lineage_1: str,
    lineage_2: Optional[str] = None,
    time_key: str = "exp_time",
    backward: bool = False,
    keys: Optional[Union[str, Sequence[str]]] = None,
    threshold: Optional[Union[float, Sequence]] = None,
    threshold_color: str = "red",
    layer: Optional[str] = None,
    use_raw: bool = False,
    size: float = 2.0,
    cmap: str = "viridis",
    alpha: Optional[float] = 0.8,
    ncols: Optional[int] = None,
    fontsize: Optional[Union[float, str]] = None,
    xticks_step_size: Optional[int] = 1,
    legend_loc: Optional[str] = "best",
    jitter: Union[bool, float] = True,
    seed: Optional[int] = None,
    figsize: Optional[Tuple[float, float]] = None,
    dpi: Optional[int] = None,
    save: Optional[Union[str, Path]] = None,
    show: bool = True,
    **kwargs: Any,
) -> Optional[Union[Axes, Sequence[Axes]]]:
    """
    Plot log-odds ratio between lineages.

    Log-odds are plotted as a function of the experimental time.

    Parameters
    ----------
    %(adata)s
    lineage_1
        The first lineage for which to compute the log-odds.
    lineage_2
        The second lineage for which to compute the log-odds. If `None`, use the rest of the lineages.
    time_key
        Key in :attr:`anndata.AnnData.obs` containing the experimental time.
    %(backward)s
    keys
        Key in :attr:`anndata.AnnData.obs` or :attr:`anndata.AnnData.var_names`.
    threshold
        Visualize whether total expression per cell is greater than ``threshold``.
        If a :class:`typing.Sequence`, it should be the same length as ``keys``.
    threshold_color
        Color to use when plotting thresholded expression values.
    layer
        Which layer to use to get expression values. If `None` or `'X'`, use :attr:`anndata.AnnData.X`.
    use_raw
        Whether to access :attr:`anndata.AnnData.raw`. If `True`, ``layer`` is ignored.
    size
        Size of the dots.
    cmap
        Colormap to use for continuous variables in ``keys``.
    alpha
        Alpha values for the dots.
    ncols
        Number of columns.
    fontsize
        Size of the font for the title, x- and y-label.
    xticks_step_size
        Show only every n-th ticks on x-axis. If `None`, don't show any ticks.
    legend_loc
        Position of the legend. If `None`, do not show the legend.
    jitter
        Amount of jitter to apply along x-axis.
    seed
        Seed for ``jitter`` to ensure reproducibility.
    %(plotting)s
    show
        If `False`, return :class:`matplotlib.pyplot.Axes` or a sequence of them.
    kwargs
        Keyword arguments for :func:`seaborn.stripplot`.

    Returns
    -------
    :class:`matplotlib.pyplot.Axes`
        The axis object(s) if ``show=False``.
    %(just_plots)s
    """
    from cellrank.tl.kernels._utils import _ensure_numeric_ordered

    def decorate(ax: Axes, *, title: Optional[str] = None, show_ylabel: bool = True) -> None:
        """Set the axis labels, the title and the x-ticks of ``ax``."""
        ax.set_xlabel(time_key, fontsize=fontsize)
        ax.set_title(title, fontdict={"fontsize": fontsize})
        ax.set_ylabel(ylabel if show_ylabel else "", fontsize=fontsize)

        if xticks_step_size is None:
            ax.set_xticks([])
        else:
            step = max(1, xticks_step_size)
            ax.set_xticks(np.arange(0, n_cats, step))
            # the points are drawn in `order` (reversed when `backward=True`),
            # so the tick labels must follow the same ordering
            ax.set_xticklabels(order[::step])

    def cont_palette(values: np.ndarray) -> Tuple[np.ndarray, ScalarMappable]:
        """Map continuous ``values`` to hex colors; NaNs are shown in grey."""
        colormap = copy(plt.get_cmap(cmap))
        colormap.set_bad("grey")
        sm = ScalarMappable(
            cmap=colormap,
            norm=Normalize(vmin=np.nanmin(values), vmax=np.nanmax(values)),
        )
        return np.array([to_hex(v) for v in (sm.to_rgba(values))]), sm

    def get_data(
        key: str,
        thresh: Optional[float] = None,
    ) -> Tuple[Optional[str], Optional[np.ndarray], Optional[np.ndarray], ScalarMappable]:
        """
        Resolve ``key`` into the coloring information used by the strip plot.

        Returns the tuple ``(hue, palette, thresh_mask, sm)``. ``key`` is first treated as
        a categorical column of ``adata.obs``, then as a continuous one, and lastly
        as a gene name.
        """
        try:
            _, palette = _get_categorical_colors(adata, key)
            df[key] = adata.obs[key].values[mask]
            df[key] = df[key].cat.remove_unused_categories()
            try:
                # seaborn doesn't like numeric categories
                df[key] = df[key].astype(float)
                palette = {float(k): v for k, v in palette.items()}
            except ValueError:
                pass
            # otherwise seaborn plots all
            palette = {k: palette[k] for k in df[key].unique()}
            hue, thresh_mask, sm = key, None, None
        except TypeError:
            palette, hue, thresh_mask, sm = (
                cont_palette(adata.obs[key].values[mask])[0],
                None,
                None,
                None,
            )
        except KeyError:
            try:
                if thresh is None:
                    # NOTE(review): unlike the thresholded branch below, these values
                    # are not subset by `mask` - confirm this is intended
                    values = adata.raw.obs_vector(key) if use_raw else adata.obs_vector(key, layer=layer)
                    palette, sm = cont_palette(values)
                    hue, thresh_mask = None, None
                else:
                    if use_raw:
                        values = np.asarray(adata.raw[:, key].X[mask].sum(1)).squeeze()
                    elif layer not in (None, "X"):
                        values = np.asarray(adata[:, key].layers[layer][mask].sum(1)).squeeze()
                    else:
                        values = np.asarray(adata[:, key].X[mask].sum(1)).squeeze()
                    thresh_mask = values > thresh
                    hue, palette, sm = None, None, None
            except KeyError as e:
                # hide the `adata.obs` lookup failure, only report the last one
                raise e from None

        return hue, palette, thresh_mask, sm

    np.random.seed(seed)
    # the orientation is fixed by the x/y variables below
    _ = kwargs.pop("orient", None)
    if use_raw and adata.raw is None:
        logg.warning("No raw attribute set. Setting `use_raw=False`")
        use_raw = False

    # define log-odds
    ln_key = str(AbsProbKey.BACKWARD if backward else AbsProbKey.FORWARD)
    if ln_key not in adata.obsm:
        raise KeyError(f"Lineages key `{ln_key!r}` not found in `adata.obsm`.")
    time = _ensure_numeric_ordered(adata, time_key)
    order = time.cat.categories[::-1 if backward else 1]

    fate1 = adata.obsm[ln_key][lineage_1].X.squeeze(-1)
    if lineage_2 is None:
        fate2 = 1 - fate1
        ylabel = rf"$\log{{\frac{{{lineage_1}}}{{rest}}}}$"
    else:
        fate2 = adata.obsm[ln_key][lineage_2].X.squeeze(-1)
        ylabel = rf"$\log{{\frac{{{lineage_1}}}{{{lineage_2}}}}}$"

    df = pd.DataFrame(
        {
            "log_odds": np.log(
                np.divide(fate1, fate2, where=fate2 != 0, out=np.zeros_like(fate1)) + 1e-12
            ),
            time_key: time,
        }
    )
    # only keep the cells where both fates are non-zero
    mask = (fate1 != 0) & (fate2 != 0)
    df = df[mask]
    n_cats = len(df[time_key].cat.categories)

    if keys is None:
        if figsize is None:
            figsize = np.array([n_cats, n_cats * 4 / 6]) / 2

        fig, ax = plt.subplots(figsize=figsize, dpi=dpi, tight_layout=True)
        # `x` and `y` are keyword-only in recent versions of seaborn
        ax = sns.stripplot(
            x=time_key,
            y="log_odds",
            data=df,
            order=order,
            jitter=jitter,
            color="k",
            size=size,
            ax=ax,
            **kwargs,
        )
        decorate(ax)
        if save is not None:
            save_fig(fig, save)
        return None if show else ax

    if isinstance(keys, str):
        keys = (keys,)
    if not len(keys):
        raise ValueError("No keys have been selected.")
    keys = _unique_order_preserving(keys)

    if not isinstance(threshold, Iterable):
        threshold = (threshold,) * len(keys)
    if len(threshold) != len(keys):
        raise ValueError(
            f"Expected `threshold` to be of length `{len(keys)}`, found `{len(threshold)}`."
        )

    ncols = max(len(keys) if ncols is None else ncols, 1)
    nrows = int(np.ceil(len(keys) / ncols))
    if figsize is None:
        figsize = np.array([n_cats * ncols, n_cats * nrows * 4 / 6]) / 2

    fig, axes = plt.subplots(
        nrows=nrows,
        ncols=ncols,
        figsize=figsize,
        dpi=dpi,
        constrained_layout=True,
        sharey="all",
    )
    axes = np.ravel([axes])

    i = 0
    for i, (key, ax, thresh) in enumerate(zip(keys, axes, threshold)):
        hue, palette, thresh_mask, sm = get_data(key, thresh)
        show_ylabel = i % ncols == 0

        if alpha is not None:
            point_alpha = alpha
        else:
            point_alpha = None if thresh_mask is None else 0.8

        ax = sns.stripplot(
            x=time_key,
            y="log_odds",
            data=df if thresh_mask is None else df[~thresh_mask],
            hue=hue,
            order=order,
            jitter=jitter,
            color="black",
            palette=palette,
            size=size,
            alpha=point_alpha,
            ax=ax,
            **kwargs,
        )
        if thresh_mask is not None:
            # highlight the cells above the threshold on top of the other cells
            sns.stripplot(
                x=time_key,
                y="log_odds",
                data=df[thresh_mask],
                hue=hue,
                order=order,
                jitter=jitter,
                color=threshold_color,
                palette=palette,
                size=size * 2,
                alpha=0.9,
                ax=ax,
                **kwargs,
            )
            key = rf"${key} > {thresh}$"

        if sm is not None:
            cax = ax.inset_axes([1.02, 0, 0.025, 1], transform=ax.transAxes)
            fig.colorbar(sm, ax=ax, cax=cax)
        elif legend_loc in (None, "none"):
            legend = ax.get_legend()
            if legend is not None:
                legend.remove()
        else:
            handles, labels = ax.get_legend_handles_labels()
            if len(handles):
                _position_legend(ax, legend_loc=legend_loc, handles=handles, labels=labels)

        decorate(ax, title=key, show_ylabel=show_ylabel)

    # remove the unused panels of the grid
    for ax in axes[i + 1:]:
        ax.remove()
    axes = axes[:i + 1]

    if save is not None:
        save_fig(fig, save)

    return None if show else axes[0] if len(axes) == 1 else axes
def gene_trends(
    adata: AnnData,
    model: _input_model_type,
    genes: Union[str, Sequence[str]],
    lineages: Optional[Union[str, Sequence[str]]] = None,
    backward: bool = False,
    data_key: str = "X",
    time_key: str = "latent_time",
    transpose: bool = False,
    time_range: Optional[Union[_time_range_type, List[_time_range_type]]] = None,
    callback: _callback_type = None,
    conf_int: Union[bool, float] = True,
    same_plot: bool = False,
    hide_cells: bool = False,
    perc: Optional[Union[Tuple[float, float], Sequence[Tuple[float, float]]]] = None,
    lineage_cmap: Optional[matplotlib.colors.ListedColormap] = None,
    abs_prob_cmap: matplotlib.colors.ListedColormap = cm.viridis,
    cell_color: Optional[str] = None,
    cell_alpha: float = 0.6,
    lineage_alpha: float = 0.2,
    size: float = 15,
    lw: float = 2,
    cbar: bool = True,
    margins: float = 0.015,
    sharex: Optional[Union[str, bool]] = None,
    sharey: Optional[Union[str, bool]] = None,
    gene_as_title: Optional[bool] = None,
    legend_loc: Optional[str] = "best",
    obs_legend_loc: Optional[str] = "best",
    ncols: int = 2,
    suptitle: Optional[str] = None,
    return_models: bool = False,
    n_jobs: Optional[int] = 1,
    backend: str = _DEFAULT_BACKEND,
    show_progress_bar: bool = True,
    figsize: Optional[Tuple[float, float]] = None,
    dpi: Optional[int] = None,
    save: Optional[Union[str, Path]] = None,
    plot_kwargs: Mapping = MappingProxyType({}),
    **kwargs,
) -> Optional[_return_model_type]:
    """
    Plot gene expression trends along lineages.

    Each lineage is defined via its lineage weights which we compute using :func:`cellrank.tl.lineages`. This
    function accepts any model based off :class:`cellrank.ul.models.BaseModel` to fit gene expression,
    where we take the lineage weights into account in the loss function.

    Parameters
    ----------
    %(adata)s
    %(model)s
    %(genes)s
    lineages
        Names of the lineages to plot. If `None`, plot all lineages.
    %(backward)s
    data_key
        Key in ``adata.layers`` or `'X'` for ``adata.X`` where the data is stored.
    time_key
        Key in ``adata.obs`` where the pseudotime is stored.
    %(time_ranges)s
    transpose
        If ``same_plot=True``, group the trends by ``lineages`` instead of ``genes``. This enforces ``hide_cells=True``.
        If ``same_plot=False``, show ``lineages`` in rows and ``genes`` in columns.
    %(model_callback)s
    conf_int
        Whether to compute and show confidence interval. If the :paramref:`model` is :class:`cellrank.ul.models.GAMR`,
        it can also specify the confidence level, the default is `0.95`.
    same_plot
        Whether to plot all lineages for each gene in the same plot.
    hide_cells
        If `True`, hide all cells.
    perc
        Percentile for colors. Valid values are in interval `[0, 100]`.
        This can improve visualization. Can be specified individually for each lineage.
    lineage_cmap
        Categorical colormap to use when coloring in the lineages. If `None` and ``same_plot``,
        use the corresponding colors in ``adata.uns``, otherwise use `'black'`.
    abs_prob_cmap
        Continuous colormap to use when visualizing the absorption probabilities for each lineage.
        Only used when ``same_plot=False``.
    cell_color
        Key in :attr:`anndata.AnnData.obs` or :attr:`anndata.AnnData.var_names` used for coloring the cells.
    cell_alpha
        Alpha channel for cells.
    lineage_alpha
        Alpha channel for lineage confidence intervals.
    size
        Size of the points.
    lw
        Line width of the smoothed values.
    cbar
        Whether to show colorbar. Always shown when percentiles for lineages differ. Only used when ``same_plot=False``.
    margins
        Margins around the plot.
    sharex
        Whether to share x-axis. Valid options are `'row'`, `'col'` or `'none'`.
    sharey
        Whether to share y-axis. Valid options are `'row'`, `'col'` or `'none'`.
    gene_as_title
        Whether to show gene names as titles instead on y-axis.
    legend_loc
        Location of the legend displaying lineages. Only used when `same_plot=True`.
    obs_legend_loc
        Location of the legend when ``cell_color`` corresponds to a categorical variable.
    ncols
        Number of columns of the plot when plotting multiple genes. Only used when ``same_plot=True``.
    suptitle
        Suptitle of the figure.
    %(return_models)s
    %(parallel)s
    %(plotting)s
    plot_kwargs
        Keyword arguments for :meth:`cellrank.ul.models.BaseModel.plot`.
    kwargs
        Keyword arguments for :meth:`cellrank.ul.models.BaseModel.prepare`.

    Returns
    -------
    %(plots_or_returns_models)s
    """
    if isinstance(genes, str):
        genes = [genes]
    genes = _unique_order_preserving(genes)

    # the fitted values are looked up either in `adata.obs` or among the genes
    collection_attr = "obs" if data_key == "obs" else "var_names"
    _check_collection(adata, genes, collection_attr, use_raw=kwargs.get("use_raw", False))

    ln_key = str(AbsProbKey.BACKWARD if backward else AbsProbKey.FORWARD)
    if ln_key not in adata.obsm:
        raise KeyError(f"Lineages key `{ln_key!r}` not found in `adata.obsm`.")

    if lineages is None:
        lineages = adata.obsm[ln_key].names
    elif isinstance(lineages, str):
        lineages = [lineages]
    elif all(ln is None for ln in lineages):
        # no lineage, all the weights are 1
        lineages = [None]
        cbar = False
        logg.debug("All lineages are `None`, setting the weights to `1`")
    lineages = _unique_order_preserving(lineages)

    if time_range is None or isinstance(time_range, (tuple, float, int)):
        # the same time range is shared by all lineages
        time_range = [time_range] * len(lineages)
    elif len(time_range) != len(lineages):
        raise ValueError(
            f"Expected time ranges to be of length `{len(lineages)}`, found `{len(time_range)}`."
        )

    # `conf_int` is forwarded as well, even though prepare doesn't take or need it
    kwargs.update(
        time_key=time_key,
        data_key=data_key,
        backward=backward,
        conf_int=conf_int,
    )

    models = _create_models(model, genes, lineages)
    callbacks = _create_callbacks(adata, callback, genes, lineages, **kwargs)
    parallel_kwargs = {
        "show_progress_bar": show_progress_bar,
        "n_jobs": _get_n_cores(n_jobs, len(genes)),
        "backend": _get_backend(models, backend),
    }
    all_models, models, genes, lineages = _fit_bulk(
        models,
        callbacks,
        genes,
        lineages,
        time_range,
        return_models=True,
        filter_all_failed=False,
        parallel_kwargs=parallel_kwargs,
        **kwargs,
    )

    lineages = sorted(lineages)
    default_colors = adata.obsm[ln_key][lineages].colors
    if lineage_cmap is None and not transpose:
        lineage_cmap = default_colors

    # work on a mutable copy, the default is a read-only mapping
    plot_kwargs = dict(plot_kwargs)
    plot_kwargs["obs_legend_loc"] = obs_legend_loc
    if transpose:
        all_models = pd.DataFrame(all_models).T.to_dict()
        models = pd.DataFrame(models).T.to_dict()
        genes, lineages = lineages, genes
        hide_cells = same_plot or hide_cells
    else:
        # information overload otherwise
        plot_kwargs.update(
            lineage_probability=False,
            lineage_probability_conf_int=False,
        )

    # first and last row, for each column, which holds a truthy model entry
    has_model = pd.DataFrame(models).T.astype(bool)
    start_rows = np.argmax(has_model.values, axis=0)
    end_rows = has_model.shape[0] - np.argmax(has_model[::-1].values, axis=0) - 1

    if same_plot:
        if gene_as_title is None:
            gene_as_title = True
        if sharex is None:
            sharex = "all"
        if sharey is None:
            sharey = "row" if plot_kwargs.get("lineage_probability", False) else "none"
        ncols = min(len(genes), ncols)
        nrows = int(np.ceil(len(genes) / ncols))
    else:
        if gene_as_title is None:
            gene_as_title = False
        if sharex is None:
            sharex = "col"
        if sharey is None:
            share_rows = not hide_cells or plot_kwargs.get("lineage_probability", False)
            sharey = "row" if share_rows else "none"
        nrows = len(genes)
        ncols = len(lineages)

    if plot_kwargs.get("xlabel", None) is None:
        plot_kwargs["xlabel"] = time_key

    fig, axes = plt.subplots(
        nrows=nrows,
        ncols=ncols,
        sharex=sharex,
        sharey=sharey,
        figsize=(6 * ncols, 4 * nrows) if figsize is None else figsize,
        tight_layout=True,
        dpi=dpi,
    )
    axes = np.reshape(axes, (nrows, ncols))

    plot_kwargs["obs_legend_loc"] = None if same_plot else obs_legend_loc
    n_plotted = 0

    logg.info("Plotting trends")
    for row, axes_row in enumerate(axes):
        for col in range(len(axes_row)):
            if n_plotted >= len(genes):
                break
            gene = genes[n_plotted]

            lpc = None
            if same_plot and plot_kwargs.get("lineage_probability", False) and transpose:
                lpc = adata.obsm[ln_key][gene].colors[0]

            if same_plot:
                # the legend is only shown in the last panel of the first row
                is_top_right = row == 0 and col == len(axes[0]) - 1
                plot_kwargs["obs_legend_loc"] = obs_legend_loc if is_top_right else None

            if same_plot:
                current_axes = axes[row, col]
                show_lineage = True
                show_xticks = (row + 1) * ncols + col >= len(genes)
            else:
                current_axes = axes[n_plotted]
                show_lineage = n_plotted == start_rows
                show_xticks = n_plotted == end_rows

            _trends_helper(
                models,
                gene=gene,
                lineage_names=lineages,
                transpose=transpose,
                same_plot=same_plot,
                hide_cells=hide_cells,
                perc=perc,
                lineage_cmap=lineage_cmap,
                abs_prob_cmap=abs_prob_cmap,
                lineage_probability_color=lpc,
                cell_color=cell_color,
                alpha=cell_alpha,
                lineage_alpha=lineage_alpha,
                size=size,
                lw=lw,
                cbar=cbar,
                margins=margins,
                sharey=sharey,
                gene_as_title=gene_as_title,
                legend_loc=legend_loc,
                figsize=figsize,
                fig=fig,
                axes=current_axes,
                show_ylabel=col == 0,
                show_lineage=show_lineage,
                show_xticks_and_label=show_xticks,
                **plot_kwargs,
            )
            n_plotted += 1

            if not same_plot:
                # plot legend on the 1st plot only
                plot_kwargs["obs_legend_loc"] = None

    if same_plot and (col != ncols):
        # drop the panels of the grid that were left empty
        for unused_ax in np.ravel(axes)[n_plotted:]:
            unused_ax.remove()

    fig.suptitle(suptitle, y=1.05)

    if save is not None:
        save_fig(fig, save)

    if return_models:
        return all_models
def cluster_fates(
    adata: AnnData,
    mode: str = ClusterFatesMode.PAGA_PIE.s,
    backward: bool = False,
    lineages: Optional[Union[str, Sequence[str]]] = None,
    cluster_key: Optional[str] = "clusters",
    clusters: Optional[Union[str, Sequence[str]]] = None,
    basis: Optional[str] = None,
    cbar: bool = True,
    ncols: Optional[int] = None,
    sharey: bool = False,
    fmt: str = "0.2f",
    xrot: float = 90,
    legend_kwargs: Mapping[str, Any] = MappingProxyType({"loc": "best"}),
    figsize: Optional[Tuple[float, float]] = None,
    dpi: Optional[int] = None,
    save: Optional[Union[str, Path]] = None,
    **kwargs,
) -> None:
    """
    Plot aggregate lineage probabilities at a cluster level.

    This can be used to investigate how likely a certain cluster is to go to the %(terminal)s states, or in turn to have
    descended from the %(initial)s states. For mode `{m.PAGA.s!r}` and `{m.PAGA_PIE.s!r}`, we use *PAGA*, see [Wolf19]_.

    Parameters
    ----------
    %(adata)s
    mode
        Type of plot to show. Valid options are:

            - `{m.BAR.s!r}` - barplot, one panel per cluster.
            - `{m.PAGA.s!r}` - scanpy's PAGA, one per %(initial_or_terminal)s state, colored in by fate.
            - `{m.PAGA_PIE.s!r}` - scanpy's PAGA with pie charts indicating aggregated fates.
            - `{m.VIOLIN.s!r}` - violin plots, one per %(initial_or_terminal)s state.
            - `{m.HEATMAP.s!r}` - a heatmap, showing average fates per cluster.
            - `{m.CLUSTERMAP.s!r}` - same as a heatmap, but with a dendrogram.
    %(backward)s
    lineages
        Lineages for which to visualize absorption probabilities. If `None`, use all lineages.
    cluster_key
        Key in ``adata.obs`` containing the clusters.
    clusters
        Clusters to visualize. If `None`, all clusters will be plotted.
    basis
        Basis for scatterplot to use when ``mode={m.PAGA_PIE.s!r}``. If `None`, don't show the scatterplot.
    cbar
        Whether to show colorbar when ``mode={m.PAGA_PIE.s!r}``.
    ncols
        Number of columns when ``mode={m.BAR.s!r}`` or ``mode={m.PAGA.s!r}``.
    sharey
        Whether to share y-axis when ``mode={m.BAR.s!r}``.
    fmt
        Format when using ``mode={m.HEATMAP.s!r}`` or ``mode={m.CLUSTERMAP.s!r}``.
    xrot
        Rotation of the labels on the x-axis.
    figsize
        Size of the figure.
    legend_kwargs
        Keyword arguments for :func:`matplotlib.axes.Axes.legend`, such as `'loc'` for legend position.
        For ``mode={m.PAGA_PIE.s!r}`` and ``basis='...'``, this controls the placement of the
        absorption probabilities legend.
    %(plotting)s
    **kwargs
        Keyword arguments for :func:`scvelo.pl.paga`, :func:`scanpy.pl.violin` or :func:`matplotlib.pyplot.bar`,
        depending on the value of ``mode``.

    Returns
    -------
    %(just_plots)s
    """
    from scanpy.plotting import violin
    from scvelo.plotting import paga
    from seaborn import heatmap, clustermap

    @valuedispatch
    def plot(mode: ClusterFatesMode, *_args, **_kwargs):
        # fallback for the modes without a registered implementation
        raise NotImplementedError(mode.value)

    @plot.register(ClusterFatesMode.BAR)
    def _():
        cols = 4 if ncols is None else ncols
        n_rows = ceil(len(clusters) / cols)
        fig = plt.figure(
            None,
            (3.5 * cols, 5 * n_rows) if figsize is None else figsize,
            dpi=dpi,
        )
        fig.tight_layout()

        gs = plt.GridSpec(n_rows, cols, figure=fig, wspace=0.5, hspace=0.5)

        ax = None
        colors = list(adata.obsm[lk][:, lin_names].colors)

        for g, k in zip(gs, d.keys()):
            current_ax = fig.add_subplot(g, sharey=ax)
            current_ax.bar(
                x=np.arange(len(lin_names)),
                height=d[k][0],
                color=colors,
                yerr=d[k][1],
                ecolor="black",
                capsize=10,
                **kwargs,
            )
            if sharey:
                # all the following panels share the y-axis with this one
                ax = current_ax

            current_ax.set_xticks(np.arange(len(lin_names)))
            current_ax.set_xticklabels(lin_names, rotation=xrot)
            if not is_all:
                current_ax.set_xlabel(points)
            current_ax.set_ylabel("absorption probability")
            current_ax.set_title(k)

        return fig

    @plot.register(ClusterFatesMode.PAGA)
    def _():
        kwargs["save"] = None
        kwargs["show"] = False
        if "cmap" not in kwargs:
            kwargs["cmap"] = cm.viridis

        cols = len(lin_names) if ncols is None else ncols
        nrows = ceil(len(lin_names) / cols)

        fig, axes = plt.subplots(
            nrows,
            cols,
            figsize=(7 * cols, 4 * nrows) if figsize is None else figsize,
            constrained_layout=True,
            dpi=dpi,
        )
        # fig.tight_layout()  can't use this because colorbar.make_axes fails

        i = 0
        axes = [axes] if not isinstance(axes, np.ndarray) else np.ravel(axes)
        vmin, vmax = np.inf, -np.inf

        if basis is not None:
            kwargs["basis"] = basis
            kwargs["scatter_flag"] = True
            kwargs["color"] = cluster_key

        for i, (ax, lineage_name) in enumerate(zip(axes, lin_names)):
            colors = [v[0][i] for v in d.values()]
            kwargs["ax"] = ax
            kwargs["colors"] = tuple(colors)
            kwargs["title"] = f"{dir_prefix} {lineage_name}"

            # keep track of the global range for the shared colorbar
            vmin = np.min(colors + [vmin])
            vmax = np.max(colors + [vmax])

            paga(adata, **kwargs)

        if cbar:
            norm = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax)
            cax, _ = mpl.colorbar.make_axes(ax, aspect=60)
            _ = mpl.colorbar.ColorbarBase(
                cax,
                ticks=np.linspace(norm.vmin, norm.vmax, 5),
                norm=norm,
                cmap=kwargs["cmap"],
                label="average absorption probability",
            )

        for ax in axes[i + 1:]:  # noqa
            ax.remove()

        return fig

    @plot.register(ClusterFatesMode.PAGA_PIE)
    def _():
        colors = list(adata.obsm[lk][:, lin_names].colors)
        # for each cluster, map the lineage colors to the mean absorption probabilities
        colors = {i: odict(zip(colors, mean)) for i, (mean, _) in enumerate(d.values())}

        fig, ax = plt.subplots(figsize=figsize, dpi=dpi)
        fig.tight_layout()

        kwargs["ax"] = ax
        kwargs["show"] = False
        kwargs["colorbar"] = False  # has to be disabled

        kwargs["node_colors"] = colors
        kwargs.pop("save", None)  # we will handle saving

        kwargs["transitions"] = kwargs.get("transitions", "transitions_confidence")
        if "legend_loc" in kwargs:
            orig_ll = kwargs["legend_loc"]
            if orig_ll != "on data":
                kwargs["legend_loc"] = "none"  # we will handle legend
        else:
            orig_ll = None
            kwargs["legend_loc"] = "on data"

        if basis is not None:
            kwargs["basis"] = basis
            kwargs["scatter_flag"] = True
            kwargs["color"] = cluster_key

        ax = paga(adata, **kwargs)
        ax.set_title(kwargs.get("title", cluster_key))

        extra_legend_kwargs = {k: v for k, v in legend_kwargs.items() if k != "loc"}

        if basis is not None and orig_ll not in ("none", "on data", None):
            handles = []
            for cluster_name, color in zip(
                adata.obs[f"{cluster_key}"].cat.categories,
                adata.uns[f"{cluster_key}_colors"],
            ):
                handles += [ax.scatter([], [], label=cluster_name, c=color)]
            first_legend = _position_legend(
                ax,
                legend_loc=orig_ll,
                handles=handles,
                **extra_legend_kwargs,
                title=cluster_key,
            )
            fig.add_artist(first_legend)

        if legend_kwargs.get("loc", None) not in ("none", "on data", None):
            # we need to use these, because scvelo can have its own handles and
            # they would be plotted here
            handles = []
            for lineage_name, color in zip(lin_names, colors[0].keys()):
                handles += [ax.scatter([], [], label=lineage_name, c=color)]
            if len(colors[0].keys()) != len(adata.obsm[lk].names):
                handles += [ax.scatter([], [], label="Rest", c="grey")]

            second_legend = _position_legend(
                ax,
                legend_loc=legend_kwargs["loc"],
                handles=handles,
                **extra_legend_kwargs,
                title=points,
            )
            fig.add_artist(second_legend)

        return fig

    @plot.register(ClusterFatesMode.VIOLIN)
    def _():
        kwargs.pop("ax", None)
        kwargs.pop("keys", None)
        kwargs.pop("save", None)  # we will handle saving

        kwargs["show"] = False
        kwargs["groupby"] = cluster_key
        kwargs["rotation"] = xrot

        cols = len(lin_names) if ncols is None else ncols
        nrows = ceil(len(lin_names) / cols)

        fig, axes = plt.subplots(
            nrows,
            cols,
            figsize=(6 * cols, 4 * nrows) if figsize is None else figsize,
            sharey=sharey,
            dpi=dpi,
        )
        fig.tight_layout()
        fig.subplots_adjust(wspace=0.2, hspace=0.3)

        if not isinstance(axes, np.ndarray):
            axes = [axes]
        axes = np.ravel(axes)

        # the probabilities are written to `adata.obs` under random keys for plotting
        with RandomKeys(adata, len(lin_names), where="obs") as keys:
            _i = 0
            for _i, (name, key, ax) in enumerate(zip(lin_names, keys, axes)):
                adata.obs[key] = adata.obsm[lk][name].X
                ax.set_title(f"{dir_prefix} {name}")
                violin(adata, ylabel="absorption probability", keys=key, ax=ax, **kwargs)
            for ax in axes[_i + 1:]:  # noqa
                ax.remove()

        return fig

    def plot_violin_no_cluster_key():
        from anndata import AnnData as _AnnData

        kwargs.pop("ax", None)
        kwargs.pop("keys", None)  # don't care
        kwargs.pop("save", None)

        kwargs["show"] = False
        kwargs["groupby"] = points
        kwargs["xlabel"] = None
        kwargs["rotation"] = xrot

        # stack the probabilities of all lineages into 1 long column
        data = np.ravel(adata.obsm[lk].X.T)[..., np.newaxis]
        tmp = _AnnData(csr_matrix(data.shape, dtype=np.float32))
        tmp.obs["absorption probability"] = data
        tmp.obs[points] = (
            pd.Series(
                np.concatenate(
                    [[f"{dir_prefix.lower()} {n}"] * adata.n_obs for n in adata.obsm[lk].names]
                )
            )
            .astype("category")
            .values
        )
        # `inplace=True` is not available in pandas>=2.0, assign the result instead
        tmp.obs[points] = tmp.obs[points].cat.reorder_categories(
            [f"{dir_prefix.lower()} {n}" for n in adata.obsm[lk].names]
        )
        tmp.uns[f"{points}_colors"] = adata.obsm[lk].colors

        fig, ax = plt.subplots(figsize=figsize if figsize is not None else (8, 6), dpi=dpi)
        ax.set_title(points.capitalize())
        violin(tmp, keys=["absorption probability"], ax=ax, **kwargs)

        return fig

    @plot.register(ClusterFatesMode.HEATMAP)
    def _():
        data = pd.DataFrame(
            [mean for mean, _ in d.values()], columns=lin_names, index=clusters
        ).T

        title = kwargs.pop("title", "average fate per cluster")

        vmin, vmax = data.values.min(), data.values.max()
        cbar_kws = {
            "label": "probability",
            "ticks": np.linspace(vmin, vmax, 5),
            "format": "%.3f",
        }
        kwargs.setdefault("cmap", "viridis")

        if use_clustermap:
            kwargs["cbar_pos"] = (0, 0.9, 0.025, 0.15) if cbar else None
            max_size = float(max(data.shape))

            g = clustermap(
                data,
                annot=True,
                vmin=vmin,
                vmax=vmax,
                fmt=fmt,
                row_colors=adata.obsm[lk][lin_names].colors,
                dendrogram_ratio=(
                    0.15 * data.shape[0] / max_size,
                    0.15 * data.shape[1] / max_size,
                ),
                cbar_kws=cbar_kws,
                figsize=figsize,
                **kwargs,
            )
            g.ax_heatmap.set_xlabel(cluster_key)
            g.ax_heatmap.set_ylabel("lineage")
            g.ax_col_dendrogram.set_title(title)

            fig = g.fig
            g = g.ax_heatmap
        else:
            fig, ax = plt.subplots(figsize=figsize, dpi=dpi)
            g = heatmap(
                data,
                vmin=vmin,
                vmax=vmax,
                annot=True,
                fmt=fmt,
                cbar=cbar,
                cbar_kws=cbar_kws,
                ax=ax,
                **kwargs,
            )
            ax.set_title(title)
            ax.set_xlabel(cluster_key)
            ax.set_ylabel("lineage")

        g.set_xticklabels(g.get_xticklabels(), rotation=xrot)
        g.set_yticklabels(g.get_yticklabels(), rotation=0)

        return fig

    mode = ClusterFatesMode(mode)

    if cluster_key is not None:
        if cluster_key not in adata.obs:
            raise KeyError(f"Key `{cluster_key!r}` not found in `adata.obs`.")
    elif mode not in (ClusterFatesMode.BAR, ClusterFatesMode.VIOLIN):
        raise ValueError(
            f"Not specifying cluster key is only available for modes "
            f"`{ClusterFatesMode.BAR!r}` and `{ClusterFatesMode.VIOLIN!r}`, found `mode={mode!r}`."
        )

    if backward:
        lk = AbsProbKey.BACKWARD.s
        points = TerminalStatesPlot.BACKWARD.s
        dir_prefix = DirPrefix.BACKWARD.s
    else:
        lk = AbsProbKey.FORWARD.s
        points = TerminalStatesPlot.FORWARD.s
        dir_prefix = DirPrefix.FORWARD.s

    if cluster_key is not None:
        is_all = False
        if clusters is not None:
            if isinstance(clusters, str):
                clusters = [clusters]
            clusters = _unique_order_preserving(clusters)
            if mode in (ClusterFatesMode.PAGA, ClusterFatesMode.PAGA_PIE):
                logg.debug(
                    f"Setting `clusters` to all available ones because of `mode={mode!r}`"
                )
                clusters = list(adata.obs[cluster_key].cat.categories)
            else:
                for cname in clusters:
                    if cname not in adata.obs[cluster_key].cat.categories:
                        raise KeyError(
                            f"Cluster `{cname!r}` not found in `adata.obs[{cluster_key!r}]`."
                        )
        else:
            clusters = list(adata.obs[cluster_key].cat.categories)
    else:
        # no clustering, aggregate over all the cells
        is_all = True
        clusters = [points]

    if lk not in adata.obsm:
        raise KeyError(f"Lineage key `{lk!r}` not found in `adata.obsm`.")

    if lineages is not None:
        if isinstance(lineages, str):
            lineages = [lineages]
        lin_names = _unique_order_preserving(lineages)
    else:
        # must be list for `sc.pl.violin`, else cats str
        lin_names = list(adata.obsm[lk].names)

    # fail early if any of the lineages is invalid
    _ = adata.obsm[lk][lin_names]

    if mode == ClusterFatesMode.VIOLIN and not is_all:
        adata = adata[np.isin(adata.obs[cluster_key], clusters)].copy()

    # cluster name -> [mean, standard error] of the absorption probabilities
    d = odict()
    for name in clusters:
        # `np.bool` has been removed from NumPy, the builtin `bool` is equivalent
        if is_all:
            mask = np.ones((adata.n_obs,), dtype=bool)
        else:
            mask = (adata.obs[cluster_key] == name).values
        mask = np.array(mask, dtype=bool)
        data = adata.obsm[lk][mask, lin_names].X
        mean = np.nanmean(data, axis=0)
        std = np.nanstd(data, axis=0) / np.sqrt(data.shape[0])
        d[name] = [mean, std]

    logg.debug(f"Plotting in mode `{mode!r}`")
    use_clustermap = False
    if mode == ClusterFatesMode.CLUSTERMAP:
        # the clustermap shares its implementation with the heatmap
        use_clustermap = True
        mode = ClusterFatesMode.HEATMAP
    elif (
        mode in (ClusterFatesMode.PAGA, ClusterFatesMode.PAGA_PIE)
        and "paga" not in adata.uns
    ):
        raise KeyError("Compute PAGA first as `scvelo.tl.paga()`.")

    if mode == ClusterFatesMode.VIOLIN and cluster_key is None:
        fig = plot_violin_no_cluster_key()
    else:
        fig = plot(mode)

    if save is not None:
        save_fig(fig, save)

    fig.show()
def cluster_lineage(
    adata: AnnData,
    model: _model_type,
    genes: Sequence[str],
    lineage: str,
    backward: bool = False,
    time_range: _time_range_type = None,
    clusters: Optional[Sequence[str]] = None,
    n_points: int = 200,
    time_key: str = "latent_time",
    cluster_key: str = "clusters",
    norm: bool = True,
    recompute: bool = False,
    callback: _callback_type = None,
    ncols: int = 3,
    sharey: Union[str, bool] = False,
    key_added: Optional[str] = None,
    show_progress_bar: bool = True,
    n_jobs: Optional[int] = 1,
    backend: str = _DEFAULT_BACKEND,
    figsize: Optional[Tuple[float, float]] = None,
    dpi: Optional[int] = None,
    save: Optional[Union[str, Path]] = None,
    pca_kwargs: Dict = MappingProxyType({"svd_solver": "arpack"}),
    neighbors_kwargs: Dict = MappingProxyType({"use_rep": "X"}),
    louvain_kwargs: Dict = MappingProxyType({}),
    **kwargs,
) -> None:
    """
    Cluster gene expression trends within a lineage and plot the clusters.

    This function is based on Palantir, see [Setty19]_. It can be used to discover modules of genes that drive
    development along a given lineage. Consider running this function on a subset of genes which are potential
    lineage drivers, identified e.g. by running :func:`cellrank.tl.lineage_drivers`.

    Parameters
    ----------
    %(adata)s
    %(model)s
    %(genes)s
    lineage
        Name of the lineage for which to cluster the genes.
    %(backward)s
    %(time_ranges)s
    clusters
        Cluster identifiers to plot. If `None`, all clusters will be considered.
        Useful when plotting previously computed clusters.
    n_points
        Number of points used for prediction.
    time_key
        Key in ``adata.obs`` where the pseudotime is stored.
    cluster_key
        Key in ``adata.obs`` where the clustering is stored.
    norm
        Whether to z-normalize each trend to have zero mean, unit variance.
    recompute
        If `True`, recompute the clustering, otherwise try to find already existing one.
    %(model_callback)s
    ncols
        Number of columns for the plot.
    sharey
        Whether to share y-axis across multiple plots.
    key_added
        Postfix to add when saving the results to ``adata.uns``.
    %(parallel)s
    %(plotting)s
    pca_kwargs
        Keyword arguments for :func:`scanpy.pp.pca`.
    neighbors_kwargs
        Keyword arguments for :func:`scanpy.pp.neighbors`.
    louvain_kwargs
        Keyword arguments for :func:`scanpy.tl.louvain`.
    **kwargs:
        Keyword arguments for :meth:`cellrank.ul.models.BaseModel.prepare`.

    Returns
    -------
    %(just_plots)s

        Updates ``adata.uns`` with the following key:

            - ``lineage_{lineage}_trend_{key_added}`` - an :class:`anndata.AnnData` object of
              shape ``(n_genes, n_points)`` containing the clustered genes.
    """
    import scanpy as sc
    from anndata import AnnData as _AnnData

    lineage_key = str(AbsProbKey.BACKWARD if backward else AbsProbKey.FORWARD)
    if lineage_key not in adata.obsm:
        raise KeyError(f"Lineages key `{lineage_key!r}` not found in `adata.obsm`.")

    # fail early if the lineage is invalid
    _ = adata.obsm[lineage_key][lineage]

    genes = _unique_order_preserving(genes)
    # pass `use_raw` by keyword, as done by the other callers of `_check_collection`
    _check_collection(adata, genes, "var_names", use_raw=kwargs.get("use_raw", False))

    key_to_add = f"lineage_{lineage}_trend"
    if key_added is not None:
        logg.debug(f"Adding key `{key_added!r}`")
        key_to_add += f"_{key_added}"

    if recompute or key_to_add not in adata.uns:
        kwargs["time_key"] = time_key  # kwargs for the model.prepare
        kwargs["n_test_points"] = n_points
        kwargs["backward"] = backward

        models = _create_models(model, genes, [lineage])
        callbacks = _create_callbacks(adata, callback, genes, [lineage], **kwargs)

        backend = _get_backend(model, backend)
        n_jobs = _get_n_cores(n_jobs, len(genes))

        start = logg.info(f"Computing gene trends using `{n_jobs}` core(s)")
        trends = parallelize(
            _cluster_lineages_helper,
            genes,
            as_array=True,
            unit="gene",
            n_jobs=n_jobs,
            backend=backend,
            extractor=np.vstack,
            show_progress_bar=show_progress_bar,
        )(models, callbacks, lineage, time_range, **kwargs)
        logg.info(" Finish", time=start)

        trends = trends.T
        if norm:
            logg.debug("Normalizing using `StandardScaler`")
            # use the returned array, `copy=False` does not guarantee an in-place update
            trends = StandardScaler(copy=False).fit_transform(trends)

        trends = _AnnData(trends.T)
        trends.obs_names = genes

        # sanity check
        if trends.n_obs != len(genes):
            raise RuntimeError(
                f"Expected to find `{len(genes)}` genes, found `{trends.n_obs}`."
            )
        if n_points is not None and trends.n_vars != n_points:
            raise RuntimeError(
                f"Expected to find `{n_points}` points, found `{trends.n_vars}`."
            )

        pca_kwargs = dict(pca_kwargs)
        # default value; `trends.n_vars` equals `n_points` whenever the latter is set
        # (checked above) and remains valid when `n_points=None`
        n_comps = pca_kwargs.pop("n_comps", min(50, trends.n_vars, len(genes)) - 1)

        sc.pp.pca(trends, n_comps=n_comps, **pca_kwargs)
        sc.pp.neighbors(trends, **neighbors_kwargs)

        louvain_kwargs = dict(louvain_kwargs)
        louvain_kwargs["key_added"] = cluster_key
        sc.tl.louvain(trends, **louvain_kwargs)

        adata.uns[key_to_add] = trends
    else:
        logg.info(f"Loading data from `adata.uns[{key_to_add!r}]`")
        trends = adata.uns[key_to_add]

    if clusters is None:
        if cluster_key not in trends.obs:
            raise KeyError(f"Invalid cluster key `{cluster_key!r}`.")
        clusters = trends.obs[cluster_key].cat.categories

    nrows = int(np.ceil(len(clusters) / ncols))
    fig, axes = plt.subplots(
        nrows,
        ncols,
        figsize=(ncols * 10, nrows * 10) if figsize is None else figsize,
        sharey=sharey,
        dpi=dpi,
    )

    if not isinstance(axes, Iterable):
        axes = [axes]
    axes = np.ravel(axes)

    j = 0
    for j, (ax, c) in enumerate(zip(axes, clusters)):  # noqa
        data = trends[trends.obs[cluster_key] == c].X
        mean, sd = np.mean(data, axis=0), np.std(data, axis=0)

        # individual trends in the background, their mean +- sd on top
        for trend in data:
            ax.plot(trend, color="gray", lw=0.5)

        ax.plot(mean, lw=2, color="black")
        ax.plot(mean - sd, lw=1.5, color="black", linestyle="--")
        ax.plot(mean + sd, lw=1.5, color="black", linestyle="--")
        ax.fill_between(range(len(mean)), mean - sd, mean + sd, color="black", alpha=0.1)

        ax.set_title(f"Cluster {c}")
        ax.set_xticks([])

        if not sharey:
            ax.set_yticks([])

    # remove the unused panels of the grid
    for ax in axes[j + 1:]:
        ax.remove()

    if save is not None:
        save_fig(fig, save)
def _create_models(
    model: _input_model_type,
    obs: Sequence[str],
    lineages: Sequence[Optional[str]],
) -> _return_model_type:
    """
    Create models for each gene and lineage.

    Parameters
    ----------
    model
        Model specification. Can be one of the following:

            - a single :class:`BaseModel` - a copy of it is used for every gene and lineage.
            - a :class:`dict` keyed by gene name whose values are either a :class:`BaseModel`
              (copied for every lineage) or a :class:`dict` keyed by lineage name with
              :class:`BaseModel` values.

        In both dictionaries, the key `'*'` denotes the fallback model used for genes/lineages
        which are not listed explicitly. The passed dictionaries are not modified.
    obs
        Sequence of observations, such as genes.
    lineages
        Sequence of lineage names.

    Returns
    -------
    The created models, as a nested mapping ``gene -> lineage -> model``.

    Raises
    ------
    TypeError
        If ``model`` or any of its values is neither a :class:`BaseModel` nor a :class:`dict`.
    ValueError
        If no genes or lineages were given, or if the specification does not cover
        all requested genes and lineages.
    """

    def process_lineages(
        obs_name: str, lin_names: Union[BaseModel, Dict[Optional[str], Any]]
    ) -> None:
        # Fill `models[obs_name]` from the per-gene specification `lin_names`.
        if isinstance(lin_names, BaseModel):
            # sharing the same models for all lineages
            for lin_name in lineages:
                models[obs_name][lin_name] = copy(lin_names)
            return

        if not isinstance(lin_names, dict):
            raise TypeError(
                f"Expected the model to be either a lineage specific `dict` or a `BaseModel`, "
                f"found `{type(lin_names).__name__!r}`."
            )

        # read-only access: the caller's dictionary must stay untouched
        lin_rest_model = lin_names.get("*", None)
        if lin_rest_model is not None and not isinstance(lin_rest_model, BaseModel):
            raise TypeError(
                f"Expected the lineage fallback model for gene `{obs_name!r}` to be of type `BaseModel`, "
                f"found `{type(lin_rest_model).__name__!r}`."
            )

        for lin_name, mod in lin_names.items():
            if lin_name == "*":
                continue
            if not isinstance(mod, BaseModel):
                raise TypeError(
                    f"Expected the model for gene `{obs_name!r}` and lineage `{lin_name!r}` "
                    f"to be of type `BaseModel`, found `{type(mod).__name__!r}`."
                )
            models[obs_name][lin_name] = copy(mod)

        if set(models[obs_name].keys()) & lineages == lineages:
            # every requested lineage has been specified explicitly
            return

        if lin_rest_model is None:
            raise ValueError(
                _ERROR_INCOMPLETE_SPEC.format(f"all lineages for gene `{obs_name!r}`")
            )
        for lin_name in lineages - set(models[obs_name].keys()):
            models[obs_name][lin_name] = copy(lin_rest_model)

    if not len(lineages):
        raise ValueError("No lineages have been selected.")

    if not len(obs):
        raise ValueError("No genes have been selected.")

    if isinstance(model, BaseModel):
        return {
            o: {lin: copy(model) for lin in _unique_order_preserving(lineages)}
            for o in _unique_order_preserving(obs)
        }

    lineages, obs = (
        set(_unique_order_preserving(lineages)),
        set(_unique_order_preserving(obs)),
    )
    models = defaultdict(dict)

    if not isinstance(model, dict):
        raise TypeError(
            f"Class `{type(model).__name__!r}` must be of type `BaseModel` or "
            f"a gene and lineage specific `dict` of `BaseModel`.."
        )

    # Work on a shallow copy: popping the fallback from the caller's dictionary
    # would silently drop it for any later call reusing the same specification.
    model = dict(model)
    obs_rest_model = model.pop("*", None)
    if obs_rest_model is not None and not isinstance(obs_rest_model, BaseModel):
        raise TypeError(
            f"Expected the gene fallback model to be of type `BaseModel`, "
            f"found `{type(obs_rest_model).__name__!r}`."
        )

    for obs_name, lin_names in model.items():
        process_lineages(obs_name, lin_names)

    if obs_rest_model is not None:
        # genes without an explicit entry use the gene fallback model
        for obs_name in obs - set(model.keys()):
            process_lineages(obs_name, obs_rest_model)
    elif set(model.keys()) != obs:
        # NOTE: this also triggers when `model` lists genes that are not in `obs`
        raise ValueError(
            _ERROR_INCOMPLETE_SPEC.format(
                f"genes `{list(obs - set(model.keys()))}`."
            )
        )

    if set(models.keys()) & obs != obs:
        raise ValueError(
            f"Missing gene models for the following genes: `{list(obs - set(models.keys()))}`."
        )

    for gene, vs in models.items():
        if set(vs.keys()) & lineages != lineages:
            raise ValueError(
                f"Missing lineage models for the gene `{gene!r}`: `{list(lineages - set(vs.keys()))}`."
            )

    return models
def circular_projection(
    adata: AnnData,
    keys: Union[str, Sequence[str]],
    backward: bool = False,
    lineages: Optional[Union[str, Sequence[str]]] = None,
    early_cells: Optional[Union[Mapping[str, Sequence[str]], Sequence[str]]] = None,
    lineage_order: Optional[Literal["default", "optimal"]] = None,
    metric: Union[str, Callable, np.ndarray, pd.DataFrame] = "correlation",
    normalize_by_mean: bool = True,
    ncols: int = 4,
    space: float = 0.25,
    use_raw: bool = False,
    text_kwargs: Mapping[str, Any] = MappingProxyType({}),
    labeldistance: float = 1.25,
    labelrot: Union[Literal["default", "best"], float] = "best",
    show_edges: bool = True,
    key_added: Optional[str] = None,
    figsize: Optional[Tuple[float, float]] = None,
    dpi: Optional[int] = None,
    save: Optional[Union[str, Path]] = None,
    **kwargs,
):
    r"""
    Plot absorption probabilities on a circular embedding as done in [Velten17]_.

    Parameters
    ----------
    %(adata)s
    keys
        Keys in :attr:`anndata.AnnData.obs` or :attr:`anndata.AnnData.var_names`. Additional keys are:

            - `'kl_divergence'` - as in [Velten17]_, computes KL-divergence between the fate probabilities
              of a cell and the average fate probabilities. See ``early_cells`` for more information.
            - `'entropy'` - as in [Setty19]_, computes entropy over a cell's fate probabilities.

    %(backward)s
    lineages
        Lineages to plot. If `None`, plot all lineages.
    early_cells
        Cell ids or a mask marking early cells used to define the average fate probabilities. If `None`, use all cells.
        Only used when `'kl_divergence'` is in ``keys``. If a :class:`dict`, key specifies a cluster key in
        :attr:`anndata.AnnData.obs` and the values specify cluster labels containing early cells.
    lineage_order
        Can be one of the following:

            - `None` - it will be determined automatically, based on the number of lineages.
            - `'optimal'` - order the lineages optimally by solving the Travelling salesman problem (TSP).
              Recommended for <= `20` lineages.
            - `'default'` - use the order as specified in ``lineages``.

    metric
        Metric to use when constructing pairwise distance matrix when ``lineage_order = 'optimal'``.
        For available options, see :func:`sklearn.metrics.pairwise_distances`.
    normalize_by_mean
        If `True`, normalize each lineage by its mean probability, as done in [Velten17]_.
    ncols
        Number of columns when plotting multiple ``keys``.
    space
        Horizontal and vertical space between the subplots, passed to :func:`matplotlib.pyplot.subplots_adjust`.
    use_raw
        Whether to access :attr:`anndata.AnnData.raw` when there are ``keys`` in :attr:`anndata.AnnData.var_names`.
    text_kwargs
        Keyword arguments for :func:`matplotlib.pyplot.text`.
    labeldistance
        Distance at which the lineage labels will be drawn. Must be non-negative.
    labelrot
        How to rotate the labels. Valid options are:

            - `'best'` - rotate labels so that they are easily readable.
            - `'default'` - use :mod:`matplotlib`'s default.
            - `None` - same as `'default'`.

        If a :class:`float`, all labels will be rotated by this many degrees.
    show_edges
        Whether to show the edges surrounding the simplex.
    key_added
        Key in :attr:`anndata.AnnData.obsm` where to add the circular embedding. If `None`, it will be set to
        `'X_fate_simplex_{fwd,bwd}'`, based on ``backward``.
    %(plotting)s
    kwargs
        Keyword arguments for :func:`scvelo.pl.scatter`. The key `'colorbar'` (default `True`)
        is applied to every panel.

    Returns
    -------
    %(just_plots)s
        Also updates ``adata`` with the following fields:

            - :attr:`anndata.AnnData.obsm` ``['{key_added}']``: the circular projection.
            - :attr:`anndata.AnnData.obs` ``['to_{initial,terminal}_states_{method}']``: the priming degree,
              if a method is present in ``keys``.
    """
    if labeldistance is not None and labeldistance < 0:
        raise ValueError(
            f"Expected `labeldistance` to be non-negative, found `{labeldistance}`."
        )

    if labelrot is None:
        labelrot = LabelRot.DEFAULT
    if isinstance(labelrot, str):
        labelrot = LabelRot(labelrot)

    suffix = "bwd" if backward else "fwd"
    if key_added is None:
        key_added = "X_fate_simplex_" + suffix

    if isinstance(keys, str):
        keys = (keys,)
    keys = _unique_order_preserving(keys)

    # keep the keys found in `adata.obs`/`adata.var_names`, plus the priming degree methods
    keys_ = _check_collection(
        adata, keys, "obs", key_name="Observation", raise_exc=False
    ) + _check_collection(
        adata, keys, "var_names", key_name="Gene", raise_exc=False, use_raw=use_raw
    )
    haystack = {s.s for s in PrimingDegree}
    keys = keys_ + [k for k in keys if k in haystack]
    keys = _unique_order_preserving(keys)

    if not len(keys):
        raise ValueError("No valid keys have been selected.")

    lineage_key = str(AbsProbKey.BACKWARD if backward else AbsProbKey.FORWARD)
    if lineage_key not in adata.obsm:
        raise KeyError(f"Lineages key `{lineage_key!r}` not found in `adata.obsm`.")

    probs = adata.obsm[lineage_key]
    if isinstance(lineages, str):
        lineages = (lineages,)
    elif lineages is None:
        lineages = probs.names

    probs = adata.obsm[lineage_key][lineages]
    n_lin = probs.shape[1]
    if n_lin <= 2:
        raise ValueError(f"Expected at least `3` lineages, found `{n_lin}`")

    X = probs.X.copy()
    if normalize_by_mean:
        X /= np.mean(X, axis=0)[None, :]
        X /= X.sum(1)[:, None]
        # the divisions above yield NaNs when a divisor is 0;
        # such entries are replaced by the uniform value `1 / n_lin`
        X = np.nan_to_num(X, nan=1.0 / n_lin, copy=False)

    if lineage_order is None:
        lineage_order = (
            LineageOrder.OPTIMAL if n_lin <= 15 else LineageOrder.DEFAULT
        )
        logg.debug(f"Set ordering to `{lineage_order}`")
    lineage_order = LineageOrder(lineage_order)

    if lineage_order == LineageOrder.OPTIMAL:
        logg.info(f"Solving TSP for `{n_lin}` states")
        _, order = _get_optimal_order(X, metric=metric)
    else:
        order = np.arange(n_lin)

    probs = probs[:, order]
    X = X[:, order]

    # place the lineages evenly on the unit circle and
    # embed each cell as the weighted average of the lineage positions
    angle_vec = np.linspace(0, 2 * np.pi, n_lin, endpoint=False)
    corner_x = np.cos(angle_vec)
    corner_y = np.sin(angle_vec)

    x = np.sum(X * corner_x, axis=1)
    y = np.sum(X * corner_y, axis=1)
    adata.obsm[key_added] = np.c_[x, y]

    nrows = int(np.ceil(len(keys) / ncols))
    fig, axs = plt.subplots(
        nrows=nrows,
        ncols=ncols,
        figsize=(ncols * 5, nrows * 5) if figsize is None else figsize,
        dpi=dpi,
    )
    fig.subplots_adjust(wspace=space, hspace=space)
    axes = np.ravel([axs])

    text_kwargs = dict(text_kwargs)
    text_kwargs["ha"] = "center"
    text_kwargs["va"] = "center"

    # pop once, before the loop, so that the user's choice applies to every panel
    colorbar = kwargs.pop("colorbar", True)
    # log-normalization of the colors is never enabled in this function
    set_lognorm = False

    _i = 0
    for _i, (k, ax) in enumerate(zip(keys, axes)):
        try:
            # only the enum lookup is guarded: errors raised while
            # computing the priming degree must not be silenced
            PrimingDegree(k)
        except ValueError:
            pass  # `k` is a key from `adata.obs` or `adata.var_names`
        else:
            logg.debug(f"Calculating priming degree using `method={k}`")
            val = probs.priming_degree(method=k, early_cells=early_cells)
            k = f"{lineage_key}_{k}"
            adata.obs[k] = val

        scv.pl.scatter(
            adata,
            basis=key_added,
            color=k,
            show=False,
            ax=ax,
            use_raw=use_raw,
            norm=LogNorm() if set_lognorm else None,
            colorbar=colorbar,
            **kwargs,
        )
        if colorbar and set_lognorm:
            cbar = ax.collections[0].colorbar
            cax = cbar.locator.axis
            ticks = cax.minor.locator.tick_values(cbar.vmin, cbar.vmax)
            ticks = [ticks[0], ticks[len(ticks) // 2 + 1], ticks[-1]]
            cbar.set_ticks(ticks)
            cbar.set_ticklabels([f"{t:.2f}" for t in ticks])
            cbar.update_ticks()

        # an invisible pie chart is used only to position the lineage labels
        patches, texts = ax.pie(
            np.ones_like(angle_vec),
            labeldistance=labeldistance,
            rotatelabels=True,
            labels=probs.names[::-1],
            startangle=-360 / len(angle_vec) / 2,
            counterclock=False,
            textprops=text_kwargs,
        )

        for patch in patches:
            patch.set_visible(False)

        # clockwise
        for color, text in zip(probs.colors[::-1], texts):
            if isinstance(labelrot, (int, float)):
                text.set_rotation(labelrot)
            elif labelrot == LabelRot.BEST:
                rot = text.get_rotation()
                text.set_rotation(rot + 90 + (1 - rot // 180) * 180)
            elif labelrot != LabelRot.DEFAULT:
                raise NotImplementedError(
                    f"Label rotation `{labelrot}` is not yet implemented."
                )
            text.set_color(color)

        if not show_edges:
            continue

        # connect neighboring lineages by an edge whose color fades from one to the other
        for i, color in enumerate(probs.colors):
            nxt = (i + 1) % n_lin
            edge_x = 1.04 * np.linspace(corner_x[i], corner_x[nxt], _N)
            edge_y = 1.04 * np.linspace(corner_y[i], corner_y[nxt], _N)
            points = np.array([edge_x, edge_y]).T.reshape(-1, 1, 2)

            segments = np.concatenate([points[:-1], points[1:]], axis=1)
            cmap = LinearSegmentedColormap.from_list(
                "abs_prob_cmap", [color, probs.colors[nxt]], N=_N
            )
            lc = LineCollection(segments, cmap=cmap, zorder=-1)
            lc.set_array(np.linspace(0, 1, _N))
            lc.set_linewidth(2)
            ax.add_collection(lc)

    # drop the axes which were not needed
    for j in range(_i + 1, len(axes)):
        axes[j].remove()

    if save is not None:
        save_fig(fig, save)
def prepare(
    self,
    cluster: str,
    clusters: Optional[Sequence[Any]] = None,
    time_points: Optional[Sequence[Numeric_t]] = None,
) -> "FlowPlotter":
    """
    Prepare itself for plotting by computing flow and contingency matrix.

    Parameters
    ----------
    cluster
        Source cluster for flow calculation.
    clusters
        Target clusters for flow calculation. If `None`, use all clusters.
    time_points
        Restrict flow calculation only to these time points. If `None`, use all time points.

    Returns
    -------
    Returns self, after modifying it in place.

    Raises
    ------
    ValueError
        If the source cluster is invalid, if fewer than `2` clusters or time points remain,
        or if the selection of clusters or time points leaves no observations.
    """
    if clusters is None:
        self._clusters = self.clusters.cat.categories
    else:
        # the source cluster always takes part in the calculation
        clusters = _unique_order_preserving([cluster] + list(clusters))
        mask = self.clusters.isin(clusters).values

        self._adata = self._adata[mask]
        if not self._adata.n_obs:
            raise ValueError("No valid clusters have been selected.")
        self._tmat = self._tmat[mask, :][:, mask]
        self._clusters = [
            c for c in clusters if c in self.clusters.cat.categories
        ]

    if cluster not in self._clusters:
        raise ValueError(f"Invalid source cluster `{cluster!r}`.")

    if len(self._clusters) < 2:
        # report the clusters actually retained; `clusters` may be `None` here
        raise ValueError(
            f"Expected at least `2` clusters, found `{len(self._clusters)}`."
        )

    if time_points is not None:
        time_points = _unique_order_preserving(time_points)
        if len(time_points) < 2:
            raise ValueError(
                f"Expected at least `2` time points, found `{len(time_points)}`."
            )

        # plain boolean array, same as for the cluster mask above
        mask = self.time.isin(time_points).values

        self._adata = self._adata[mask]
        if not self._adata.n_obs:
            raise ValueError("No valid time points have been selected.")
        self._tmat = self._tmat[mask, :][:, mask]

    # the flow is computed between consecutive time point categories
    time_points = list(
        zip(self.time.cat.categories[:-1], self.time.cat.categories[1:])
    )

    logg.info(
        f"Computing flow from `{cluster}` into `{len(self._clusters) - 1}` cluster(s) "
        f"in `{len(time_points)}` time points"
    )
    self._cluster = cluster
    self._cmat = self.compute_contingency_matrix()
    self._flow = self.compute_flow(time_points, cluster)

    return self
def gene_trends(
    adata: AnnData,
    model: _model_type,
    genes: Union[str, Sequence[str]],
    lineages: Optional[Union[str, Sequence[str]]] = None,
    backward: bool = False,
    data_key: str = "X",
    time_key: str = "latent_time",
    time_range: Optional[Union[_time_range_type, List[_time_range_type]]] = None,
    callback: _callback_type = None,
    conf_int: bool = True,
    same_plot: bool = False,
    hide_cells: bool = False,
    perc: Optional[Union[Tuple[float, float], Sequence[Tuple[float, float]]]] = None,
    lineage_cmap: Optional[matplotlib.colors.ListedColormap] = None,
    abs_prob_cmap: matplotlib.colors.ListedColormap = cm.viridis,
    cell_color: str = "black",
    cell_alpha: float = 0.6,
    lineage_alpha: float = 0.2,
    size: float = 15,
    lw: float = 2,
    show_cbar: bool = True,
    margins: float = 0.015,
    sharex: Optional[Union[str, bool]] = None,
    sharey: Optional[Union[str, bool]] = None,
    gene_as_title: Optional[bool] = None,
    legend_loc: Optional[str] = "best",
    ncols: int = 2,
    suptitle: Optional[str] = None,
    n_jobs: Optional[int] = 1,
    backend: str = _DEFAULT_BACKEND,
    show_progres_bar: bool = True,
    figsize: Optional[Tuple[float, float]] = None,
    dpi: Optional[int] = None,
    save: Optional[Union[str, Path]] = None,
    plot_kwargs: Mapping = MappingProxyType({}),
    **kwargs,
) -> None:
    """
    Plot gene expression trends along lineages.

    Each lineage is defined via its lineage weights which we compute using :func:`cellrank.tl.lineages`. This
    function accepts any model based off :class:`cellrank.ul.models.BaseModel` to fit gene expression,
    where we take the lineage weights into account in the loss function.

    Parameters
    ----------
    %(adata)s
    %(model)s
    %(genes)s
    lineages
        Names of the lineages to plot. If `None`, plot all lineages.
    %(backward)s
    data_key
        Key in ``adata.layers`` or `'X'` for ``adata.X`` where the data is stored.
    time_key
        Key in ``adata.obs`` where the pseudotime is stored.
    %(time_ranges)s
    %(model_callback)s
    conf_int
        Whether to compute and show confidence intervals.
    same_plot
        Whether to plot all lineages for each gene in the same plot.
    hide_cells
        If `True`, hide all cells.
    perc
        Percentile for colors. Valid values are in interval `[0, 100]`.
        This can improve visualization. Can be specified individually for each lineage.
    lineage_cmap
        Colormap to use when coloring in the lineages. If `None` and ``same_plot``,
        use the corresponding colors in ``adata.uns``, otherwise use `'black'`.
    abs_prob_cmap
        Colormap to use when visualizing the absorption probabilities for each lineage.
        Only used when ``same_plot=False``.
    cell_color
        Color of the cells when not visualizing absorption probabilities. Only used when ``same_plot=True``.
    cell_alpha
        Alpha channel for cells.
    lineage_alpha
        Alpha channel for lineage confidence intervals.
    size
        Size of the points.
    lw
        Line width of the smoothed values.
    show_cbar
        Whether to show colorbar. Always shown when percentiles for lineages differ.
        Only used when ``same_plot=False``.
    margins
        Margins around the plot.
    sharex
        Whether to share x-axis. Valid options are `'row'`, `'col'` or `'none'`.
    sharey
        Whether to share y-axis. Valid options are `'row'`, `'col'` or `'none'`.
    gene_as_title
        Whether to show gene names as titles instead on y-axis.
    legend_loc
        Location of the legend displaying lineages. Only used when `same_plot=True`.
    ncols
        Number of columns of the plot when plotting multiple genes. Only used when ``same_plot=True``.
    suptitle
        Suptitle of the figure. If `None`, no suptitle is added.
    %(parallel)s
    %(plotting)s
    plot_kwargs
        Keyword arguments for :meth:`cellrank.ul.models.BaseModel.plot`.
    **kwargs
        Keyword arguments for :meth:`cellrank.ul.models.BaseModel.prepare`.

    Returns
    -------
    %(just_plots)s
    """
    if isinstance(genes, str):
        genes = [genes]
    genes = _unique_order_preserving(genes)

    # the "genes" are looked up in `adata.obs` when the data itself comes from there
    _check_collection(
        adata,
        genes,
        "var_names" if data_key != "obs" else "obs",
        use_raw=kwargs.get("use_raw", False),
    )

    ln_key = str(AbsProbKey.BACKWARD if backward else AbsProbKey.FORWARD)
    if ln_key not in adata.obsm:
        raise KeyError(f"Lineages key `{ln_key!r}` not found in `adata.obsm`.")

    if lineages is None:
        lineages = adata.obsm[ln_key].names
    elif isinstance(lineages, str):
        lineages = [lineages]
    elif all(ln is None for ln in lineages):
        # no lineage, all the weights are 1
        lineages = [None]
        show_cbar = False
        logg.debug("All lineages are `None`, setting the weights to `1`")

    lineages = _unique_order_preserving(lineages)

    # Validate the inputs before any figure is created, so that
    # an invalid argument does not leave an orphan figure behind.
    _ = adata.obsm[ln_key][[lin for lin in lineages if lin is not None]]

    if isinstance(time_range, (tuple, float, int, type(None))):
        time_range = [time_range] * len(lineages)
    elif len(time_range) != len(lineages):
        raise ValueError(
            f"Expected time ranges to be of length `{len(lineages)}`, found `{len(time_range)}`."
        )

    if same_plot:
        # one panel per gene, laid out on a grid with at most `ncols` columns
        gene_as_title = True if gene_as_title is None else gene_as_title
        sharex = "all" if sharex is None else sharex
        sharey = "none" if sharey is None else sharey
        ncols = len(genes) if ncols >= len(genes) else ncols
        nrows = int(np.ceil(len(genes) / ncols))
    else:
        # one row per gene, one column per lineage
        gene_as_title = False if gene_as_title is None else gene_as_title
        sharex = "col" if sharex is None else sharex
        if sharey is None:
            sharey = "none" if hide_cells else "row"
        nrows = len(genes)
        ncols = len(lineages)

    fig, axes = plt.subplots(
        nrows=nrows,
        ncols=ncols,
        sharex=sharex,
        sharey=sharey,
        figsize=(6 * ncols, 4 * nrows) if figsize is None else figsize,
        constrained_layout=True,
        dpi=dpi,
    )
    axes = np.reshape(axes, (-1, ncols))

    kwargs["time_key"] = time_key
    kwargs["data_key"] = data_key
    kwargs["backward"] = backward
    callbacks = _create_callbacks(adata, callback, genes, lineages, **kwargs)

    kwargs["conf_int"] = conf_int  # prepare doesn't take or need this
    models = _create_models(model, genes, lineages)

    plot_kwargs = dict(plot_kwargs)
    if plot_kwargs.get("xlabel", None) is None:
        plot_kwargs["xlabel"] = time_key

    n_jobs = _get_n_cores(n_jobs, len(genes))
    backend = _get_backend(model, backend)

    start = logg.info(f"Computing trends using `{n_jobs}` core(s)")
    models = parallelize(
        _fit_gene_trends,
        genes,
        unit="gene" if data_key != "obs" else "obs",
        backend=backend,
        n_jobs=n_jobs,
        extractor=lambda modelss: {k: v for m in modelss for k, v in m.items()},
        show_progress_bar=show_progres_bar,
    )(models, callbacks, lineages, time_range, **kwargs)
    logg.info("    Finish", time=start)

    logg.info("Plotting trends")

    # walk the grid in row-major order, assigning one gene per visited position
    for cnt, ((row, col), gene) in enumerate(zip(np.ndindex(*axes.shape), genes)):
        if same_plot:
            show_xticks_and_label = (row + 1) * ncols + col >= len(genes)
        else:
            show_xticks_and_label = cnt == len(axes) - 1

        _trends_helper(
            adata,
            models,
            gene=gene,
            lineage_names=lineages,
            ln_key=ln_key,
            same_plot=same_plot,
            hide_cells=hide_cells,
            perc=perc,
            lineage_cmap=lineage_cmap,
            abs_prob_cmap=abs_prob_cmap,
            cell_color=cell_color,
            alpha=cell_alpha,
            lineage_alpha=lineage_alpha,
            size=size,
            lw=lw,
            show_cbar=show_cbar,
            margins=margins,
            sharey=sharey,
            gene_as_title=gene_as_title,
            legend_loc=legend_loc,
            dpi=dpi,
            figsize=figsize,
            fig=fig,
            axes=axes[row, col] if same_plot else axes[cnt],
            show_ylabel=col == 0,
            show_lineage=cnt == 0 or same_plot,
            show_xticks_and_label=show_xticks_and_label,
            **plot_kwargs,
        )

    if same_plot:
        # remove the grid positions which were not assigned a gene
        for ax in np.ravel(axes)[len(genes):]:
            ax.remove()

    if suptitle is not None:
        fig.suptitle(suptitle)

    if save is not None:
        save_fig(fig, save)