Beispiel #1
0
def _run_in_parallel(fn: Callable, conn: csr_matrix, **kwargs) -> Any:
    """
    Run ``fn`` in parallel over the rows of ``conn`` and reassemble the result.

    Parameters
    ----------
    fn
        Worker function. Its ``__name__`` decides the order in which rows are
        distributed and the unit shown by the progress reporting.
    conn
        Sparse matrix whose rows are split among the workers. Its ``indices``
        and ``indptr`` arrays are forwarded as keyword arguments.
    kwargs
        Keyword arguments. The ones accepted by :func:`parallelize` configure the
        parallelization, the ones accepted by ``fn`` are passed to it.

    Returns
    -------
    The output of ``_reconstruct_one`` applied to the concatenated worker results.

    Raises
    ------
    RuntimeError
        If ``fn`` is ``_run_stochastic`` and `jax` is not available.
    """
    name = fn.__name__

    if name != "_run_stochastic":
        # random row order
        order = np.arange(conn.shape[0])
        np.random.shuffle(order)
    elif _HAS_JAX:
        # rows with the most non-zero entries are handed out first
        nnz_per_row = np.array((conn != 0).sum(1)).ravel()
        order = np.argsort(nnz_per_row)[::-1]
    else:
        raise RuntimeError(
            "Install `jax` and `jaxlib` as `pip install jax jaxlib`.")

    many_samples = name == "_run_mc" and kwargs.get("n_samples", 1) > 1
    unit = "sample" if many_samples else "cell"

    kwargs.update(indices=conn.indices, indptr=conn.indptr)

    def _extract(res):
        # stitch the per-worker chunks back together in the original row order
        return _reconstruct_one(np.concatenate(res, axis=-1), conn, order)

    runner = parallelize(
        fn,
        order,
        as_array=False,
        extractor=_extract,
        unit=unit,
        **_filter_kwargs(parallelize, **kwargs),
    )
    return runner(**_filter_kwargs(fn, **kwargs))
Beispiel #2
0
    def copy(self) -> "SimpleNaryExpression":
        """%(copy)s"""  # noqa: D400, D401
        # NOTE: inside this body ``copy`` refers to the module-level function,
        # not to this method (method names are not in the function's scope).
        constructor = type(self)
        kwargs = {"op_name": self._op_name, "fn": self._fn}

        # preserve the type information so that combination can properly work
        # we test for this
        sne = constructor(
            [copy(k) for k in self], **_filter_kwargs(constructor, **kwargs)
        )

        # TODO: copying's a bit buggy - the parent stays the same
        # we could disallow it for inner expressions
        sne._transition_matrix = copy(self._transition_matrix)
        # write the condition number to the same attribute it is read from;
        # it was previously stored under ``_condition_number`` and thus lost
        sne._cond_num = self._cond_num
        sne._normalize = self._normalize

        return sne
