Example #1
    def _density_normalize(
            self, other: Union[np.ndarray,
                               spmatrix]) -> Union[np.ndarray, spmatrix]:
        """
        Density normalization by the underlying KNN graph.

        Parameters
        ----------
        other
            Matrix to normalize.

        Returns
        -------
        :class:`numpy.ndarray` or :class:`scipy.sparse.spmatrix`
            Density-normalized transition matrix.
        """

        logg.debug("Density-normalizing the transition matrix")

        q = np.asarray(self._conn.sum(axis=0)).ravel()

        if not issparse(other):
            Q = np.diag(1.0 / q)
        else:
            Q = spdiags(1.0 / q, 0, other.shape[0], other.shape[0])

        return Q @ other @ Q
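
For intuition, a minimal, self-contained sketch of the same `Q @ K @ Q` normalization on toy connectivities (the matrix values are made up for illustration):

import numpy as np
from scipy.sparse import csr_matrix, spdiags

# toy symmetric connectivities for 3 cells
conn = csr_matrix(np.array([[0.0, 1.0, 0.5],
                            [1.0, 0.0, 0.2],
                            [0.5, 0.2, 0.0]]))

q = np.asarray(conn.sum(axis=0)).ravel()  # per-cell density
Q = spdiags(1.0 / q, 0, conn.shape[0], conn.shape[0])

# density normalization down-weights cells in dense regions
K_dn = Q @ conn @ Q
print(K_dn.toarray())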
Example #2
    def write(self,
              fname: Union[str, Path],
              ext: Optional[str] = "pickle") -> None:
        """
        Serialize self to a file.

        Parameters
        ----------
        fname
            Filename where to save the object.
        ext
            Filename extension to use. If `None`, don't append any extension.

        Returns
        -------
        None
            Nothing, just writes itself to a file using :mod:`pickle`.
        """

        fname = str(fname)
        if ext is not None:
            if not ext.startswith("."):
                ext = "." + ext
            if not fname.endswith(ext):
                fname += ext

        logg.debug(f"Writing to `{fname}`")

        with open(fname, "wb") as fout:
            pickle.dump(self, fout)
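
A round trip with any picklable object (the dictionary below is just a stand-in for `self`):

import pickle

fname = "obj.pickle"
obj = {"params": {"dnorm": True}}  # stand-in for a picklable object

with open(fname, "wb") as fout:
    pickle.dump(obj, fout)

with open(fname, "rb") as fin:
    restored = pickle.load(fin)

assert restored == obj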
Example #3
    def __init__(
        self,
        obj: Union[AnnData, np.ndarray, spmatrix, KernelExpression],
        key: Optional[str] = None,
        obsp_key: Optional[str] = None,
        write_to_adata: bool = True,
    ):
        if isinstance(obj, KernelExpression):
            self._kernel = obj
        elif isinstance(obj, (np.ndarray, spmatrix)):
            self._kernel = PrecomputedKernel(obj)
        elif isinstance(obj, AnnData):
            if obsp_key is None:
                raise ValueError(
                    "Please specify `obsp_key=...` when supplying an AnnData object."
                )
            elif obsp_key not in obj.obsp.keys():
                raise KeyError(f"Key `{obsp_key!r}` not found in `adata.obsp`.")
            self._kernel = PrecomputedKernel(obj.obsp[obsp_key], adata=obj)
        else:
            raise TypeError(
                f"Expected an object of type `KernelExpression`, `numpy.ndarray`, `scipy.sparse.spmatrix` "
                f"or `anndata.AnnData`, got `{type(obj).__name__!r}`."
            )

        if self.kernel.transition_matrix is None:
            logg.debug("Computing transition matrix using default parameters")
            self.kernel.compute_transition_matrix()

        if write_to_adata:
            self.kernel.write_to_adata(key=key)
Example #4
    def maybe_sanity_check(callbacks: Dict[str, Dict[str, Callable]]) -> None:
        if not perform_sanity_check:
            return

        from sklearn.svm import SVR

        logg.debug("Performing callback sanity checks")
        for gene in callbacks.keys():
            for lineage, cb in callbacks[gene].items():
                # create the model here because the callback can search the attribute
                dummy_model = SKLearnModel(adata, model=SVR())
                try:
                    model = cb(dummy_model, gene=gene, lineage=lineage, **kwargs)
                    assert model is dummy_model, (
                        "Creation of new models is not allowed. "
                        "Ensure that callback returns the same model."
                    )
                    assert (
                        model.prepared
                    ), "Model is not prepared. Ensure that callback calls `.prepare()`."
                    assert (
                        model._gene == gene
                    ), f"Callback modified the gene from `{gene!r}` to `{model._gene!r}`."
                    assert (
                        model._lineage == lineage
                    ), f"Callback modified the lineage from `{lineage!r}` to `{model._lineage!r}`."
                except Exception as e:
                    raise RuntimeError(
                        f"Callback validation failed for gene `{gene!r}` and lineage `{lineage!r}`."
                    ) from e
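
For reference, a callback that passes these checks must prepare and return the very model instance it receives; a sketch under the assumption (implied by the assertions above) that `.prepare()` returns the model itself:

def default_callback(model, gene: str, lineage: str, **kwargs):
    # prepare the model in place and return the *same* instance
    return model.prepare(gene, lineage, **kwargs)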
Example #5
    def _density_normalize(
            self, other: Union[np.ndarray,
                               spmatrix]) -> Union[np.ndarray, spmatrix]:
        """
        Density normalization by the underlying KNN graph.

        Parameters
        ----------
        other
            Matrix to normalize.

        Returns
        -------
        :class:`numpy.ndarray` or :class:`scipy.sparse.spmatrix`
            Density-normalized transition matrix.

        Raises
        ------
        ValueError
            If KNN connectivities are not set.
        """
        if self._conn is None:
            raise ValueError(
                "Unable to density normalize the transition matrix "
                "because KNN connectivities are not set.")

        logg.debug("Density-normalizing the transition matrix")

        q = np.asarray(self._conn.sum(axis=0)).ravel()
        Q = spdiags(1.0 / q, 0, other.shape[0], other.shape[0])

        return Q @ other @ Q
Example #6
def _read_graph_data(adata: AnnData, key: str) -> Union[np.ndarray, spmatrix]:
    """
    Read graph data from :mod:`anndata`.

    :mod:`anndata` >= 0.7 stores `(n_obs x n_obs)` matrices in `.obsp` rather than `.uns`;
    this function checks both locations for backward compatibility.

    Parameters
    ----------
    adata
        Annotated data object.
    key
        Key from either ``adata.uns`` or ``adata.obsp``.

    Returns
    -------
    :class:`numpy.ndarray` or :class:`scipy.sparse.spmatrix`
        The graph data.
    """

    if hasattr(adata, "obsp") and adata.obsp is not None and key in adata.obsp.keys():
        logg.debug(f"Reading key `{key!r}` from `adata.obsp`")
        return adata.obsp[key]

    if key in adata.uns.keys():
        logg.debug(f"Reading key `{key!r}` from `adata.uns`")
        return adata.uns[key]

    raise KeyError(f"Unable to find key `{key!r}` in `adata.obsp` or `adata.uns`.")
Example #7
    def _reconstruct_lineage(self, attr: PrettyEnum, obsm_key: str):

        self._set_or_debug(obsm_key, self.adata.obsm, attr)
        names = self._set_or_debug(_lin_names(self._term_key), self.adata.uns)
        colors = self._set_or_debug(_colors(self._term_key), self.adata.uns)

        # choosing this instead of property because GPCCA doesn't have property for FIN_ABS_PROBS
        probs = self._get(attr)

        if probs is not None:
            if len(names) != probs.shape[1]:
                logg.debug(
                    f"Expected lineage names to be of length `{probs.shape[1]}`, found `{len(names)}`. "
                    f"Creating new names")
                names = [f"Lineage {i}" for i in range(probs.shape[1])]
            if len(colors) != probs.shape[1] or not all(
                    map(lambda c: isinstance(c, str) and is_color_like(c),
                        colors)):
                logg.debug(
                    f"Expected lineage colors to be of length `{probs.shape[1]}`, found `{len(colors)}`. "
                    f"Creating new colors")
                colors = _create_categorical_colors(probs.shape[1])
            self._set(attr, Lineage(probs, names=names, colors=colors))

            self.adata.obsm[obsm_key] = self._get(attr)
            self.adata.uns[_lin_names(self._term_key)] = names
            self.adata.uns[_colors(self._term_key)] = colors
Example #8
    def compute_transition_matrix(self, *args,
                                  **kwargs) -> "SimpleNaryExpression":
        """Compute and combine the transition matrices."""
        # must be done first, because the underlying expressions don't have to be normed
        if isinstance(self, KernelSimpleAdd):
            self._maybe_recalculate_constants(Constant)
        elif isinstance(self, KernelAdaptiveAdd):
            self._maybe_recalculate_constants(ConstantMatrix)

        for kexpr in self:
            if kexpr._transition_matrix is None:
                if isinstance(kexpr, Kernel):
                    raise RuntimeError(
                        f"Kernel `{kexpr}` is uninitialized. "
                        f"Compute its transition matrix first as `.compute_transition_matrix()`."
                    )
                kexpr.compute_transition_matrix()
            elif isinstance(kexpr, Kernel):
                logg.debug(_LOG_USING_CACHE)

        self.transition_matrix = csr_matrix(
            self._fn([kexpr.transition_matrix for kexpr in self]))

        # only the top level expression and kernels will have condition number computed
        if self._parent is None:
            self._maybe_compute_cond_num()

        return self
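
For the additive expressions, `self._fn` reduces to a weighted sum of row-stochastic matrices, which stays row-stochastic when the weights sum to 1; a toy check:

import numpy as np
from scipy.sparse import csr_matrix

T1 = csr_matrix(np.array([[0.5, 0.5], [0.1, 0.9]]))
T2 = csr_matrix(np.array([[1.0, 0.0], [0.0, 1.0]]))

# convex combination of transition matrices
T = csr_matrix(0.3 * T1 + 0.7 * T2)
assert np.allclose(T.sum(axis=1), 1.0)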
Example #9
    def maybe_create_lineage(
        direction: Union[str, Direction], pretty_name: Optional[str] = None
    ):
        if isinstance(direction, Direction):
            lin_key = str(
                AbsProbKey.FORWARD
                if direction == Direction.FORWARD
                else AbsProbKey.BACKWARD
            )
        else:
            lin_key = direction

        pretty_name = "" if pretty_name is None else (pretty_name + " ")
        names_key, colors_key = _lin_names(lin_key), _colors(lin_key)

        if lin_key in adata.obsm.keys():
            n_cells, n_lineages = adata.obsm[lin_key].shape
            logg.info(f"Creating {pretty_name}`Lineage` from `adata.obsm[{lin_key!r}]`")

            if names_key not in adata.uns.keys():
                logg.warning(
                    f"    Lineage names not found in `adata.uns[{names_key!r}]`, creating new names"
                )
                names = [f"Lineage {i}" for i in range(n_lineages)]
            elif len(adata.uns[names_key]) != n_lineages:
                logg.warning(
                    f"    Lineage names don't have the required length ({n_lineages}), creating new names"
                )
                names = [f"Lineage {i}" for i in range(n_lineages)]
            else:
                logg.info("    Successfully loaded names")
                names = adata.uns[names_key]

            if colors_key not in adata.uns.keys():
                logg.warning(
                    f"    Lineage colors not found in `adata.uns[{colors_key!r}]`, creating new colors"
                )
                colors = _create_categorical_colors(n_lineages)
            elif len(adata.uns[colors_key]) != n_lineages or not all(
                map(lambda c: is_color_like(c), adata.uns[colors_key])
            ):
                logg.warning(
                    f"    Lineage colors don't have the required length ({n_lineages}) "
                    f"or are not color-like, creating new colors"
                )
                colors = _create_categorical_colors(n_lineages)
            else:
                logg.info("    Successfully loaded colors")
                colors = adata.uns[colors_key]

            adata.obsm[lin_key] = Lineage(
                adata.obsm[lin_key], names=names, colors=colors
            )
            adata.uns[colors_key] = colors
            adata.uns[names_key] = names
        else:
            logg.debug(
                f"Unable to load {pretty_name}`Lineage` from `adata.obsm[{lin_key!r}]`"
            )
Example #10
 def _remove_min_clusters(self, min_flow: float) -> None:
     logg.debug("Removing clusters with no incoming flow edges")
     columns = (self._flow.loc[(slice(None), self._cluster), :] >
                min_flow).any()
     columns = columns[columns].index
     if not len(columns):
         raise ValueError(
             "After removing clusters with no incoming flow edges, none remain."
         )
     self._flow = self._flow[columns]
Example #11
 def _set_or_debug(
     self, needle: str, haystack, attr: Optional[Union[str, PrettyEnum]] = None
 ) -> Optional[Any]:
     if isinstance(attr, PrettyEnum):
         attr = attr.s
     if needle in haystack:
         if attr is None:
             return haystack[needle]
         setattr(self, attr, haystack[needle])
     elif attr is not None:
         logg.debug(f"Unable to set attribute `.{attr}`, skipping")
Example #12
def _filter_models(
    models,
    return_models: bool = False,
    filter_all_failed: bool = True
) -> Tuple[_return_model_type, _return_model_type, Sequence[str],
           Sequence[str]]:
    def is_valid(x: Union[BaseModel, BulkRes]) -> bool:
        if return_models:
            assert isinstance(
                x, BaseModel
            ), f"Expected `BaseModel`, found `{type(x).__name__!r}`."
            return bool(x)

        return (x.x_test is not None and x.y_test is not None
                and np.all(np.isfinite(x.y_test)))

    modelmat = pd.DataFrame(models).T
    modelmask = modelmat.applymap(is_valid)
    to_keep = modelmask[modelmask.any(axis=1)]
    to_keep = to_keep.loc[:, to_keep.any(axis=0)].T

    filtered_models = {
        gene: {
            ln: models[gene][ln]
            for ln in (ln for ln in v.keys() if (
                is_valid(models[gene][ln]) if filter_all_failed else True))
        }
        for gene, v in to_keep.to_dict().items()
    }

    if not len(filtered_models):
        if not return_models:
            raise RuntimeError(
                "Fitting has failed for all gene/lineage combinations. "
                "Specify `return_models=True` for more information.")
        for ms in models.values():
            for model in ms.values():
                assert isinstance(
                    model, FailedModel
                ), f"Expected `FailedModel`, found `{type(model).__name__!r}`."
                model.reraise()

    if not np.all(modelmask.values):
        failed_models = modelmat.values[~modelmask.values]
        logg.warning(
            f"Unable to fit `{len(failed_models)}` models." +
            ("" if return_models else
             " Consider specifying `return_models=True` for further inspection."))
        logg.debug("The failed models were:\n`{}`".format("\n".join(
            f"    {m}" for m in failed_models)))

    # return all genes and the lineages that survived the filtering
    return models, filtered_models, tuple(filtered_models.keys()), tuple(
        to_keep.index)
Example #13
    def _reuse_cache(self,
                     expected_params: Dict[str, Any],
                     *,
                     time: Optional[Any] = None) -> bool:
        if expected_params == self._params:
            assert self.transition_matrix is not None, _ERROR_EMPTY_CACHE_MSG
            logg.debug(_LOG_USING_CACHE)
            logg.info("    Finish", time=time)
            return True

        self._params = expected_params
        return False
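
The same pattern in isolation: compare the requested parameters against the cached ones and short-circuit on a match. A generic sketch, not the library's class:

class Cached:
    def __init__(self):
        self._params = {}
        self._result = None

    def compute(self, **params):
        if params == self._params and self._result is not None:
            return self._result  # cache hit: skip the expensive work
        self._params = params
        self._result = sum(params.values())  # stand-in for the real computation
        return self._result

c = Cached()
assert c.compute(a=1, b=2) == c.compute(a=1, b=2)  # second call uses the cache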
Example #14
 def _maybe_compute_cond_num(self):
     if self._compute_cond_num and self._cond_num is None:
         logg.debug(f"Computing condition number of `{repr(self)}`")
         self._cond_num = np.linalg.cond(
             self._transition_matrix.toarray()
             if issparse(self._transition_matrix)
             else self._transition_matrix
         )
         if self._cond_num > _cond_num_tolerance:
             logg.warning(
                 f"`{repr(self)}` may be ill-conditioned, its condition number is `{self._cond_num:.2e}`"
             )
         else:
             logg.info(f"Condition number is `{self._cond_num:.2e}`")
Example #15
    def _read_from_adata(self, **kwargs):
        super()._read_from_adata(**kwargs)

        # check whether velocities have been computed
        vkey = kwargs.pop("vkey", "velocity")
        if vkey not in self.adata.layers.keys():
            raise KeyError("Compute RNA velocity first as `scv.tl.velocity()`.")

        # restrict genes to a subset, i.e. velocity genes or user provided list
        gene_subset = kwargs.pop("gene_subset", None)
        subset = np.ones(self.adata.n_vars, bool)
        if gene_subset is not None:
            var_names_subset = self.adata.var_names.isin(gene_subset)
            subset &= var_names_subset if len(var_names_subset) > 0 else gene_subset
        elif f"{vkey}_genes" in self.adata.var.keys():
            subset &= np.array(self.adata.var[f"{vkey}_genes"].values, dtype=bool)

        # choose the data representation to use for transcriptomic displacements
        xkey = kwargs.pop("xkey", "Ms")
        xkey = xkey if xkey in self.adata.layers.keys() else "spliced"

        # filter both the velocities and the gene expression profiles to the gene subset. Densify the matrices.
        X = np.array(
            self.adata.layers[xkey].A[:, subset]
            if issparse(self.adata.layers[xkey])
            else self.adata.layers[xkey][:, subset]
        )
        V = np.array(
            self.adata.layers[vkey].A[:, subset]
            if issparse(self.adata.layers[vkey])
            else self.adata.layers[vkey][:, subset]
        )

        # remove genes with NaN values in the velocities (filtered from both X and V)
        nans = np.isnan(np.sum(V, axis=0))
        if np.any(nans):
            X = X[:, ~nans]
            V = V[:, ~nans]

        # check the velocity parameters
        par_key = f"{vkey}_params"
        if par_key in self.adata.uns.keys():
            velocity_params = self.adata.uns[par_key]
        else:
            velocity_params = None
            logg.debug(
                f"Unable to load velocity parameters from `adata.uns[{par_key!r}]`"
            )

        # add to self
        self._velocity = V.astype(np.float64)
        self._gene_expression = X.astype(np.float64)
        self._velocity_params = velocity_params
Example #16
    def _get_n_states_from_minchi(
            self, n_states: Union[Tuple[int, int], List[int],
                                  Dict[str, int]]) -> int:
        if self._gpcca is None:
            raise RuntimeError(
                "Compute Schur decomposition first as `.compute_schur()` when `use_min_chi=True`."
            )

        if not isinstance(n_states, (dict, tuple, list)):
            raise TypeError(
                f"Expected `n_states` to be either `dict`, `tuple` or a `list`, "
                f"found `{type(n_states).__name__}`.")
        if len(n_states) != 2:
            raise ValueError(
                f"Expected `n_states` to be of size `2`, found `{len(n_states)}`."
            )

        if isinstance(n_states, dict):
            if "min" not in n_states or "max" not in n_states:
                raise KeyError(
                    f"Expected the dictionary to have `'min'` and `'max'` keys, "
                    f"found `{tuple(n_states.keys())}`.")
            minn, maxx = n_states["min"], n_states["max"]
        else:
            minn, maxx = n_states

        if minn > maxx:
            logg.debug(
                f"Swapping minimum and maximum because `{minn}` > `{maxx}`")
            minn, maxx = maxx, minn

        if minn <= 1:
            raise ValueError(f"Minimum value must be > `1`, found `{minn}`.")
        elif minn == 2:
            logg.warning(
                "In most cases, 2 clusters will always be optimal. "
                "If you really expect 2 clusters, use `n_states=2` and `use_min_chi=False`. Setting minimum to `3`."
            )
            minn = 3

        if minn >= maxx:
            maxx = minn + 1
            logg.debug(
                f"Setting maximum to `{maxx}` because it was <= than minimum `{minn}`"
            )

        logg.info(f"Calculating minChi within interval `[{minn}, {maxx}]`")

        return int(
            np.arange(minn, maxx + 1)[np.argmax(self._gpcca.minChi(minn, maxx))]
        )
Example #17
def _load_dataset_from_url(
    fpath: Union[os.PathLike, str], url: str, expected_shape: Tuple[int, int], **kwargs
) -> AnnData:

    fpath = str(fpath)
    if not fpath.endswith(".h5ad"):
        fpath += ".h5ad"

    if os.path.isfile(fpath):
        logg.debug(f"Loading dataset from `{fpath!r}`")
    else:
        logg.debug(f"Downloading dataset from `{url!r}` as `{fpath!r}`")

    dirname, _ = os.path.split(fpath)
    try:
        if not os.path.isdir(dirname):
            logg.debug(f"Creating directory `{dirname!r}`")
            os.makedirs(dirname, exist_ok=True)
    except OSError as e:
        logg.debug(f"Unable to create directory `{dirname!r}`. Reason `{e}`")

    kwargs.setdefault("sparse", True)
    kwargs.setdefault("cache", True)

    adata = read(fpath, backup_url=url, **kwargs)

    if adata.shape != expected_shape:
        raise ValueError(
            f"Expected `anndata.AnnData` object to have shape `{expected_shape}`, found `{adata.shape}`."
        )

    adata.var_names_make_unique()

    return adata
Example #18
    def _read_from_adata(self, **kwargs):
        super()._read_from_adata(**kwargs)

        time_key = kwargs.pop("time_key", "dpt_pseudotime")
        if time_key not in self.adata.obs.keys():
            raise KeyError(
                f"Could not find time key in `adata.obs[{time_key!r}]`.")

        self._pseudotime = np.array(self.adata.obs[time_key]).astype(_dtype)

        if np.any(np.isnan(self._pseudotime)):
            raise ValueError("Encountered NaN values in pseudotime.")

        logg.debug("Clipping the pseudotime to 0-1 range")
        self._pseudotime = np.clip(self._pseudotime, 0, 1)
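
The validation and clipping steps on a toy pseudotime vector:

import numpy as np

pseudotime = np.array([-0.1, 0.2, 0.7, 1.3])
if np.any(np.isnan(pseudotime)):
    raise ValueError("Encountered NaN values in pseudotime.")

pseudotime = np.clip(pseudotime, 0, 1)  # -> [0.0, 0.2, 0.7, 1.0]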
Example #19
    def __init__(
        self,
        transition_matrix: Optional[Union[np.ndarray, spmatrix, str]] = None,
        adata: Optional[AnnData] = None,
        backward: bool = False,
        compute_cond_num: bool = False,
    ):
        from anndata import AnnData as _AnnData

        if transition_matrix is None:
            transition_matrix = _transition(
                Direction.BACKWARD if backward else Direction.FORWARD)
            logg.debug(
                f"Setting transition matrix key to `{transition_matrix!r}`")

        if isinstance(transition_matrix, str):
            if adata is None:
                raise ValueError(
                    "When `transition_matrix` specifies a key to `adata.obsp`, `adata` cannot be None."
                )
            transition_matrix = _read_graph_data(adata, transition_matrix)

        if not isinstance(transition_matrix, (np.ndarray, spmatrix)):
            raise TypeError(
                f"Expected transition matrix to be of type `numpy.ndarray` or `scipy.sparse.spmatrix`, "
                f"found `{type(transition_matrix).__name__!r}`.")

        if transition_matrix.shape[0] != transition_matrix.shape[1]:
            raise ValueError(
                f"Expected transition matrix to be square, found `{transition_matrix.shape}`."
            )

        if not np.allclose(np.sum(transition_matrix, axis=1), 1.0, rtol=_RTOL):
            raise ValueError(
                "Not a valid transition matrix: not all rows sum to 1.")

        if adata is None:
            logg.warning("Creating empty `AnnData` object")
            adata = _AnnData(
                csr_matrix((transition_matrix.shape[0], 1), dtype=np.float32))

        super().__init__(adata,
                         backward=backward,
                         compute_cond_num=compute_cond_num)
        self._transition_matrix = csr_matrix(transition_matrix)
        self._maybe_compute_cond_num()
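
The two validity checks from the constructor, applied to a toy matrix:

import numpy as np
from scipy.sparse import csr_matrix

T = np.array([[0.9, 0.1, 0.0],
              [0.0, 0.5, 0.5],
              [0.2, 0.0, 0.8]])

assert T.shape[0] == T.shape[1], "transition matrix must be square"
assert np.allclose(T.sum(axis=1), 1.0), "all rows must sum to 1"

T = csr_matrix(T)  # stored as sparse, as in the constructor above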
Example #20
    def test_logfile(self, tmp_path, logging_state):
        settings.verbosity = Verbosity.hint

        io = StringIO()
        settings.logfile = io
        assert settings.logfile is io
        assert settings.logpath is None
        logg.error("test!")
        assert io.getvalue() == "ERROR: test!\n"

        p = tmp_path / "test.log"
        settings.logpath = p
        assert settings.logpath == p
        assert settings.logfile.name == str(p)
        logg.hint("test2")
        logg.debug("invisible")
        assert settings.logpath.read_text() == "--> test2\n"
Example #21
def _ensure_numeric_ordered(adata: AnnData, key: str) -> pd.Series:
    if key not in adata.obs.keys():
        raise KeyError(f"Unable to find data in `adata.obs[{key!r}]`.")

    exp_time = adata.obs[key].copy()
    if not is_numeric_dtype(np.asarray(exp_time)):
        try:
            exp_time = np.asarray(exp_time).astype(float)
        except ValueError as e:
            raise TypeError(
                f"Unable to convert `adata.obs[{key!r}]` of type `{infer_dtype(adata.obs[key])}` to `float`."
            ) from e

    if not is_categorical_dtype(exp_time):
        logg.debug(f"Converting `adata.obs[{key!r}]` to `categorical`")
        exp_time = np.asarray(exp_time)
        categories = sorted(set(exp_time[~np.isnan(exp_time)]))
        if len(categories) > 100:
            raise ValueError(
                f"Unable to convert `adata.obs[{key!r}]` to `categorical` since it "
                f"would create `{len(categories)}` categories."
            )
        exp_time = pd.Series(
            pd.Categorical(
                exp_time,
                categories=categories,
                ordered=True,
            )
        )

    if not exp_time.cat.ordered:
        logg.warning("Categories are not ordered. Using ascending order")
        exp_time = exp_time.cat.as_ordered()

    exp_time = pd.Series(pd.Categorical(exp_time, ordered=True), index=adata.obs_names)
    if exp_time.isnull().any():
        raise ValueError("Series contains NaN value(s).")

    n_cats = len(exp_time.cat.categories)
    if n_cats < 2:
        raise ValueError(f"Expected to find at least `2` categories, found `{n_cats}`.")

    return exp_time
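
The core conversion on a toy series of experimental time points:

import numpy as np
import pandas as pd

t = np.array([0.0, 1.0, 0.0, 2.0])
categories = sorted(set(t[~np.isnan(t)]))
exp_time = pd.Series(pd.Categorical(t, categories=categories, ordered=True))

assert exp_time.cat.ordered
assert len(exp_time.cat.categories) >= 2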
Example #22
    def write(self,
              fname: Union[str, Path],
              ext: Optional[str] = "pickle") -> None:
        """
        %(pickleable.full_desc)s

        Parameters
        ----------
        %(pickleable.parameters)s

        Returns
        -------
        %(pickleable.returns)s
        """  # noqa

        fname = str(fname)
        if ext is not None:
            if not ext.startswith("."):
                ext = "." + ext
            if not fname.endswith(ext):
                fname += ext

        logg.debug(f"Writing to `{fname}`")

        with open(fname, "wb") as fout:
            if version_info[:2] > (3, 6):
                pickle.dump(self, fout)
            else:
                # we need to use PrecomputedKernel because Python 3.6 can't pickle Enums
                # and they are present in VelocityKernel
                logg.warning(
                    "Saving kernel as `cellrank.tl.kernels.PrecomputedKernel`")
                orig_kernel = self.kernel
                self._kernel = PrecomputedKernel(self.kernel)
                try:
                    pickle.dump(self, fout)
                finally:
                    self._kernel = orig_kernel
Example #23
def _write_graph_data(
    adata: AnnData,
    data: Union[np.ndarray, spmatrix],
    key: str,
):
    """
    Write graph data to :mod:`anndata`.

    :mod:`anndata` >= 0.7 stores `(n_obs x n_obs)` matrices in `.obsp` rather than `.uns`;
    this function tries `.obsp` first and falls back to `.uns` for backward compatibility.

    Parameters
    ----------
    adata
        Annotated data object.
    data
        The graph data we want to write.
    key
        Key from either ``adata.uns`` or ``adata.obsp``.

    Returns
    -------
    None
        Nothing, just writes the data.
    """

    try:
        adata.obsp[key] = data
        write_to = "obsp"

        if data.shape[0] != data.shape[1]:
            logg.warning(
                f"`adata.obsp` attribute should only contain square matrices, found shape `{data.shape}`"
            )

    except AttributeError:
        adata.uns[key] = data
        write_to = "uns"

    logg.debug(f"Writing graph data to `adata.{write_to}[{key!r}]`")
Example #24
def _read_graph_data(adata: AnnData, key: str) -> Union[np.ndarray, spmatrix]:
    """
    Read graph data from :mod:`anndata`.

    Parameters
    ----------
    adata
        Annotated data object.
    key
        Key in ``adata.obsp``.

    Returns
    -------
    :class:`numpy.ndarray` or :class:`scipy.sparse.spmatrix`
        The graph data.
    """

    logg.debug(f"Reading key `{key!r}` from `adata.obsp`")
    if key in adata.obsp.keys():
        return adata.obsp[key]

    raise KeyError(f"Unable to find key `{key!r}` in `adata.obsp`.")
Example #25
    def compute_transition_matrix(
        self, density_normalize: bool = True
    ) -> "ConnectivityKernel":
        """
        Compute transition matrix based on transcriptomic similarity.

        Uses a symmetric, weighted KNN graph to compute a symmetric transition matrix. The connectivities are
        computed using :func:`scanpy.pp.neighbors`. Depending on the parameters used there, they can be UMAP
        connectivities or Gaussian-kernel-based connectivities with adaptive kernel width.

        Parameters
        ----------
        density_normalize
            Whether or not to use the underlying KNN graph for density normalization.

        Returns
        -------
        :class:`cellrank.tl.kernels.ConnectivityKernel`
            Makes :paramref:`transition_matrix` available.
        """

        start = logg.info("Computing transition matrix based on connectivities")

        params = {"dnorm": density_normalize}
        if params == self.params:
            assert self.transition_matrix is not None, _ERROR_EMPTY_CACHE_MSG
            logg.debug(_LOG_USING_CACHE)
            logg.info("    Finish", time=start)
            return self

        self._params = params
        self._compute_transition_matrix(
            matrix=self._conn.copy(), density_normalize=density_normalize
        )

        logg.info("    Finish", time=start)

        return self
Example #26
    def _read_from_adata(
        self,
        conn_key: Optional[str] = "connectivities",
        read_conn: bool = True,
        **kwargs: Any,
    ) -> None:
        """
        Import the base-KNN graph and optionally check for symmetry and connectivity.

        Parameters
        ----------
        conn_key
            Key in :attr:`anndata.AnnData.uns` where connectivities are stored.
        read_conn
            Whether to read connectivities or set them to `None`. Useful when not exposing density normalization or
            when KNN connectivities are not used to compute the transition matrix.
        kwargs
            Additional keyword arguments.

        Returns
        -------
        Nothing, just sets :attr:`_conn`.
        """
        if not read_conn:
            self._conn = None
            return

        self._conn = _get_neighs(self.adata,
                                 mode="connectivities",
                                 key=conn_key).astype(_dtype)

        check_connectivity = kwargs.pop("check_connectivity", False)
        if check_connectivity:
            start = logg.debug("Checking the KNN graph for connectedness")
            if not _connected(self._conn):
                logg.warning("KNN graph is not connected", time=start)
            else:
                logg.debug("KNN graph is connected", time=start)

        start = logg.debug("Checking the KNN graph for symmetry")
        if not _symmetric(self._conn):
            logg.warning("KNN graph is not symmetric", time=start)
        else:
            logg.debug("KNN graph is symmetric", time=start)
Example #27
    def _read_from_adata(self, **kwargs):
        """Import the base-KNN graph and optionally check for symmetry and connectivity."""

        if not _has_neighs(self.adata):
            raise KeyError("Compute KNN graph first as `scanpy.pp.neighbors()`.")

        self._conn = _get_neighs(self.adata, "connectivities").astype(_dtype)

        check_connectivity = kwargs.pop("check_connectivity", False)
        if check_connectivity:
            start = logg.debug("Checking the KNN graph for connectedness")
            if not _connected(self._conn):
                logg.warning("KNN graph is not connected", time=start)
            else:
                logg.debug("KNN graph is connected", time=start)

        start = logg.debug("Checking the KNN graph for symmetry")
        if not _symmetric(self._conn):
            logg.warning("KNN graph is not symmetric", time=start)
        else:
            logg.debug("KNN graph is symmetric", time=start)
Example #28
    def compute_terminal_states(
        self,
        use: Optional[Union[int, Tuple[int], List[int], range]] = None,
        percentile: Optional[int] = 98,
        method: str = "kmeans",
        cluster_key: Optional[str] = None,
        n_clusters_kmeans: Optional[int] = None,
        n_neighbors: int = 20,
        resolution: float = 0.1,
        n_matches_min: Optional[int] = 0,
        n_neighbors_filtering: int = 15,
        basis: Optional[str] = None,
        n_comps: int = 5,
        scale: bool = False,
        en_cutoff: Optional[float] = 0.7,
        p_thresh: float = 1e-15,
    ) -> None:
        """
        Find approximate recurrent classes of the Markov chain.

        Filter to obtain recurrent states in left eigenvectors.
        Cluster to obtain approximate recurrent classes in right eigenvectors.

        Parameters
        ----------
        use
            Which or how many first eigenvectors to use as features for clustering/filtering.
            If `None`, use the `eigengap` statistic.
        percentile
            Threshold used for filtering out cells which are most likely transient states. Cells which are in the
            lower ``percentile`` percent of each eigenvector will be removed from the data matrix.
        method
            Method to be used for clustering. Must be one of `'louvain'`, `'leiden'` or `'kmeans'`.
        cluster_key
            If a key to cluster labels is given, :attr:`{fs}` will be associated with these for naming and colors.
        n_clusters_kmeans
            If `None`, this is set to ``use + 1``.
        n_neighbors
            If we use `'louvain'` or `'leiden'` for clustering cells, we need to build a KNN graph.
            This is the :math:`K` parameter for that, the number of neighbors for each cell.
        resolution
            Resolution parameter for `'louvain'` or `'leiden'` clustering. Should be chosen relatively small.
        n_matches_min
            Filters out cells which don't have at least ``n_matches_min`` neighbors from the same class.
            This filters out some cells which are transient but have been misassigned.
        n_neighbors_filtering
            Parameter for filtering cells. Cells are filtered out if they don't have at least ``n_matches_min``
            neighbors among their ``n_neighbors_filtering`` nearest cells.
        basis
            Key from :paramref:`adata` ``.obsm`` to be used as additional features for the clustering.
        n_comps
            Number of embedding components to be used when ``basis`` is not `None`.
        scale
            Scale to z-scores. Consider using this if appending embedding to features.
        %(en_cutoff_p_thresh)s

        Returns
        -------
        None
            Nothing, but updates the following fields:

                - :attr:`{fsp}`
                - :attr:`{fs}`
        """

        def _compute_macrostates_prob() -> Series:
            """Compute a global score of being an approximate recurrent class."""

            # get the truncated eigendecomposition
            V, evals = eig["V_l"].real[:, use], eig["D"].real[use]

            # shift and scale
            V_pos = np.abs(V)
            V_shifted = V_pos - np.min(V_pos, axis=0)
            V_scaled = V_shifted / np.max(V_shifted, axis=0)

            # check the ranges are correct
            assert np.allclose(np.min(V_scaled, axis=0), 0), "Lower limit is not zero."
            assert np.allclose(np.max(V_scaled, axis=0), 1), "Upper limit is not one."

            # further scale by the eigenvalues
            V_eigs = V_scaled / evals

            # sum over cols and scale
            c_ = np.sum(V_eigs, axis=1)
            c = c_ / np.max(c_)

            return Series(c, index=self.adata.obs_names)

        def check_use(use) -> List[int]:
            if method not in ["kmeans", "louvain", "leiden"]:
                raise ValueError(
                    f"Invalid method `{method!r}`. Valid options are `'louvain'`, `'leiden'` or `'kmeans'`."
                )

            if use is None:
                use = eig["eigengap"] + 1  # add one b/c indexing starts at 0
            if isinstance(use, int):
                use = list(range(use))
            elif not isinstance(use, (tuple, list, range)):
                raise TypeError(
                    f"Argument `use` must be either `int`, `tuple`, `list` or `range`, "
                    f"found `{type(use).__name__!r}`."
                )
            else:
                if not all(map(lambda u: isinstance(u, int), use)):
                    raise TypeError("Not all values in `use` argument are integers.")
            use = list(use)

            if len(use) == 0:
                raise ValueError(
                    f"Number of eigenvector must be larger than `0`, found `{len(use)}`."
                )

            muse = max(use)
            if muse >= eig["V_l"].shape[1] or muse >= eig["V_r"].shape[1]:
                raise ValueError(
                    f"Maximum specified eigenvector `{muse}` is larger "
                    f'than the number of computed eigenvectors `{eig["V_l"].shape[1]}`. '
                    f"Use `.compute_eigendecomposition(k={muse})` to recompute the eigendecomposition."
                )

            return use

        eig = self._get(P.EIG)
        if eig is None:
            raise RuntimeError(
                "Compute eigendecomposition first as `.compute_eigendecomposition()`."
            )
        use = check_use(use)

        start = logg.info("Computing approximate recurrent classes")
        # we check for complex values only in the left eigenvectors; that's fine
        # because the complex pattern is identical for the left and right ones
        V_l, V_r = eig["V_l"][:, use], eig["V_r"].real[:, use]
        V_l = _complex_warning(V_l, use, use_imag=False)

        # compute a rc probability
        logg.debug("Computing probabilities of approximate recurrent classes")
        self._set(A.TERM_PROBS, _compute_macrostates_prob())

        # retrieve embedding and concatenate
        if basis is not None:
            bkey = f"X_{basis}"
            if bkey not in self.adata.obsm.keys():
                raise KeyError(f"Basis key `{bkey!r}` not found in `adata.obsm`")

            X_em = self.adata.obsm[bkey][:, :n_comps]
            X = np.concatenate([V_r, X_em], axis=1)
        else:
            logg.debug("Basis is `None`. Setting X equal to the right eigenvectors")
            X = V_r

        # filter out cells which are in the lowest q percentile in abs value in each eigenvector
        if percentile is not None:
            logg.debug("Filtering out cells according to percentile")
            if percentile < 0 or percentile > 100:
                raise ValueError(
                    f"Percentile must be in interval `[0, 100]`, found `{percentile}`."
                )
            cutoffs = np.percentile(np.abs(V_l), percentile, axis=0)
            ixs = np.sum(np.abs(V_l) < cutoffs, axis=1) < V_l.shape[1]
            X = X[ixs, :]

        # scale
        if scale:
            X = zscore(X, axis=0)

        # cluster X
        if method == "kmeans" and n_clusters_kmeans is None:
            n_clusters_kmeans = len(use) + (percentile is None)
            if X.shape[0] < n_clusters_kmeans:
                raise ValueError(
                    f"Filtering resulted in only {X.shape[0]} cell(s), insufficient to cluster into "
                    f"`{n_clusters_kmeans}` clusters. Consider decreasing the value of `percentile`."
                )

        logg.debug(
            f"Using `{use}` eigenvectors, basis `{basis!r}` and method `{method!r}` for clustering"
        )
        labels = _cluster_X(
            X,
            method=method,
            n_clusters=n_clusters_kmeans,
            n_neighbors=n_neighbors,
            resolution=resolution,
        )

        # fill in the labels in case we filtered out cells before
        if percentile is not None:
            rc_labels = np.repeat(None, self.adata.n_obs)
            rc_labels[ixs] = labels
        else:
            rc_labels = labels

        rc_labels = Series(rc_labels, index=self.adata.obs_names, dtype="category")
        rc_labels.cat.categories = list(rc_labels.cat.categories.astype("str"))

        # filtering to get rid of some of the left over transient states
        if n_matches_min > 0:
            logg.debug(f"Filtering according to `n_matches_min={n_matches_min}`")
            distances = _get_connectivities(
                self.adata, mode="distances", n_neighbors=n_neighbors_filtering
            )
            rc_labels = _filter_cells(
                distances, rc_labels=rc_labels, n_matches_min=n_matches_min
            )

        self.set_terminal_states(
            labels=rc_labels,
            cluster_key=cluster_key,
            en_cutoff=en_cutoff,
            p_thresh=p_thresh,
            add_to_existing=False,
            time=start,
        )
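
The percentile filter above, in isolation: a cell is kept only if it exceeds the per-eigenvector cutoff in at least one column.

import numpy as np

V = np.random.RandomState(0).normal(size=(100, 3))  # stand-in for eigenvectors

cutoffs = np.percentile(np.abs(V), 98, axis=0)
# a cell is dropped iff it is below the cutoff in *every* eigenvector
ixs = np.sum(np.abs(V) < cutoffs, axis=1) < V.shape[1]
V_kept = V[ixs]
print(V_kept.shape)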
Example #29
def cluster_lineage(
    adata: AnnData,
    model: _input_model_type,
    genes: Sequence[str],
    lineage: str,
    backward: bool = False,
    time_range: _time_range_type = None,
    clusters: Optional[Sequence[str]] = None,
    n_points: int = 200,
    time_key: str = "latent_time",
    norm: bool = True,
    recompute: bool = False,
    callback: _callback_type = None,
    ncols: int = 3,
    sharey: Union[str, bool] = False,
    key: Optional[str] = None,
    random_state: Optional[int] = None,
    use_leiden: bool = False,
    show_progress_bar: bool = True,
    n_jobs: Optional[int] = 1,
    backend: str = _DEFAULT_BACKEND,
    figsize: Optional[Tuple[float, float]] = None,
    dpi: Optional[int] = None,
    save: Optional[Union[str, Path]] = None,
    pca_kwargs: Dict = MappingProxyType({"svd_solver": "arpack"}),
    neighbors_kwargs: Dict = MappingProxyType({"use_rep": "X"}),
    clustering_kwargs: Dict = MappingProxyType({}),
    return_models: bool = False,
    **kwargs,
) -> Optional[_return_model_type]:
    """
    Cluster gene expression trends within a lineage and plot the clusters.

    This function is based on Palantir, see [Setty19]_. It can be used to discover modules of genes that drive
    development along a given lineage. Consider running this function on a subset of genes which are potential
    lineage drivers, identified e.g. by running :func:`cellrank.tl.lineage_drivers`.

    Parameters
    ----------
    %(adata)s
    %(model)s
    %(genes)s
    lineage
        Name of the lineage for which to cluster the genes.
    %(backward)s
    %(time_ranges)s
    clusters
        Cluster identifiers to plot. If `None`, all clusters will be considered. Useful when
        plotting previously computed clusters.
    n_points
        Number of points used for prediction.
    time_key
        Key in ``adata.obs`` where the pseudotime is stored.
    norm
        Whether to z-normalize each trend to have zero mean, unit variance.
    recompute
        If `True`, recompute the clustering, otherwise try to find already existing one.
    %(model_callback)s
    ncols
        Number of columns for the plot.
    sharey
        Whether to share y-axis across multiple plots.
    key
        Key in ``adata.uns`` where to save the results. If `None`, it will be saved as ``lineage_{lineage}_trend``.
    random_state
        Random seed for reproducibility.
    use_leiden
        Whether to use :func:`scanpy.tl.leiden` for clustering or :func:`scanpy.tl.louvain`.
    %(parallel)s
    %(plotting)s
    pca_kwargs
        Keyword arguments for :func:`scanpy.pp.pca`.
    neighbors_kwargs
        Keyword arguments for :func:`scanpy.pp.neighbors`.
    clustering_kwargs
        Keyword arguments for :func:`scanpy.tl.louvain` or :func:`scanpy.tl.leiden`.
    %(return_models)s
    kwargs
        Keyword arguments for :meth:`cellrank.ul.models.BaseModel.prepare`.

    Returns
    -------
    %(plots_or_returns_models)s

        Also updates ``adata.uns`` with the following:

            - ``key`` or ``lineage_{lineage}_trend`` - an :class:`anndata.AnnData` object of
              shape `(n_genes, n_points)` containing the clustered genes.
    """

    import scanpy as sc
    from anndata import AnnData as _AnnData

    lineage_key = str(AbsProbKey.BACKWARD if backward else AbsProbKey.FORWARD)
    if lineage_key not in adata.obsm:
        raise KeyError(
            f"Lineages key `{lineage_key!r}` not found in `adata.obsm`.")

    _ = adata.obsm[lineage_key][lineage]

    genes = _unique_order_preserving(genes)
    _check_collection(adata, genes, "var_names", kwargs.get("use_raw", False))

    if key is None:
        key = f"lineage_{lineage}_trend"

    if recompute or key not in adata.uns:
        kwargs["backward"] = backward
        kwargs["time_key"] = time_key
        kwargs["n_test_points"] = n_points
        models = _create_models(model, genes, [lineage])
        all_models, models, genes, _ = _fit_bulk(
            models,
            _create_callbacks(adata, callback, genes, [lineage], **kwargs),
            genes,
            lineage,
            time_range,
            return_models=True,  # always return (better error messages)
            filter_all_failed=True,
            parallel_kwargs={
                "show_progress_bar": show_progress_bar,
                "n_jobs": _get_n_cores(n_jobs, len(genes)),
                "backend": _get_backend(models, backend),
            },
            **kwargs,
        )

        # shape `(n_test_points, n_genes)` after the transpose
        trends = np.vstack(
            [model[lineage].y_test for model in models.values()]).T

        if norm:
            logg.debug("Normalizing trends")
            _ = StandardScaler(copy=False).fit_transform(trends)

        trends = _AnnData(trends.T)
        trends.obs_names = genes

        # sanity check
        if trends.n_obs != len(genes):
            raise RuntimeError(
                f"Expected to find `{len(genes)}` genes, found `{trends.n_obs}`."
            )
        if trends.n_vars != n_points:
            raise RuntimeError(
                f"Expected to find `{n_points}` points, found `{trends.n_vars}`."
            )

        random_state = np.random.mtrand.RandomState(random_state).randint(
            2**16)

        pca_kwargs = dict(pca_kwargs)
        pca_kwargs.setdefault("n_comps", min(50, n_points, len(genes)) - 1)
        pca_kwargs.setdefault("random_state", random_state)
        sc.pp.pca(trends, **pca_kwargs)

        neighbors_kwargs = dict(neighbors_kwargs)
        neighbors_kwargs.setdefault("random_state", random_state)
        sc.pp.neighbors(trends, **neighbors_kwargs)

        clustering_kwargs = dict(clustering_kwargs)
        clustering_kwargs["key_added"] = "clusters"
        clustering_kwargs.setdefault("random_state", random_state)
        try:
            if use_leiden:
                sc.tl.leiden(trends, **clustering_kwargs)
            else:
                sc.tl.louvain(trends, **clustering_kwargs)
        except ImportError as e:
            logg.warning(str(e))
            # fall back to the other clustering algorithm
            if use_leiden:
                sc.tl.louvain(trends, **clustering_kwargs)
            else:
                sc.tl.leiden(trends, **clustering_kwargs)

        logg.info(f"Saving data to `adata.uns[{key!r}]`")
        adata.uns[key] = trends
    else:
        all_models = None
        logg.info(f"Loading data from `adata.uns[{key!r}]`")
        trends = adata.uns[key]

    if "clusters" not in trends.obs:
        raise KeyError(
            "Unable to find the clustering in `trends.obs['clusters']`.")

    if clusters is None:
        clusters = trends.obs["clusters"].cat.categories
    for c in clusters:
        if c not in trends.obs["clusters"].cat.categories:
            raise ValueError(
                f"Invalid cluster name `{c!r}`. "
                f"Valid options are `{list(trends.obs['clusters'].cat.categories)}`."
            )

    nrows = int(np.ceil(len(clusters) / ncols))
    fig, axes = plt.subplots(
        nrows,
        ncols,
        figsize=(ncols * 10, nrows * 10) if figsize is None else figsize,
        sharey=sharey,
        dpi=dpi,
    )

    if not isinstance(axes, Iterable):
        axes = [axes]
    axes = np.ravel(axes)

    j = 0
    for j, (ax, c) in enumerate(zip(axes, clusters)):  # noqa
        data = trends[trends.obs["clusters"] == c].X
        mean = np.mean(data, axis=0)
        sd = np.std(data, axis=0)

        for i in range(data.shape[0]):
            ax.plot(data[i], color="gray", lw=0.5)

        ax.plot(mean, lw=2, color="black")
        ax.plot(mean - sd, lw=1.5, color="black", linestyle="--")
        ax.plot(mean + sd, lw=1.5, color="black", linestyle="--")
        ax.fill_between(range(len(mean)),
                        mean - sd,
                        mean + sd,
                        color="black",
                        alpha=0.1)

        ax.set_title(f"Cluster {c}")
        ax.set_xticks([])

        if not sharey:
            ax.set_yticks([])

    for j in range(j + 1, len(axes)):
        axes[j].remove()

    if save is not None:
        save_fig(fig, save)

    if return_models:
        return all_models
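
The `norm=True` branch z-normalizes each gene's trend; the equivalent step in isolation:

import numpy as np
from sklearn.preprocessing import StandardScaler

trends = np.random.RandomState(0).uniform(size=(200, 10))  # (n_points, n_genes)
_ = StandardScaler(copy=False).fit_transform(trends)  # in-place, per-gene z-scores

assert np.allclose(trends.mean(axis=0), 0.0, atol=1e-12)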
Example #30
def heatmap(
    adata: AnnData,
    model: _input_model_type,
    genes: Sequence[str],
    lineages: Optional[Union[str, Sequence[str]]] = None,
    backward: bool = False,
    mode: str = HeatmapMode.LINEAGES.s,
    time_key: str = "latent_time",
    time_range: Optional[Union[_time_range_type,
                               List[_time_range_type]]] = None,
    callback: _callback_type = None,
    cluster_key: Optional[Union[str, Sequence[str]]] = None,
    show_absorption_probabilities: bool = False,
    cluster_genes: bool = False,
    keep_gene_order: bool = False,
    scale: bool = True,
    n_convolve: Optional[int] = 5,
    show_all_genes: bool = False,
    cbar: bool = True,
    lineage_height: float = 0.33,
    fontsize: Optional[float] = None,
    xlabel: Optional[str] = None,
    cmap: mcolors.ListedColormap = cm.viridis,
    dendrogram: bool = True,
    return_genes: bool = False,
    return_models: bool = False,
    n_jobs: Optional[int] = 1,
    backend: str = _DEFAULT_BACKEND,
    show_progress_bar: bool = True,
    figsize: Optional[Tuple[float, float]] = None,
    dpi: Optional[int] = None,
    save: Optional[Union[str, Path]] = None,
    **kwargs,
) -> Optional[Union[Dict[str, pd.DataFrame],
                    Tuple[_return_model_type, Dict[str, pd.DataFrame]]]]:
    """
    Plot a heatmap of smoothed gene expression along specified lineages.

    Parameters
    ----------
    %(adata)s
    %(model)s
    %(genes)s
    lineages
        Names of the lineages for which to plot. If `None`, plot all lineages.
    %(backward)s
    mode
        Valid options are:

            - `{m.LINEAGES.s!r}` - group by ``genes`` for each lineage in ``lineages``.
            - `{m.GENES.s!r}` - group by ``lineages`` for each gene in ``genes``.
    time_key
        Key in ``adata.obs`` where the pseudotime is stored.
    %(time_ranges)s
    %(model_callback)s
    cluster_key
        Key(s) in ``adata.obs`` containing categorical observations to be plotted on top of the heatmap.
        Only available when ``mode={m.LINEAGES.s!r}``.
    show_absorption_probabilities
        Whether to also plot absorption probabilities alongside the smoothed expression.
        Only available when ``mode={m.LINEAGES.s!r}``.
    cluster_genes
        Whether to cluster genes using :func:`seaborn.clustermap` when ``mode='lineages'``.
    keep_gene_order
        Whether to keep the gene order for later lineages after the first was sorted.
        Only available when ``cluster_genes=False`` and ``mode={m.LINEAGES.s!r}``.
    scale
        Whether to normalize the gene expression to the `0-1` range.
    n_convolve
        Size of the convolution window when smoothing absorption probabilities.
    show_all_genes
        Whether to show all genes on y-axis.
    cbar
        Whether to show the colorbar.
    lineage_height
        Height of a bar when ``mode={m.GENES.s!r}``.
    fontsize
        Size of the title's font.
    xlabel
        Label on the x-axis. If `None`, it is determined based on ``time_key``.
    cmap
        Colormap to use when visualizing the smoothed expression.
    dendrogram
        Whether to show dendrogram when ``cluster_genes=True``.
    return_genes
        Whether to return the sorted or clustered genes. Only available when ``mode={m.LINEAGES.s!r}``.
    %(return_models)s
    %(parallel)s
    %(plotting)s
    kwargs
        Keyword arguments for :meth:`cellrank.ul.models.BaseModel.prepare`.

    Returns
    -------
    %(plots_or_returns_models)s
    :class:`pandas.DataFrame`
        If ``return_genes=True`` and ``mode={m.LINEAGES.s!r}``, returns :class:`pandas.DataFrame`
        containing the clustered or sorted genes.
    """

    import seaborn as sns

    def find_indices(series: pd.Series, values) -> Tuple[Any, ...]:
        def find_nearest(array: np.ndarray, value: float) -> int:
            ix = np.searchsorted(array, value, side="left")
            if ix > 0 and (ix == len(array) or fabs(value - array[ix - 1]) <
                           fabs(value - array[ix])):
                return ix - 1
            return ix

        series = series[np.argsort(series.values)]

        return tuple(series[[find_nearest(series.values, v)
                             for v in values]].index)

    def subset_lineage(lname: str, rng: np.ndarray) -> np.ndarray:
        time_series = adata.obs[time_key]
        ixs = find_indices(time_series, rng)

        lin = adata[ixs, :].obsm[lineage_key][lname]

        lin = lin.X.copy().squeeze()
        if n_convolve is not None:
            lin = convolve(lin,
                           np.ones(n_convolve) / n_convolve,
                           mode="nearest")

        return lin

    def create_col_colors(lname: str,
                          rng: np.ndarray) -> Tuple[np.ndarray, Cmap, Norm]:
        color = adata.obsm[lineage_key][lname].colors[0]
        lin = subset_lineage(lname, rng)

        h, _, v = mcolors.rgb_to_hsv(mcolors.to_rgb(color))
        end_color = mcolors.hsv_to_rgb([h, 1, v])

        lineage_cmap = mcolors.LinearSegmentedColormap.from_list(
            "lineage_cmap", ["#ffffff", end_color], N=len(rng))
        norm = mcolors.Normalize(vmin=np.min(lin), vmax=np.max(lin))
        scalar_map = cm.ScalarMappable(cmap=lineage_cmap, norm=norm)

        return (
            np.array([mcolors.to_hex(c) for c in scalar_map.to_rgba(lin)]),
            lineage_cmap,
            norm,
        )

    def create_col_categorical_color(cluster_key: str,
                                     rng: np.ndarray) -> np.ndarray:
        if not is_categorical_dtype(adata.obs[cluster_key]):
            raise TypeError(
                f"Expected `adata.obs[{cluster_key!r}]` to be categorical, "
                f"found `{adata.obs[cluster_key].dtype.name!r}`.")

        color_key = f"{cluster_key}_colors"
        if color_key not in adata.uns:
            logg.warning(
                f"Color key `{color_key!r}` not found in `adata.uns`. Creating new colors"
            )
            colors = _create_categorical_colors(
                len(adata.obs[cluster_key].cat.categories))
            adata.uns[color_key] = colors
        else:
            colors = adata.uns[color_key]

        time_series = adata.obs[time_key]
        ixs = find_indices(time_series, rng)

        mapper = dict(zip(adata.obs[cluster_key].cat.categories, colors))

        return np.array([
            mcolors.to_hex(mapper[v])
            for v in adata[ixs, :].obs[cluster_key].values
        ])

    def create_cbar(
        ax,
        x_delta: float,
        cmap: Cmap,
        norm: Norm,
        label: Optional[str] = None,
    ) -> Ax:
        cax = inset_axes(
            ax,
            width="1%",
            height="100%",
            loc="lower right",
            bbox_to_anchor=(x_delta, 0, 1, 1),
            bbox_transform=ax.transAxes,
        )

        _ = mpl.colorbar.ColorbarBase(
            cax,
            cmap=cmap,
            norm=norm,
            label=label,
            ticks=np.linspace(norm.vmin, norm.vmax, 5),
        )

        return cax

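    # `valuedispatch` dispatches on the *value* of the first argument, similar
    # in spirit to `functools.singledispatch`; the handlers registered below
    # close over `data`, `genes`, etc. from the enclosing scope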
    @valuedispatch
    def _plot_heatmap(_mode: HeatmapMode) -> Fig:
        pass

    @_plot_heatmap.register(HeatmapMode.GENES)
    def _() -> Tuple[Fig, None]:
        def color_fill_rec(ax, xs, y1s, y2s, colors=None, cmap=cmap,
                           **kwargs) -> None:
            colors = colors if cmap is None else cmap(colors)

            x, y2 = 0, 0
            for i, (color, x, y1, y2) in enumerate(zip(colors, xs, y1s, y2s)):
                # rectangle width: distance to the next x-coordinate; the last
                # rectangle reuses the previous interval
                dx = (xs[i + 1] - xs[i]) if i < len(xs) - 1 else (xs[-1] -
                                                                  xs[-2])
                ax.add_patch(
                    plt.Rectangle((x, y1),
                                  dx,
                                  y2 - y1,
                                  color=color,
                                  ec=color,
                                  **kwargs))

            # invisible line so that autoscaling accounts for the patches
            ax.plot(x, y2, lw=0)

        fig, axes = plt.subplots(
            nrows=len(genes) + show_absorption_probabilities,
            figsize=(12, len(genes) + len(lineages) * lineage_height)
            if figsize is None else figsize,
            dpi=dpi,
            constrained_layout=True,
        )

        if not isinstance(axes, Iterable):
            axes = [axes]
        axes = np.ravel(axes)

        if show_absorption_probabilities:
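            # reuse the first gene's fitted models - only their x-coordinates
            # are needed; the plotted values come from `subset_lineage` below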
            data["absorption probability"] = data[next(iter(data.keys()))]

        for ax, (gene, models) in zip(axes, data.items()):
            if scale:
                vmin, vmax = 0, 1
            else:
                c = np.array([m.y_test for m in models.values()])
                vmin, vmax = np.nanmin(c), np.nanmax(c)

            norm = mcolors.Normalize(vmin=vmin, vmax=vmax)

            ix = 0
            ys = [ix]

            if gene == "absorption probability":
                norm = mcolors.Normalize(vmin=0, vmax=1)
                for ln, x in ((ln, m.x_test) for ln, m in models.items()):
                    y = np.ones_like(x)
                    c = subset_lineage(ln, x.squeeze())

                    color_fill_rec(ax,
                                   x,
                                   y * ix,
                                   y * (ix + lineage_height),
                                   colors=norm(c))

                    ix += lineage_height
                    ys.append(ix)
            else:
                for x, c in ((m.x_test, m.y_test) for m in models.values()):
                    y = np.ones_like(x)
                    c = _min_max_scale(c) if scale else c

                    color_fill_rec(ax,
                                   x,
                                   y * ix,
                                   y * (ix + lineage_height),
                                   colors=norm(c))

                    ix += lineage_height
                    ys.append(ix)

            xs = np.array([m.x_test for m in models.values()])
            x_min, x_max = np.min(xs), np.max(xs)
            ax.set_xticks(np.linspace(x_min, x_max, _N_XTICKS))

            ax.set_yticks(np.array(ys[:-1]) + lineage_height / 2)
            ax.spines["left"].set_position(
                ("data", 0)
            )  # move the left spine to the rectangles to get nicer yticks
            ax.set_yticklabels(models.keys(), ha="right")

            ax.set_title(gene, fontdict={"fontsize": fontsize})
            ax.set_ylabel("lineage")

            for pos in ["top", "bottom", "left", "right"]:
                ax.spines[pos].set_visible(False)

            if cbar:
                cax, _ = mpl.colorbar.make_axes(ax)
                _ = mpl.colorbar.ColorbarBase(
                    cax,
                    ticks=np.linspace(vmin, vmax, 5),
                    norm=norm,
                    cmap=cmap,
                    label="value" if gene == "absorption probability" else
                    "scaled expression" if scale else "expression",
                )

            ax.tick_params(
                top=False,
                bottom=False,
                left=True,
                right=False,
                labelleft=True,
                labelbottom=False,
            )

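        # `ax` is the last axis from the loop above - only the bottom-most
        # plot shows the x-tick labels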
        ax.xaxis.set_major_formatter(FormatStrFormatter("%.3f"))
        ax.tick_params(
            top=False,
            bottom=True,
            left=True,
            right=False,
            labelleft=True,
            labelbottom=True,
        )
        ax.set_xlabel(xlabel)

        return fig, None

    @_plot_heatmap.register(HeatmapMode.LINEAGES)
    def _() -> Tuple[List[Fig], pd.DataFrame]:
        data_t = defaultdict(dict)  # transpose
        for gene, lns in data.items():
            for ln, y in lns.items():
                data_t[ln][gene] = y

        figs = []
        gene_order = None
        sorted_genes = pd.DataFrame() if return_genes else None

        for lname, models in data_t.items():
            xs = np.array([m.x_test for m in models.values()])
            x_min, x_max = np.nanmin(xs), np.nanmax(xs)

            df = pd.DataFrame([m.y_test for m in models.values()],
                              index=models.keys())
            df.index.name = "genes"

            if not cluster_genes:
                if gene_order is not None:
                    df = df.loc[gene_order]
                else:
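                    # order genes by where their min-max scaled expression
                    # peaks along the x-axis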
                    max_sort = np.argsort(
                        np.argmax(df.apply(_min_max_scale, axis=1).values,
                                  axis=1))
                    df = df.iloc[max_sort, :]
                    if keep_gene_order:
                        gene_order = df.index

            cat_colors = None
            if cluster_key is not None:
                cat_colors = np.stack(
                    [
                        create_col_categorical_color(
                            c, np.linspace(x_min, x_max, df.shape[1]))
                        for c in cluster_key
                    ],
                    axis=0,
                )

            if show_absorption_probabilities:
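                # add the absorption-probability strip as an additional
                # column-annotation row after any categorical annotations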
                col_colors, col_cmap, col_norm = create_col_colors(
                    lname, np.linspace(x_min, x_max, df.shape[1]))
                if cat_colors is not None:
                    col_colors = np.vstack([cat_colors, col_colors[None, :]])
            else:
                col_colors, col_cmap, col_norm = cat_colors, None, None

            row_cluster = cluster_genes and df.shape[0] > 1
            show_clust = row_cluster and dendrogram

            g = sns.clustermap(
                df,
                cmap=cmap,
                figsize=(10, min(len(genes) / 8 +
                                 1, 10)) if figsize is None else figsize,
                xticklabels=False,
                row_cluster=row_cluster,
                col_colors=col_colors,
                colors_ratio=0,
                col_cluster=False,
                cbar_pos=None,
                yticklabels=show_all_genes or "auto",
                standard_scale=0 if scale else None,
            )

            if cbar:
                cax = create_cbar(
                    g.ax_heatmap,
                    0.1,
                    cmap=cmap,
                    norm=mcolors.Normalize(
                        vmin=0 if scale else np.min(df.values),
                        vmax=1 if scale else np.max(df.values),
                    ),
                    label="scaled expression" if scale else "expression",
                )
                g.fig.add_axes(cax)

                if col_cmap is not None and col_norm is not None:
                    cax = create_cbar(
                        g.ax_heatmap,
                        0.25,
                        cmap=col_cmap,
                        norm=col_norm,
                        label="absorption probability",
                    )
                    g.fig.add_axes(cax)

            if g.ax_col_colors is not None:
                main_bbox = _get_ax_bbox(g.fig, g.ax_heatmap)
                n_bars = show_absorption_probabilities + (
                    len(cluster_key) if cluster_key is not None else 0)
                _set_ax_height_to_cm(
                    g.fig,
                    g.ax_col_colors,
                    height=min(
                        5,
                        max(n_bars * main_bbox.height / len(df),
                            0.25 * n_bars)),
                )
                g.ax_col_colors.set_title(lname,
                                          fontdict={"fontsize": fontsize})
            else:
                g.ax_heatmap.set_title(lname, fontdict={"fontsize": fontsize})

            # hide the column dendrogram axes - it only adds free space on top
            g.ax_col_dendrogram.set_visible(False)

            g.ax_heatmap.yaxis.tick_left()
            g.ax_heatmap.yaxis.set_label_position("right")

            g.ax_heatmap.set_xlabel(xlabel)
            g.ax_heatmap.set_xticks(np.linspace(0, len(df.columns), _N_XTICKS))
            g.ax_heatmap.set_xticklabels(
                [round(n, 3) for n in np.linspace(x_min, x_max, _N_XTICKS)])

            if show_clust:
                # robustly show dendrogram, because gene names can be long
                g.ax_row_dendrogram.set_visible(True)
                dendro_box = g.ax_row_dendrogram.get_position()

                pad = 0.005
                bb = g.ax_heatmap.yaxis.get_tightbbox(
                    g.fig.canvas.get_renderer()).transformed(
                        g.fig.transFigure.inverted())

                dendro_box.x0 = bb.x0 - dendro_box.width - pad
                dendro_box.x1 = bb.x0 - pad

                g.ax_row_dendrogram.set_position(dendro_box)
            else:
                g.ax_row_dendrogram.set_visible(False)

            if return_genes:
                sorted_genes[lname] = (
                    df.index[g.dendrogram_row.reordered_ind]
                    if hasattr(g, "dendrogram_row")
                    and g.dendrogram_row is not None else df.index)

            figs.append(g)

        return figs, sorted_genes

    mode = HeatmapMode(mode)

    lineage_key = str(AbsProbKey.BACKWARD if backward else AbsProbKey.FORWARD)
    if lineage_key not in adata.obsm:
        raise KeyError(
            f"Lineages key `{lineage_key!r}` not found in `adata.obsm`.")

    if lineages is None:
        lineages = adata.obsm[lineage_key].names
    elif isinstance(lineages, str):
        lineages = [lineages]
    lineages = _unique_order_preserving(lineages)

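    # subsetting validates that all requested lineages are present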
    _ = adata.obsm[lineage_key][lineages]

    if cluster_key is not None:
        if isinstance(cluster_key, str):
            cluster_key = [cluster_key]
        cluster_key = _unique_order_preserving(cluster_key)

    if isinstance(genes, str):
        genes = [genes]
    genes = _unique_order_preserving(genes)
    _check_collection(adata,
                      genes,
                      "var_names",
                      use_raw=kwargs.get("use_raw", False))

    kwargs["backward"] = backward
    kwargs["time_key"] = time_key
    models = _create_models(model, genes, lineages)
    all_models, data, genes, lineages = _fit_bulk(
        models,
        _create_callbacks(adata, callback, genes, lineages, **kwargs),
        genes,
        lineages,
        time_range,
        return_models=True,  # always return (better error messages)
        filter_all_failed=True,
        parallel_kwargs={
            "show_progress_bar": show_progress_bar,
            "n_jobs": _get_n_cores(n_jobs, len(genes)),
            "backend": _get_backend(models, backend),
        },
        **kwargs,
    )

    xlabel = time_key if xlabel is None else xlabel

    logg.debug(f"Plotting `{mode.s!r}` heatmap")
    fig, genes = _plot_heatmap(mode)

    if save is not None and fig is not None:
        if not isinstance(fig, Iterable):
            save_fig(fig, save)
        elif len(fig) == 1:
            save_fig(fig[0], save)
        else:
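            # multiple figures - save one per lineage under the `save` directory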
            for ln, f in zip(lineages, fig):
                save_fig(f, os.path.join(save, f"lineage_{ln}"))

    if return_genes and mode == HeatmapMode.LINEAGES:
        return (all_models, genes) if return_models else genes
    elif return_models:
        return all_models