def _run_in_parallel(fn: Callable, conn: csr_matrix, **kwargs) -> Any:
    """
    Run ``fn`` over the rows of ``conn`` via :func:`parallelize` and reconstruct the result.

    Parameters
    ----------
    fn
        Worker function. Its ``__name__`` selects the row ordering: ``'_run_stochastic'``
        requires `jax` and processes rows by decreasing number of non-zero entries,
        any other function processes the rows in a random order.
    conn
        Sparse matrix whose ``indices`` and ``indptr`` are forwarded to ``fn``.
    kwargs
        Keyword arguments, filtered separately for :func:`parallelize` and for ``fn``.

    Returns
    -------
    The object built by ``_reconstruct_one`` from the concatenated results.

    Raises
    ------
    RuntimeError
        If ``fn`` is ``_run_stochastic`` and `jax` is not available.
    """
    worker_name = fn.__name__
    is_stochastic = worker_name == "_run_stochastic"

    if is_stochastic and not _HAS_JAX:
        raise RuntimeError("Install `jax` and `jaxlib` as `pip install jax jaxlib`.")

    if is_stochastic:
        # rows with the most non-zero entries are scheduled first
        nnz_per_row = np.array((conn != 0).sum(1)).ravel()
        order = np.argsort(nnz_per_row)[::-1]
    else:
        # random row order, drawn from NumPy's global RNG
        order = np.arange(conn.shape[0])
        np.random.shuffle(order)

    many_samples = worker_name == "_run_mc" and kwargs.get("n_samples", 1) > 1
    unit = "sample" if many_samples else "cell"

    kwargs["indices"] = conn.indices
    kwargs["indptr"] = conn.indptr

    def _extract(chunks):
        # join the collected results along the last axis; `order` is handed to
        # `_reconstruct_one`, presumably to map rows back to their original position
        return _reconstruct_one(np.concatenate(chunks, axis=-1), conn, order)

    runner = parallelize(
        fn,
        order,
        as_array=False,
        extractor=_extract,
        unit=unit,
        **_filter_kwargs(parallelize, **kwargs),
    )
    return runner(**_filter_kwargs(fn, **kwargs))
def copy(self) -> "SimpleNaryExpression":
    """%(copy)s"""  # noqa: D400, D401
    # NOTE(review): `copy(...)` below presumably resolves to a module-level import
    # (e.g. `from copy import copy`), not to this method - confirm.

    # preserve the type information so that combination can properly work
    # we test for this
    constructor = type(self)
    kwargs = {"op_name": self._op_name, "fn": self._fn}
    sne = constructor(
        [copy(k) for k in self], **_filter_kwargs(constructor, **kwargs)
    )

    # TODO: copying's a bit buggy - the parent stays the same
    # we could disallow it for inner expressions
    sne._transition_matrix = copy(self._transition_matrix)
    # the condition number is stored in `_cond_num`; it must be written back under the
    # same name, otherwise the copy keeps whatever value its constructor assigned
    sne._cond_num = self._cond_num
    sne._normalize = self._normalize

    return sne
