def create_ixs(ixs: Indices_t, *, kind: str) -> Optional[np.ndarray]:
    if ixs is None:
        return None
    if isinstance(ixs, dict):
        # fmt: off
        if len(ixs) != 1:
            raise ValueError(f"Expected to find only 1 cluster key, found `{len(ixs)}`.")

        cluster_key = next(iter(ixs.keys()))
        if cluster_key not in self.adata.obs:
            raise KeyError(f"Unable to find `adata.obs[{cluster_key!r}]`.")
        if not is_categorical_dtype(self.adata.obs[cluster_key]):
            raise TypeError(
                f"Expected `adata.obs[{cluster_key!r}]` to be categorical, "
                f"found `{infer_dtype(self.adata.obs[cluster_key])}`."
            )

        ixs = np.where(np.isin(self.adata.obs[cluster_key], ixs[cluster_key]))[0]
        # fmt: on
    elif isinstance(ixs, str):
        ixs = np.where(self.adata.obs_names == ixs)[0]
    else:
        ixs = np.where(np.isin(self.adata.obs_names, ixs))[0]

    if not len(ixs):
        logg.warning(f"No {kind} indices have been selected, using `None`")
        return None

    return ixs
def _reconstruct_lineage(self, attr: PrettyEnum, obsm_key: str):
    self._set_or_debug(obsm_key, self.adata.obsm, attr)
    names = self._set_or_debug(_lin_names(self._term_key), self.adata.uns)
    colors = self._set_or_debug(_colors(self._term_key), self.adata.uns)
    probs = self._get(attr)

    if probs is not None:
        if len(names) != probs.shape[1]:
            if isinstance(probs, Lineage):
                names = probs.names
            else:
                logg.warning(
                    f"Expected lineage names to be of length `{probs.shape[1]}`, found `{len(names)}`. "
                    f"Creating new names"
                )
                names = [f"Lineage {i}" for i in range(probs.shape[1])]

        if len(colors) != probs.shape[1] or not all(
            map(lambda c: isinstance(c, str) and is_color_like(c), colors)
        ):
            if isinstance(probs, Lineage):
                colors = probs.colors
            else:
                logg.warning(
                    f"Expected lineage colors to be of length `{probs.shape[1]}`, found `{len(colors)}`. "
                    f"Creating new colors"
                )
                colors = _create_categorical_colors(probs.shape[1])

        self._set(attr, Lineage(probs, names=names, colors=colors))

        self.adata.obsm[obsm_key] = self._get(attr)
        self.adata.uns[_lin_names(self._term_key)] = names
        self.adata.uns[_colors(self._term_key)] = colors
def _detect_cc_stages(self, rc_labels: Series, p_thresh: float = 1e-15) -> None:
    """
    Detect cell-cycle driven start or endpoints.

    Parameters
    ----------
    rc_labels
        Approximate recurrent classes.
    p_thresh
        P-value threshold for the rank-sum test for the group to be considered cell-cycle driven.

    Returns
    -------
    None
        Nothing, but warns if a group is cell-cycle driven.
    """
    # initialize the groups (start or end clusters) and scores
    groups = rc_labels.cat.categories

    scores = []
    if self._G2M_score is not None:
        scores.append(self._G2M_score)
    if self._S_score is not None:
        scores.append(self._S_score)

    # loop over groups and scores
    for group in groups:
        mask = rc_labels == group
        for score in scores:
            a, b = score[mask], score[~mask]
            statistic, pvalue = ranksums(a, b)
            if statistic > 0 and pvalue < p_thresh:
                logg.warning(f"Group `{group!r}` appears to be cell-cycle driven")
                break
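# The snippet below is a minimal, self-contained illustration of the rank-sum
# criterion used in `_detect_cc_stages` above: a group is flagged when its
# cell-cycle scores are significantly higher than everywhere else
# (`statistic > 0` makes the check one-sided). Toy data, not the CellRank API.
def _demo_cc_ranksum() -> None:
    import numpy as np
    from scipy.stats import ranksums

    rng = np.random.default_rng(0)
    # 50 "cycling" cells with high scores, 500 background cells
    score = np.r_[rng.normal(1.0, 0.1, 50), rng.normal(0.0, 0.1, 500)]
    mask = np.zeros_like(score, dtype=bool)
    mask[:50] = True  # cells in the candidate group

    statistic, pvalue = ranksums(score[mask], score[~mask])
    assert statistic > 0 and pvalue < 1e-15  # group would be flagged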
def __init__(
    self,
    obj: Union[AnnData, np.ndarray, spmatrix, KernelExpression],
    key: Optional[str] = None,
    obsp_key: Optional[str] = None,
    write_to_adata: bool = True,
):
    if isinstance(obj, KernelExpression):
        self._kernel = obj
    elif isinstance(obj, (np.ndarray, spmatrix)):
        self._kernel = PrecomputedKernel(obj)
    elif isinstance(obj, AnnData):
        if obsp_key is None:
            raise ValueError(
                "Specify `obsp_key=...` when supplying an `AnnData` object."
            )
        elif obsp_key not in obj.obsp.keys():
            raise KeyError(f"Key `{obsp_key!r}` not found in `adata.obsp`.")
        self._kernel = PrecomputedKernel(obj.obsp[obsp_key], adata=obj)
    else:
        raise TypeError(
            f"Expected an object of type `KernelExpression`, `numpy.ndarray`, `scipy.sparse.spmatrix` "
            f"or `anndata.AnnData`, got `{type(obj).__name__!r}`."
        )

    if self.kernel._transition_matrix is None:
        # access the private attribute to avoid accidentally computing the transition matrix;
        # in principle, it doesn't make a difference, apart from not seeing the message
        logg.warning("Computing transition matrix using the default parameters")
        self.kernel.compute_transition_matrix()

    if write_to_adata:
        self.kernel.write_to_adata(key=key)
def _fit_terminal_states(
    self,
    n_lineages: Optional[int] = None,
    cluster_key: Optional[str] = None,
    method: str = "krylov",
    **kwargs,
) -> None:
    # remember whether the number of lineages is determined automatically,
    # since `n_lineages` is overwritten below
    use_eigengap = n_lineages is None

    if n_lineages is None or n_lineages == 1:
        self.compute_eigendecomposition()
        if n_lineages is None:
            n_lineages = self.eigendecomposition["eigengap"] + 1

    if n_lineages > 1:
        self.compute_schur(n_lineages, method=method)

    try:
        self.compute_macrostates(n_states=n_lineages, cluster_key=cluster_key, **kwargs)
    except ValueError:
        logg.warning(
            f"Computing `{n_lineages}` macrostates cuts through a block of complex conjugates. "
            f"Increasing `n_lineages` to `{n_lineages + 1}`"
        )
        self.compute_macrostates(n_states=n_lineages + 1, cluster_key=cluster_key, **kwargs)

    fs_kwargs = {"n_cells": kwargs["n_cells"]} if "n_cells" in kwargs else {}

    if use_eigengap:
        self.compute_terminal_states(method="eigengap", **fs_kwargs)
    else:
        self.set_terminal_states_from_macrostates(**fs_kwargs)
def _(self, data: pd.Series, prop: str, discrete: bool = False, **kwargs) -> None:
    if discrete and kwargs.get("mode", "embedding") == "time":
        logg.warning(
            "`mode='time'` is only implemented in the continuous case, plotting in continuous mode"
        )
        discrete = False

    if discrete:
        self._plot_discrete(data, prop, **kwargs)
    elif prop == P.MACRO.v:  # GPCCA
        prop = P.MACRO_MEMBER.v
        self._plot_continuous(getattr(self, prop, None), prop, **kwargs)
    elif prop == P.TERM.v:
        probs = getattr(self, A.TERM_ABS_PROBS.s, None)  # we have this only in GPCCA
        if isinstance(probs, Lineage):
            self._plot_continuous(probs, P.TERM_PROBS.v, **kwargs)
        else:
            logg.warning(
                f"Unable to plot continuous observations for `{prop!r}`, plotting in discrete mode"
            )
            self._plot_discrete(data, prop, **kwargs)
    else:
        raise NotImplementedError(f"Unable to plot property `.{prop}` in continuous mode.")
def create_col_categorical_color(cluster_key: str, rng: np.ndarray) -> np.ndarray:
    if not is_categorical_dtype(adata.obs[cluster_key]):
        raise TypeError(
            f"Expected `adata.obs[{cluster_key!r}]` to be categorical, "
            f"found `{adata.obs[cluster_key].dtype.name!r}`."
        )

    color_key = f"{cluster_key}_colors"
    if color_key not in adata.uns:
        logg.warning(f"Color key `{color_key!r}` not found in `adata.uns`. Creating new colors")
        colors = _create_categorical_colors(len(adata.obs[cluster_key].cat.categories))
        adata.uns[color_key] = colors
    else:
        colors = adata.uns[color_key]

    time_series = adata.obs[time_key]
    ixs = find_indices(time_series, rng)

    mapper = dict(zip(adata.obs[cluster_key].cat.categories, colors))

    return np.array(
        [mcolors.to_hex(mapper[v]) for v in adata[ixs, :].obs[cluster_key].values]
    )
def _check_states_validity(self, n_states: int) -> int:
    if self._invalid_n_states is not None and n_states in self._invalid_n_states:
        logg.warning(
            f"Unable to compute macrostates with `n_states={n_states}` because it will "
            f"split the conjugate eigenvalues. Increasing `n_states` to `{n_states + 1}`"
        )
        n_states += 1  # cannot force recomputation of the Schur decomposition
        assert n_states not in self._invalid_n_states, "Sanity check failed."

    return n_states
def lineage_drivers(
    adata: AnnData,
    lineage: str,
    backward: bool = False,
    n_genes: int = 8,
    ncols: Optional[int] = None,
    use_raw: bool = False,
    title_fmt: str = "{gene} qval={qval:.4e}",
    **kwargs,
) -> None:
    """
    Plot lineage drivers that were uncovered using :func:`cellrank.tl.lineage_drivers`.

    Parameters
    ----------
    %(adata)s
    %(backward)s
    %(plot_lineage_drivers.parameters)s

    Returns
    -------
    %(just_plots)s
    """
    pk = DummyKernel(adata, backward=backward)
    mc = GPCCA(pk, read_from_adata=True, write_to_adata=False)

    if use_raw and adata.raw is None:
        logg.warning("No raw attribute set. Using `adata.var` instead")
        use_raw = False

    direction = DirPrefix.BACKWARD if backward else DirPrefix.FORWARD
    needle = f"{direction} {lineage} corr"
    haystack = adata.raw.var if use_raw else adata.var

    if needle not in haystack:
        raise RuntimeError(
            f"Unable to find lineage drivers in "
            f"`{'adata.raw.var' if use_raw else 'adata.var'}[{needle!r}]`. "
            f"Compute lineage drivers first as `cellrank.tl.lineage_drivers(lineages={lineage!r}, "
            f"use_raw={use_raw}, backward={backward})`."
        )

    drivers = pd.DataFrame(haystack[[needle, f"{direction} {lineage} qval"]])
    drivers.columns = [f"{lineage} corr", f"{lineage} qval"]
    mc._set(A.LIN_DRIVERS, drivers)

    mc.plot_lineage_drivers(
        lineage,
        n_genes=n_genes,
        use_raw=use_raw,
        ncols=ncols,
        title_fmt=title_fmt,
        **kwargs,
    )
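# Hedged usage sketch for the plotting function above; the lineage name
# "Alpha" and `adata` are placeholders. The `cellrank.tl.lineage_drivers`
# call mirrors the hint embedded in the `RuntimeError` message.
def _demo_plot_lineage_drivers(adata) -> None:
    import cellrank as cr

    # compute driver correlations first, then plot the top genes
    cr.tl.lineage_drivers(adata, lineages="Alpha", use_raw=False, backward=False)
    cr.pl.lineage_drivers(adata, lineage="Alpha", n_genes=8, use_raw=False)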
def test_formats(self, capsys, logging_state):
    settings.logfile = sys.stderr
    settings.verbosity = Verbosity.debug

    logg.error("0")
    assert capsys.readouterr().err == "ERROR: 0\n"
    logg.warning("1")
    assert capsys.readouterr().err == "WARNING: 1\n"
    logg.info("2")
    assert capsys.readouterr().err == "2\n"
    logg.hint("3")
    assert capsys.readouterr().err == "--> 3\n"
def _filter_models(
    models, return_models: bool = False, filter_all_failed: bool = True
) -> Tuple[_return_model_type, _return_model_type, Sequence[str], Sequence[str]]:
    def is_valid(x: Union[BaseModel, BulkRes]) -> bool:
        if return_models:
            assert isinstance(
                x, BaseModel
            ), f"Expected `BaseModel`, found `{type(x).__name__!r}`."
            return bool(x)

        return (
            x.x_test is not None
            and x.y_test is not None
            and np.all(np.isfinite(x.y_test))
        )

    modelmat = pd.DataFrame(models).T
    modelmask = modelmat.applymap(is_valid)
    to_keep = modelmask[modelmask.any(axis=1)]
    to_keep = to_keep.loc[:, to_keep.any(axis=0)].T

    filtered_models = {
        gene: {
            ln: models[gene][ln]
            for ln in (
                ln
                for ln in v.keys()
                if (is_valid(models[gene][ln]) if filter_all_failed else True)
            )
        }
        for gene, v in to_keep.to_dict().items()
    }

    if not len(filtered_models):
        if not return_models:
            raise RuntimeError(
                "Fitting has failed for all gene/lineage combinations. "
                "Specify `return_models=True` for more information."
            )
        for ms in models.values():
            for model in ms.values():
                assert isinstance(
                    model, FailedModel
                ), f"Expected `FailedModel`, found `{type(model).__name__!r}`."
                model.reraise()

    if not np.all(modelmask.values):
        failed_models = modelmat.values[~modelmask.values]
        logg.warning(
            f"Unable to fit `{len(failed_models)}` models."
            + ("" if return_models else " Consider specifying `return_models=True` for further inspection.")
        )
        logg.debug(
            "The failed models were:\n`{}`".format("\n".join(f"  {m}" for m in failed_models))
        )

    # after the final transpose, `to_keep.index` holds the lineage names
    return models, filtered_models, tuple(filtered_models.keys()), tuple(to_keep.index)
def _maybe_compute_cond_num(self):
    if self._compute_cond_num and self._cond_num is None:
        logg.debug(f"Computing condition number of `{repr(self)}`")
        self._cond_num = np.linalg.cond(
            self._transition_matrix.toarray()
            if issparse(self._transition_matrix)
            else self._transition_matrix
        )

        if self._cond_num > _cond_num_tolerance:
            logg.warning(
                f"`{repr(self)}` may be ill-conditioned, its condition number is `{self._cond_num:.2e}`"
            )
        else:
            logg.info(f"Condition number is `{self._cond_num:.2e}`")
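# Toy illustration of the conditioning check above: a nearly singular
# row-stochastic matrix has a huge condition number. The threshold here is
# illustrative, not the library's `_cond_num_tolerance`.
def _demo_cond_num() -> None:
    import numpy as np

    eps = 1e-12
    t = np.array([[0.5 + eps, 0.5 - eps], [0.5 - eps, 0.5 + eps]])
    assert np.linalg.cond(t) > 1e5  # ill-conditioned: a warning would be emitted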
def _remove_zero_rows(a: np.ndarray, b: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    if a.shape[0] != b.shape[0]:
        raise ValueError("Lineage objects have unequal cell numbers.")

    bool_a = (a == 0).any(axis=1)
    bool_b = (b == 0).any(axis=1)
    mask = ~np.logical_or(bool_a, bool_b)

    logg.warning(f"Removed {a.shape[0] - np.sum(mask)} rows because they contained zeros")

    return a[mask, :], b[mask, :]
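# Minimal numpy sketch of the masking logic above: drop every row in which
# either array contains a zero. Plain arrays stand in for `Lineage` objects.
def _demo_remove_zero_rows() -> None:
    import numpy as np

    a = np.array([[0.2, 0.8], [0.0, 1.0], [0.5, 0.5]])
    b = np.array([[0.3, 0.7], [0.4, 0.6], [1.0, 0.0]])

    mask = ~np.logical_or((a == 0).any(axis=1), (b == 0).any(axis=1))
    a, b = a[mask, :], b[mask, :]
    assert a.shape == b.shape == (1, 2)  # the second and third rows were removed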
def _check_collection(
    adata: AnnData,
    needles: Iterable[str],
    attr_name: str,
    key_name: str = "Gene",
    use_raw: bool = False,
    raise_exc: bool = True,
) -> List[str]:
    """
    Check if a given collection contains all the keys.

    Parameters
    ----------
    adata: :class:`anndata.AnnData`
        Annotated data object.
    needles
        Keys to check.
    attr_name
        Attribute of ``adata`` where the needles are stored.
    key_name
        Pretty name of the key which will be displayed when an error is found.
    use_raw
        Whether to access ``adata.raw`` or just ``adata``.
    raise_exc
        Whether to raise a :class:`KeyError` when a needle is not found, or to collect it instead.

    Returns
    -------
    The needles which were not found. Raises a :class:`KeyError` if ``raise_exc=True``
    and one of the needles is not found.
    """
    adata_name = "adata"

    if use_raw and adata.raw is None:
        logg.warning(
            "Argument `use_raw` was set to `True`, but no `raw` attribute is found. Ignoring"
        )
        use_raw = False
    if use_raw:
        adata_name = "adata.raw"
        adata = adata.raw

    haystack, res = getattr(adata, attr_name), []
    for needle in needles:
        if needle not in haystack:
            if raise_exc:
                raise KeyError(f"{key_name} `{needle}` not found in `{adata_name}.{attr_name}`.")
            res.append(needle)

    return res
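# Hedged usage sketch for `_check_collection`: with `raise_exc=False`, missing
# keys are returned instead of raising. The `AnnData` object is a toy example.
def _demo_check_collection() -> None:
    import numpy as np
    from anndata import AnnData

    adata = AnnData(np.zeros((3, 2)))
    adata.var_names = ["geneA", "geneB"]

    missing = _check_collection(
        adata, ["geneA", "geneC"], attr_name="var_names", raise_exc=False
    )
    assert missing == ["geneC"]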
def _get_n_states_from_minchi(
    self, n_states: Union[Tuple[int, int], List[int], Dict[str, int]]
) -> int:
    if self._gpcca is None:
        raise RuntimeError(
            "Compute Schur decomposition first as `.compute_schur()` when `use_min_chi=True`."
        )

    if not isinstance(n_states, (dict, tuple, list)):
        raise TypeError(
            f"Expected `n_states` to be either `dict`, `tuple` or a `list`, "
            f"found `{type(n_states).__name__}`."
        )
    if len(n_states) != 2:
        raise ValueError(f"Expected `n_states` to be of size `2`, found `{len(n_states)}`.")

    if isinstance(n_states, dict):
        if "min" not in n_states or "max" not in n_states:
            raise KeyError(
                f"Expected the dictionary to have `'min'` and `'max'` keys, "
                f"found `{tuple(n_states.keys())}`."
            )
        minn, maxx = n_states["min"], n_states["max"]
    else:
        minn, maxx = n_states

    if minn > maxx:
        logg.debug(f"Swapping minimum and maximum because `{minn}` > `{maxx}`")
        minn, maxx = maxx, minn

    if minn <= 1:
        raise ValueError(f"Minimum value must be > `1`, found `{minn}`.")
    elif minn == 2:
        logg.warning(
            "In most cases, `2` clusters will be optimal. "
            "If you really expect `2` clusters, use `n_states=2` and `use_min_chi=False`. "
            "Setting minimum to `3`"
        )
        minn = 3

    if minn >= maxx:
        maxx = minn + 1
        logg.debug(f"Setting maximum to `{maxx}` because it was not greater than the minimum `{minn}`")

    logg.info(f"Calculating minChi within interval `[{minn}, {maxx}]`")

    return int(np.arange(minn, maxx + 1)[np.argmax(self._gpcca.minChi(minn, maxx))])
def _compute_meta_for_one_state(
    self,
    n_cells: int,
    cluster_key: Optional[str],
    en_cutoff: Optional[float],
    p_thresh: float,
) -> None:
    start = logg.info("Computing metastable states")
    logg.warning("For `n_states=1`, the stationary distribution is computed")

    eig = self._get(P.EIG)
    if eig is not None and "stationary_dist" in eig and eig["params"]["which"] == "LR":
        stationary_dist = eig["stationary_dist"]
    else:
        self.compute_eigendecomposition(only_evals=False, which="LR")
        stationary_dist = self._get(P.EIG)["stationary_dist"]

    self._set_meta_states(
        memberships=stationary_dist[:, None],
        n_cells=n_cells,
        cluster_key=cluster_key,
        p_thresh=p_thresh,
        en_cutoff=en_cutoff,
    )
    self._set(
        A.META_PROBS,
        Lineage(
            stationary_dist,
            names=list(self._get(A.META).cat.categories),
            colors=self._get(A.META_COLORS),
        ),
    )

    # reset all the things
    for key in (
        A.ABS_PROBS,
        A.SCHUR,
        A.SCHUR_MAT,
        A.COARSE_T,
        A.COARSE_STAT_D,
    ):
        self._set(key.s, None)

    logg.info(
        f"Adding `.{P.META_PROBS}`\n `.{P.META}`\n Finish",
        time=start,
    )
def maybe_create_lineage(
    direction: Union[str, Direction], pretty_name: Optional[str] = None
):
    if isinstance(direction, Direction):
        lin_key = str(
            AbsProbKey.FORWARD if direction == Direction.FORWARD else AbsProbKey.BACKWARD
        )
    else:
        lin_key = direction

    pretty_name = "" if pretty_name is None else (pretty_name + " ")
    names_key, colors_key = _lin_names(lin_key), _colors(lin_key)

    if lin_key in adata.obsm.keys():
        n_cells, n_lineages = adata.obsm[lin_key].shape
        logg.info(f"Creating {pretty_name}`Lineage` from `adata.obsm[{lin_key!r}]`")

        if names_key not in adata.uns.keys():
            logg.warning(
                f"    Lineage names not found in `adata.uns[{names_key!r}]`, creating new names"
            )
            names = [f"Lineage {i}" for i in range(n_lineages)]
        elif len(adata.uns[names_key]) != n_lineages:
            logg.warning(
                f"    Lineage names don't have the required length ({n_lineages}), creating new names"
            )
            names = [f"Lineage {i}" for i in range(n_lineages)]
        else:
            logg.info("    Successfully loaded names")
            names = adata.uns[names_key]

        if colors_key not in adata.uns.keys():
            logg.warning(
                f"    Lineage colors not found in `adata.uns[{colors_key!r}]`, creating new colors"
            )
            colors = _create_categorical_colors(n_lineages)
        elif len(adata.uns[colors_key]) != n_lineages or not all(
            map(lambda c: is_color_like(c), adata.uns[colors_key])
        ):
            logg.warning(
                f"    Lineage colors don't have the required length ({n_lineages}) "
                f"or are not color-like, creating new colors"
            )
            colors = _create_categorical_colors(n_lineages)
        else:
            logg.info("    Successfully loaded colors")
            colors = adata.uns[colors_key]

        adata.obsm[lin_key] = Lineage(adata.obsm[lin_key], names=names, colors=colors)
        adata.uns[colors_key] = colors
        adata.uns[names_key] = names
    else:
        logg.debug(f"Unable to load {pretty_name}`Lineage` from `adata.obsm[{lin_key!r}]`")
def __init__(
    self,
    transition_matrix: Optional[Union[np.ndarray, spmatrix, str]] = None,
    adata: Optional[AnnData] = None,
    backward: bool = False,
    compute_cond_num: bool = False,
):
    from anndata import AnnData as _AnnData

    if transition_matrix is None:
        transition_matrix = _transition(
            Direction.BACKWARD if backward else Direction.FORWARD
        )
        logg.debug(f"Setting transition matrix key to `{transition_matrix!r}`")

    if isinstance(transition_matrix, str):
        if adata is None:
            raise ValueError(
                "When `transition_matrix` specifies a key to `adata.obsp`, `adata` cannot be None."
            )
        transition_matrix = _read_graph_data(adata, transition_matrix)

    if not isinstance(transition_matrix, (np.ndarray, spmatrix)):
        raise TypeError(
            f"Expected transition matrix to be of type `numpy.ndarray` or `scipy.sparse.spmatrix`, "
            f"found `{type(transition_matrix).__name__!r}`."
        )

    if transition_matrix.shape[0] != transition_matrix.shape[1]:
        raise ValueError(
            f"Expected transition matrix to be square, found `{transition_matrix.shape}`."
        )

    if not np.allclose(np.sum(transition_matrix, axis=1), 1.0, rtol=_RTOL):
        raise ValueError("Not a valid transition matrix: not all rows sum to 1.")

    if adata is None:
        logg.warning("Creating empty `AnnData` object")
        adata = _AnnData(csr_matrix((transition_matrix.shape[0], 1), dtype=np.float32))

    super().__init__(adata, backward=backward, compute_cond_num=compute_cond_num)
    self._transition_matrix = csr_matrix(transition_matrix)
    self._maybe_compute_cond_num()
def _(self, data: Lineage, prop: str, discrete: bool = False, **kwargs) -> None:
    if discrete and kwargs.get("mode", "embedding") == "time":
        logg.warning(
            "`mode='time'` is only implemented in the continuous case, plotting in continuous mode"
        )
        discrete = False

    if not discrete:
        self._plot_continuous(data, prop, **kwargs)
    elif prop == P.ABS_PROBS.v:
        # for discrete and abs. probs, plot the terminal states
        prop = P.TERM.v
        self._plot_discrete(getattr(self, prop, None), prop, **kwargs)
    else:
        raise NotImplementedError(f"Unable to plot property `.{prop}` in discrete mode.")
def _read_from_adata(
    self,
    conn_key: Optional[str] = "connectivities",
    read_conn: bool = True,
    **kwargs: Any,
) -> None:
    """
    Import the base KNN graph and optionally check for symmetry and connectivity.

    Parameters
    ----------
    conn_key
        Key in :attr:`anndata.AnnData.obsp` where the connectivities are stored.
    read_conn
        Whether to read the connectivities or set them to `None`. Useful when not exposing
        density normalization or when KNN connectivities are not used to compute the
        transition matrix.
    kwargs
        Additional keyword arguments.

    Returns
    -------
    Nothing, just sets :attr:`_conn`.
    """
    if not read_conn:
        self._conn = None
        return

    self._conn = _get_neighs(self.adata, mode="connectivities", key=conn_key).astype(_dtype)

    check_connectivity = kwargs.pop("check_connectivity", False)
    if check_connectivity:
        start = logg.debug("Checking the KNN graph for connectedness")
        if not _connected(self._conn):
            logg.warning("KNN graph is not connected", time=start)
        else:
            logg.debug("KNN graph is connected", time=start)

    start = logg.debug("Checking the KNN graph for symmetry")
    if not _symmetric(self._conn):
        logg.warning("KNN graph is not symmetric", time=start)
    else:
        logg.debug("KNN graph is symmetric", time=start)
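# A self-contained sketch of the two graph checks above, using
# `scipy.sparse.csgraph` in place of the internal `_connected`/`_symmetric`
# helpers (assumed to behave similarly).
def _demo_knn_checks() -> None:
    import numpy as np
    from scipy.sparse import csr_matrix
    from scipy.sparse.csgraph import connected_components

    conn = csr_matrix(np.array([[0.0, 1.0], [1.0, 0.0]]))

    is_symmetric = (conn != conn.T).nnz == 0
    n_components, _ = connected_components(conn, directed=False)
    assert is_symmetric and n_components == 1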
def _ensure_numeric_ordered(adata: AnnData, key: str) -> pd.Series:
    if key not in adata.obs.keys():
        raise KeyError(f"Unable to find data in `adata.obs[{key!r}]`.")

    exp_time = adata.obs[key].copy()
    if not is_numeric_dtype(np.asarray(exp_time)):
        try:
            exp_time = np.asarray(exp_time).astype(float)
        except ValueError as e:
            raise TypeError(
                f"Unable to convert `adata.obs[{key!r}]` of type `{infer_dtype(adata.obs[key])}` to `float`."
            ) from e

    if not is_categorical_dtype(exp_time):
        logg.debug(f"Converting `adata.obs[{key!r}]` to `categorical`")
        exp_time = np.asarray(exp_time)
        categories = sorted(set(exp_time[~np.isnan(exp_time)]))
        if len(categories) > 100:
            raise ValueError(
                f"Unable to convert `adata.obs[{key!r}]` to `categorical` since it "
                f"would create `{len(categories)}` categories."
            )
        exp_time = pd.Series(
            pd.Categorical(exp_time, categories=categories, ordered=True)
        )

    if not exp_time.cat.ordered:
        logg.warning("Categories are not ordered. Using ascending order")
        exp_time = exp_time.cat.as_ordered()

    exp_time = pd.Series(pd.Categorical(exp_time, ordered=True), index=adata.obs_names)
    if exp_time.isnull().any():
        raise ValueError("Series contains NaN value(s).")

    n_cats = len(exp_time.cat.categories)
    if n_cats < 2:
        raise ValueError(f"Expected to find at least `2` categories, found `{n_cats}`.")

    return exp_time
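# Minimal pandas sketch of the conversion above: numeric values become an
# ordered categorical with ascending categories.
def _demo_ordered_categories() -> None:
    import pandas as pd

    exp_time = pd.Series([0.0, 1.0, 0.0, 2.0])
    categories = sorted(set(exp_time))
    exp_time = pd.Series(pd.Categorical(exp_time, categories=categories, ordered=True))

    assert exp_time.cat.ordered
    assert list(exp_time.cat.categories) == [0.0, 1.0, 2.0]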
def _invert_matrix(mat, use_petsc: bool = True, **kwargs) -> np.ndarray:
    if use_petsc:
        try:
            import petsc4py  # noqa
        except ImportError:
            global _PETSC_ERROR_MSG_SHOWN
            if not _PETSC_ERROR_MSG_SHOWN:
                _PETSC_ERROR_MSG_SHOWN = True
                logg.warning(_PETSC_ERROR_MSG.format(_DEFAULT_SOLVER))
            kwargs["solver"] = _DEFAULT_SOLVER
            use_petsc = False

    if use_petsc:
        return _solve_lin_system(mat, speye(mat.shape[0]), use_petsc=True, **kwargs)

    return sinv(mat).toarray() if issparse(mat) else np.linalg.inv(mat)
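# The pure-scipy fallback path of `_invert_matrix`, demonstrated on a toy
# matrix; PETSc is only used when `petsc4py` is importable.
def _demo_invert() -> None:
    import numpy as np
    from scipy.sparse import csc_matrix, issparse
    from scipy.sparse.linalg import inv as sinv

    mat = csc_matrix(np.array([[2.0, 0.0], [0.0, 4.0]]))
    res = sinv(mat).toarray() if issparse(mat) else np.linalg.inv(mat)
    assert np.allclose(res, [[0.5, 0.0], [0.0, 0.25]])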
def _write_terminal_states(self, time=None) -> None:
    super()._write_terminal_states(time=time)

    term_abs_probs = self._get(A.TERM_ABS_PROBS)
    if term_abs_probs is None:
        # possibly remove previous value if it's inconsistent
        term_abs_probs = self.adata.obsm.get(self._term_abs_prob_key, None)

    if term_abs_probs is not None:
        new = list(self._get(P.TERM).cat.categories)
        old = list(term_abs_probs.names)

        if term_abs_probs.shape[1] == len(new) and new == old:
            self.adata.obsm[self._term_abs_prob_key] = term_abs_probs
        else:
            logg.warning(
                f"Removing previously computed `adata.obsm[{self._term_abs_prob_key!r}]` because the "
                f"names mismatch `{new}` (new), `{old}` (old)."
            )
            self._set(A.TERM_ABS_PROBS, None)
            self.adata.obsm.pop(self._term_abs_prob_key, None)
def _read_from_adata(self, **kwargs):
    """Import the base KNN graph and optionally check for symmetry and connectivity."""
    if not _has_neighs(self.adata):
        raise KeyError("Compute KNN graph first as `scanpy.pp.neighbors()`.")

    self._conn = _get_neighs(self.adata, "connectivities").astype(_dtype)

    check_connectivity = kwargs.pop("check_connectivity", False)
    if check_connectivity:
        start = logg.debug("Checking the KNN graph for connectedness")
        if not _connected(self._conn):
            logg.warning("KNN graph is not connected", time=start)
        else:
            logg.debug("KNN graph is connected", time=start)

    start = logg.debug("Checking the KNN graph for symmetry")
    if not _symmetric(self._conn):
        logg.warning("KNN graph is not symmetric", time=start)
    else:
        logg.debug("KNN graph is symmetric", time=start)
def compute_partition(self) -> None:
    """
    Compute communication classes for the Markov chain.

    Returns
    -------
    None
        Nothing, but updates the following fields:

        - :attr:`recurrent_classes`
        - :attr:`transient_classes`
        - :attr:`is_irreducible`
    """
    start = logg.info("Computing communication classes")

    n_states = len(self)
    rec_classes, trans_classes = _partition(self.transition_matrix)
    self._is_irreducible = len(rec_classes) == 1 and len(trans_classes) == 0

    if not self._is_irreducible:
        self._trans_classes = _make_cat(trans_classes, n_states, self.adata.obs_names)
        self._rec_classes = _make_cat(rec_classes, n_states, self.adata.obs_names)

        logg.info(
            f"Found `{len(rec_classes)}` recurrent and `{len(trans_classes)}` transient classes\n"
            f"Adding `.recurrent_classes`\n"
            f"       `.transient_classes`\n"
            f"       `.is_irreducible`\n"
            f"    Finish",
            time=start,
        )
    else:
        logg.warning(
            "The transition matrix is irreducible, cannot further partition it\n    Finish",
            time=start,
        )
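# A self-contained sketch of how a chain can be split into recurrent and
# transient classes: recurrent classes are strongly connected components that
# no probability mass leaves. This mimics what the internal `_partition`
# helper is assumed to do, using `scipy.sparse.csgraph`.
def _demo_partition() -> None:
    import numpy as np
    from scipy.sparse.csgraph import connected_components

    # state 0 is transient; states 1 and 2 form a recurrent class
    t = np.array([[0.5, 0.5, 0.0],
                  [0.0, 0.0, 1.0],
                  [0.0, 1.0, 0.0]])
    n, labels = connected_components(t > 0, directed=True, connection="strong")

    recurrent = []
    for c in range(n):
        states = np.where(labels == c)[0]
        # a class is recurrent iff all its rows sum to 1 within the class
        if np.isclose(t[np.ix_(states, states)].sum(), len(states)):
            recurrent.append(states.tolist())
    assert recurrent == [[1, 2]]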
def write(self, fname: Union[str, Path], ext: Optional[str] = "pickle") -> None:
    """
    %(pickleable.full_desc)s

    Parameters
    ----------
    %(pickleable.parameters)s

    Returns
    -------
    %(pickleable.returns)s
    """  # noqa
    fname = str(fname)
    if ext is not None:
        if not ext.startswith("."):
            ext = "." + ext
        if not fname.endswith(ext):
            fname += ext

    logg.debug(f"Writing to `{fname}`")

    with open(fname, "wb") as fout:
        if version_info[:2] > (3, 6):
            pickle.dump(self, fout)
        else:
            # we need to use `PrecomputedKernel` because Python 3.6 can't pickle enums
            # and they are present in `VelocityKernel`
            logg.warning("Saving kernel as `cellrank.tl.kernels.PrecomputedKernel`")
            orig_kernel = self.kernel
            self._kernel = PrecomputedKernel(self.kernel)
            try:
                pickle.dump(self, fout)
            finally:
                self._kernel = orig_kernel
def _write_graph_data(
    adata: AnnData,
    data: Union[np.ndarray, spmatrix],
    key: str,
):
    """
    Write graph data to :class:`~anndata.AnnData`.

    :mod:`anndata` >= 0.7 stores `(n_obs, n_obs)` matrices in ``.obsp`` rather than ``.uns``.
    This function exists for backward compatibility.

    Parameters
    ----------
    adata
        Annotated data object.
    data
        The graph data we want to write.
    key
        Key from either ``adata.uns`` or ``adata.obsp``.

    Returns
    -------
    None
        Nothing, just writes the data.
    """
    try:
        adata.obsp[key] = data
        write_to = "obsp"

        if data.shape[0] != data.shape[1]:
            logg.warning(
                f"`adata.obsp` attribute should only contain square matrices, found shape `{data.shape}`"
            )
    except AttributeError:
        adata.uns[key] = data
        write_to = "uns"

    logg.debug(f"Writing graph data to `adata.{write_to}[{key!r}]`")
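# Hedged usage sketch: with anndata >= 0.7 the square matrix lands in `.obsp`;
# on older versions the `AttributeError` branch falls back to `.uns`.
def _demo_write_graph_data() -> None:
    import numpy as np
    from anndata import AnnData

    adata = AnnData(np.zeros((3, 2)))
    _write_graph_data(adata, np.eye(3), key="toy_graph")
    assert np.allclose(adata.obsp["toy_graph"], np.eye(3))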
def _compute_transition_matrix(
    self,
    matrix: spmatrix,
    density_normalize: bool = True,
):
    # density correction based on node degrees in the KNN graph
    matrix = csr_matrix(matrix) if not isspmatrix_csr(matrix) else matrix
    if density_normalize:
        matrix = self._density_normalize(matrix)

    # check for zero-rows
    problematic_indices = np.where(np.array(matrix.sum(1)).flatten() == 0)[0]
    if len(problematic_indices):
        logg.warning(
            f"Detected `{len(problematic_indices)}` absorbing states in the transition matrix. "
            f"This matrix won't be irreducible"
        )
        matrix[problematic_indices, problematic_indices] = 1.0

    # setting this property automatically row-normalizes
    self.transition_matrix = matrix
    self._maybe_compute_cond_num()
def _compute_transition_matrix(
    self,
    matrix: Union[np.ndarray, spmatrix],
    density_normalize: bool = True,
    check_irreducibility: bool = False,
):
    if matrix.shape[0] != matrix.shape[1]:
        raise ValueError(f"Expected a square matrix, found `{matrix.shape}`.")
    if matrix.shape[0] != self.adata.n_obs:
        raise ValueError(
            f"Expected matrix to be of shape `{(self.adata.n_obs, self.adata.n_obs)}`, "
            f"found `{matrix.shape}`."
        )

    matrix = matrix.astype(_dtype)
    if issparse(matrix) and not isspmatrix_csr(matrix):
        matrix = csr_matrix(matrix)

    # density correction based on node degrees in the KNN graph
    if density_normalize:
        matrix = self._density_normalize(matrix)

    # check for zero-rows
    problematic_indices = np.where(np.array(matrix.sum(1)).flatten() == 0)[0]
    if len(problematic_indices):
        logg.warning(
            f"Detected `{len(problematic_indices)}` absorbing states in the transition matrix. "
            f"This matrix won't be irreducible"
        )
        matrix[problematic_indices, problematic_indices] = 1.0

    if check_irreducibility:
        _irreducible(matrix)

    # setting this property automatically row-normalizes
    self.transition_matrix = matrix
    self._maybe_compute_cond_num()
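# Toy illustration of the zero-row fix above: rows summing to zero get a
# self-loop so that the matrix can be row-normalized afterwards.
def _demo_fix_absorbing_rows() -> None:
    import numpy as np
    from scipy.sparse import lil_matrix

    m = lil_matrix((3, 3))
    m[0, 1] = 1.0
    m[1, 0] = 1.0  # the third row is all zeros

    problematic = np.where(np.array(m.sum(1)).flatten() == 0)[0]
    m[problematic, problematic] = 1.0
    assert np.allclose(np.array(m.sum(1)).flatten(), 1.0)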
def __init__(
    self,
    adata: AnnData,
    n_splines: Optional[int] = 10,
    spline_order: int = 3,
    distribution: str = "gamma",
    link: str = "log",
    max_iter: int = 2000,
    expectile: Optional[float] = None,
    use_default_conf_int: bool = False,
    grid: Optional[Mapping] = None,
    spline_kwargs: Mapping = MappingProxyType({}),
    **kwargs,
):
    term = s(
        0,
        spline_order=spline_order,
        n_splines=n_splines,
        penalties=["derivative", "l2"],
        **_filter_kwargs(s, **{**{"lam": 3}, **spline_kwargs}),
    )
    link = GamLinkFunction(link)
    distribution = GamDistribution(distribution)
    if distribution == GamDistribution.GAUSS:
        distribution = GamDistribution.NORMAL

    if expectile is not None:
        if not (0 < expectile < 1):
            raise ValueError(
                f"Expected `expectile` to be in `(0, 1)`, found `{expectile}`."
            )
        if distribution != "normal" or link != "identity":
            logg.warning(
                f"Expectile GAM works only with `normal` distribution and `identity` link function, "
                f"found `{distribution!r}` distribution and `{link!r}` link function."
            )
        model = ExpectileGAM(
            term, expectile=expectile, max_iter=max_iter, verbose=False, **kwargs
        )
    else:
        # doing it like this ensures that the user can specify `scale`
        gam = _gams[distribution, link]
        kwargs["link"] = link.s
        kwargs["distribution"] = distribution.s
        model = gam(
            term,
            max_iter=max_iter,
            verbose=False,
            **_filter_kwargs(gam.__init__, **kwargs),
        )

    super().__init__(adata, model=model)
    self._use_default_conf_int = use_default_conf_int

    if grid is None:
        self._grid = None
    elif isinstance(grid, dict):
        self._grid = _copy(grid)
    elif isinstance(grid, str):
        # sentinel object marking that the default grid should be used
        self._grid = object() if grid == "default" else None
    else:
        raise TypeError(
            f"Expected `grid` to be `dict`, `str` or `None`, found `{type(grid).__name__!r}`."
        )