def _reconstruct_lineage(self, attr: PrettyEnum, obsm_key: str):
    """
    Rebuild the :class:`Lineage` stored under ``attr`` and sync it with :attr:`adata`.

    Names and colors are taken from ``adata.uns`` when they are consistent with the
    number of lineages (columns of the probabilities). Otherwise they are taken from
    the probabilities themselves, if these already are a :class:`Lineage`, or they
    are generated anew.

    Parameters
    ----------
    attr
        Attribute under which the lineage probabilities are stored on this object.
    obsm_key
        Key in ``adata.obsm`` where the reconstructed lineage is written.

    Returns
    -------
    None
        Nothing, updates ``attr``, ``adata.obsm[obsm_key]`` and the names/colors in ``adata.uns``.
    """
    # NOTE(review): `_set_or_debug` is defined elsewhere; it presumably loads the value
    # from the given container and returns it - confirm what it returns for a missing key
    self._set_or_debug(obsm_key, self.adata.obsm, attr)
    names = self._set_or_debug(_lin_names(self._term_key), self.adata.uns)
    colors = self._set_or_debug(_colors(self._term_key), self.adata.uns)

    probs = self._get(attr)
    if probs is None:
        # nothing to reconstruct
        return

    n_lineages = probs.shape[1]

    if len(names) != n_lineages:
        if isinstance(probs, Lineage):
            names = probs.names
        else:
            logg.warning(
                f"Expected lineage names to be of length `{n_lineages}`, found `{len(names)}`. "
                f"Creating new names"
            )
            names = [f"Lineage {i}" for i in range(n_lineages)]

    colors_are_valid = len(colors) == n_lineages and all(
        isinstance(c, str) and is_color_like(c) for c in colors
    )
    if not colors_are_valid:
        if isinstance(probs, Lineage):
            colors = probs.colors
        else:
            # report the length of the colors, not of the names
            logg.warning(
                f"Expected lineage colors to be of length `{n_lineages}`, found `{len(colors)}`. "
                f"Creating new colors"
            )
            colors = _create_categorical_colors(n_lineages)

    self._set(attr, Lineage(probs, names=names, colors=colors))

    # keep `adata` in sync with the reconstructed object
    self.adata.obsm[obsm_key] = self._get(attr)
    self.adata.uns[_lin_names(self._term_key)] = names
    self.adata.uns[_colors(self._term_key)] = colors
def _reconstruct_lineage(self, attr: PrettyEnum, obsm_key: str):
    """
    Rebuild the :class:`Lineage` stored under ``attr`` and sync it with :attr:`adata`.

    Names and colors are taken from ``adata.uns`` when they are consistent with the
    number of lineages (columns of the probabilities), otherwise they are generated anew.

    Parameters
    ----------
    attr
        Attribute under which the lineage probabilities are stored on this object.
    obsm_key
        Key in ``adata.obsm`` where the reconstructed lineage is written.

    Returns
    -------
    None
        Nothing, updates ``attr``, ``adata.obsm[obsm_key]`` and the names/colors in ``adata.uns``.
    """
    # NOTE(review): `_set_or_debug` is defined elsewhere; it presumably loads the value
    # from the given container and returns it - confirm what it returns for a missing key
    self._set_or_debug(obsm_key, self.adata.obsm, attr)
    names = self._set_or_debug(_lin_names(self._term_key), self.adata.uns)
    colors = self._set_or_debug(_colors(self._term_key), self.adata.uns)

    # choosing this instead of property because GPCCA doesn't have property for FIN_ABS_PROBS
    probs = self._get(attr)
    if probs is None:
        # nothing to reconstruct
        return

    n_lineages = probs.shape[1]

    if len(names) != n_lineages:
        logg.debug(
            f"Expected lineage names to be of length `{n_lineages}`, found `{len(names)}`. "
            f"Creating new names"
        )
        names = [f"Lineage {i}" for i in range(n_lineages)]

    colors_are_valid = len(colors) == n_lineages and all(
        isinstance(c, str) and is_color_like(c) for c in colors
    )
    if not colors_are_valid:
        # report the length of the colors, not of the names
        logg.debug(
            f"Expected lineage colors to be of length `{n_lineages}`, found `{len(colors)}`. "
            f"Creating new colors"
        )
        colors = _create_categorical_colors(n_lineages)

    self._set(attr, Lineage(probs, names=names, colors=colors))

    # keep `adata` in sync with the reconstructed object
    self.adata.obsm[obsm_key] = self._get(attr)
    self.adata.uns[_lin_names(self._term_key)] = names
    self.adata.uns[_colors(self._term_key)] = colors