Beispiel #3
0
    def __init__(
        self,
        adata: AnnData,
        n_splines: Optional[int] = 10,
        spline_order: int = 3,
        distribution: str = "gamma",
        link: str = "log",
        max_iter: int = 2000,
        expectile: Optional[float] = None,
        use_default_conf_int: bool = False,
        grid: Optional[Mapping] = None,
        spline_kwargs: Mapping = MappingProxyType({}),
        **kwargs,
    ):
        """
        Build the underlying GAM and initialize the model.

        Parameters
        ----------
        adata
            Annotated data object, passed to the base class.
        n_splines
            Number of splines of the single smooth term.
        spline_order
            Order of the splines.
        distribution
            Name of the distribution, converted to ``GamDistribution``. ``'gauss'`` is treated as ``'normal'``.
        link
            Name of the link function, converted to ``GamLinkFunction``.
        max_iter
            Maximum number of iterations passed to the model.
        expectile
            If not `None`, must lie in `(0, 1)` and an ``ExpectileGAM`` is used.
        use_default_conf_int
            Stored as ``_use_default_conf_int``.
        grid
            `None`, a mapping (stored as a copy) or a string. ``'default'`` stores a sentinel object,
            any other string stores `None`.
        spline_kwargs
            Extra keyword arguments for the spline term; ``lam`` defaults to `3`.
        kwargs
            Keyword arguments for the model. When ``expectile`` is `None`, they are filtered to the ones
            accepted by the selected GAM class.

        Raises
        ------
        ValueError
            If ``expectile`` is not in `(0, 1)`.
        TypeError
            If ``grid`` is not a mapping, a string or `None`.
        """
        term = s(
            0,
            spline_order=spline_order,
            n_splines=n_splines,
            penalties=["derivative", "l2"],
            **_filter_kwargs(s, **{**{"lam": 3}, **spline_kwargs}),
        )
        link = GamLinkFunction(link)
        distribution = GamDistribution(distribution)
        if distribution == GamDistribution.GAUSS:
            distribution = GamDistribution.NORMAL

        if expectile is not None:
            if not (0 < expectile < 1):
                raise ValueError(
                    f"Expected `expectile` to be in `(0, 1)`, found `{expectile}`."
                )
            # compare the string values (the same `.s` values passed to the model
            # below) so the check does not rely on the enums comparing equal to `str`
            if distribution.s != "normal" or link.s != "identity":
                logg.warning(
                    f"Expectile GAM works only with `normal` distribution and `identity` link function, "
                    f"found `{distribution.s!r}` distribution and `{link.s!r}` link function."
                )
            model = ExpectileGAM(
                term, expectile=expectile, max_iter=max_iter, verbose=False, **kwargs
            )
        else:
            # look up the class by (distribution, link) and pass both explicitly;
            # doing it like this ensures that the user can specify `scale`
            gam = _gams[distribution, link]
            kwargs["link"] = link.s
            kwargs["distribution"] = distribution.s
            model = gam(
                term,
                max_iter=max_iter,
                verbose=False,
                **_filter_kwargs(gam.__init__, **kwargs),
            )
        super().__init__(adata, model=model)
        self._use_default_conf_int = use_default_conf_int

        if grid is None:
            self._grid = None
        elif isinstance(grid, Mapping):
            # `dict`s are copied as before; other mappings (e.g. a read-only
            # `MappingProxyType`) are converted to a plain `dict`
            self._grid = _copy(grid) if isinstance(grid, dict) else dict(grid)
        elif isinstance(grid, str):
            # sentinel object for the default grid, `None` for any other string
            self._grid = object() if grid == "default" else None
        else:
            raise TypeError(
                f"Expected `grid` to be `dict`, `str` or `None`, found `{type(grid).__name__!r}`."
            )