def __init__(
    self,
    adata: AnnData,
    n_splines: Optional[int] = 10,
    spline_order: int = 3,
    distribution: str = "gamma",
    link: str = "log",
    max_iter: int = 2000,
    expectile: Optional[float] = None,
    use_default_conf_int: bool = False,
    grid: Optional[Mapping] = None,
    spline_kwargs: Mapping = MappingProxyType({}),
    **kwargs,
):
    """
    Build a spline-based GAM model and attach it to ``adata``.

    If ``expectile`` is given, an expectile GAM is created; otherwise the GAM class is
    looked up from the ``(distribution, link)`` pair.

    Raises
    ------
    ValueError
        If ``expectile`` is not in the open interval ``(0, 1)``.
    TypeError
        If ``grid`` is not a mapping, a string or `None`.
    """
    # one spline term on feature 0; entries of `spline_kwargs` override the default `lam=3`
    term = s(
        0,
        spline_order=spline_order,
        n_splines=n_splines,
        penalties=["derivative", "l2"],
        **_filter_kwargs(s, **{"lam": 3, **spline_kwargs}),
    )
    link = GamLinkFunction(link)
    distribution = GamDistribution(distribution)
    # `gauss` is handled as `normal`
    if distribution == GamDistribution.GAUSS:
        distribution = GamDistribution.NORMAL

    if expectile is not None:
        if not (0 < expectile < 1):
            raise ValueError(
                f"Expected `expectile` to be in `(0, 1)`, found `{expectile}`."
            )
        if distribution != "normal" or link != "identity":
            logg.warning(
                f"Expectile GAM works only with `normal` distribution and `identity` link function, "
                f"found `{distribution!r}` distribution and `{link!r}` link functions."
            )
        model = ExpectileGAM(
            term, expectile=expectile, max_iter=max_iter, verbose=False, **kwargs
        )
    else:
        # doing it like this ensures that the user can specify `scale`
        gam = _gams[distribution, link]
        kwargs["link"] = link.s
        kwargs["distribution"] = distribution.s
        model = gam(
            term,
            max_iter=max_iter,
            verbose=False,
            **_filter_kwargs(gam.__init__, **kwargs),
        )
    super().__init__(adata, model=model)
    self._use_default_conf_int = use_default_conf_int

    if grid is None:
        self._grid = None
    elif isinstance(grid, Mapping):
        # `dict`s are copied with `_copy` as before; any other mapping accepted by the
        # annotation (e.g. `MappingProxyType`) is converted to a plain `dict`
        self._grid = _copy(grid) if isinstance(grid, dict) else dict(grid)
    elif isinstance(grid, str):
        # a fresh `object()` marks the "default" grid; any other string disables it
        self._grid = object() if grid == "default" else None
    else:
        raise TypeError(
            f"Expected `grid` to be `dict`, `str` or `None`, found `{type(grid).__name__!r}`."
        )
