Example #1
0
    def _reconstruct_lineage(self, attr: PrettyEnum, obsm_key: str):
        """
        Rebuild a :class:`Lineage` for ``attr`` and write it back to :attr:`adata`.

        The probabilities are (presumably) loaded from :attr:`adata` ``.obsm[obsm_key]`` into ``attr`` via
        ``_set_or_debug``; lineage names and colors are read from :attr:`adata` ``.uns``. If the names or colors
        are missing or do not match the number of lineages (or the colors are not color-like), they are taken from
        the probabilities when those are already a :class:`Lineage`, otherwise they are regenerated with a warning.

        Parameters
        ----------
        attr
            Attribute under which the probabilities are stored on this object.
        obsm_key
            Key in :attr:`adata` ``.obsm`` holding the probabilities.

        Returns
        -------
        None
            Updates ``attr``, :attr:`adata` ``.obsm[obsm_key]`` and the lineage names/colors in :attr:`adata` ``.uns``.
            Nothing is done if no probabilities are found for ``attr``.
        """
        self._set_or_debug(obsm_key, self.adata.obsm, attr)
        names = self._set_or_debug(_lin_names(self._term_key), self.adata.uns)
        colors = self._set_or_debug(_colors(self._term_key), self.adata.uns)

        probs = self._get(attr)
        if probs is None:
            return

        n_lineages = probs.shape[1]

        # `_set_or_debug` can yield `None` when the key is absent - treat that as a mismatch instead of crashing
        if names is None or len(names) != n_lineages:
            if isinstance(probs, Lineage):
                names = probs.names
            else:
                logg.warning(
                    f"Expected lineage names to be of length `{n_lineages}`, "
                    f"found `{None if names is None else len(names)}`. Creating new names"
                )
                names = [f"Lineage {i}" for i in range(n_lineages)]

        if (
            colors is None
            or len(colors) != n_lineages
            or not all(isinstance(c, str) and is_color_like(c) for c in colors)
        ):
            if isinstance(probs, Lineage):
                colors = probs.colors
            else:
                # report the number of *colors* found (previously this wrongly reported the number of names)
                logg.warning(
                    f"Expected lineage colors to be of length `{n_lineages}`, "
                    f"found `{None if colors is None else len(colors)}`. Creating new colors"
                )
                colors = _create_categorical_colors(n_lineages)

        self._set(attr, Lineage(probs, names=names, colors=colors))

        self.adata.obsm[obsm_key] = self._get(attr)
        self.adata.uns[_lin_names(self._term_key)] = names
        self.adata.uns[_colors(self._term_key)] = colors
Example #2
0
    def _reconstruct_lineage(self, attr: PrettyEnum, obsm_key: str):
        """
        Rebuild a :class:`Lineage` for ``attr`` and write it back to :attr:`adata`.

        The probabilities are (presumably) loaded from :attr:`adata` ``.obsm[obsm_key]`` into ``attr`` via
        ``_set_or_debug``; lineage names and colors are read from :attr:`adata` ``.uns``. Missing or mismatching
        names/colors (or colors that are not color-like) are regenerated and a debug message is logged.

        Parameters
        ----------
        attr
            Attribute under which the probabilities are stored on this object.
        obsm_key
            Key in :attr:`adata` ``.obsm`` holding the probabilities.

        Returns
        -------
        None
            Updates ``attr``, :attr:`adata` ``.obsm[obsm_key]`` and the lineage names/colors in :attr:`adata` ``.uns``.
            Nothing is done if no probabilities are found for ``attr``.
        """
        self._set_or_debug(obsm_key, self.adata.obsm, attr)
        names = self._set_or_debug(_lin_names(self._term_key), self.adata.uns)
        colors = self._set_or_debug(_colors(self._term_key), self.adata.uns)

        # choosing this instead of property because GPCCA doesn't have property for FIN_ABS_PROBS
        probs = self._get(attr)
        if probs is None:
            return

        n_lineages = probs.shape[1]

        # `_set_or_debug` can yield `None` when the key is absent - treat that as a mismatch instead of crashing
        if names is None or len(names) != n_lineages:
            logg.debug(
                f"Expected lineage names to be of length `{n_lineages}`, "
                f"found `{None if names is None else len(names)}`. Creating new names"
            )
            names = [f"Lineage {i}" for i in range(n_lineages)]

        if (
            colors is None
            or len(colors) != n_lineages
            or not all(isinstance(c, str) and is_color_like(c) for c in colors)
        ):
            # report the number of *colors* found (previously this wrongly reported the number of names)
            logg.debug(
                f"Expected lineage colors to be of length `{n_lineages}`, "
                f"found `{None if colors is None else len(colors)}`. Creating new colors"
            )
            colors = _create_categorical_colors(n_lineages)

        self._set(attr, Lineage(probs, names=names, colors=colors))

        self.adata.obsm[obsm_key] = self._get(attr)
        self.adata.uns[_lin_names(self._term_key)] = names
        self.adata.uns[_colors(self._term_key)] = colors