def maybe_create_lineage(
    direction: Union[str, Direction], pretty_name: Optional[str] = None
):
    """
    Convert ``adata.obsm[lin_key]`` to a :class:`Lineage`, if the key is present.

    Names and colors are loaded from ``adata.uns`` when they exist and match the
    number of lineages, otherwise new ones are created. The final names and colors
    are written back to ``adata.uns``.

    Parameters
    ----------
    direction
        Either a :class:`Direction`, which is mapped to the forward/backward
        absorption probability key, or the key in ``adata.obsm`` itself.
    pretty_name
        Optional name used only in the log messages.

    Returns
    -------
    None
        Nothing, just modifies ``adata`` from the enclosing scope.
    """
    if isinstance(direction, Direction):
        lin_key = str(
            AbsProbKey.FORWARD
            if direction == Direction.FORWARD
            else AbsProbKey.BACKWARD
        )
    else:
        lin_key = direction

    pretty_name = "" if pretty_name is None else (pretty_name + " ")
    names_key, colors_key = _lin_names(lin_key), _colors(lin_key)

    if lin_key not in adata.obsm:
        logg.debug(
            f"Unable to load {pretty_name}`Lineage` from `adata.obsm[{lin_key!r}]`"
        )
        return

    # the unpacking also ensures that the stored array is 2-dimensional
    _, n_lineages = adata.obsm[lin_key].shape
    logg.info(f"Creating {pretty_name}`Lineage` from `adata.obsm[{lin_key!r}]`")

    if names_key not in adata.uns:
        logg.warning(
            f" Lineage names not found in `adata.uns[{names_key!r}]`, creating new names"
        )
        names = [f"Lineage {i}" for i in range(n_lineages)]
    elif len(adata.uns[names_key]) != n_lineages:
        logg.warning(
            f" Lineage names don't have the required length ({n_lineages}), creating new names"
        )
        names = [f"Lineage {i}" for i in range(n_lineages)]
    else:
        logg.info(" Successfully loaded names")
        names = adata.uns[names_key]

    if colors_key not in adata.uns:
        logg.warning(
            f" Lineage colors not found in `adata.uns[{colors_key!r}]`, creating new colors"
        )
        colors = _create_categorical_colors(n_lineages)
    elif len(adata.uns[colors_key]) != n_lineages or not all(
        is_color_like(c) for c in adata.uns[colors_key]
    ):
        logg.warning(
            f" Lineage colors don't have the required length ({n_lineages}) "
            f"or are not color-like, creating new colors"
        )
        colors = _create_categorical_colors(n_lineages)
    else:
        logg.info(" Successfully loaded colors")
        colors = adata.uns[colors_key]

    adata.obsm[lin_key] = Lineage(adata.obsm[lin_key], names=names, colors=colors)
    adata.uns[colors_key] = colors
    adata.uns[names_key] = names
def _compute_one_macrostate(
    self,
    n_cells: int,
    cluster_key: Optional[str],
    en_cutoff: Optional[float],
    p_thresh: float,
) -> None:
    """
    Handle the special case of a single macrostate.

    The stationary distribution is used as the membership vector of the only
    macrostate. It is reused from a previous eigendecomposition if that one was
    computed with ``which='LR'``, otherwise the eigendecomposition is recomputed.

    Parameters
    ----------
    n_cells
        Passed to :meth:`_set_macrostates`.
    cluster_key
        Passed to :meth:`_set_macrostates`.
    en_cutoff
        Passed to :meth:`_set_macrostates`.
    p_thresh
        Passed to :meth:`_set_macrostates`.

    Returns
    -------
    None
        Nothing, sets the macrostates and their memberships and resets the attributes
        listed at the end of this method to `None`.
    """
    start = logg.warning("For 1 macrostate, stationary distribution is computed")

    eig = self._get(P.EIG)
    can_reuse = (
        eig is not None
        and "stationary_dist" in eig
        and eig["params"]["which"] == "LR"
    )
    if can_reuse:
        stationary_dist = eig["stationary_dist"]
    else:
        self.compute_eigendecomposition(only_evals=False, which="LR")
        stationary_dist = self._get(P.EIG)["stationary_dist"]

    self._set_macrostates(
        memberships=stationary_dist[:, None],
        n_cells=n_cells,
        cluster_key=cluster_key,
        p_thresh=p_thresh,
        en_cutoff=en_cutoff,
    )
    # overwrite the memberships set above with the 1-dimensional stationary distribution
    self._set(
        A.MACRO_MEMBER,
        Lineage(
            stationary_dist,
            names=list(self._get(A.MACRO).cat.categories),
            colors=self._get(A.MACRO_COLORS),
        ),
    )

    # reset all the things
    # NOTE(review): `A.COARSE_STAT_D` used to be listed twice here; the duplicate was
    # removed, but possibly a different key was meant to be reset as well - confirm
    for key in (
        A.ABS_PROBS,
        A.PRIME_DEG,
        A.SCHUR,
        A.SCHUR_MAT,
        A.COARSE_T,
        A.COARSE_STAT_D,
    ):
        self._set(key.s, None)

    logg.info(
        f"Adding `.{P.MACRO_MEMBER}`\n `.{P.MACRO}`\n Finish",
        time=start,
    )