Beispiel #4
0
    def _plot_continuous(
        self,
        probs: Optional[Lineage],
        prop: str,
        lineages: Optional[Union[str, Iterable[str]]] = None,
        cluster_key: Optional[str] = None,
        mode: str = "embedding",
        time_key: str = "latent_time",
        title: Optional[str] = None,
        same_plot: bool = True,
        cmap: Union[str, mpl.colors.ListedColormap] = cm.viridis,
        **kwargs,
    ) -> None:
        """
        Plot continuous observations such as macrostates memberships or lineages in an embedding.

        Parameters
        ----------
        probs
            Values to plot, one column per lineage. Must not be `None`.
        prop
            Name of the plotted property, used in error messages and in the default title.
        lineages
            Plot only these lineages. If `None`, plot all lineages.
        cluster_key
            Key from :attr:`adata` ``.obs`` for plotting categorical observations.
        %(time_mode)s
        time_key
            Key from :attr:`adata` ``.obs`` to use as a pseudotime ordering of the cells.
        title
            Either `None`, in which case titles are ``'{to,from} {terminal,initial} {state}'``,
            or an array of titles, one per lineage.
        same_plot
            Whether to plot the lineages on the same plot using color gradients when ``mode='embedding'``.
        cmap
            Colormap to use.
        %(basis)s
        kwargs
            Keyword arguments for :func:`scvelo.pl.scatter`.

        Returns
        -------
        %(just_plots)s

        Raises
        ------
        RuntimeError
            If ``probs`` is `None` or an empty subset of lineages has been selected.
        KeyError
            If ``mode='time'`` and ``time_key`` is not in :attr:`adata` ``.obs``.
        ValueError
            If ``mode`` is neither ``'embedding'`` nor ``'time'``.
        """

        if probs is None:
            raise RuntimeError(
                f"Compute `.{prop}` first as `.{F.COMPUTE.fmt(prop)}()`.")

        if isinstance(lineages, str):
            lineages = [lineages]

        if lineages is None:
            lineages = probs.names
            A = probs
        else:
            A = probs[lineages]

        if not len(lineages):
            raise RuntimeError(
                "Nothing to plot because empty subset has been selected.")

        prefix = DirPrefix.BACKWARD if self.kernel.backward else DirPrefix.FORWARD
        same_plot = same_plot and mode == "embedding"  # set this silently

        A = A.copy()  # the below code modifies stuff inplace
        X = A.X  # list(A.T) behaves differently, because it's Lineage

        if X.shape[1] == 1:
            # color_gradients for 1 state is buggy (looks empty)
            same_plot = False
            # this is the case for only 1 recurrent class - all cells have prob. 1 of going there
            # however, matplotlib's plotting really picks up the slightest differences in the colormap, here we set
            # everything to one, if applicable
            if np.allclose(X, 1.0):
                X = np.ones_like(X)

        for col in X.T:
            mask = ~np.isclose(col, 1.0)
            # change the maximum value - the 1 is artificial and obscures the color scaling
            # (`col` is a view, so this writes into `X`)
            if np.sum(mask):
                max_not_one = np.max(col[mask])
                col[~mask] = max_not_one

        if mode == "time":
            if time_key not in self.adata.obs.keys():
                raise KeyError(
                    f"Time key `{time_key!r}` not found in `adata.obs`.")
            time = self.adata.obs[time_key]
            if cluster_key is not None:
                logg.warning(
                    f"Cluster key `{cluster_key!r}` is ignored when `mode='time'`"
                )
            cluster_key = None

        color = list(X.T)
        if title is None:
            if same_plot:
                # use the same enum for both directions (previously the backward case
                # used `DirectionPlot` while the forward case used `Direction`)
                direction = (
                    Direction.BACKWARD if self.kernel.backward else Direction.FORWARD
                )
                title = [f"{prop.replace('_', ' ')} ({direction})"]
            else:
                title = [f"{prefix} {lin}" for lin in lineages]
        elif isinstance(title, str):
            title = [title]
        else:
            # accept any iterable of titles (e.g. a tuple); it is concatenated with a list below
            title = list(title)

        if isinstance(cluster_key, str):
            cluster_key = [cluster_key]
        elif cluster_key is None:
            cluster_key = []
        if not isinstance(cluster_key, list):
            cluster_key = list(cluster_key)

        if not same_plot:
            color = cluster_key + color
            title = cluster_key + title

        if mode == "embedding":
            if same_plot:
                # to complement: https://github.com/theislab/scvelo/blob/master/scvelo/plotting/scatter.py#L269
                sorted_idx = np.argsort(X, axis=1)[:, ::-1][:, :2]
                # sort within each row (axis=1) so that the unordered pairs (i, j) and (j, i)
                # are counted together; sorting along axis=0 would mix indices of different cells
                pairs, cnts = np.unique(
                    np.sort(sorted_idx, axis=1), axis=0, return_counts=True
                )
                absent = set(A.names) - set(A.names[pairs[cnts > 1].flatten()])
                if absent:
                    # can't print which because it would be inaccurate, i.e. we would print the legend
                    # is missing when in fact, it would be visible
                    logg.warning("Legend for some lineages may be missing")

                kwargs["color_gradients"] = A
                if len(cluster_key):
                    logg.warning(
                        "Ignoring `cluster_key` when plotting continuous observations in the same plot"
                    )
                # kwargs["color"] = cluster_key  this results in a bug, cluster_key data is overwritten, will make a PR
            else:
                kwargs["color"] = color

            if probs.shape[1] == 1 and prop in (P.MACRO_MEMBER.s,
                                                P.TERM_PROBS.s):
                if "perc" not in kwargs:
                    logg.warning(
                        "Did not detect percentile for stationary distribution. Setting `perc=[0, 95]`"
                    )
                    kwargs["perc"] = [0, 95]
                kwargs["color"] = X
                kwargs.pop("color_gradients", None)

            scv.pl.scatter(
                self.adata,
                title=title,
                color_map=cmap,
                **_filter_kwargs(scv.pl.scatter, **kwargs),
            )
        elif mode == "time":
            scv.pl.scatter(
                self.adata,
                x=time,
                color_map=cmap,
                y=color,
                title=title,
                xlabel=[time_key] * len(title),
                ylabel=["probability"] * len(title),
                **_filter_kwargs(scv.pl.scatter, **kwargs),
            )
        else:
            raise ValueError(
                f"Invalid mode `{mode!r}`. Valid options are: `'embedding'` or `'time'`."
            )