Example #3
0
    def maybe_create_lineage(
        direction: Union[str, Direction], pretty_name: Optional[str] = None
    ):
        """
        Wrap ``adata.obsm[<key>]`` into a :class:`Lineage`, if that key exists.

        ``adata`` is taken from the enclosing scope. Lineage names and colors are loaded from ``adata.uns``;
        when they are missing, have the wrong length or (for colors) are not color-like, new ones are created
        and a warning is logged. The resulting names and colors are written back to ``adata.uns``.

        Parameters
        ----------
        direction
            Either a :class:`Direction`, mapped to the forward/backward absorption probability key, or the
            ``adata.obsm`` key itself.
        pretty_name
            Optional human-readable prefix used only in log messages.

        Returns
        -------
        None
            Modifies ``adata.obsm`` and ``adata.uns`` in place; only logs a debug message if the key is absent.
        """
        if isinstance(direction, Direction):
            lin_key = str(
                AbsProbKey.FORWARD
                if direction == Direction.FORWARD
                else AbsProbKey.BACKWARD
            )
        else:
            lin_key = direction

        pretty_name = "" if pretty_name is None else (pretty_name + " ")
        names_key, colors_key = _lin_names(lin_key), _colors(lin_key)

        if lin_key not in adata.obsm.keys():
            logg.debug(
                f"Unable to load {pretty_name}`Lineage` from `adata.obsm[{lin_key!r}]`"
            )
            return

        _, n_lineages = adata.obsm[lin_key].shape
        logg.info(f"Creating {pretty_name}`Lineage` from `adata.obsm[{lin_key!r}]`")

        if names_key not in adata.uns.keys():
            logg.warning(
                f"    Lineage names not found in `adata.uns[{names_key!r}]`, creating new names"
            )
            names = [f"Lineage {i}" for i in range(n_lineages)]
        elif len(adata.uns[names_key]) != n_lineages:
            logg.warning(
                f"    Lineage names don't have the required length ({n_lineages}), creating new names"
            )
            names = [f"Lineage {i}" for i in range(n_lineages)]
        else:
            logg.info("    Successfully loaded names")
            names = adata.uns[names_key]

        if colors_key not in adata.uns.keys():
            logg.warning(
                f"    Lineage colors not found in `adata.uns[{colors_key!r}]`, creating new colors"
            )
            colors = _create_categorical_colors(n_lineages)
        elif len(adata.uns[colors_key]) != n_lineages or not all(
            map(is_color_like, adata.uns[colors_key])
        ):
            logg.warning(
                f"    Lineage colors don't have the required length ({n_lineages}) "
                f"or are not color-like, creating new colors"
            )
            colors = _create_categorical_colors(n_lineages)
        else:
            logg.info("    Successfully loaded colors")
            colors = adata.uns[colors_key]

        adata.obsm[lin_key] = Lineage(
            adata.obsm[lin_key], names=names, colors=colors
        )
        adata.uns[colors_key] = colors
        adata.uns[names_key] = names