def _set_macrostates(
    self,
    memberships: np.ndarray,
    n_cells: Optional[int] = 30,
    cluster_key: str = "clusters",
    en_cutoff: Optional[float] = 0.7,
    p_thresh: float = 1e-15,
    check_row_sums: bool = True,
) -> None:
    """
    Map fuzzy clustering to pre-computed annotations to get names and colors.

    Given the fuzzy clustering, the most likely cells of each state are selected and
    used to give each state a name and a color by comparing them with pre-computed,
    categorical cluster annotations.

    Parameters
    ----------
    memberships
        Fuzzy clustering.
    %(n_cells)s
    cluster_key
        Key from :attr:`adata` ``.obs`` to get reference cluster annotations.
    en_cutoff
        Threshold to decide when we want to warn the user about an uncertain name mapping. This happens when
        one fuzzy state overlaps with several reference clusters, and the most likely cells are distributed almost
        evenly across the reference clusters.
    p_thresh
        Only used to detect cell cycle stages. These have to be present in
        :attr:`adata` ``.obs`` as `'G2M_score'` and `'S_score'`.
    check_row_sums
        Check whether rows in `memberships` sum to `1`.

    Returns
    -------
    None
        Writes a :class:`cellrank.tl.Lineage` object with mapped names and colors. Also writes a categorical
        :class:`pandas.Series`, where top ``n_cells`` cells represent each fuzzy state.
    """
    if n_cells is not None:
        logg.debug("Setting the macrostates using macrostates memberships")
        # select the most likely cells from each macrostate
        macrostates, not_enough_cells = self._create_states(
            memberships,
            n_cells=n_cells,
            check_row_sums=check_row_sums,
            return_not_enough_cells=True,
        )
        not_enough_cells = not_enough_cells.astype("str")
    else:
        logg.debug("Setting the macrostates using macrostate assignment")
        n_states = memberships.shape[1]
        hard_labels = np.argmax(memberships, axis=1)
        assignment = pd.Series(
            index=self.adata.obs_names, data=hard_labels, dtype="category"
        )
        # sometimes, the assignment can have a missing category and the Lineage creation therefore fails
        # keep it as ints when `n_cells != None`
        assignment = assignment.cat.set_categories(list(range(n_states)))
        # NOTE(review): the round trip through `str` may drop the unobserved
        # categories added just above - confirm that this is intended
        as_strings = assignment.astype(str)
        macrostates = as_strings.astype("category").copy()
        not_enough_cells = []

    # _set_categorical_labels creates the names, we still need to remap the group names
    categories_before = macrostates.cat.categories
    self._set_categorical_labels(
        attr_key=A.MACRO.v,
        color_key=A.MACRO_COLORS.v,
        pretty_attr_key=P.MACRO.v,
        add_to_existing_error_msg="Compute macrostates first as `.compute_macrostates()`.",
        categories=macrostates,
        cluster_key=cluster_key,
        en_cutoff=en_cutoff,
        p_thresh=p_thresh,
        add_to_existing=False,
    )
    categories_after = self._get(P.MACRO).cat.categories
    old_to_new = {
        old: new for old, new in zip(categories_before, categories_after)
    }
    renamed = [old_to_new.get(name, name) for name in not_enough_cells]
    _print_insufficient_number_of_cells(renamed, n_cells)

    logg.debug("Setting macrostates memberships based on GPCCA membership vectors")
    membership_lineage = Lineage(
        memberships,
        names=list(macrostates.cat.categories),
        colors=self._get(A.MACRO_COLORS),
    )
    self._set(A.MACRO_MEMBER, membership_lineage)
