def _density_normalize(
    self, other: Union[np.ndarray, spmatrix]
) -> Union[np.ndarray, spmatrix]:
    """
    Density normalization by the underlying KNN graph.

    Parameters
    ----------
    other
        Matrix to normalize.

    Returns
    -------
    :class:`numpy.ndarray` or :class:`scipy.sparse.spmatrix`
        Density-normalized transition matrix.
    """
    logg.debug("Density-normalizing the transition matrix")

    # flatten to 1D: sparse `.sum(axis=0)` returns a `(1, n)` matrix, which `np.diag` would misread
    q = np.asarray(self._conn.sum(axis=0)).ravel()

    if not issparse(other):
        Q = np.diag(1.0 / q)
    else:
        Q = spdiags(1.0 / q, 0, other.shape[0], other.shape[0])

    return Q @ other @ Q
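# A minimal, self-contained sketch of the normalization above: with `q` the KNN-graph
# degrees, `Q @ K @ Q` divides each entry `K[i, j]` by `q[i] * q[j]`. The matrices here
# are small placeholders, not taken from this codebase.
import numpy as np
from scipy.sparse import csr_matrix, spdiags

conn = csr_matrix(np.array([[0.0, 1.0, 0.5], [1.0, 0.0, 0.2], [0.5, 0.2, 0.0]]))
K = csr_matrix(np.ones((3, 3)))

q = np.asarray(conn.sum(axis=0)).ravel()            # degree of each node
Q = spdiags(1.0 / q, 0, K.shape[0], K.shape[0])
K_dn = (Q @ K @ Q).toarray()                        # entry (i, j) is K[i, j] / (q[i] * q[j])

assert np.allclose(K_dn, 1.0 / np.outer(q, q))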
def write(self, fname: Union[str, Path], ext: Optional[str] = "pickle") -> None:
    """
    Serialize self to a file.

    Parameters
    ----------
    fname
        Filename where to save the object.
    ext
        Filename extension to use. If `None`, don't append any extension.

    Returns
    -------
    None
        Nothing, just writes itself to a file using :mod:`pickle`.
    """
    fname = str(fname)
    if ext is not None:
        if not ext.startswith("."):
            ext = "." + ext
        if not fname.endswith(ext):
            fname += ext

    logg.debug(f"Writing to `{fname}`")

    with open(fname, "wb") as fout:
        pickle.dump(self, fout)
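# A quick round-trip sketch for `write` above; `read_pickled` is a hypothetical loader,
# shown only to illustrate that the file is a plain pickle:
import pickle

def read_pickled(fname: str):
    # hypothetical counterpart to `write`
    with open(fname, "rb") as fin:
        return pickle.load(fin)

# obj.write("estimator")                      # writes `estimator.pickle`
# restored = read_pickled("estimator.pickle")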
def __init__(
    self,
    obj: Union[AnnData, np.ndarray, spmatrix, KernelExpression],
    key: Optional[str] = None,
    obsp_key: Optional[str] = None,
    write_to_adata: bool = True,
):
    if isinstance(obj, KernelExpression):
        self._kernel = obj
    elif isinstance(obj, (np.ndarray, spmatrix)):
        self._kernel = PrecomputedKernel(obj)
    elif isinstance(obj, AnnData):
        if obsp_key is None:
            raise ValueError(
                "Please specify `obsp_key=...` when supplying an AnnData object."
            )
        elif obsp_key not in obj.obsp.keys():
            raise KeyError(f"Key `{obsp_key!r}` not found in `adata.obsp`.")
        self._kernel = PrecomputedKernel(obj.obsp[obsp_key], adata=obj)
    else:
        raise TypeError(
            f"Expected an object of type `KernelExpression`, `numpy.ndarray`, `scipy.sparse.spmatrix` "
            f"or `anndata.AnnData`, got `{type(obj).__name__!r}`."
        )

    if self.kernel.transition_matrix is None:
        logg.debug("Computing transition matrix using default parameters")
        self.kernel.compute_transition_matrix()

    if write_to_adata:
        self.kernel.write_to_adata(key=key)
def maybe_sanity_check(callbacks: Dict[str, Dict[str, Callable]]) -> None:
    if not perform_sanity_check:
        return

    from sklearn.svm import SVR

    logg.debug("Performing callback sanity checks")
    for gene in callbacks.keys():
        for lineage, cb in callbacks[gene].items():
            # create the model here because the callback can search the attribute
            dummy_model = SKLearnModel(adata, model=SVR())
            try:
                model = cb(dummy_model, gene=gene, lineage=lineage, **kwargs)
                assert model is dummy_model, (
                    "Creation of new models is not allowed. "
                    "Ensure that the callback returns the same model."
                )
                assert (
                    model.prepared
                ), "Model is not prepared. Ensure that the callback calls `.prepare()`."
                assert (
                    model._gene == gene
                ), f"Callback modified the gene from `{gene!r}` to `{model._gene!r}`."
                assert (
                    model._lineage == lineage
                ), f"Callback modified the lineage from `{lineage!r}` to `{model._lineage!r}`."
            except Exception as e:
                raise RuntimeError(
                    f"Callback validation failed for gene `{gene!r}` and lineage `{lineage!r}`."
                ) from e
def _density_normalize(
    self, other: Union[np.ndarray, spmatrix]
) -> Union[np.ndarray, spmatrix]:
    """
    Density normalization by the underlying KNN graph.

    Parameters
    ----------
    other
        Matrix to normalize.

    Returns
    -------
    :class:`numpy.ndarray` or :class:`scipy.sparse.spmatrix`
        Density-normalized transition matrix.

    Raises
    ------
    ValueError
        If KNN connectivities are not set.
    """
    if self._conn is None:
        raise ValueError(
            "Unable to density normalize the transition matrix "
            "because KNN connectivities are not set."
        )

    logg.debug("Density-normalizing the transition matrix")

    q = np.asarray(self._conn.sum(axis=0))
    Q = spdiags(1.0 / q, 0, other.shape[0], other.shape[0])

    return Q @ other @ Q
def _read_graph_data(adata: AnnData, key: str) -> Union[np.ndarray, spmatrix]:
    """
    Read graph data from :mod:`anndata`.

    :mod:`anndata` >= 0.7 stores `(n_obs x n_obs)` matrices in `.obsp` rather than `.uns`.
    This is for backward compatibility.

    Parameters
    ----------
    adata
        Annotated data object.
    key
        Key from either ``adata.uns`` or ``adata.obsp``.

    Returns
    -------
    :class:`numpy.ndarray` or :class:`scipy.sparse.spmatrix`
        The graph data.
    """
    if hasattr(adata, "obsp") and adata.obsp is not None and key in adata.obsp.keys():
        logg.debug(f"Reading key `{key!r}` from `adata.obsp`")
        return adata.obsp[key]

    if key in adata.uns.keys():
        logg.debug(f"Reading key `{key!r}` from `adata.uns`")
        return adata.uns[key]

    raise KeyError(f"Unable to find key `{key!r}` in `adata.obsp` or `adata.uns`.")
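# Toy usage of the `.obsp`-with-`.uns`-fallback read above, assuming `anndata` is installed:
import numpy as np
from anndata import AnnData

adata = AnnData(np.zeros((3, 1), dtype=np.float32))
adata.obsp["T_fwd"] = np.eye(3)

T = _read_graph_data(adata, "T_fwd")  # found in `.obsp`, so `.uns` is never consulted
assert T.shape == (3, 3)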
def _reconstruct_lineage(self, attr: PrettyEnum, obsm_key: str):
    self._set_or_debug(obsm_key, self.adata.obsm, attr)
    names = self._set_or_debug(_lin_names(self._term_key), self.adata.uns)
    colors = self._set_or_debug(_colors(self._term_key), self.adata.uns)

    # choosing this instead of a property because GPCCA doesn't have a property for FIN_ABS_PROBS
    probs = self._get(attr)

    if probs is not None:
        if len(names) != probs.shape[1]:
            logg.debug(
                f"Expected lineage names to be of length `{probs.shape[1]}`, found `{len(names)}`. "
                f"Creating new names"
            )
            names = [f"Lineage {i}" for i in range(probs.shape[1])]
        if len(colors) != probs.shape[1] or not all(
            map(lambda c: isinstance(c, str) and is_color_like(c), colors)
        ):
            logg.debug(
                f"Expected lineage colors to be of length `{probs.shape[1]}`, found `{len(colors)}`. "
                f"Creating new colors"
            )
            colors = _create_categorical_colors(probs.shape[1])
        self._set(attr, Lineage(probs, names=names, colors=colors))

        self.adata.obsm[obsm_key] = self._get(attr)
        self.adata.uns[_lin_names(self._term_key)] = names
        self.adata.uns[_colors(self._term_key)] = colors
def compute_transition_matrix(self, *args, **kwargs) -> "SimpleNaryExpression":
    """Compute and combine the transition matrices."""
    # must be done first, because the underlying expressions don't have to be normed
    if isinstance(self, KernelSimpleAdd):
        self._maybe_recalculate_constants(Constant)
    elif isinstance(self, KernelAdaptiveAdd):
        self._maybe_recalculate_constants(ConstantMatrix)

    for kexpr in self:
        if kexpr._transition_matrix is None:
            if isinstance(kexpr, Kernel):
                raise RuntimeError(
                    f"Kernel `{kexpr}` is uninitialized. "
                    f"Compute its transition matrix first as `.compute_transition_matrix()`."
                )
            kexpr.compute_transition_matrix()
        elif isinstance(kexpr, Kernel):
            logg.debug(_LOG_USING_CACHE)

    self.transition_matrix = csr_matrix(
        self._fn([kexpr.transition_matrix for kexpr in self])
    )

    # only the top-level expression and kernels will have the condition number computed
    if self._parent is None:
        self._maybe_compute_cond_num()

    return self
def maybe_create_lineage(
    direction: Union[str, Direction], pretty_name: Optional[str] = None
):
    if isinstance(direction, Direction):
        lin_key = str(
            AbsProbKey.FORWARD
            if direction == Direction.FORWARD
            else AbsProbKey.BACKWARD
        )
    else:
        lin_key = direction

    pretty_name = "" if pretty_name is None else (pretty_name + " ")
    names_key, colors_key = _lin_names(lin_key), _colors(lin_key)

    if lin_key in adata.obsm.keys():
        n_cells, n_lineages = adata.obsm[lin_key].shape
        logg.info(f"Creating {pretty_name}`Lineage` from `adata.obsm[{lin_key!r}]`")

        if names_key not in adata.uns.keys():
            logg.warning(
                f" Lineage names not found in `adata.uns[{names_key!r}]`, creating new names"
            )
            names = [f"Lineage {i}" for i in range(n_lineages)]
        elif len(adata.uns[names_key]) != n_lineages:
            logg.warning(
                f" Lineage names don't have the required length ({n_lineages}), creating new names"
            )
            names = [f"Lineage {i}" for i in range(n_lineages)]
        else:
            logg.info(" Successfully loaded names")
            names = adata.uns[names_key]

        if colors_key not in adata.uns.keys():
            logg.warning(
                f" Lineage colors not found in `adata.uns[{colors_key!r}]`, creating new colors"
            )
            colors = _create_categorical_colors(n_lineages)
        elif len(adata.uns[colors_key]) != n_lineages or not all(
            map(lambda c: is_color_like(c), adata.uns[colors_key])
        ):
            logg.warning(
                f" Lineage colors don't have the required length ({n_lineages}) "
                f"or are not color-like, creating new colors"
            )
            colors = _create_categorical_colors(n_lineages)
        else:
            logg.info(" Successfully loaded colors")
            colors = adata.uns[colors_key]

        adata.obsm[lin_key] = Lineage(adata.obsm[lin_key], names=names, colors=colors)
        adata.uns[colors_key] = colors
        adata.uns[names_key] = names
    else:
        logg.debug(
            f"Unable to load {pretty_name}`Lineage` from `adata.obsm[{lin_key!r}]`"
        )
def _remove_min_clusters(self, min_flow: float) -> None:
    logg.debug("Removing clusters with no incoming flow edges")
    columns = (self._flow.loc[(slice(None), self._cluster), :] > min_flow).any()
    columns = columns[columns].index
    if not len(columns):
        raise ValueError(
            "After removing clusters with no incoming flow edges, none remain."
        )
    self._flow = self._flow[columns]
def _set_or_debug(
    self, needle: str, haystack, attr: Optional[Union[str, PrettyEnum]] = None
) -> Optional[Any]:
    if isinstance(attr, PrettyEnum):
        attr = attr.s
    if needle in haystack:
        if attr is None:
            return haystack[needle]
        setattr(self, attr, haystack[needle])
    elif attr is not None:
        logg.debug(f"Unable to set attribute `.{attr}`, skipping")
def _filter_models(
    models, return_models: bool = False, filter_all_failed: bool = True
) -> Tuple[_return_model_type, _return_model_type, Sequence[str], Sequence[str]]:
    def is_valid(x: Union[BaseModel, BulkRes]) -> bool:
        if return_models:
            assert isinstance(
                x, BaseModel
            ), f"Expected `BaseModel`, found `{type(x).__name__!r}`."
            return bool(x)

        return (
            x.x_test is not None
            and x.y_test is not None
            and np.all(np.isfinite(x.y_test))
        )

    modelmat = pd.DataFrame(models).T
    modelmask = modelmat.applymap(is_valid)
    to_keep = modelmask[modelmask.any(axis=1)]
    to_keep = to_keep.loc[:, to_keep.any(axis=0)].T

    filtered_models = {
        gene: {
            ln: models[gene][ln]
            for ln in (
                ln
                for ln in v.keys()
                if (is_valid(models[gene][ln]) if filter_all_failed else True)
            )
        }
        for gene, v in to_keep.to_dict().items()
    }

    if not len(filtered_models):
        if not return_models:
            raise RuntimeError(
                "Fitting has failed for all gene/lineage combinations. "
                "Specify `return_models=True` for more information."
            )
        for ms in models.values():
            for model in ms.values():
                assert isinstance(
                    model, FailedModel
                ), f"Expected `FailedModel`, found `{type(model).__name__!r}`."
                model.reraise()

    if not np.all(modelmask.values):
        failed_models = modelmat.values[~modelmask.values]
        # parenthesize the conditional tail, otherwise `a + "" if cond else b`
        # parses as `(a + "") if cond else b` and drops the first part of the message
        logg.warning(
            f"Unable to fit `{len(failed_models)}` models."
            + ("" if return_models else " Consider specifying `return_models=True` for further inspection.")
        )
        logg.debug(
            "The failed models were:\n`{}`".format(
                "\n".join(f" {m}" for m in failed_models)
            )
        )

    # `lineages` is the max number of lineages
    return models, filtered_models, tuple(filtered_models.keys()), tuple(to_keep.index)
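# Note on the message grouping in the warning above: in Python, the conditional
# expression binds looser than `+`, so `a + "" if cond else b` evaluates to `b`
# whenever `cond` is false, silently dropping `a`. A tiny demonstration:
msg = "Unable to fit `3` models."
assert (msg + "" if False else "tail") == "tail"          # ungrouped: `msg` is lost
assert msg + ("" if False else " tail") == msg + " tail"  # grouped tail keeps `msg`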
def _reuse_cache(
    self, expected_params: Dict[str, Any], *, time: Optional[Any] = None
) -> bool:
    if expected_params == self._params:
        assert self.transition_matrix is not None, _ERROR_EMPTY_CACHE_MSG
        logg.debug(_LOG_USING_CACHE)
        logg.info(" Finish", time=time)
        return True

    self._params = expected_params
    return False
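# A sketch of how a `_reuse_cache`-style guard is typically used at the top of a
# kernel's `compute_transition_matrix` (illustrative call site, not copied from
# this codebase):
#
#     start = logg.info("Computing transition matrix")
#     if self._reuse_cache({"dnorm": density_normalize}, time=start):
#         return self  # parameters unchanged, the cached matrix is reused
#     ...              # otherwise recompute and store the new parameters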
def _maybe_compute_cond_num(self):
    if self._compute_cond_num and self._cond_num is None:
        logg.debug(f"Computing condition number of `{repr(self)}`")
        self._cond_num = np.linalg.cond(
            self._transition_matrix.toarray()
            if issparse(self._transition_matrix)
            else self._transition_matrix
        )
        if self._cond_num > _cond_num_tolerance:
            logg.warning(
                f"`{repr(self)}` may be ill-conditioned, its condition number is `{self._cond_num:.2e}`"
            )
        else:
            logg.info(f"Condition number is `{self._cond_num:.2e}`")
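# A self-contained sketch of the conditioning check above: `np.linalg.cond` is
# evaluated on the densified matrix, and large values flag potential numerical
# instability in downstream linear solves. The matrix here is a placeholder.
import numpy as np

T = np.array([[0.9, 0.1], [0.1, 0.9]])  # a small row-stochastic matrix
cond = np.linalg.cond(T)                # == 1.25: singular values are 1.0 and 0.8
assert cond < 1e3                       # well below any ill-conditioning tolerance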
def _read_from_adata(self, **kwargs):
    super()._read_from_adata(**kwargs)

    # check whether velocities have been computed
    vkey = kwargs.pop("vkey", "velocity")
    if vkey not in self.adata.layers.keys():
        raise KeyError("Compute RNA velocity first as `scv.tl.velocity()`.")

    # restrict genes to a subset, i.e. velocity genes or a user-provided list
    gene_subset = kwargs.pop("gene_subset", None)
    subset = np.ones(self.adata.n_vars, bool)
    if gene_subset is not None:
        var_names_subset = self.adata.var_names.isin(gene_subset)
        subset &= var_names_subset if len(var_names_subset) > 0 else gene_subset
    elif f"{vkey}_genes" in self.adata.var.keys():
        subset &= np.array(self.adata.var[f"{vkey}_genes"].values, dtype=bool)

    # choose the data representation to use for transcriptomic displacements
    xkey = kwargs.pop("xkey", "Ms")
    xkey = xkey if xkey in self.adata.layers.keys() else "spliced"

    # filter both the velocities and the gene expression profiles to the gene subset; densify the matrices
    X = np.array(
        self.adata.layers[xkey].A[:, subset]
        if issparse(self.adata.layers[xkey])
        else self.adata.layers[xkey][:, subset]
    )
    V = np.array(
        self.adata.layers[vkey].A[:, subset]
        if issparse(self.adata.layers[vkey])
        else self.adata.layers[vkey][:, subset]
    )

    # remove genes that have any NaN values in the velocities and filter X accordingly
    nans = np.isnan(np.sum(V, axis=0))
    if np.any(nans):
        X = X[:, ~nans]
        V = V[:, ~nans]

    # check the velocity parameters
    par_key = f"{vkey}_params"
    if par_key in self.adata.uns.keys():
        velocity_params = self.adata.uns[par_key]
    else:
        velocity_params = None
        logg.debug(f"Unable to load velocity parameters from `adata.uns[{par_key!r}]`")

    # add to self
    self._velocity = V.astype(np.float64)
    self._gene_expression = X.astype(np.float64)
    self._velocity_params = velocity_params
def _get_n_states_from_minchi(
    self, n_states: Union[Tuple[int, int], List[int], Dict[str, int]]
) -> int:
    if self._gpcca is None:
        raise RuntimeError(
            "Compute Schur decomposition first as `.compute_schur()` when `use_min_chi=True`."
        )

    if not isinstance(n_states, (dict, tuple, list)):
        raise TypeError(
            f"Expected `n_states` to be either `dict`, `tuple` or a `list`, "
            f"found `{type(n_states).__name__}`."
        )
    if len(n_states) != 2:
        raise ValueError(
            f"Expected `n_states` to be of size `2`, found `{len(n_states)}`."
        )

    if isinstance(n_states, dict):
        if "min" not in n_states or "max" not in n_states:
            raise KeyError(
                f"Expected the dictionary to have `'min'` and `'max'` keys, "
                f"found `{tuple(n_states.keys())}`."
            )
        minn, maxx = n_states["min"], n_states["max"]
    else:
        minn, maxx = n_states

    if minn > maxx:
        logg.debug(f"Swapping minimum and maximum because `{minn}` > `{maxx}`")
        minn, maxx = maxx, minn

    if minn <= 1:
        raise ValueError(f"Minimum value must be > `1`, found `{minn}`.")
    elif minn == 2:
        logg.warning(
            "In most cases, 2 clusters will be optimal. "
            "If you really expect 2 clusters, use `n_states=2` and `use_min_chi=False`. Setting minimum to `3`"
        )
        minn = 3

    if minn >= maxx:
        maxx = minn + 1
        logg.debug(
            f"Setting maximum to `{maxx}` because it was not greater than the minimum `{minn}`"
        )

    logg.info(f"Calculating minChi within interval `[{minn}, {maxx}]`")

    return int(np.arange(minn, maxx + 1)[np.argmax(self._gpcca.minChi(minn, maxx))])
def _load_dataset_from_url(
    fpath: Union[os.PathLike, str],
    url: str,
    expected_shape: Tuple[int, int],
    **kwargs,
) -> AnnData:
    fpath = str(fpath)
    if not fpath.endswith(".h5ad"):
        fpath += ".h5ad"

    if os.path.isfile(fpath):
        logg.debug(f"Loading dataset from `{fpath!r}`")
    else:
        logg.debug(f"Downloading dataset from `{url!r}` as `{fpath!r}`")

    dirname, _ = os.path.split(fpath)
    try:
        if not os.path.isdir(dirname):
            logg.debug(f"Creating directory `{dirname!r}`")
            os.makedirs(dirname, exist_ok=True)
    except OSError as e:
        logg.debug(f"Unable to create directory `{dirname!r}`. Reason `{e}`")

    kwargs.setdefault("sparse", True)
    kwargs.setdefault("cache", True)

    adata = read(fpath, backup_url=url, **kwargs)

    if adata.shape != expected_shape:
        raise ValueError(
            f"Expected `anndata.AnnData` object to have shape `{expected_shape}`, found `{adata.shape}`."
        )

    adata.var_names_make_unique()

    return adata
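# Hypothetical usage of the loader above -- the path, URL and shape are placeholders,
# not a dataset shipped with this package:
#
#     adata = _load_dataset_from_url(
#         "datasets/example",
#         url="https://example.com/example.h5ad",
#         expected_shape=(2000, 1500),
#     )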
def _read_from_adata(self, **kwargs):
    super()._read_from_adata(**kwargs)

    time_key = kwargs.pop("time_key", "dpt_pseudotime")
    if time_key not in self.adata.obs.keys():
        raise KeyError(f"Could not find time key in `adata.obs[{time_key!r}]`.")

    self._pseudotime = np.array(self.adata.obs[time_key]).astype(_dtype)

    if np.any(np.isnan(self._pseudotime)):
        raise ValueError("Encountered NaN values in pseudotime.")

    logg.debug("Clipping the pseudotime to 0-1 range")
    self._pseudotime = np.clip(self._pseudotime, 0, 1)
def __init__(
    self,
    transition_matrix: Optional[Union[np.ndarray, spmatrix, str]] = None,
    adata: Optional[AnnData] = None,
    backward: bool = False,
    compute_cond_num: bool = False,
):
    from anndata import AnnData as _AnnData

    if transition_matrix is None:
        transition_matrix = _transition(
            Direction.BACKWARD if backward else Direction.FORWARD
        )
        logg.debug(f"Setting transition matrix key to `{transition_matrix!r}`")

    if isinstance(transition_matrix, str):
        if adata is None:
            raise ValueError(
                "When `transition_matrix` specifies a key to `adata.obsp`, `adata` cannot be None."
            )
        transition_matrix = _read_graph_data(adata, transition_matrix)

    if not isinstance(transition_matrix, (np.ndarray, spmatrix)):
        raise TypeError(
            f"Expected transition matrix to be of type `numpy.ndarray` or `scipy.sparse.spmatrix`, "
            f"found `{type(transition_matrix).__name__!r}`."
        )

    if transition_matrix.shape[0] != transition_matrix.shape[1]:
        raise ValueError(
            f"Expected transition matrix to be square, found `{transition_matrix.shape}`."
        )

    if not np.allclose(np.sum(transition_matrix, axis=1), 1.0, rtol=_RTOL):
        raise ValueError("Not a valid transition matrix: not all rows sum to 1.")

    if adata is None:
        logg.warning("Creating empty `AnnData` object")
        adata = _AnnData(
            csr_matrix((transition_matrix.shape[0], 1), dtype=np.float32)
        )

    super().__init__(adata, backward=backward, compute_cond_num=compute_cond_num)
    self._transition_matrix = csr_matrix(transition_matrix)
    self._maybe_compute_cond_num()
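# A minimal sketch of constructing the precomputed kernel above (assuming this class
# is exposed as `cellrank.tl.kernels.PrecomputedKernel`) from a row-stochastic matrix:
import numpy as np

T = np.array([[0.5, 0.5, 0.0], [0.0, 0.5, 0.5], [0.5, 0.0, 0.5]])
assert np.allclose(T.sum(axis=1), 1.0)  # rows must sum to 1, as validated in `__init__`
# pk = PrecomputedKernel(T)  # no `adata` given -> an empty placeholder AnnData is created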
def test_logfile(self, tmp_path, logging_state):
    settings.verbosity = Verbosity.hint

    io = StringIO()
    settings.logfile = io
    assert settings.logfile is io
    assert settings.logpath is None
    logg.error("test!")
    assert io.getvalue() == "ERROR: test!\n"

    p = tmp_path / "test.log"
    settings.logpath = p
    assert settings.logpath == p
    assert settings.logfile.name == str(p)
    logg.hint("test2")
    logg.debug("invisible")
    assert settings.logpath.read_text() == "--> test2\n"
def _ensure_numeric_ordered(adata: AnnData, key: str) -> pd.Series:
    if key not in adata.obs.keys():
        raise KeyError(f"Unable to find data in `adata.obs[{key!r}]`.")

    exp_time = adata.obs[key].copy()
    if not is_numeric_dtype(np.asarray(exp_time)):
        try:
            exp_time = np.asarray(exp_time).astype(float)
        except ValueError as e:
            raise TypeError(
                f"Unable to convert `adata.obs[{key!r}]` of type `{infer_dtype(adata.obs[key])}` to `float`."
            ) from e

    if not is_categorical_dtype(exp_time):
        logg.debug(f"Converting `adata.obs[{key!r}]` to `categorical`")
        exp_time = np.asarray(exp_time)
        categories = sorted(set(exp_time[~np.isnan(exp_time)]))
        if len(categories) > 100:
            raise ValueError(
                f"Unable to convert `adata.obs[{key!r}]` to `categorical` since it "
                f"would create `{len(categories)}` categories."
            )
        exp_time = pd.Series(
            pd.Categorical(exp_time, categories=categories, ordered=True)
        )

    if not exp_time.cat.ordered:
        logg.warning("Categories are not ordered. Using ascending order")
        # rebind the Series; assigning to the `.cat` accessor would be a silent no-op
        exp_time = exp_time.cat.as_ordered()

    exp_time = pd.Series(pd.Categorical(exp_time, ordered=True), index=adata.obs_names)
    if exp_time.isnull().any():
        raise ValueError("Series contains NaN value(s).")

    n_cats = len(exp_time.cat.categories)
    if n_cats < 2:
        raise ValueError(f"Expected to find at least `2` categories, found `{n_cats}`.")

    return exp_time
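# A toy run of the conversion above, assuming `anndata` is installed; a numeric
# `.obs` column becomes an ordered categorical with sorted categories.
import numpy as np
from anndata import AnnData

adata = AnnData(np.zeros((4, 1), dtype=np.float32))
adata.obs["exp_time"] = [0.0, 1.0, 1.0, 2.0]

t = _ensure_numeric_ordered(adata, "exp_time")
assert t.cat.ordered
assert list(t.cat.categories) == [0.0, 1.0, 2.0]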
def write(self, fname: Union[str, Path], ext: Optional[str] = "pickle") -> None:
    """
    %(pickleable.full_desc)s

    Parameters
    ----------
    %(pickleable.parameters)s

    Returns
    -------
    %(pickleable.returns)s
    """  # noqa

    fname = str(fname)
    if ext is not None:
        if not ext.startswith("."):
            ext = "." + ext
        if not fname.endswith(ext):
            fname += ext

    logg.debug(f"Writing to `{fname}`")

    with open(fname, "wb") as fout:
        if version_info[:2] > (3, 6):
            pickle.dump(self, fout)
        else:
            # we need to use PrecomputedKernel because Python 3.6 can't pickle Enums
            # and they are present in VelocityKernel
            logg.warning("Saving kernel as `cellrank.tl.kernels.PrecomputedKernel`")
            orig_kernel = self.kernel
            self._kernel = PrecomputedKernel(self.kernel)
            try:
                pickle.dump(self, fout)
            finally:
                self._kernel = orig_kernel
def _write_graph_data(
    adata: AnnData,
    data: Union[np.ndarray, spmatrix],
    key: str,
):
    """
    Write graph data to :mod:`anndata`.

    :mod:`anndata` >= 0.7 stores `(n_obs x n_obs)` matrices in `.obsp` rather than `.uns`.
    This is for backward compatibility.

    Parameters
    ----------
    adata
        Annotated data object.
    data
        The graph data we want to write.
    key
        Key from either ``adata.uns`` or ``adata.obsp``.

    Returns
    -------
    None
        Nothing, just writes the data.
    """
    try:
        adata.obsp[key] = data
        write_to = "obsp"

        if data.shape[0] != data.shape[1]:
            logg.warning(
                f"`adata.obsp` attribute should only contain square matrices, found shape `{data.shape}`"
            )
    except AttributeError:
        adata.uns[key] = data
        write_to = "uns"

    logg.debug(f"Writing graph data to `adata.{write_to}[{key!r}]`")
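# Toy demonstration of the `.obsp`-first write with the `.uns` fallback above; with a
# modern `anndata`, a square `(n_obs x n_obs)` matrix lands in `.obsp`.
import numpy as np
from anndata import AnnData

adata = AnnData(np.zeros((3, 1), dtype=np.float32))
_write_graph_data(adata, np.eye(3), "T_fwd")
assert "T_fwd" in adata.obsp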
def _read_graph_data(adata: AnnData, key: str) -> Union[np.ndarray, spmatrix]:
    """
    Read graph data from :mod:`anndata`.

    Parameters
    ----------
    adata
        Annotated data object.
    key
        Key in ``adata.obsp``.

    Returns
    -------
    :class:`numpy.ndarray` or :class:`scipy.sparse.spmatrix`
        The graph data.
    """
    if key in adata.obsp.keys():
        logg.debug(f"Reading key `{key!r}` from `adata.obsp`")
        return adata.obsp[key]

    raise KeyError(f"Unable to find key `{key!r}` in `adata.obsp`.")
def compute_transition_matrix(
    self, density_normalize: bool = True
) -> "ConnectivityKernel":
    """
    Compute transition matrix based on transcriptomic similarity.

    Uses the symmetric, weighted KNN graph to compute a symmetric transition matrix.
    The connectivities are computed using :func:`scanpy.pp.neighbors`. Depending on the parameters used there,
    they can be UMAP connectivities or Gaussian-kernel-based connectivities with adaptive kernel width.

    Parameters
    ----------
    density_normalize
        Whether or not to use the underlying KNN graph for density normalization.

    Returns
    -------
    :class:`cellrank.tl.kernels.ConnectivityKernel`
        Makes :paramref:`transition_matrix` available.
    """
    start = logg.info("Computing transition matrix based on connectivities")

    params = {"dnorm": density_normalize}
    if params == self.params:
        assert self.transition_matrix is not None, _ERROR_EMPTY_CACHE_MSG
        logg.debug(_LOG_USING_CACHE)
        logg.info(" Finish", time=start)
        return self

    self._params = params
    self._compute_transition_matrix(
        matrix=self._conn.copy(), density_normalize=density_normalize
    )

    logg.info(" Finish", time=start)

    return self
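# A sketch of the public API above, assuming a KNN graph was already computed,
# e.g. via `scanpy.pp.neighbors(adata)`:
#
#     from cellrank.tl.kernels import ConnectivityKernel
#     ck = ConnectivityKernel(adata).compute_transition_matrix(density_normalize=True)
#     ck.transition_matrix  # row-stochastic `(n_obs x n_obs)` sparse matrix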
def _read_from_adata(
    self,
    conn_key: Optional[str] = "connectivities",
    read_conn: bool = True,
    **kwargs: Any,
) -> None:
    """
    Import the base KNN graph and optionally check for symmetry and connectivity.

    Parameters
    ----------
    conn_key
        Key in :attr:`anndata.AnnData.uns` where connectivities are stored.
    read_conn
        Whether to read connectivities or set them to `None`. Useful when not exposing density normalization
        or when KNN connectivities are not used to compute the transition matrix.
    kwargs
        Additional keyword arguments.

    Returns
    -------
    Nothing, just sets :attr:`_conn`.
    """
    if not read_conn:
        self._conn = None
        return

    self._conn = _get_neighs(
        self.adata, mode="connectivities", key=conn_key
    ).astype(_dtype)

    check_connectivity = kwargs.pop("check_connectivity", False)
    if check_connectivity:
        start = logg.debug("Checking the KNN graph for connectedness")
        if not _connected(self._conn):
            logg.warning("KNN graph is not connected", time=start)
        else:
            logg.debug("KNN graph is connected", time=start)

    start = logg.debug("Checking the KNN graph for symmetry")
    if not _symmetric(self._conn):
        logg.warning("KNN graph is not symmetric", time=start)
    else:
        logg.debug("KNN graph is symmetric", time=start)
def _read_from_adata(self, **kwargs): """Import the base-KNN graph and optionally check for symmetry and connectivity.""" if not _has_neighs(self.adata): raise KeyError("Compute KNN graph first as `scanpy.pp.neighbors()`.") self._conn = _get_neighs(self.adata, "connectivities").astype(_dtype) check_connectivity = kwargs.pop("check_connectivity", False) if check_connectivity: start = logg.debug("Checking the KNN graph for connectedness") if not _connected(self._conn): logg.warning("KNN graph is not connected", time=start) else: logg.debug("KNN graph is connected", time=start) start = logg.debug("Checking the KNN graph for symmetry") if not _symmetric(self._conn): logg.warning("KNN graph is not symmetric", time=start) else: logg.debug("KNN graph is symmetric", time=start)
def compute_terminal_states(
    self,
    use: Optional[Union[int, Tuple[int], List[int], range]] = None,
    percentile: Optional[int] = 98,
    method: str = "kmeans",
    cluster_key: Optional[str] = None,
    n_clusters_kmeans: Optional[int] = None,
    n_neighbors: int = 20,
    resolution: float = 0.1,
    n_matches_min: Optional[int] = 0,
    n_neighbors_filtering: int = 15,
    basis: Optional[str] = None,
    n_comps: int = 5,
    scale: bool = False,
    en_cutoff: Optional[float] = 0.7,
    p_thresh: float = 1e-15,
) -> None:
    """
    Find approximate recurrent classes of the Markov chain.

    Filter to obtain recurrent states in left eigenvectors.
    Cluster to obtain approximate recurrent classes in right eigenvectors.

    Parameters
    ----------
    use
        Which or how many first eigenvectors to use as features for clustering/filtering.
        If `None`, use the `eigengap` statistic.
    percentile
        Threshold used for filtering out cells which are most likely transient states.
        Cells which are in the lower ``percentile`` percent of each eigenvector will be removed from the data matrix.
    method
        Method to be used for clustering. Must be one of `'louvain'`, `'leiden'` or `'kmeans'`.
    cluster_key
        If a key to cluster labels is given, :attr:`{fs}` will get associated with these for naming and colors.
    n_clusters_kmeans
        If `None`, this is set to ``use + 1``.
    n_neighbors
        If we use `'louvain'` or `'leiden'` for clustering cells, we need to build a KNN graph.
        This is the :math:`K` parameter for that, the number of neighbors for each cell.
    resolution
        Resolution parameter for `'louvain'` or `'leiden'` clustering. Should be chosen relatively small.
    n_matches_min
        Filters out cells which don't have at least ``n_matches_min`` neighbors from the same class.
        This filters out some cells which are transient but have been misassigned.
    n_neighbors_filtering
        Parameter for filtering cells. Cells are filtered out if they don't have at least ``n_matches_min``
        neighbors among their ``n_neighbors_filtering`` nearest cells.
    basis
        Key from :paramref:`adata` ``.obsm`` to be used as additional features for the clustering.
    n_comps
        Number of embedding components to be used when ``basis`` is not `None`.
    scale
        Scale to z-scores. Consider using this if appending embedding to features.
    %(en_cutoff_p_thresh)s

    Returns
    -------
    None
        Nothing, but updates the following fields:

            - :attr:`{fsp}`
            - :attr:`{fs}`
    """

    def _compute_macrostates_prob() -> Series:
        """Compute a global score of being an approximate recurrent class."""
        # get the truncated eigendecomposition
        V, evals = eig["V_l"].real[:, use], eig["D"].real[use]

        # shift and scale
        V_pos = np.abs(V)
        V_shifted = V_pos - np.min(V_pos, axis=0)
        V_scaled = V_shifted / np.max(V_shifted, axis=0)

        # check that the ranges are correct
        assert np.allclose(np.min(V_scaled, axis=0), 0), "Lower limit is not zero."
        assert np.allclose(np.max(V_scaled, axis=0), 1), "Upper limit is not one."

        # further scale by the eigenvalues
        V_eigs = V_scaled / evals

        # sum over columns and scale
        c_ = np.sum(V_eigs, axis=1)
        c = c_ / np.max(c_)

        return Series(c, index=self.adata.obs_names)

    def check_use(use) -> List[int]:
        if method not in ["kmeans", "louvain", "leiden"]:
            raise ValueError(
                f"Invalid method `{method!r}`. Valid options are `'louvain'`, `'leiden'` or `'kmeans'`."
            )

        if use is None:
            use = eig["eigengap"] + 1  # add one b/c indexing starts at 0
        if isinstance(use, int):
            use = list(range(use))
        elif not isinstance(use, (tuple, list, range)):
            raise TypeError(
                f"Argument `use` must be either `int`, `tuple`, `list` or `range`, "
                f"found `{type(use).__name__!r}`."
            )
        else:
            if not all(map(lambda u: isinstance(u, int), use)):
                raise TypeError("Not all values in `use` argument are integers.")
        use = list(use)

        if len(use) == 0:
            raise ValueError(
                f"Number of eigenvectors must be larger than `0`, found `{len(use)}`."
            )

        muse = max(use)
        if muse >= eig["V_l"].shape[1] or muse >= eig["V_r"].shape[1]:
            raise ValueError(
                f"Maximum specified eigenvector `{muse}` is larger "
                f'than the number of computed eigenvectors `{eig["V_l"].shape[1]}`. '
                f"Use `.compute_eigendecomposition(k={muse})` to recompute the eigendecomposition."
            )

        return use

    eig = self._get(P.EIG)
    if eig is None:
        raise RuntimeError(
            "Compute eigendecomposition first as `.compute_eigendecomposition()`."
        )
    use = check_use(use)

    start = logg.info("Computing approximate recurrent classes")
    # we check for complex values only in the left eigenvectors; that's okay because
    # the complex pattern will be identical for left and right
    V_l, V_r = eig["V_l"][:, use], eig["V_r"].real[:, use]
    V_l = _complex_warning(V_l, use, use_imag=False)

    # compute a recurrent-class probability
    logg.debug("Computing probabilities of approximate recurrent classes")
    self._set(A.TERM_PROBS, _compute_macrostates_prob())

    # retrieve the embedding and concatenate
    if basis is not None:
        bkey = f"X_{basis}"
        if bkey not in self.adata.obsm.keys():
            raise KeyError(f"Basis key `{bkey!r}` not found in `adata.obsm`")

        X_em = self.adata.obsm[bkey][:, :n_comps]
        X = np.concatenate([V_r, X_em], axis=1)
    else:
        logg.debug("Basis is `None`. Setting X equal to the right eigenvectors")
        X = V_r

    # filter out cells which are in the lowest q percentile in absolute value in each eigenvector
    if percentile is not None:
        logg.debug("Filtering out cells according to percentile")
        if percentile < 0 or percentile > 100:
            raise ValueError(
                f"Percentile must be in interval `[0, 100]`, found `{percentile}`."
            )
        cutoffs = np.percentile(np.abs(V_l), percentile, axis=0)
        ixs = np.sum(np.abs(V_l) < cutoffs, axis=1) < V_l.shape[1]
        X = X[ixs, :]

    # scale
    if scale:
        X = zscore(X, axis=0)

    # cluster X
    if method == "kmeans" and n_clusters_kmeans is None:
        n_clusters_kmeans = len(use) + (percentile is None)
        if X.shape[0] < n_clusters_kmeans:
            raise ValueError(
                f"Filtering resulted in only {X.shape[0]} cell(s), insufficient to cluster into "
                f"`{n_clusters_kmeans}` clusters. Consider decreasing the value of `percentile`."
            )

    logg.debug(
        f"Using `{use}` eigenvectors, basis `{basis!r}` and method `{method!r}` for clustering"
    )
    labels = _cluster_X(
        X,
        method=method,
        n_clusters=n_clusters_kmeans,
        n_neighbors=n_neighbors,
        resolution=resolution,
    )

    # fill in the labels in case we filtered out cells before
    if percentile is not None:
        rc_labels = np.repeat(None, self.adata.n_obs)
        rc_labels[ixs] = labels
    else:
        rc_labels = labels

    rc_labels = Series(rc_labels, index=self.adata.obs_names, dtype="category")
    rc_labels.cat.categories = list(rc_labels.cat.categories.astype("str"))

    # filtering to get rid of some of the leftover transient states
    if n_matches_min > 0:
        logg.debug(f"Filtering according to `n_matches_min={n_matches_min}`")
        distances = _get_connectivities(
            self.adata, mode="distances", n_neighbors=n_neighbors_filtering
        )
        rc_labels = _filter_cells(
            distances, rc_labels=rc_labels, n_matches_min=n_matches_min
        )

    self.set_terminal_states(
        labels=rc_labels,
        cluster_key=cluster_key,
        en_cutoff=en_cutoff,
        p_thresh=p_thresh,
        add_to_existing=False,
        time=start,
    )
def cluster_lineage(
    adata: AnnData,
    model: _input_model_type,
    genes: Sequence[str],
    lineage: str,
    backward: bool = False,
    time_range: _time_range_type = None,
    clusters: Optional[Sequence[str]] = None,
    n_points: int = 200,
    time_key: str = "latent_time",
    norm: bool = True,
    recompute: bool = False,
    callback: _callback_type = None,
    ncols: int = 3,
    sharey: Union[str, bool] = False,
    key: Optional[str] = None,
    random_state: Optional[int] = None,
    use_leiden: bool = False,
    show_progress_bar: bool = True,
    n_jobs: Optional[int] = 1,
    backend: str = _DEFAULT_BACKEND,
    figsize: Optional[Tuple[float, float]] = None,
    dpi: Optional[int] = None,
    save: Optional[Union[str, Path]] = None,
    pca_kwargs: Dict = MappingProxyType({"svd_solver": "arpack"}),
    neighbors_kwargs: Dict = MappingProxyType({"use_rep": "X"}),
    clustering_kwargs: Dict = MappingProxyType({}),
    return_models: bool = False,
    **kwargs,
) -> Optional[_return_model_type]:
    """
    Cluster gene expression trends within a lineage and plot the clusters.

    This function is based on Palantir, see [Setty19]_. It can be used to discover modules of genes that drive
    development along a given lineage. Consider running this function on a subset of genes which are potential
    lineage drivers, identified e.g. by running :func:`cellrank.tl.lineage_drivers`.

    Parameters
    ----------
    %(adata)s
    %(model)s
    %(genes)s
    lineage
        Name of the lineage for which to cluster the genes.
    %(backward)s
    %(time_ranges)s
    clusters
        Cluster identifiers to plot. If `None`, all clusters will be considered.
        Useful when plotting previously computed clusters.
    n_points
        Number of points used for prediction.
    time_key
        Key in ``adata.obs`` where the pseudotime is stored.
    norm
        Whether to z-normalize each trend to have zero mean, unit variance.
    recompute
        If `True`, recompute the clustering, otherwise try to find an already existing one.
    %(model_callback)s
    ncols
        Number of columns for the plot.
    sharey
        Whether to share the y-axis across multiple plots.
    key
        Key in ``adata.uns`` where to save the results. If `None`, it will be saved as ``lineage_{lineage}_trend``.
    random_state
        Random seed for reproducibility.
    use_leiden
        Whether to use :func:`scanpy.tl.leiden` for clustering or :func:`scanpy.tl.louvain`.
    %(parallel)s
    %(plotting)s
    pca_kwargs
        Keyword arguments for :func:`scanpy.pp.pca`.
    neighbors_kwargs
        Keyword arguments for :func:`scanpy.pp.neighbors`.
    clustering_kwargs
        Keyword arguments for :func:`scanpy.tl.louvain` or :func:`scanpy.tl.leiden`.
    %(return_models)s
    kwargs
        Keyword arguments for :meth:`cellrank.ul.models.BaseModel.prepare`.

    Returns
    -------
    %(plots_or_returns_models)s

    Also updates ``adata.uns`` with the following:

        - ``key`` or ``lineage_{lineage}_trend`` - an :class:`anndata.AnnData` object of shape
          `(n_genes, n_points)` containing the clustered genes.
""" import scanpy as sc from anndata import AnnData as _AnnData lineage_key = str(AbsProbKey.BACKWARD if backward else AbsProbKey.FORWARD) if lineage_key not in adata.obsm: raise KeyError( f"Lineages key `{lineage_key!r}` not found in `adata.obsm`.") _ = adata.obsm[lineage_key][lineage] genes = _unique_order_preserving(genes) _check_collection(adata, genes, "var_names", kwargs.get("use_raw", False)) if key is None: key = f"lineage_{lineage}_trend" if recompute or key not in adata.uns: kwargs["backward"] = backward kwargs["time_key"] = time_key kwargs["n_test_points"] = n_points models = _create_models(model, genes, [lineage]) all_models, models, genes, _ = _fit_bulk( models, _create_callbacks(adata, callback, genes, [lineage], **kwargs), genes, lineage, time_range, return_models=True, # always return (better error messages) filter_all_failed=True, parallel_kwargs={ "show_progress_bar": show_progress_bar, "n_jobs": _get_n_cores(n_jobs, len(genes)), "backend": _get_backend(models, backend), }, **kwargs, ) # `n_genes, n_test_points` trends = np.vstack( [model[lineage].y_test for model in models.values()]).T if norm: logg.debug("Normalizing trends") _ = StandardScaler(copy=False).fit_transform(trends) trends = _AnnData(trends.T) trends.obs_names = genes # sanity check if trends.n_obs != len(genes): raise RuntimeError( f"Expected to find `{len(genes)}` genes, found `{trends.n_obs}`." ) if trends.n_vars != n_points: raise RuntimeError( f"Expected to find `{n_points}` points, found `{trends.n_vars}`." ) random_state = np.random.mtrand.RandomState(random_state).randint( 2**16) pca_kwargs = dict(pca_kwargs) pca_kwargs.setdefault("n_comps", min(50, n_points, len(genes)) - 1) pca_kwargs.setdefault("random_state", random_state) sc.pp.pca(trends, **pca_kwargs) neighbors_kwargs = dict(neighbors_kwargs) neighbors_kwargs.setdefault("random_state", random_state) sc.pp.neighbors(trends, **neighbors_kwargs) clustering_kwargs = dict(clustering_kwargs) clustering_kwargs["key_added"] = "clusters" clustering_kwargs.setdefault("random_state", random_state) try: if use_leiden: sc.tl.leiden(trends, **clustering_kwargs) else: sc.tl.louvain(trends, **clustering_kwargs) except ImportError as e: logg.warning(str(e)) if use_leiden: sc.tl.louvain(trends, **clustering_kwargs) else: sc.tl.leiden(trends, **clustering_kwargs) logg.info(f"Saving data to `adata.uns[{key!r}]`") adata.uns[key] = trends else: all_models = None logg.info(f"Loading data from `adata.uns[{key!r}]`") trends = adata.uns[key] if "clusters" not in trends.obs: raise KeyError( "Unable to find the clustering in `trends.obs['clusters']`.") if clusters is None: clusters = trends.obs["clusters"].cat.categories for c in clusters: if c not in trends.obs["clusters"].cat.categories: raise ValueError( f"Invalid cluster name `{c!r}`. " f"Valid options are `{list(trends.obs['clusters'].cat.categories)}`." 
) nrows = int(np.ceil(len(clusters) / ncols)) fig, axes = plt.subplots( nrows, ncols, figsize=(ncols * 10, nrows * 10) if figsize is None else figsize, sharey=sharey, dpi=dpi, ) if not isinstance(axes, Iterable): axes = [axes] axes = np.ravel(axes) j = 0 for j, (ax, c) in enumerate(zip(axes, clusters)): # noqa data = trends[trends.obs["clusters"] == c].X mean, sd = np.mean(data, axis=0), np.var(data, axis=0) sd = np.sqrt(sd) for i in range(data.shape[0]): ax.plot(data[i], color="gray", lw=0.5) ax.plot(mean, lw=2, color="black") ax.plot(mean - sd, lw=1.5, color="black", linestyle="--") ax.plot(mean + sd, lw=1.5, color="black", linestyle="--") ax.fill_between(range(len(mean)), mean - sd, mean + sd, color="black", alpha=0.1) ax.set_title(f"Cluster {c}") ax.set_xticks([]) if not sharey: ax.set_yticks([]) for j in range(j + 1, len(axes)): axes[j].remove() if save is not None: save_fig(fig, save) if return_models: return all_models
def heatmap(
    adata: AnnData,
    model: _input_model_type,
    genes: Sequence[str],
    lineages: Optional[Union[str, Sequence[str]]] = None,
    backward: bool = False,
    mode: str = HeatmapMode.LINEAGES.s,
    time_key: str = "latent_time",
    time_range: Optional[Union[_time_range_type, List[_time_range_type]]] = None,
    callback: _callback_type = None,
    cluster_key: Optional[Union[str, Sequence[str]]] = None,
    show_absorption_probabilities: bool = False,
    cluster_genes: bool = False,
    keep_gene_order: bool = False,
    scale: bool = True,
    n_convolve: Optional[int] = 5,
    show_all_genes: bool = False,
    cbar: bool = True,
    lineage_height: float = 0.33,
    fontsize: Optional[float] = None,
    xlabel: Optional[str] = None,
    cmap: mcolors.ListedColormap = cm.viridis,
    dendrogram: bool = True,
    return_genes: bool = False,
    return_models: bool = False,
    n_jobs: Optional[int] = 1,
    backend: str = _DEFAULT_BACKEND,
    show_progress_bar: bool = True,
    figsize: Optional[Tuple[float, float]] = None,
    dpi: Optional[int] = None,
    save: Optional[Union[str, Path]] = None,
    **kwargs,
) -> Optional[
    Union[Dict[str, pd.DataFrame], Tuple[_return_model_type, Dict[str, pd.DataFrame]]]
]:
    """
    Plot a heatmap of smoothed gene expression along specified lineages.

    Parameters
    ----------
    %(adata)s
    %(model)s
    %(genes)s
    lineages
        Names of the lineages for which to plot. If `None`, plot all lineages.
    %(backward)s
    mode
        Valid options are:

            - `{m.LINEAGES.s!r}` - group by ``genes`` for each lineage in ``lineages``.
            - `{m.GENES.s!r}` - group by ``lineages`` for each gene in ``genes``.
    time_key
        Key in ``adata.obs`` where the pseudotime is stored.
    %(time_ranges)s
    %(model_callback)s
    cluster_key
        Key(s) in ``adata.obs`` containing categorical observations to be plotted on top of the heatmap.
        Only available when ``mode={m.LINEAGES.s!r}``.
    show_absorption_probabilities
        Whether to also plot absorption probabilities alongside the smoothed expression.
        Only available when ``mode={m.LINEAGES.s!r}``.
    cluster_genes
        Whether to cluster genes using :func:`seaborn.clustermap` when ``mode='lineages'``.
    keep_gene_order
        Whether to keep the gene order for later lineages after the first one was sorted.
        Only available when ``cluster_genes=False`` and ``mode={m.LINEAGES.s!r}``.
    scale
        Whether to normalize the gene expression to the `0-1` range.
    n_convolve
        Size of the convolution window when smoothing absorption probabilities.
    show_all_genes
        Whether to show all genes on the y-axis.
    cbar
        Whether to show the colorbar.
    lineage_height
        Height of a bar when ``mode={m.GENES.s!r}``.
    fontsize
        Size of the title's font.
    xlabel
        Label on the x-axis. If `None`, it is determined based on ``time_key``.
    cmap
        Colormap to use when visualizing the smoothed expression.
    dendrogram
        Whether to show the dendrogram when ``cluster_genes=True``.
    return_genes
        Whether to return the sorted or clustered genes. Only available when ``mode={m.LINEAGES.s!r}``.
    %(return_models)s
    %(parallel)s
    %(plotting)s
    kwargs
        Keyword arguments for :meth:`cellrank.ul.models.BaseModel.prepare`.

    Returns
    -------
    %(plots_or_returns_models)s
    :class:`pandas.DataFrame`
        If ``return_genes=True`` and ``mode={m.LINEAGES.s!r}``, returns :class:`pandas.DataFrame`
        containing the clustered or sorted genes.
""" import seaborn as sns def find_indices(series: pd.Series, values) -> Tuple[Any]: def find_nearest(array: np.ndarray, value: float) -> int: ix = np.searchsorted(array, value, side="left") if ix > 0 and (ix == len(array) or fabs(value - array[ix - 1]) < fabs(value - array[ix])): return ix - 1 return ix series = series[np.argsort(series.values)] return tuple(series[[find_nearest(series.values, v) for v in values]].index) def subset_lineage(lname: str, rng: np.ndarray) -> np.ndarray: time_series = adata.obs[time_key] ixs = find_indices(time_series, rng) lin = adata[ixs, :].obsm[lineage_key][lname] lin = lin.X.copy().squeeze() if n_convolve is not None: lin = convolve(lin, np.ones(n_convolve) / n_convolve, mode="nearest") return lin def create_col_colors(lname: str, rng: np.ndarray) -> Tuple[np.ndarray, Cmap, Norm]: color = adata.obsm[lineage_key][lname].colors[0] lin = subset_lineage(lname, rng) h, _, v = mcolors.rgb_to_hsv(mcolors.to_rgb(color)) end_color = mcolors.hsv_to_rgb([h, 1, v]) lineage_cmap = mcolors.LinearSegmentedColormap.from_list( "lineage_cmap", ["#ffffff", end_color], N=len(rng)) norm = mcolors.Normalize(vmin=np.min(lin), vmax=np.max(lin)) scalar_map = cm.ScalarMappable(cmap=lineage_cmap, norm=norm) return ( np.array([mcolors.to_hex(c) for c in scalar_map.to_rgba(lin)]), lineage_cmap, norm, ) def create_col_categorical_color(cluster_key: str, rng: np.ndarray) -> np.ndarray: if not is_categorical_dtype(adata.obs[cluster_key]): raise TypeError( f"Expected `adata.obs[{cluster_key!r}]` to be categorical, " f"found `{adata.obs[cluster_key].dtype.name!r}`.") color_key = f"{cluster_key}_colors" if color_key not in adata.uns: logg.warning( f"Color key `{color_key!r}` not found in `adata.uns`. Creating new colors" ) colors = _create_categorical_colors( len(adata.obs[cluster_key].cat.categories)) adata.uns[color_key] = colors else: colors = adata.uns[color_key] time_series = adata.obs[time_key] ixs = find_indices(time_series, rng) mapper = dict(zip(adata.obs[cluster_key].cat.categories, colors)) return np.array([ mcolors.to_hex(mapper[v]) for v in adata[ixs, :].obs[cluster_key].values ]) def create_cbar( ax, x_delta: float, cmap: Cmap, norm: Norm, label: Optional[str] = None, ) -> Ax: cax = inset_axes( ax, width="1%", height="100%", loc="lower right", bbox_to_anchor=(x_delta, 0, 1, 1), bbox_transform=ax.transAxes, ) _ = mpl.colorbar.ColorbarBase( cax, cmap=cmap, norm=norm, label=label, ticks=np.linspace(norm.vmin, norm.vmax, 5), ) return cax @valuedispatch def _plot_heatmap(_mode: HeatmapMode) -> Fig: pass @_plot_heatmap.register(HeatmapMode.GENES) def _() -> Tuple[Fig, None]: def color_fill_rec(ax, xs, y1, y2, colors=None, cmap=cmap, **kwargs) -> None: colors = colors if cmap is None else cmap(colors) x = 0 for i, (color, x, y1, y2) in enumerate(zip(colors, xs, y1, y2)): dx = (xs[i + 1] - xs[i]) if i < len(x) else (xs[-1] - xs[-2]) ax.add_patch( plt.Rectangle((x, y1), dx, y2 - y1, color=color, ec=color, **kwargs)) ax.plot(x, y2, lw=0) fig, axes = plt.subplots( nrows=len(genes) + show_absorption_probabilities, figsize=(12, len(genes) + len(lineages) * lineage_height) if figsize is None else figsize, dpi=dpi, constrained_layout=True, ) if not isinstance(axes, Iterable): axes = [axes] axes = np.ravel(axes) if show_absorption_probabilities: data["absorption probability"] = data[next(iter(data.keys()))] for ax, (gene, models) in zip(axes, data.items()): if scale: vmin, vmax = 0, 1 else: c = np.array([m.y_test for m in models.values()]) vmin, vmax = np.nanmin(c), np.nanmax(c) norm = 
mcolors.Normalize(vmin=vmin, vmax=vmax) ix = 0 ys = [ix] if gene == "absorption probability": norm = mcolors.Normalize(vmin=0, vmax=1) for ln, x in ((ln, m.x_test) for ln, m in models.items()): y = np.ones_like(x) c = subset_lineage(ln, x.squeeze()) color_fill_rec(ax, x, y * ix, y * (ix + lineage_height), colors=norm(c)) ix += lineage_height ys.append(ix) else: for x, c in ((m.x_test, m.y_test) for m in models.values()): y = np.ones_like(x) c = _min_max_scale(c) if scale else c color_fill_rec(ax, x, y * ix, y * (ix + lineage_height), colors=norm(c)) ix += lineage_height ys.append(ix) xs = np.array([m.x_test for m in models.values()]) x_min, x_max = np.min(xs), np.max(xs) ax.set_xticks(np.linspace(x_min, x_max, _N_XTICKS)) ax.set_yticks(np.array(ys[:-1]) + lineage_height / 2) ax.spines["left"].set_position( ("data", 0) ) # move the left spine to the rectangles to get nicer yticks ax.set_yticklabels(models.keys(), ha="right") ax.set_title(gene, fontdict={"fontsize": fontsize}) ax.set_ylabel("lineage") for pos in ["top", "bottom", "left", "right"]: ax.spines[pos].set_visible(False) if cbar: cax, _ = mpl.colorbar.make_axes(ax) _ = mpl.colorbar.ColorbarBase( cax, ticks=np.linspace(vmin, vmax, 5), norm=norm, cmap=cmap, label="value" if gene == "absorption probability" else "scaled expression" if scale else "expression", ) ax.tick_params( top=False, bottom=False, left=True, right=False, labelleft=True, labelbottom=False, ) ax.xaxis.set_major_formatter(FormatStrFormatter("%.3f")) ax.tick_params( top=False, bottom=True, left=True, right=False, labelleft=True, labelbottom=True, ) ax.set_xlabel(xlabel) return fig, None @_plot_heatmap.register(HeatmapMode.LINEAGES) def _() -> Tuple[List[Fig], pd.DataFrame]: data_t = defaultdict(dict) # transpose for gene, lns in data.items(): for ln, y in lns.items(): data_t[ln][gene] = y figs = [] gene_order = None sorted_genes = pd.DataFrame() if return_genes else None for lname, models in data_t.items(): xs = np.array([m.x_test for m in models.values()]) x_min, x_max = np.nanmin(xs), np.nanmax(xs) df = pd.DataFrame([m.y_test for m in models.values()], index=models.keys()) df.index.name = "genes" if not cluster_genes: if gene_order is not None: df = df.loc[gene_order] else: max_sort = np.argsort( np.argmax(df.apply(_min_max_scale, axis=1).values, axis=1)) df = df.iloc[max_sort, :] if keep_gene_order: gene_order = df.index cat_colors = None if cluster_key is not None: cat_colors = np.stack( [ create_col_categorical_color( c, np.linspace(x_min, x_max, df.shape[1])) for c in cluster_key ], axis=0, ) if show_absorption_probabilities: col_colors, col_cmap, col_norm = create_col_colors( lname, np.linspace(x_min, x_max, df.shape[1])) if cat_colors is not None: col_colors = np.vstack([cat_colors, col_colors[None, :]]) else: col_colors, col_cmap, col_norm = cat_colors, None, None row_cluster = cluster_genes and df.shape[0] > 1 show_clust = row_cluster and dendrogram g = sns.clustermap( df, cmap=cmap, figsize=(10, min(len(genes) / 8 + 1, 10)) if figsize is None else figsize, xticklabels=False, row_cluster=row_cluster, col_colors=col_colors, colors_ratio=0, col_cluster=False, cbar_pos=None, yticklabels=show_all_genes or "auto", standard_scale=0 if scale else None, ) if cbar: cax = create_cbar( g.ax_heatmap, 0.1, cmap=cmap, norm=mcolors.Normalize( vmin=0 if scale else np.min(df.values), vmax=1 if scale else np.max(df.values), ), label="scaled expression" if scale else "expression", ) g.fig.add_axes(cax) if col_cmap is not None and col_norm is not None: cax = create_cbar( 
g.ax_heatmap, 0.25, cmap=col_cmap, norm=col_norm, label="absorption probability", ) g.fig.add_axes(cax) if g.ax_col_colors: main_bbox = _get_ax_bbox(g.fig, g.ax_heatmap) n_bars = show_absorption_probabilities + ( len(cluster_key) if cluster_key is not None else 0) _set_ax_height_to_cm( g.fig, g.ax_col_colors, height=min( 5, max(n_bars * main_bbox.height / len(df), 0.25 * n_bars)), ) g.ax_col_colors.set_title(lname, fontdict={"fontsize": fontsize}) else: g.ax_heatmap.set_title(lname, fontdict={"fontsize": fontsize}) g.ax_col_dendrogram.set_visible( False) # gets rid of top free space g.ax_heatmap.yaxis.tick_left() g.ax_heatmap.yaxis.set_label_position("right") g.ax_heatmap.set_xlabel(xlabel) g.ax_heatmap.set_xticks(np.linspace(0, len(df.columns), _N_XTICKS)) g.ax_heatmap.set_xticklabels( list( map(lambda n: round(n, 3), np.linspace(x_min, x_max, _N_XTICKS)))) if show_clust: # robustly show dendrogram, because gene names can be long g.ax_row_dendrogram.set_visible(True) dendro_box = g.ax_row_dendrogram.get_position() pad = 0.005 bb = g.ax_heatmap.yaxis.get_tightbbox( g.fig.canvas.get_renderer()).transformed( g.fig.transFigure.inverted()) dendro_box.x0 = bb.x0 - dendro_box.width - pad dendro_box.x1 = bb.x0 - pad g.ax_row_dendrogram.set_position(dendro_box) else: g.ax_row_dendrogram.set_visible(False) if return_genes: sorted_genes[lname] = (df.index[g.dendrogram_row.reordered_ind] if hasattr(g, "dendrogram_row") and g.dendrogram_row is not None else df.index) figs.append(g) return figs, sorted_genes mode = HeatmapMode(mode) lineage_key = str(AbsProbKey.BACKWARD if backward else AbsProbKey.FORWARD) if lineage_key not in adata.obsm: raise KeyError( f"Lineages key `{lineage_key!r}` not found in `adata.obsm`.") if lineages is None: lineages = adata.obsm[lineage_key].names elif isinstance(lineages, str): lineages = [lineages] lineages = _unique_order_preserving(lineages) _ = adata.obsm[lineage_key][lineages] if cluster_key is not None: if isinstance(cluster_key, str): cluster_key = [cluster_key] cluster_key = _unique_order_preserving(cluster_key) if isinstance(genes, str): genes = [genes] genes = _unique_order_preserving(genes) _check_collection(adata, genes, "var_names", use_raw=kwargs.get("use_raw", False)) kwargs["backward"] = backward kwargs["time_key"] = time_key models = _create_models(model, genes, lineages) all_models, data, genes, lineages = _fit_bulk( models, _create_callbacks(adata, callback, genes, lineages, **kwargs), genes, lineages, time_range, return_models=True, # always return (better error messages) filter_all_failed=True, parallel_kwargs={ "show_progress_bar": show_progress_bar, "n_jobs": _get_n_cores(n_jobs, len(genes)), "backend": _get_backend(models, backend), }, **kwargs, ) xlabel = time_key if xlabel is None else xlabel logg.debug(f"Plotting `{mode.s!r}` heatmap") fig, genes = _plot_heatmap(mode) if save is not None and fig is not None: if not isinstance(fig, Iterable): save_fig(fig, save) elif len(fig) == 1: save_fig(fig[0], save) else: for ln, f in zip(lineages, fig): save_fig(f, os.path.join(save, f"lineage_{ln}")) if return_genes and mode == HeatmapMode.LINEAGES: return (all_models, genes) if return_models else genes elif return_models: return all_models