Example #4
0
    def _compute_one_macrostate(
        self,
        n_cells: int,
        cluster_key: Optional[str],
        en_cutoff: Optional[float],
        p_thresh: float,
    ) -> None:
        """
        Use the stationary distribution as the membership vector of a single macrostate.

        Parameters
        ----------
        n_cells
            Forwarded to :meth:`_set_macrostates`.
        cluster_key
            Forwarded to :meth:`_set_macrostates`.
        en_cutoff
            Forwarded to :meth:`_set_macrostates`.
        p_thresh
            Forwarded to :meth:`_set_macrostates`.

        Returns
        -------
        None
            Sets the macrostates and their memberships and resets the absorption probabilities, Schur-related
            results and coarse-grained quantities to `None`.
        """
        start = logg.warning(
            "For 1 macrostate, stationary distribution is computed")

        # reuse a cached eigendecomposition only if it holds the stationary distribution and used `which="LR"`
        eig = self._get(P.EIG)
        if (eig is not None and "stationary_dist" in eig
                and eig["params"]["which"] == "LR"):
            stationary_dist = eig["stationary_dist"]
        else:
            self.compute_eigendecomposition(only_evals=False, which="LR")
            stationary_dist = self._get(P.EIG)["stationary_dist"]

        # one macrostate -> the membership matrix is the stationary distribution as a single column;
        # use the same 2-D matrix both for `_set_macrostates` and for the stored `Lineage`
        memberships = stationary_dist[:, None]

        self._set_macrostates(
            memberships=memberships,
            n_cells=n_cells,
            cluster_key=cluster_key,
            p_thresh=p_thresh,
            en_cutoff=en_cutoff,
        )
        self._set(
            A.MACRO_MEMBER,
            Lineage(
                memberships,
                names=list(self._get(A.MACRO).cat.categories),
                colors=self._get(A.MACRO_COLORS),
            ),
        )

        # reset all the things (presumably because they depend on the previous coarse-graining)
        for key in (
                A.ABS_PROBS,
                A.PRIME_DEG,
                A.SCHUR,
                A.SCHUR_MAT,
                A.COARSE_T,
                A.COARSE_STAT_D,
        ):
            self._set(key.s, None)

        logg.info(
            f"Adding `.{P.MACRO_MEMBER}`\n        `.{P.MACRO}`\n    Finish",
            time=start,
        )
Example #5
0
    def _set_macrostates(
        self,
        memberships: np.ndarray,
        n_cells: Optional[int] = 30,
        cluster_key: str = "clusters",
        en_cutoff: Optional[float] = 0.7,
        p_thresh: float = 1e-15,
        check_row_sums: bool = True,
    ) -> None:
        """
        Map fuzzy clustering to pre-computed annotations to get names and colors.

        Given the fuzzy clustering, we would like to select the most likely cells from each state and use these to
        give each state a name and a color by comparing with pre-computed, categorical cluster annotations.

        Parameters
        ----------
        memberships
            Fuzzy clustering.
        %(n_cells)s
        cluster_key
            Key from :attr:`adata` ``.obs`` to get reference cluster annotations.
        en_cutoff
            Threshold to decide when we want to warn the user about an uncertain name mapping. This happens when
            one fuzzy state overlaps with several reference clusters, and the most likely cells are distributed almost
            evenly across the reference clusters.
        p_thresh
            Only used to detect cell cycle stages. These have to be present in :attr:`adata` ``.obs`` as
            `'G2M_score'` and `'S_score'`.
        check_row_sums
            Check whether rows in `memberships` sum to `1`.

        Returns
        -------
        None
            Writes a :class:`cellrank.tl.Lineage` object with mapped names and colors.
            Also writes a categorical :class:`pandas.Series`, where top ``n_cells`` cells represent each fuzzy state.
        """

        if n_cells is None:
            logg.debug("Setting the macrostates using macrostate assignment")

            max_assignment = np.argmax(memberships, axis=1)
            _macro_assignment = pd.Series(
                index=self.adata.obs_names, data=max_assignment, dtype="category"
            )
            # sometimes, the assignment can have a missing category and the Lineage creation therefore fails,
            # so register one category per column of `memberships`
            _macro_assignment = _macro_assignment.cat.set_categories(
                list(range(memberships.shape[1]))
            )
            # rename to strings instead of `.astype(str).astype("category")`: the latter re-infers categories,
            # dropping the empty ones added above and sorting them lexicographically ('0', '1', '10', '2', ...),
            # which misaligns the category order with the columns of `memberships`
            macrostates = _macro_assignment.cat.rename_categories(
                [str(c) for c in _macro_assignment.cat.categories]
            )
            not_enough_cells = []
        else:
            logg.debug("Setting the macrostates using macrostates memberships")

            # select the most likely cells from each macrostate
            macrostates, not_enough_cells = self._create_states(
                memberships,
                n_cells=n_cells,
                check_row_sums=check_row_sums,
                return_not_enough_cells=True,
            )
            not_enough_cells = not_enough_cells.astype("str")

        # _set_categorical_labels creates the names, we still need to remap the group names
        orig_cats = macrostates.cat.categories
        self._set_categorical_labels(
            attr_key=A.MACRO.v,
            color_key=A.MACRO_COLORS.v,
            pretty_attr_key=P.MACRO.v,
            add_to_existing_error_msg=
            "Compute macrostates first as `.compute_macrostates()`.",
            categories=macrostates,
            cluster_key=cluster_key,
            en_cutoff=en_cutoff,
            p_thresh=p_thresh,
            add_to_existing=False,
        )

        name_mapper = dict(zip(orig_cats, self._get(P.MACRO).cat.categories))
        _print_insufficient_number_of_cells(
            [name_mapper.get(n, n) for n in not_enough_cells], n_cells)
        logg.debug(
            "Setting macrostates memberships based on GPCCA membership vectors"
        )

        self._set(
            A.MACRO_MEMBER,
            Lineage(
                memberships,
                names=list(macrostates.cat.categories),
                colors=self._get(A.MACRO_COLORS),
            ),
        )