def _plot_continuous(
    self,
    probs: Optional[Lineage],
    prop: str,
    lineages: Optional[Union[str, Iterable[str]]] = None,
    cluster_key: Optional[str] = None,
    mode: str = "embedding",
    time_key: str = "latent_time",
    title: Optional[str] = None,
    same_plot: bool = True,
    cmap: Union[str, mpl.colors.ListedColormap] = cm.viridis,
    **kwargs,
) -> None:
    """
    Plot continuous observations such as macrostates memberships or lineages in an embedding.

    Parameters
    ----------
    lineages
        Plot only these lineages. If `None`, plot all lineages.
    cluster_key
        Key from :attr:`adata` ``.obs`` for plotting categorical observations.
    %(time_mode)s
    time_key
        Key from :attr:`adata` ``.obs`` to use as a pseudotime ordering of the cells.
    title
        Either `None`, in which case titles are ``'{to,from} {terminal,initial} {state}'``,
        or an array of titles, one per lineage.
    same_plot
        Whether to plot the lineages on the same plot using color gradients when ``mode='embedding'``.
    cmap
        Colormap to use.
    %(basis)s
    kwargs
        Keyword arguments for :func:`scvelo.pl.scatter`.

    Returns
    -------
    %(just_plots)s
    """
    # `probs` holds the computed values, `prop` is the property name used in messages/titles
    if probs is None:
        raise RuntimeError(
            f"Compute `.{prop}` first as `.{F.COMPUTE.fmt(prop)}()`."
        )

    if isinstance(lineages, str):
        lineages = [lineages]

    if lineages is None:
        lineages = probs.names
        A = probs
    else:
        A = probs[lineages]

    if not len(lineages):
        raise RuntimeError(
            "Nothing to plot because empty subset has been selected."
        )

    prefix = DirPrefix.BACKWARD if self.kernel.backward else DirPrefix.FORWARD
    same_plot = same_plot and mode == "embedding"  # set this silently

    A = A.copy()  # the below code modifies stuff inplace
    X = A.X  # list(A.T) behaves differently, because it's Lineage

    if X.shape[1] == 1:
        same_plot = False  # because color_gradients for 1 state is buggy (looks empty)
        # this is the case for only 1 recurrent class - all cells have prob. 1 of going there
        # however, matplotlib's plotting really picks up the slightest differences in the colormap,
        # here we set everything to one, if applicable
        if np.allclose(X, 1.0):
            X = np.ones_like(X)

    for col in X.T:
        mask = ~np.isclose(col, 1.0)
        # change the maximum value - the 1 is artificial and obscures the color scaling
        if np.sum(mask):
            max_not_one = np.max(col[mask])
            col[~mask] = max_not_one

    if mode == "time":
        if time_key not in self.adata.obs.keys():
            raise KeyError(
                f"Time key `{time_key!r}` not found in `adata.obs`.")
        time = self.adata.obs[time_key]
        if cluster_key is not None:
            logg.warning(
                f"Cluster key `{cluster_key!r}` is ignored when `mode='time'`"
            )
        cluster_key = None

    color = list(X.T)
    if title is None:
        if same_plot:
            # use the same direction enum for both branches (as `_plot_discrete` does)
            title = [
                f"{prop.replace('_', ' ')} "
                f"({Direction.BACKWARD if self.kernel.backward else Direction.FORWARD})"
            ]
        else:
            title = [f"{prefix} {lin}" for lin in lineages]
    elif isinstance(title, str):
        title = [title]

    if isinstance(cluster_key, str):
        cluster_key = [cluster_key]
    elif cluster_key is None:
        cluster_key = []
    if not isinstance(cluster_key, list):
        cluster_key = list(cluster_key)

    if not same_plot:
        color = cluster_key + color
        title = cluster_key + title

    if mode == "embedding":
        if same_plot:
            # to complement: https://github.com/theislab/scvelo/blob/master/scvelo/plotting/scatter.py#L269
            # per cell, take the indices of its two largest values; sorting within each row
            # (axis=1) turns (i, j) and (j, i) into the same unordered pair before counting
            top_two = np.argsort(X, axis=1)[:, ::-1][:, :2]
            pairs, cnts = np.unique(
                np.sort(top_two, axis=1), axis=0, return_counts=True
            )
            absent = set(A.names) - set(A.names[pairs[cnts > 1].flatten()])
            if absent:
                # can't print which because it would be inaccurate, i.e. we would print the legend
                # is missing when in fact, it would be visible
                # this has been checked exactly according to scVelo's logic
                logg.warning("Legend for some lineages may be missing")

            kwargs["color_gradients"] = A
            if len(cluster_key):
                logg.warning(
                    "Ignoring `cluster_key` when plotting continuous observations in the same plot"
                )
            # kwargs["color"] = cluster_key this results in a bug, cluster_key data is overwritten, will make a PR
        else:
            kwargs["color"] = color

        if probs.shape[1] == 1 and prop in (P.MACRO_MEMBER.s, P.TERM_PROBS.s):
            if "perc" not in kwargs:
                logg.warning(
                    "Did not detect percentile for stationary distribution. Setting `perc=[0, 95]`"
                )
                kwargs["perc"] = [0, 95]
            kwargs["color"] = X
            kwargs.pop("color_gradients", None)

        scv.pl.scatter(
            self.adata,
            title=title,
            color_map=cmap,
            **_filter_kwargs(scv.pl.scatter, **kwargs),
        )
    elif mode == "time":
        scv.pl.scatter(
            self.adata,
            x=time,
            color_map=cmap,
            y=color,
            title=title,
            xlabel=[time_key] * len(title),
            ylabel=["probability"] * len(title),
            **_filter_kwargs(scv.pl.scatter, **kwargs),
        )
    else:
        raise ValueError(
            f"Invalid mode `{mode!r}`. Valid options are: `'embedding'` or `'time'`."
        )