def compute_lineage_drivers(
    self,
    lineages: Optional[Union[str, Sequence]] = None,
    method: str = TestMethod.FISCHER.s,
    cluster_key: Optional[str] = None,
    clusters: Optional[Union[str, Sequence]] = None,
    layer: str = "X",
    use_raw: bool = False,
    confidence_level: float = 0.95,
    n_perms: int = 1000,
    seed: Optional[int] = None,
    return_drivers: bool = True,
    **kwargs,
) -> Optional[pd.DataFrame]:
    """
    Compute driver genes per lineage.

    Correlates gene expression with lineage probabilities, for a given lineage and set of clusters.
    Often, it makes sense to restrict this to a set of clusters which are relevant for the specified lineages.

    Parameters
    ----------
    lineages
        Either a set of lineage names from :paramref:`absorption_probabilities` `.names` or `None`,
        in which case all lineages are considered.
    method
        Mode to use when calculating p-values and confidence intervals. Can be one of:

            - {tm.FISCHER.s!r} - use Fisher transformation [Fischer21]_.
            - {tm.PERM_TEST.s!r} - use permutation test.

    cluster_key
        Key from :paramref:`adata` ``.obs`` to obtain cluster annotations. These are considered for ``clusters``.
    clusters
        Restrict the correlations to these clusters.
    layer
        Key from :paramref:`adata` ``.layers``.
    use_raw
        Whether or not to use :paramref:`adata` ``.raw`` to correlate gene expression.
        If using a layer other than ``.X``, this must be set to `False`.
    confidence_level
        Confidence level for the confidence interval calculation. Must be in `[0, 1]`.
    n_perms
        Number of permutations to use when ``method={tm.PERM_TEST.s!r}``.
    seed
        Random seed when ``method={tm.PERM_TEST.s!r}``.
    return_drivers
        Whether to return the drivers. This also contains the lower and upper ``confidence_level``
        confidence interval bounds.
    %(parallel)s

    Returns
    -------
    %(correlation_test.returns)s
        Only if ``return_drivers=True``.
    None
        Updates :paramref:`adata` ``.var`` or :paramref:`adata` ``.raw.var``, depending on ``use_raw`` with:

            - ``'{{direction}} {{lineage}} corr'`` - the potential lineage drivers.
            - ``'{{direction}} {{lineage}} qval'`` - the corrected p-values.

        Updates the following fields:

            - :paramref:`{lin_drivers}` - same as the returned values.

    Raises
    ------
    RuntimeError
        If the absorption probabilities have not been computed.
    ValueError
        If no lineages have been selected or if ``use_raw=True`` is combined with a layer other than `'X'`.
    KeyError
        If ``cluster_key``, any of the ``clusters`` or ``layer`` is not found.

    References
    ----------
    .. [Fischer21] Fisher, R. A. (1921),
        *On the “probable error” of a coefficient of correlation deduced from a small sample.*,
        `Metron 1 3–32 <http://hdl.handle.net/2440/15169>`__.
    """
    # check that lineage probs have been computed
    method = TestMethod(method)
    abs_probs = self._get(P.ABS_PROBS)
    prefix = DirPrefix.BACKWARD if self.kernel.backward else DirPrefix.FORWARD

    if abs_probs is None:
        raise RuntimeError(
            "Compute absorption probabilities first as `.compute_absorption_probabilities()`."
        )
    elif abs_probs.shape[1] == 1:
        logg.warning(
            "There is only 1 lineage present. Using the stationary distribution instead"
        )
        abs_probs = Lineage(
            self._get(P.TERM_PROBS).values,
            names=abs_probs.names,
            colors=abs_probs.colors,
        )

    # check all lin_keys exist in self.lin_names
    if isinstance(lineages, str):
        lineages = [lineages]
    if lineages is not None:
        # the indexing is done only for its validation side effect
        _ = abs_probs[lineages]
    else:
        lineages = abs_probs.names

    # `len` instead of truthiness, since `lineages` may be an array
    if not len(lineages):
        raise ValueError("No lineages have been selected.")

    # use `cluster_key` and clusters to subset the data
    if clusters is not None:
        if cluster_key not in self.adata.obs.keys():
            raise KeyError(f"Key `{cluster_key!r}` not found in `adata.obs`.")
        if isinstance(clusters, str):
            clusters = [clusters]

        known_clusters = set(self.adata.obs[cluster_key].cat.categories)
        # plain Python objects keep the error message free of NumPy scalar reprs
        missing_clusters = [name for name in clusters if name not in known_clusters]
        if missing_clusters:
            raise KeyError(
                f"Clusters `{missing_clusters}` not found in "
                f"`adata.obs[{cluster_key!r}]`."
            )
        # `np.in1d` is deprecated in favor of `np.isin`
        subset_mask = np.isin(self.adata.obs[cluster_key], clusters)
        adata_comp = self.adata[subset_mask]
        lin_probs = abs_probs[subset_mask, :]
    else:
        adata_comp = self.adata
        lin_probs = abs_probs

    # check that the layer exists, and that use raw is only used with layer X
    if layer != "X":
        if layer not in self.adata.layers:
            raise KeyError(f"Layer `{layer!r}` not found in `adata.layers`.")
        if use_raw:
            raise ValueError("For `use_raw=True`, layer must be 'X'.")
        data = adata_comp.layers[layer]
        var_names = adata_comp.var_names
    else:
        if use_raw and self.adata.raw is None:
            logg.warning("No raw attribute set. Using `.X` instead")
            use_raw = False
        data = adata_comp.raw.X if use_raw else adata_comp.X
        var_names = adata_comp.raw.var_names if use_raw else adata_comp.var_names

    start = logg.info(
        f"Computing correlations for lineages `{lineages}` restricted to clusters `{clusters}` in "
        f"layer `{layer}` with `use_raw={use_raw}`"
    )

    drivers = _correlation_test(
        data,
        lin_probs[lineages],
        gene_names=var_names,
        method=method,
        n_perms=n_perms,
        seed=seed,
        confidence_level=confidence_level,
        **kwargs,
    )
    self._set(A.LIN_DRIVERS, drivers)

    corrs = [f"{lin} corr" for lin in lineages]
    qvals = [f"{lin} qval" for lin in lineages]
    if use_raw:
        self.adata.raw.var[[f"{prefix} {col}" for col in corrs]] = drivers[corrs]
        self.adata.raw.var[[f"{prefix} {col}" for col in qvals]] = drivers[qvals]
    else:
        self.adata.var[[f"{prefix} {col}" for col in corrs]] = drivers[corrs]
        self.adata.var[[f"{prefix} {col}" for col in qvals]] = drivers[qvals]

    field = "raw.var" if use_raw else "var"
    keys_added = [f"`adata.{field}['{prefix} {lin} corr']`" for lin in lineages]

    logg.info(
        f"Adding `.{P.LIN_DRIVERS}`\n "
        + "\n ".join(keys_added)
        + "\n Finish",
        time=start,
    )

    if return_drivers:
        return drivers