Example #6
0
    def compute_lineage_drivers(
        self,
        lineages: Optional[Union[str, Sequence]] = None,
        method: str = TestMethod.FISCHER.s,
        cluster_key: Optional[str] = None,
        clusters: Optional[Union[str, Sequence]] = None,
        layer: str = "X",
        use_raw: bool = False,
        confidence_level: float = 0.95,
        n_perms: int = 1000,
        seed: Optional[int] = None,
        return_drivers: bool = True,
        **kwargs,
    ) -> Optional[pd.DataFrame]:
        """
        Compute driver genes per lineage.

        Correlates gene expression with lineage probabilities, for a given lineage and set of clusters.
        Often, it makes sense to restrict this to a set of clusters which are relevant for the specified lineages.

        Parameters
        ----------
        lineages
            Either a set of lineage names from :paramref:`absorption_probabilities` `.names` or `None`,
            in which case all lineages are considered.
        method
            Mode to use when calculating p-values and confidence intervals. Can be one of:

                - {tm.FISCHER.s!r} - use Fisher z-transformation [Fischer21]_.
                - {tm.PERM_TEST.s!r} - use permutation test.

        cluster_key
            Key from :paramref:`adata` ``.obs`` to obtain cluster annotations. These are considered for ``clusters``.
        clusters
            Restrict the correlations to these clusters.
        layer
            Key from :paramref:`adata` ``.layers``.
        use_raw
            Whether or not to use :paramref:`adata` ``.raw`` to correlate gene expression.
            If using a layer other than ``.X``, this must be set to `False`.
        confidence_level
            Confidence level for the confidence interval calculation. Must be in `[0, 1]`.
        n_perms
            Number of permutations to use when ``method={tm.PERM_TEST.s!r}``.
        seed
            Random seed when ``method={tm.PERM_TEST.s!r}``.
        return_drivers
            Whether to return the drivers. This also contains the lower and upper ``confidence_level`` confidence
            interval bounds.
        %(parallel)s

        Returns
        -------
        %(correlation_test.returns)s
            Only if ``return_drivers=True``.

        None
            Updates :paramref:`adata` ``.var`` or :paramref:`adata` ``.raw.var``, depending on ``use_raw``, with:

                - ``'{{direction}} {{lineage}} corr'`` - the potential lineage drivers.
                - ``'{{direction}} {{lineage}} qval'`` - the corrected p-values.

            Updates the following fields:

                - :paramref:`{lin_drivers}` - same as the returned values.

        Raises
        ------
        RuntimeError
            If absorption probabilities have not been computed.
        ValueError
            If no lineages are selected or if ``use_raw=True`` is combined with a layer other than ``'X'``.
        KeyError
            If ``cluster_key``, any of ``clusters`` or ``layer`` cannot be found.

        References
        ----------
        .. [Fischer21] Fisher, R. A. (1921),
            *On the “probable error” of a coefficient of correlation deduced from a small sample.*,
            `Metron 1 3–32 <http://hdl.handle.net/2440/15169>`__.
        """

        # check that lineage probs have been computed
        method = TestMethod(method)
        abs_probs = self._get(P.ABS_PROBS)
        prefix = DirPrefix.BACKWARD if self.kernel.backward else DirPrefix.FORWARD

        if abs_probs is None:
            raise RuntimeError(
                "Compute absorption probabilities first as `.compute_absorption_probabilities()`."
            )
        elif abs_probs.shape[1] == 1:
            logg.warning(
                "There is only 1 lineage present. Using the stationary distribution instead"
            )
            abs_probs = Lineage(
                self._get(P.TERM_PROBS).values,
                names=abs_probs.names,
                colors=abs_probs.colors,
            )

        # check all lin_keys exist in self.lin_names (indexing raises for unknown names)
        if isinstance(lineages, str):
            lineages = [lineages]
        if lineages is not None:
            _ = abs_probs[lineages]
        else:
            lineages = abs_probs.names

        if not len(lineages):
            raise ValueError("No lineages have been selected.")

        # use `cluster_key` and clusters to subset the data
        if clusters is not None:
            if cluster_key not in self.adata.obs.keys():
                raise KeyError(f"Key `{cluster_key!r}` not found in `adata.obs`.")
            if isinstance(clusters, str):
                clusters = [clusters]

            all_clusters = np.array(self.adata.obs[cluster_key].cat.categories)
            cluster_mask = np.array([name not in all_clusters for name in clusters])

            if any(cluster_mask):
                raise KeyError(
                    f"Clusters `{list(np.array(clusters)[cluster_mask])}` not found in "
                    f"`adata.obs[{cluster_key!r}]`."
                )
            # `np.in1d` is deprecated since NumPy 2.0; `np.isin` is the drop-in replacement for 1-D input
            subset_mask = np.isin(self.adata.obs[cluster_key], clusters)
            adata_comp = self.adata[subset_mask]
            lin_probs = abs_probs[subset_mask, :]
        else:
            adata_comp = self.adata
            lin_probs = abs_probs

        # check that the layer exists, and that use raw is only used with layer X
        if layer != "X":
            if layer not in self.adata.layers:
                raise KeyError(f"Layer `{layer!r}` not found in `adata.layers`.")
            if use_raw:
                raise ValueError("For `use_raw=True`, layer must be 'X'.")
            data = adata_comp.layers[layer]
            var_names = adata_comp.var_names
        else:
            if use_raw and self.adata.raw is None:
                logg.warning("No raw attribute set. Using `.X` instead")
                use_raw = False
            data = adata_comp.raw.X if use_raw else adata_comp.X
            var_names = adata_comp.raw.var_names if use_raw else adata_comp.var_names

        start = logg.info(
            f"Computing correlations for lineages `{lineages}` restricted to clusters `{clusters}` in "
            f"layer `{layer}` with `use_raw={use_raw}`"
        )

        drivers = _correlation_test(
            data,
            lin_probs[lineages],
            gene_names=var_names,
            method=method,
            n_perms=n_perms,
            seed=seed,
            confidence_level=confidence_level,
            **kwargs,
        )
        self._set(A.LIN_DRIVERS, drivers)

        corrs = [f"{lin} corr" for lin in lineages]
        qvals = [f"{lin} qval" for lin in lineages]
        added_cols = corrs + qvals

        # assign one column at a time instead of `var[[new, ...]] = drivers[[old, ...]]`, so the result does
        # not depend on how a given pandas version aligns differently-named columns in a multi-column assignment
        target = self.adata.raw.var if use_raw else self.adata.var
        for col in added_cols:
            target[f"{prefix} {col}"] = drivers[col]

        field = "raw.var" if use_raw else "var"
        keys_added = [f"`adata.{field}['{prefix} {col}']`" for col in added_cols]

        logg.info(
            f"Adding `.{P.LIN_DRIVERS}`\n       "
            + "\n       ".join(keys_added)
            + "\n    Finish",
            time=start,
        )

        if return_drivers:
            return drivers