def _plot_discrete(
    self,
    data: pd.Series,
    prop: str,
    lineages: Optional[Union[str, Sequence[str]]] = None,
    cluster_key: Optional[str] = None,
    same_plot: bool = True,
    title: Optional[Union[str, List[str]]] = None,
    **kwargs,
) -> None:
    """
    Plot the states for each uncovered lineage.

    Parameters
    ----------
    lineages
        Plot only these lineages. If `None`, plot all lineages.
    cluster_key
        Key from :attr:`adata` ``.obs`` for plotting categorical observations.
    same_plot
        Whether to plot the lineages on the same plot or separately.
    title
        The title of the plot.
    %(basis)s
    kwargs
        Keyword arguments for :func:`scvelo.pl.scatter`.

    Returns
    -------
    %(just_plots)s
    """
    if data is None:
        raise RuntimeError(
            f"Compute `.{prop}` first as `.{F.COMPUTE.fmt(prop)}()`.")
    if not is_categorical_dtype(data):
        raise TypeError(
            f"Expected property `.{prop}` to be categorical, found `{type(data).__name__!r}`."
        )

    # `prop` is the public property name (see the messages above), so compare against
    # the `.s` form; `.v` is still accepted for backward compatibility
    if prop in (P.ABS_PROBS.s, P.TERM.s):
        colors = getattr(self, A.TERM_COLORS.v, None)
    elif prop in (P.MACRO.s, P.MACRO.v):
        colors = getattr(self, A.MACRO_COLORS.v, None)
    else:
        colors = None
    # missing or too few stored colors would crash below (`zip(..., None)` / `KeyError`)
    if colors is None or len(colors) < len(data.cat.categories):
        logg.debug("No colors found. Creating new ones")
        colors = _create_categorical_colors(len(data.cat.categories))
    colors = dict(zip(data.cat.categories, colors))

    if lineages is not None:
        # these are states per-se, but I want to keep the arg names for dispatch the same
        if isinstance(lineages, str):
            lineages = [lineages]
        for state in lineages:
            if state not in data.cat.categories:
                raise ValueError(
                    f"Invalid state `{state!r}`. Valid options are `{list(data.cat.categories)}`."
                )

        data = data.copy()
        to_remove = list(set(data.cat.categories) - set(lineages))
        if len(to_remove) == len(data.cat.categories):
            raise RuntimeError(
                "Nothing to plot because empty subset has been selected.")
        for state in to_remove:
            data[data == state] = np.nan
        data = data.cat.remove_categories(to_remove)

    if cluster_key is None:
        cluster_key = []
    elif isinstance(cluster_key, str):
        cluster_key = [cluster_key]
    if not isinstance(cluster_key, list):
        cluster_key = list(cluster_key)

    same_plot = same_plot or len(data.cat.categories) == 1
    kwargs["legend_loc"] = kwargs.get("legend_loc", "on data")

    with RandomKeys(
        self.adata, None if same_plot else len(data.cat.categories), where="obs"
    ) as keys:
        if same_plot:
            key = keys[0]
            self.adata.obs[key] = data
            self.adata.uns[f"{key}_colors"] = [
                colors[c] for c in data.cat.categories
            ]

            if title is None:
                title = (
                    f"{prop.replace('_', ' ')} "
                    f"({Direction.BACKWARD if self.kernel.backward else Direction.FORWARD})"
                )
            if isinstance(title, str):
                title = [title]

            scv.pl.scatter(
                self.adata,
                title=cluster_key + title,
                color=cluster_key + keys,
                **_filter_kwargs(scv.pl.scatter, **kwargs),
            )
        else:
            # one temporary `.obs` column per state, each holding only that state
            for key, cat in zip(keys, data.cat.categories):
                d = data.copy()
                d[data != cat] = None
                d = d.cat.set_categories([cat])

                self.adata.obs[key] = d
                self.adata.uns[f"{key}_colors"] = [colors[cat]]

            scv.pl.scatter(
                self.adata,
                color=cluster_key + keys,
                title=(
                    cluster_key
                    + [
                        f"{_initial if self.kernel.backward else _terminal} state {c}"
                        for c in data.cat.categories
                    ]
                )
                if title is None
                else title,
                **_filter_kwargs(scv.pl.scatter, **kwargs),
            )