Beispiel #5
0
    def _plot_discrete(
        self,
        data: pd.Series,
        prop: str,
        lineages: Optional[Union[str, Sequence[str]]] = None,
        cluster_key: Optional[str] = None,
        same_plot: bool = True,
        title: Optional[Union[str, List[str]]] = None,
        **kwargs,
    ) -> None:
        """
        Plot the states for each uncovered lineage.

        Parameters
        ----------
        data
            Categorical assignment of cells to states. Must not be `None`.
        prop
            Name of the plotted property; selects the stored colors and is used in the default title.
        lineages
            Plot only these lineages. If `None`, plot all lineages.
        cluster_key
            Key from :attr:`adata` ``.obs`` for plotting categorical observations.
        same_plot
            Whether to plot the lineages on the same plot or separately.
        title
            The title of the plot.
        %(basis)s
        kwargs
            Keyword arguments for :func:`scvelo.pl.scatter`.

        Returns
        -------
        %(just_plots)s

        Raises
        ------
        RuntimeError
            If ``data`` is `None` or an empty subset has been selected.
        TypeError
            If ``data`` is not categorical.
        ValueError
            If any of ``lineages`` is not a category of ``data``.
        """

        if data is None:
            raise RuntimeError(
                f"Compute `.{prop}` first as `.{F.COMPUTE.fmt(prop)}()`.")
        if not is_categorical_dtype(data):
            raise TypeError(
                f"Expected property `.{prop}` to be categorical, found `{type(data).__name__!r}`."
            )

        # use the stored colors for the known properties, if present
        if prop in (P.ABS_PROBS.s, P.TERM.s):
            colors = getattr(self, A.TERM_COLORS.v, None)
        elif prop == P.MACRO.v:
            colors = getattr(self, A.MACRO_COLORS.v, None)
        else:
            colors = None
        if colors is None:
            # also covers a missing/unset colors attribute, for which
            # `zip(..., None)` previously raised a `TypeError`
            logg.debug("No colors found. Creating new ones")
            colors = _create_categorical_colors(len(data.cat.categories))
        colors = dict(zip(data.cat.categories, colors))

        # these are states per-se, but I want to keep the arg names for dispatch the same
        if lineages is not None:
            if isinstance(lineages, str):
                lineages = [lineages]
            for state in lineages:
                if state not in data.cat.categories:
                    raise ValueError(
                        f"Invalid state `{state!r}`. Valid options are `{list(data.cat.categories)}`."
                    )
            data = data.copy()  # don't modify the caller's data below
            to_remove = list(set(data.cat.categories) - set(lineages))

            if len(to_remove) == len(data.cat.categories):
                raise RuntimeError(
                    "Nothing to plot because empty subset has been selected.")

            for state in to_remove:
                data[data == state] = np.nan
            data = data.cat.remove_categories(to_remove)

        if cluster_key is None:
            cluster_key = []
        elif isinstance(cluster_key, str):
            cluster_key = [cluster_key]
        if not isinstance(cluster_key, list):
            cluster_key = list(cluster_key)

        same_plot = same_plot or len(data.cat.categories) == 1
        kwargs.setdefault("legend_loc", "on data")

        # temporary `.obs` keys: one for the combined plot, one per state otherwise
        with RandomKeys(self.adata,
                        None if same_plot else len(data.cat.categories),
                        where="obs") as keys:
            if same_plot:
                key = keys[0]
                self.adata.obs[key] = data
                self.adata.uns[f"{key}_colors"] = [
                    colors[c] for c in data.cat.categories
                ]

                if title is None:
                    title = (
                        f"{prop.replace('_', ' ')} "
                        f"({Direction.BACKWARD if self.kernel.backward else Direction.FORWARD})"
                    )
                if isinstance(title, str):
                    title = [title]

                scv.pl.scatter(
                    self.adata,
                    title=cluster_key + title,
                    color=cluster_key + keys,
                    **_filter_kwargs(scv.pl.scatter, **kwargs),
                )
            else:
                # one single-category column per state, all other cells are missing
                for key, cat in zip(keys, data.cat.categories):
                    d = data.copy()
                    d[data != cat] = None
                    d = d.cat.set_categories([cat])

                    self.adata.obs[key] = d
                    self.adata.uns[f"{key}_colors"] = [colors[cat]]

                scv.pl.scatter(
                    self.adata,
                    color=cluster_key + keys,
                    title=(cluster_key + [
                        f"{_initial if self.kernel.backward else _terminal} state {c}"
                        for c in data.cat.categories
                    ]) if title is None else title,
                    **_filter_kwargs(scv.pl.scatter, **kwargs),
                )