Example #7
0
    def compute_absorption_probabilities(
        self,
        keys: Optional[Sequence[str]] = None,
        check_irreducibility: bool = False,
        solver: str = "gmres",
        use_petsc: bool = True,
        time_to_absorption: Optional[
            Union[
                str,
                Sequence[Union[str, Sequence[str]]],
                Dict[Union[str, Sequence[str]], str],
            ]
        ] = None,
        n_jobs: Optional[int] = None,
        backend: str = "loky",
        show_progress_bar: bool = True,
        tol: float = 1e-6,
        preconditioner: Optional[str] = None,
    ) -> None:
        """
        Compute absorption probabilities of a Markov chain.

        For each cell, this computes the probability of it reaching any of the approximate recurrent classes defined
        by :paramref:`{fs}`.

        Parameters
        ----------
        keys
            Keys defining the recurrent classes.
        check_irreducibility
            Check whether the transition matrix is irreducible and log the result.
        solver
            Solver to use for the linear problem. Options are `'direct', 'gmres', 'lgmres', 'bicgstab' or 'gcrotmk'`
            when ``use_petsc=False`` or one of :class:`petsc4py.PETSc.KSP.Type` otherwise.

            Information on the :mod:`scipy` iterative solvers can be found in :func:`scipy.sparse.linalg` or for
            :mod:`petsc4py` solver `here <https://www.mcs.anl.gov/petsc/documentation/linearsolvertable.html>`__.
        use_petsc
            Whether to use solvers from :mod:`petsc4py` or :mod:`scipy`. Recommended for large problems.
            If no installation is found, defaults to :func:`scipy.sparse.linalg.gmres`.
        time_to_absorption
            Whether to compute mean time to absorption and its variance to specific absorbing states.

            If a :class:`dict`, can be specified as ``{{'Alpha': 'var', ...}}`` to also compute variance.
            In case when states are a :class:`tuple`, time to absorption will be computed to the subset of these states,
            such as ``[('Alpha', 'Beta'), ...]`` or ``{{('Alpha', 'Beta'): 'mean', ...}}``.
            Can be specified as ``'all'`` to compute it to any absorbing state in ``keys``, which is more efficient
            than listing all absorbing states. Subsets listing the same states in a different order are only
            computed once.

            It might be beneficial to disable the progress bar as ``show_progress_bar=False``, because many linear
            systems are being solved.
        n_jobs
            Number of parallel jobs to use when using an iterative solver. When ``use_petsc=True`` or for
            quickly-solvable problems, we recommend higher number (>=8) of jobs in order to fully saturate the cores.
        backend
            Which backend to use for multiprocessing. See :class:`joblib.Parallel` for valid options.
        show_progress_bar
            Whether to show progress bar when the solver isn't a direct one.
        tol
            Convergence tolerance for the iterative solver. The default is fine for most cases, only consider
            decreasing this for severely ill-conditioned matrices.
        preconditioner
            Preconditioner to use, only available when ``use_petsc=True``. For available values, see
            `here <https://www.mcs.anl.gov/petsc/petsc-current/docs/manualpages/PC/PCType.html#PCType>`__ or the values
            of `petsc4py.PETSc.PC.Type`.
            We recommend `'ilu'` preconditioner for badly conditioned problems.

        Returns
        -------
        None
            Nothing, but updates the following fields:

                - :paramref:`{abs_prob}` - probabilities of being absorbed into the terminal states.
                - :paramref:`{lat}` - mean times until absorption to subset absorbing states and optionally
                  their variances saved as ``'{{lineage}} mean'`` and ``'{{lineage}} var'``, respectively,
                  for each subset of absorbing states specified in ``time_to_absorption``.

        Raises
        ------
        RuntimeError
            If terminal states have not been computed or if there are no transient states.
        ValueError
            If ``time_to_absorption`` contains an invalid moment or an unknown absorbing state.
        """

        if self._get(P.TERM) is None:
            raise RuntimeError(_COMP_TERM_STATES_MSG)
        if keys is not None:
            keys = sorted(set(keys))

        start = logg.info("Computing absorption probabilities")

        # get the transition matrix
        t = self.transition_matrix
        if not self.issparse:
            logg.warning(
                "Attempting to solve a potentially large linear system with dense transition matrix"
            )

        # process the current annotations according to `keys`
        terminal_states_, colors_ = _process_series(
            series=self._get(P.TERM), keys=keys, colors=self._get(A.TERM_COLORS)
        )
        # warn in case only one state is left
        keys = list(terminal_states_.cat.categories)
        if len(keys) == 1:
            logg.warning(
                "There is only 1 recurrent class, all cells will have probability 1 of going there"
            )

        lin_abs_times = {}
        if time_to_absorption is not None:
            if isinstance(time_to_absorption, (str, tuple)):
                time_to_absorption = [time_to_absorption]
            if not isinstance(time_to_absorption, dict):
                time_to_absorption = {ln: "mean" for ln in time_to_absorption}

            # order-insensitive key -> key actually stored in `lin_abs_times`; must be created once, outside the
            # loop, otherwise e.g. `('Alpha', 'Beta')` and `('Beta', 'Alpha')` are never recognized as duplicates
            seen = {}
            for ln, moment in time_to_absorption.items():
                if moment not in ("mean", "var"):
                    raise ValueError(
                        f"Moment must be either `'mean'` or `'var'`, found `{moment!r}` for `{ln!r}`."
                    )

                if isinstance(ln, str):
                    ln = tuple(keys) if ln == "all" else (ln,)
                sorted_ln = tuple(sorted(ln))  # only used for de-duplication, `ln` preserves the user order

                if sorted_ln in seen:
                    # same subset requested again: keep the first ordering, but never drop a request for
                    # the variance (which also computes the mean)
                    if moment == "var":
                        lin_abs_times[seen[sorted_ln]] = moment
                    continue

                for lin in ln:
                    if lin not in keys:
                        raise ValueError(
                            f"Invalid absorbing state `{lin!r}` in `{ln}`. "
                            f"Valid options are `{list(terminal_states_.cat.categories)}`."
                        )
                seen[sorted_ln] = tuple(ln)
                lin_abs_times[tuple(ln)] = moment

        # define the dimensions of this problem
        n_cells = t.shape[0]
        n_macrostates = len(terminal_states_.cat.categories)

        # get indices corresponding to recurrent and transient states
        rec_indices, trans_indices, lookup_dict = _get_cat_and_null_indices(
            terminal_states_
        )
        if not len(trans_indices):
            raise RuntimeError("Cannot proceed - Markov chain is irreducible.")

        # create Q (restriction transient-transient), S (restriction transient-recurrent)
        q = t[trans_indices, :][:, trans_indices]
        s = t[trans_indices, :][:, rec_indices]

        # check for irreducibility
        if check_irreducibility:
            is_irreducible = self.is_irreducible
            if is_irreducible is None:
                # store the result (presumably what backs `self.is_irreducible`) and report it right away,
                # instead of silently skipping the log on the first call
                is_irreducible = _irreducible(self.transition_matrix)
                self._is_irreducible = is_irreducible
            if is_irreducible:
                logg.debug("Transition matrix is irreducible")
            else:
                logg.warning("Transition matrix is not irreducible")

        logg.debug(f"Found `{n_cells}` cells and `{s.shape[1]}` absorbing states")

        # solve the linear system of equations
        mat_x = _solve_lin_system(
            q,
            s,
            solver=solver,
            use_petsc=use_petsc,
            n_jobs=n_jobs,
            backend=backend,
            tol=tol,
            use_eye=True,
            show_progress_bar=show_progress_bar,
            preconditioner=preconditioner,
        )

        if time_to_absorption is not None:
            abs_time_means = _calculate_lineage_absorption_time_means(
                q,
                s,  # identical to `t[trans_indices, :][:, rec_indices]`, no need to slice again
                trans_indices,
                n=n_cells,
                ixs=lookup_dict,
                lineages=lin_abs_times,
                solver=solver,
                use_petsc=use_petsc,
                n_jobs=n_jobs,
                backend=backend,
                tol=tol,
                show_progress_bar=show_progress_bar,
                preconditioner=preconditioner,
            )
            abs_time_means.index = self.adata.obs_names
        else:
            abs_time_means = None

        # take individual solutions and piece them together to get absorption probabilities towards the classes
        macro_ix_helper = np.cumsum(
            [0] + [len(indices) for indices in lookup_dict.values()]
        )
        _abs_classes = np.concatenate(
            [
                mat_x[:, np.arange(a, b)].sum(1)[:, None]
                for a, b in _pairwise(macro_ix_helper)
            ],
            axis=1,
        )

        # for recurrent states, set their self-absorption probability to one
        abs_classes = np.zeros((len(self), n_macrostates))
        rec_classes_full = {
            cl: np.where(terminal_states_ == cl)[0]
            for cl in terminal_states_.cat.categories
        }
        for col, cl_indices in enumerate(rec_classes_full.values()):
            abs_classes[trans_indices, col] = _abs_classes[:, col]
            abs_classes[cl_indices, col] = 1

        self._set(
            A.ABS_PROBS,
            Lineage(
                abs_classes,
                names=terminal_states_.cat.categories,
                colors=colors_,
            ),
        )

        extra_msg = ""
        if abs_time_means is not None:
            self._set(A.LIN_ABS_TIMES, abs_time_means)
            extra_msg = f"       `.{P.LIN_ABS_TIMES}`\n"

        self._write_absorption_probabilities(time=start, extra_msg=extra_msg)