def _plot_continuous(
    self,
    probs: Optional[Lineage],
    prop: str,
    diff_potential: Optional[pd.Series] = None,
    lineages: Optional[Union[str, Iterable[str]]] = None,
    cluster_key: Optional[str] = None,
    mode: str = "embedding",
    time_key: str = "latent_time",
    show_dp: bool = True,
    title: Optional[str] = None,
    same_plot: bool = True,
    cmap: Union[str, mpl.colors.ListedColormap] = cm.viridis,
    **kwargs,
) -> None:
    """
    Plot continuous observations, such as lineages, in an embedding.

    Parameters
    ----------
    lineages
        Plot only these lineages. If `None`, plot all lineages.
    cluster_key
        Key from :paramref:`adata` ``.obs`` for plotting categorical observations.
    %(time_mode)s
    time_key
        Key from :paramref:`adata` ``.obs`` to use as a pseudotime ordering of the cells.
    show_dp
        Whether to add ``diff_potential`` as an extra panel. Only used when the lineages
        are plotted separately and there is more than one lineage.
    title
        Either `None`, in which case titles are ``'{to, from} {terminal, initial} {state}'``,
        or an array of titles, one per lineage.
    same_plot
        Whether to plot the lineages on the same plot using color gradients when ``mode='embedding'``.
    cmap
        Colormap to use.
    **kwargs
        Keyword arguments for :func:`scvelo.pl.scatter`.

    Returns
    -------
    %(just_plots)s
    """
    # `probs` holds the computed values, `prop` is the property name used in messages/titles
    if probs is None:
        raise RuntimeError(
            f"Compute `.{prop}` first as `.{F.COMPUTE.fmt(prop)}()`."
        )

    if isinstance(lineages, str):
        lineages = [lineages]

    if lineages is None:
        lineages = probs.names
        A = probs
    else:
        A = probs[lineages]

    if not len(lineages):
        raise RuntimeError(
            "Nothing to plot because empty subset has been selected."
        )

    prefix = DirPrefix.BACKWARD if self.kernel.backward else DirPrefix.FORWARD
    same_plot = same_plot and mode == "embedding"  # set this silently
    # becomes a one-element list (extra panel) or an empty list (no panel)
    diff_potential = (
        [diff_potential.values]
        if show_dp
        and not same_plot
        and diff_potential is not None
        and probs.shape[1] > 1
        else []
    )

    A = A.copy()  # the below code modifies stuff inplace
    X = A.X  # list(A.T) behaves differently, because it's Lineage

    if X.shape[1] == 1:
        same_plot = False  # because color_gradients for 1 state is buggy (looks empty)
        # this is the case for only 1 recurrent class - all cells have prob. 1 of going there
        # however, matplotlib's plotting really picks up the slightest differences in the colormap,
        # here we set everything to one, if applicable
        if np.allclose(X, 1.0):
            X = np.ones_like(X)

    for col in X.T:
        mask = ~np.isclose(col, 1.0)
        # change the maximum value - the 1 is artificial and obscures the color scaling
        if np.sum(mask):
            max_not_one = np.max(col[mask])
            col[~mask] = max_not_one

    if mode == "time":
        if time_key not in self.adata.obs.keys():
            raise KeyError(f"Time key `{time_key!r}` not found in `adata.obs`.")
        time = self.adata.obs[time_key]
        if cluster_key is not None:
            logg.warning(
                f"Cluster key `{cluster_key!r}` is ignored when `mode='time'`"
            )
        cluster_key = None

    color = list(X.T) + diff_potential
    if title is None:
        if same_plot:
            # use the same direction enum for both branches (as `_plot_discrete` does)
            title = [
                f"{prop.replace('_', ' ')} "
                f"({Direction.BACKWARD if self.kernel.backward else Direction.FORWARD})"
            ]
        else:
            title = [f"{prefix} {lin}" for lin in lineages] + (
                ["differentiation potential"] if diff_potential else []
            )
    elif isinstance(title, str):
        title = [title]

    if isinstance(cluster_key, str):
        cluster_key = [cluster_key]
    elif cluster_key is None:
        cluster_key = []
    if not isinstance(cluster_key, list):
        cluster_key = list(cluster_key)

    if not same_plot:
        color = cluster_key + color
        title = cluster_key + title

    if mode == "embedding":
        if same_plot:
            kwargs["color_gradients"] = A
            if len(cluster_key):
                logg.warning(
                    "Ignoring `cluster_key` when plotting probabilities in the same plot"
                )
            # kwargs["color"] = cluster_key this results in a bug, cluster_key data is overwritten, will make a PR
        else:
            kwargs["color"] = color

        if probs.shape[1] == 1 and prop in (P.META_PROBS.s, P.FIN_PROBS.s):
            if "perc" not in kwargs:
                logg.warning(
                    "Did not detect percentile for stationary distribution. Setting `perc=[0, 95]`"
                )
                kwargs["perc"] = [0, 95]
            kwargs["color"] = X
            kwargs.pop("color_gradients", None)

        scv.pl.scatter(
            self.adata,
            title=title,
            color_map=cmap,
            **_filter_kwargs(scv.pl.scatter, **kwargs),
        )
    elif mode == "time":
        scv.pl.scatter(
            self.adata,
            x=time,
            color_map=cmap,
            y=color,
            title=title,
            xlabel=[time_key] * len(title),
            ylabel=["probability"] * len(title),
            **_filter_kwargs(scv.pl.scatter, **kwargs),
        )
    else:
        raise ValueError(
            f"Invalid mode `{mode!r}`. Valid options are: `'embedding'` or `'time'`."
        )