Beispiel #6
0
    def _plot_continuous(
        self,
        probs: Optional[Lineage],
        prop: str,
        diff_potential: Optional[pd.Series] = None,
        lineages: Optional[Union[str, Iterable[str]]] = None,
        cluster_key: Optional[str] = None,
        mode: str = "embedding",
        time_key: str = "latent_time",
        show_dp: bool = True,
        title: Optional[str] = None,
        same_plot: bool = True,
        cmap: Union[str, mpl.colors.ListedColormap] = cm.viridis,
        **kwargs,
    ) -> None:
        """
        Plot continuous observations, such as lineages, in an embedding.

        Parameters
        ----------
        probs
            Values to plot, one column per lineage. Must not be `None`.
        prop
            Name of the plotted property, used in error messages and in the default title.
        diff_potential
            Values shown as an additional plot, see ``show_dp``.
        lineages
            Plot only these lineages. If `None`, plot all lineages.
        cluster_key
            Key from :paramref:`adata` ``.obs`` for plotting categorical observations.
        %(time_mode)s
        time_key
            Key from :paramref:`adata` ``.obs`` to use as a pseudotime ordering of the cells.
        show_dp
            Whether to also plot ``diff_potential``. Only used when it is not `None`, the lineages
            are not plotted in the same plot and ``probs`` has more than one column.
        title
            Either `None`, in which case titles are ``'{to, from} {terminal, initial} {state}'``,
            or an array of titles, one per lineage.
        same_plot
            Whether to plot the lineages on the same plot using color gradients when ``mode='embedding'``.
        cmap
            Colormap to use.
        **kwargs
            Keyword arguments for :func:`scvelo.pl.scatter`.

        Returns
        -------
        %(just_plots)s

        Raises
        ------
        RuntimeError
            If ``probs`` is `None` or an empty subset of lineages has been selected.
        KeyError
            If ``mode='time'`` and ``time_key`` is not in :paramref:`adata` ``.obs``.
        ValueError
            If ``mode`` is neither ``'embedding'`` nor ``'time'``.
        """

        if probs is None:
            raise RuntimeError(
                f"Compute `.{prop}` first as `.{F.COMPUTE.fmt(prop)}()`."
            )

        if isinstance(lineages, str):
            lineages = [lineages]

        if lineages is None:
            lineages = probs.names
            A = probs
        else:
            A = probs[lineages]

        if not len(lineages):
            raise RuntimeError(
                "Nothing to plot because empty subset has been selected."
            )

        prefix = DirPrefix.BACKWARD if self.kernel.backward else DirPrefix.FORWARD
        same_plot = same_plot and mode == "embedding"  # set this silently

        # either an empty list or a 1-element list, so it can be appended to `color`
        diff_potential = (
            [diff_potential.values]
            if show_dp
            and not same_plot
            and diff_potential is not None
            and probs.shape[1] > 1
            else []
        )

        A = A.copy()  # the below code modifies stuff inplace
        X = A.X  # list(A.T) behaves differently, because it's Lineage

        if X.shape[1] == 1:
            # color_gradients for 1 state is buggy (looks empty)
            same_plot = False
            # this is the case for only 1 recurrent class - all cells have prob. 1 of going there
            # however, matplotlib's plotting really picks up the slightest differences in the colormap, here we set
            # everything to one, if applicable
            if np.allclose(X, 1.0):
                X = np.ones_like(X)

        for col in X.T:
            mask = ~np.isclose(col, 1.0)
            # change the maximum value - the 1 is artificial and obscures the color scaling
            # (`col` is a view, so this writes into `X`)
            if np.sum(mask):
                max_not_one = np.max(col[mask])
                col[~mask] = max_not_one

        if mode == "time":
            if time_key not in self.adata.obs.keys():
                raise KeyError(f"Time key `{time_key!r}` not found in `adata.obs`.")
            time = self.adata.obs[time_key]
            if cluster_key is not None:
                logg.warning(
                    f"Cluster key `{cluster_key!r}` is ignored when `mode='time'`"
                )
            cluster_key = None

        color = list(X.T) + diff_potential
        if title is None:
            if same_plot:
                # use the same enum for both directions (previously the backward case
                # used `DirectionPlot` while the forward case used `Direction`)
                direction = (
                    Direction.BACKWARD if self.kernel.backward else Direction.FORWARD
                )
                title = [f"{prop.replace('_', ' ')} ({direction})"]
            else:
                title = [f"{prefix} {lin}" for lin in lineages] + (
                    ["differentiation potential"] if diff_potential else []
                )
        elif isinstance(title, str):
            title = [title]
        else:
            # accept any iterable of titles (e.g. a tuple); it is concatenated with a list below
            title = list(title)

        if isinstance(cluster_key, str):
            cluster_key = [cluster_key]
        elif cluster_key is None:
            cluster_key = []
        if not isinstance(cluster_key, list):
            cluster_key = list(cluster_key)

        if not same_plot:
            color = cluster_key + color
            title = cluster_key + title

        if mode == "embedding":
            if same_plot:
                kwargs["color_gradients"] = A
                if len(cluster_key):
                    logg.warning(
                        "Ignoring `cluster_key` when plotting probabilities in the same plot"
                    )
                # kwargs["color"] = cluster_key  this results in a bug, cluster_key data is overwritten, will make a PR
            else:
                kwargs["color"] = color

            if probs.shape[1] == 1 and prop in (P.META_PROBS.s, P.FIN_PROBS.s):
                if "perc" not in kwargs:
                    logg.warning(
                        "Did not detect percentile for stationary distribution. Setting `perc=[0, 95]`"
                    )
                    kwargs["perc"] = [0, 95]
                kwargs["color"] = X
                kwargs.pop("color_gradients", None)

            scv.pl.scatter(
                self.adata,
                title=title,
                color_map=cmap,
                **_filter_kwargs(scv.pl.scatter, **kwargs),
            )
        elif mode == "time":
            scv.pl.scatter(
                self.adata,
                x=time,
                color_map=cmap,
                y=color,
                title=title,
                xlabel=[time_key] * len(title),
                ylabel=["probability"] * len(title),
                **_filter_kwargs(scv.pl.scatter, **kwargs),
            )
        else:
            raise ValueError(
                f"Invalid mode `{mode!r}`. Valid options are: `'embedding'` or `'time'`."
            )