Example #8
0
    def compute_lineage_drivers(
        self,
        lineages: Optional[Union[str, Sequence]] = None,
        cluster_key: Optional[str] = None,
        clusters: Optional[Union[str, Sequence]] = None,
        layer: str = "X",
        use_raw: bool = False,
        return_drivers: bool = False,
    ) -> Optional[pd.DataFrame]:
        """
        Compute driver genes per lineage.

        Correlates gene expression with lineage probabilities, for a given lineage and set of clusters.
        Often, it makes sense to restrict this to a set of clusters which are relevant for the specified lineages.

        Parameters
        ----------
        lineages
            Either a set of lineage names from :paramref:`absorption_probabilities` `.names` or `None`,
            in which case all lineages are considered.
        cluster_key
            Key from :paramref:`adata` ``.obs`` to obtain cluster annotations. These are considered for ``clusters``.
        clusters
            Restrict the correlations to these clusters.
        layer
            Key from :paramref:`adata` ``.layers``.
        use_raw
            Whether or not to use :paramref:`adata` ``.raw`` to correlate gene expression.
            If using a layer other than ``.X``, this must be set to `False`.
        return_drivers
            Whether to return the lineage drivers as :class:`pandas.DataFrame`.

        Returns
        -------
        :class:`pandas.DataFrame` or :obj:`None`
            Updates :paramref:`adata` ``.var`` or :paramref:`adata` ``.raw.var``, depending on ``use_raw``,
            with lineage drivers saved as columns of the form ``{{direction}} {{lineage}}``.
            Also updates the following fields:

                - :paramref:`{lin_drivers}` - the driver genes for each lineage.

            If ``return_drivers=True``, returns the lineage drivers as :class:`pandas.DataFrame`.

        Raises
        ------
        RuntimeError
            If absorption probabilities have not been computed.
        ValueError
            If no lineages are selected or if ``use_raw=True`` is combined with a layer other than ``'X'``.
        KeyError
            If ``cluster_key``, any of ``clusters`` or ``layer`` cannot be found.
        """

        # check that lineage probs have been computed
        abs_probs = self._get(P.ABS_PROBS)
        prefix = DirPrefix.BACKWARD if self.kernel.backward else DirPrefix.FORWARD

        if abs_probs is None:
            raise RuntimeError(
                "Compute absorption probabilities first as `.compute_absorption_probabilities()`."
            )
        elif abs_probs.shape[1] == 1:
            logg.warning(
                "There is only 1 lineage present. Using the stationary distribution instead"
            )
            abs_probs = Lineage(
                self._get(P.FIN_PROBS).values,
                names=abs_probs.names,
                colors=abs_probs.colors,
            )

        # check all lin_keys exist in self.lin_names (indexing raises for unknown names)
        if isinstance(lineages, str):
            lineages = [lineages]
        if lineages is not None:
            _ = abs_probs[lineages]
        else:
            lineages = abs_probs.names

        if not len(lineages):
            raise ValueError("No lineages have been selected.")

        # use `cluster_key` and clusters to subset the data
        if clusters is not None:
            if cluster_key not in self.adata.obs.keys():
                raise KeyError(f"Key `{cluster_key!r}` not found in `adata.obs`.")
            if isinstance(clusters, str):
                clusters = [clusters]

            all_clusters = np.array(self.adata.obs[cluster_key].cat.categories)
            cluster_mask = np.array([name not in all_clusters for name in clusters])

            if any(cluster_mask):
                raise KeyError(
                    f"Clusters `{list(np.array(clusters)[cluster_mask])}` not found in "
                    f"`adata.obs[{cluster_key!r}]`."
                )
            # `np.in1d` is deprecated since NumPy 2.0; `np.isin` is the drop-in replacement for 1-D input
            subset_mask = np.isin(self.adata.obs[cluster_key], clusters)
            adata_comp = self.adata[subset_mask]
            lin_probs = abs_probs[subset_mask, :]
        else:
            adata_comp = self.adata
            lin_probs = abs_probs

        # check that the layer exists, and that use raw is only used with layer X
        if layer != "X":
            if layer not in self.adata.layers:
                raise KeyError(f"Layer `{layer!r}` not found in `adata.layers`.")
            if use_raw:
                raise ValueError("For `use_raw=True`, layer must be 'X'.")
            data = adata_comp.layers[layer]
            var_names = adata_comp.var_names
        else:
            if use_raw and self.adata.raw is None:
                logg.warning("No raw attribute set. Using `.X` instead")
                use_raw = False
            data = adata_comp.raw.X if use_raw else adata_comp.X
            var_names = adata_comp.raw.var_names if use_raw else adata_comp.var_names

        start = logg.info(
            f"Computing correlations for lineages `{lineages}` restricted to clusters `{clusters}` in "
            f"layer `{layer}` with `use_raw={use_raw}`"
        )

        target = self.adata.raw.var if use_raw else self.adata.var
        lin_corrs = {}
        for lineage in lineages:
            correlations = _vec_mat_corr(data, lin_probs[:, lineage].X.squeeze())
            lin_corrs[lineage] = correlations
            target[f"{prefix} {lineage}"] = correlations

        drivers = pd.DataFrame(lin_corrs, index=var_names)
        self._set(A.LIN_DRIVERS, drivers)

        field = "raw.var" if use_raw else "var"
        keys_added = [f"`adata.{field}['{prefix} {lin}']`" for lin in lineages]

        logg.info(
            f"Adding `.{P.LIN_DRIVERS}`\n       "
            + "\n       ".join(keys_added)
            + "\n    Finish",
            time=start,
        )

        if return_drivers:
            return drivers