def compute_absorption_probabilities(
    self,
    keys: Optional[Sequence[str]] = None,
    check_irreducibility: bool = False,
    solver: str = "gmres",
    use_petsc: bool = True,
    time_to_absorption: Optional[
        Union[
            str,
            Sequence[Union[str, Sequence[str]]],
            Dict[Union[str, Sequence[str]], str],
        ]
    ] = None,
    n_jobs: Optional[int] = None,
    backend: str = "loky",
    show_progress_bar: bool = True,
    tol: float = 1e-6,
    preconditioner: Optional[str] = None,
) -> None:
    """
    Compute absorption probabilities of a Markov chain.

    For each cell, this computes the probability of it reaching any of the approximate recurrent classes defined
    by :paramref:`{fs}`.

    Parameters
    ----------
    keys
        Keys defining the recurrent classes.
    check_irreducibility
        Check whether the transition matrix is irreducible.
    solver
        Solver to use for the linear problem. Options are `'direct', 'gmres', 'lgmres', 'bicgstab' or 'gcrotmk'`
        when ``use_petsc=False`` or one of :class:`petsc4py.PETSc.KSP.Type` otherwise.

        Information on the :mod:`scipy` iterative solvers can be found in :func:`scipy.sparse.linalg` or for
        :mod:`petsc4py` solver `here <https://www.mcs.anl.gov/petsc/documentation/linearsolvertable.html>`__.
    use_petsc
        Whether to use solvers from :mod:`petsc4py` or :mod:`scipy`. Recommended for large problems.
        If no installation is found, defaults to :func:`scipy.sparse.linalg.gmres`.
    time_to_absorption
        Whether to compute mean time to absorption and its variance to specific absorbing states.

        If a :class:`dict`, can be specified as ``{{'Alpha': 'var', ...}}`` to also compute variance.
        In case when states are a :class:`tuple`, time to absorption will be computed to the subset of these states,
        such as ``[('Alpha', 'Beta'), ...]`` or ``{{('Alpha', 'Beta'): 'mean', ...}}``.
        Can be specified as ``'all'`` to compute it to any absorbing state in ``keys``, which is more efficient
        than listing all absorbing states.

        It might be beneficial to disable the progress bar as ``show_progress_bar=False``, because many linear
        systems are being solved.
    n_jobs
        Number of parallel jobs to use when using an iterative solver.
        When ``use_petsc=True`` or for quickly-solvable problems, we recommend higher number (>=8) of jobs
        in order to fully saturate the cores.
    backend
        Which backend to use for multiprocessing. See :class:`joblib.Parallel` for valid options.
    show_progress_bar
        Whether to show progress bar when the solver isn't a direct one.
    tol
        Convergence tolerance for the iterative solver. The default is fine for most cases, only consider
        decreasing this for severely ill-conditioned matrices.
    preconditioner
        Preconditioner to use, only available when ``use_petsc=True``. For available values, see
        `here <https://www.mcs.anl.gov/petsc/petsc-current/docs/manualpages/PC/PCType.html#PCType>`__ or the values
        of `petsc4py.PETSc.PC.Type`.
        We recommend `'ilu'` preconditioner for badly conditioned problems.

    Returns
    -------
    None
        Nothing, but updates the following fields:

            - :paramref:`{abs_prob}` - probabilities of being absorbed into the terminal states.
            - :paramref:`{lat}` - mean times until absorption to subset absorbing states and optionally
              their variances saved as ``'{{lineage}} mean'`` and ``'{{lineage}} var'``, respectively,
              for each subset of absorbing states specified in ``time_to_absorption``.

    Raises
    ------
    RuntimeError
        If the terminal states have not been computed or if there are no transient states.
    ValueError
        If ``time_to_absorption`` contains an invalid moment or an invalid absorbing state.
    """
    if self._get(P.TERM) is None:
        raise RuntimeError(_COMP_TERM_STATES_MSG)
    if keys is not None:
        keys = sorted(set(keys))

    start = logg.info("Computing absorption probabilities")

    # get the transition matrix
    t = self.transition_matrix
    if not self.issparse:
        logg.warning(
            "Attempting to solve a potentially large linear system with dense transition matrix"
        )

    # process the current annotations according to `keys`
    terminal_states_, colors_ = _process_series(
        series=self._get(P.TERM), keys=keys, colors=self._get(A.TERM_COLORS)
    )
    # warn in case only one state is left
    keys = list(terminal_states_.cat.categories)
    if len(keys) == 1:
        logg.warning(
            "There is only 1 recurrent class, all cells will have probability 1 of going there"
        )

    lin_abs_times = {}
    if time_to_absorption is not None:
        if isinstance(time_to_absorption, (str, tuple)):
            time_to_absorption = [time_to_absorption]
        if not isinstance(time_to_absorption, dict):
            time_to_absorption = {ln: "mean" for ln in time_to_absorption}

        # must live outside of the loop, otherwise the same subset of absorbing states
        # given in a different order is never recognized as a duplicate
        seen = set()
        for ln, moment in time_to_absorption.items():
            if moment not in ("mean", "var"):
                raise ValueError(
                    f"Moment must be either `'mean'` or `'var'`, found `{moment!r}` for `{ln!r}`."
                )

            if isinstance(ln, str):
                ln = tuple(keys) if ln == "all" else (ln,)
            # sorted only for the duplicate detection; the user order is preserved in the key below
            sorted_ln = tuple(sorted(ln))

            if sorted_ln in seen:
                continue
            seen.add(sorted_ln)

            for lin in ln:
                if lin not in keys:
                    raise ValueError(
                        f"Invalid absorbing state `{lin!r}` in `{ln}`. "
                        f"Valid options are `{keys}`."
                    )
            lin_abs_times[tuple(ln)] = moment

    # define the dimensions of this problem
    n_cells = t.shape[0]
    n_macrostates = len(terminal_states_.cat.categories)

    # get indices corresponding to recurrent and transient states
    rec_indices, trans_indices, lookup_dict = _get_cat_and_null_indices(
        terminal_states_
    )
    if not len(trans_indices):
        raise RuntimeError("Cannot proceed - Markov chain is irreducible.")

    # create Q (restriction transient-transient), S (restriction transient-recurrent)
    q = t[trans_indices, :][:, trans_indices]
    s = t[trans_indices, :][:, rec_indices]

    # check for irreducibility
    if check_irreducibility:
        if self.is_irreducible is None:
            self._is_irreducible = _irreducible(self.transition_matrix)
        else:
            if not self.is_irreducible:
                logg.warning("Transition matrix is not irreducible")
            else:
                logg.debug("Transition matrix is irreducible")

    logg.debug(f"Found `{n_cells}` cells and `{s.shape[1]}` absorbing states")

    # solve the linear system of equations
    mat_x = _solve_lin_system(
        q,
        s,
        solver=solver,
        use_petsc=use_petsc,
        n_jobs=n_jobs,
        backend=backend,
        tol=tol,
        use_eye=True,
        show_progress_bar=show_progress_bar,
        preconditioner=preconditioner,
    )

    if time_to_absorption is not None:
        abs_time_means = _calculate_lineage_absorption_time_means(
            q,
            s,  # same restriction as computed above, no need to slice again
            trans_indices,
            n=n_cells,
            ixs=lookup_dict,
            lineages=lin_abs_times,
            solver=solver,
            use_petsc=use_petsc,
            n_jobs=n_jobs,
            backend=backend,
            tol=tol,
            show_progress_bar=show_progress_bar,
            preconditioner=preconditioner,
        )
        abs_time_means.index = self.adata.obs_names
    else:
        abs_time_means = None

    # take individual solutions and piece them together to get absorption probabilities towards the classes
    macro_ix_helper = np.cumsum(
        [0] + [len(indices) for indices in lookup_dict.values()]
    )
    _abs_classes = np.concatenate(
        [
            mat_x[:, np.arange(a, b)].sum(1)[:, None]
            for a, b in _pairwise(macro_ix_helper)
        ],
        axis=1,
    )

    # for recurrent states, set their self-absorption probability to one
    abs_classes = np.zeros((len(self), n_macrostates))
    for col, cl in enumerate(terminal_states_.cat.categories):
        cl_indices = np.where(terminal_states_ == cl)[0]
        abs_classes[trans_indices, col] = _abs_classes[:, col]
        abs_classes[cl_indices, col] = 1

    self._set(
        A.ABS_PROBS,
        Lineage(
            abs_classes,
            names=terminal_states_.cat.categories,
            colors=colors_,
        ),
    )

    extra_msg = ""
    if abs_time_means is not None:
        self._set(A.LIN_ABS_TIMES, abs_time_means)
        extra_msg = f" `.{P.LIN_ABS_TIMES}`\n"

    self._write_absorption_probabilities(time=start, extra_msg=extra_msg)
def compute_lineage_drivers(
    self,
    lineages: Optional[Union[str, Sequence]] = None,
    cluster_key: Optional[str] = None,
    clusters: Optional[Union[str, Sequence]] = None,
    layer: str = "X",
    use_raw: bool = False,
    return_drivers: bool = False,
) -> Optional[pd.DataFrame]:
    """
    Compute driver genes per lineage.

    Correlates gene expression with lineage probabilities, for a given lineage and set of clusters.
    Often, it makes sense to restrict this to a set of clusters which are relevant for the specified lineages.

    Parameters
    ----------
    lineages
        Either a set of lineage names from :paramref:`absorption_probabilities` `.names` or `None`,
        in which case all lineages are considered.
    cluster_key
        Key from :paramref:`adata` ``.obs`` to obtain cluster annotations. These are considered for ``clusters``.
    clusters
        Restrict the correlations to these clusters.
    layer
        Key from :paramref:`adata` ``.layers``.
    use_raw
        Whether or not to use :paramref:`adata` ``.raw`` to correlate gene expression.
        If using a layer other than ``.X``, this must be set to `False`.
    return_drivers
        Whether to return the lineage drivers as :class:`pandas.DataFrame`.

    Returns
    -------
    :class:`pandas.DataFrame` or :obj:`None`
        Updates :paramref:`adata` ``.var`` or :paramref:`adata` ``.raw.var``, depending on ``use_raw``
        with lineage drivers saved as columns of the form ``{{direction}} {{lineages}}``.

        Also updates the following fields:

            - :paramref:`{lin_drivers}` - the driver genes for each lineage.

        If ``return_drivers=True``, returns the lineage drivers as :class:`pandas.DataFrame`.

    Raises
    ------
    RuntimeError
        If the absorption probabilities have not been computed.
    ValueError
        If no lineages have been selected or if ``use_raw=True`` is combined with a layer other than `'X'`.
    KeyError
        If ``cluster_key``, any of the ``clusters`` or ``layer`` is not found.
    """
    # check that lineage probs have been computed
    abs_probs = self._get(P.ABS_PROBS)
    prefix = DirPrefix.BACKWARD if self.kernel.backward else DirPrefix.FORWARD

    if abs_probs is None:
        raise RuntimeError(
            "Compute absorption probabilities first as `.compute_absorption_probabilities()`."
        )
    elif abs_probs.shape[1] == 1:
        logg.warning(
            "There is only 1 lineage present. Using the stationary distribution instead"
        )
        abs_probs = Lineage(
            self._get(P.FIN_PROBS).values,
            names=abs_probs.names,
            colors=abs_probs.colors,
        )

    # check all lin_keys exist in self.lin_names
    if isinstance(lineages, str):
        lineages = [lineages]
    if lineages is not None:
        # the indexing is done only for its validation side effect
        _ = abs_probs[lineages]
    else:
        lineages = abs_probs.names

    # `len` instead of truthiness, since `lineages` may be an array
    if not len(lineages):
        raise ValueError("No lineages have been selected.")

    # use `cluster_key` and clusters to subset the data
    if clusters is not None:
        if cluster_key not in self.adata.obs.keys():
            raise KeyError(f"Key `{cluster_key!r}` not found in `adata.obs`.")
        if isinstance(clusters, str):
            clusters = [clusters]

        known_clusters = set(self.adata.obs[cluster_key].cat.categories)
        # plain Python objects keep the error message free of NumPy scalar reprs
        missing_clusters = [name for name in clusters if name not in known_clusters]
        if missing_clusters:
            raise KeyError(
                f"Clusters `{missing_clusters}` not found in "
                f"`adata.obs[{cluster_key!r}]`."
            )
        # `np.in1d` is deprecated in favor of `np.isin`
        subset_mask = np.isin(self.adata.obs[cluster_key], clusters)
        adata_comp = self.adata[subset_mask]
        lin_probs = abs_probs[subset_mask, :]
    else:
        adata_comp = self.adata
        lin_probs = abs_probs

    # check that the layer exists, and that use raw is only used with layer X
    if layer != "X":
        if layer not in self.adata.layers:
            raise KeyError(f"Layer `{layer!r}` not found in `adata.layers`.")
        if use_raw:
            raise ValueError("For `use_raw=True`, layer must be 'X'.")
        data = adata_comp.layers[layer]
        var_names = adata_comp.var_names
    else:
        if use_raw and self.adata.raw is None:
            logg.warning("No raw attribute set. Using `.X` instead")
            use_raw = False
        data = adata_comp.raw.X if use_raw else adata_comp.X
        var_names = adata_comp.raw.var_names if use_raw else adata_comp.var_names

    start = logg.info(
        f"Computing correlations for lineages `{lineages}` restricted to clusters `{clusters}` in "
        f"layer `{layer}` with `use_raw={use_raw}`"
    )

    # correlate every lineage with the expression and save it in `adata` as we go
    lin_corrs = {}
    for lineage in lineages:
        key = f"{prefix} {lineage}"

        correlations = _vec_mat_corr(data, lin_probs[:, lineage].X.squeeze())
        lin_corrs[lineage] = correlations

        if use_raw:
            self.adata.raw.var[key] = correlations
        else:
            self.adata.var[key] = correlations

    drivers = pd.DataFrame(lin_corrs, index=var_names)
    self._set(A.LIN_DRIVERS, drivers)

    field = "raw.var" if use_raw else "var"
    keys_added = [f"`adata.{field}['{prefix} {lin}']`" for lin in lineages]

    logg.info(
        f"Adding `.{P.LIN_DRIVERS}`\n "
        + "\n ".join(keys_added)
        + "\n Finish",
        time=start,
    )

    if return_drivers:
        return drivers