Example #1
        def create_ixs(ixs: Indices_t, *, kind: str) -> Optional[np.ndarray]:
            if ixs is None:
                return None
            if isinstance(ixs, dict):
                # fmt: off
                if len(ixs) != 1:
                    raise ValueError(
                        f"Expected to find only 1 cluster key, found `{len(ixs)}`."
                    )
                cluster_key = next(iter(ixs.keys()))
                if cluster_key not in self.adata.obs:
                    raise KeyError(
                        f"Unable to find `adata.obs[{cluster_key!r}]`.")
                if not is_categorical_dtype(self.adata.obs[cluster_key]):
                    raise TypeError(
                        f"Expected `adata.obs[{cluster_key!r}]` to be categorical, "
                        f"found `{infer_dtype(self.adata.obs[cluster_key])}`.")
                ixs = np.where(
                    np.isin(self.adata.obs[cluster_key], ixs[cluster_key]))[0]
                # fmt: on
            elif isinstance(ixs, str):
                ixs = np.where(self.adata.obs_names == ixs)[0]
            else:
                ixs = np.where(np.isin(self.adata.obs_names, ixs))[0]

            if not len(ixs):
                logg.warning(
                    f"No {kind} indices have been selected, using `None`")
                return None

            return ixs

    def _reconstruct_lineage(self, attr: PrettyEnum, obsm_key: str):
        self._set_or_debug(obsm_key, self.adata.obsm, attr)
        names = self._set_or_debug(_lin_names(self._term_key), self.adata.uns)
        colors = self._set_or_debug(_colors(self._term_key), self.adata.uns)

        probs = self._get(attr)

        if probs is not None:
            if len(names) != probs.shape[1]:
                if isinstance(probs, Lineage):
                    names = probs.names
                else:
                    logg.warning(
                        f"Expected lineage names to be of length `{probs.shape[1]}`, found `{len(names)}`. "
                        f"Creating new names"
                    )
                    names = [f"Lineage {i}" for i in range(probs.shape[1])]
            if len(colors) != probs.shape[1] or not all(
                map(lambda c: isinstance(c, str) and is_color_like(c), colors)
            ):
                if isinstance(probs, Lineage):
                    colors = probs.colors
                else:
                    logg.warning(
                        f"Expected lineage colors to be of length `{probs.shape[1]}`, found `{len(names)}`. "
                        f"Creating new colors"
                    )
                    colors = _create_categorical_colors(probs.shape[1])
            self._set(attr, Lineage(probs, names=names, colors=colors))

            self.adata.obsm[obsm_key] = self._get(attr)
            self.adata.uns[_lin_names(self._term_key)] = names
            self.adata.uns[_colors(self._term_key)] = colors

    def _detect_cc_stages(self, rc_labels: Series, p_thresh: float = 1e-15) -> None:
        """
        Detect cell-cycle driven start or endpoints.

        Parameters
        ----------
        rc_labels
            Approximate recurrent classes.
        p_thresh
            P-value threshold for the rank-sum test for the group to be considered cell-cycle driven.

        Returns
        -------
        None
            Nothing, but warns if a group is cell-cycle driven.
        """

        # initialize the groups (start or end clusters) and scores
        groups = rc_labels.cat.categories
        scores = []
        if self._G2M_score is not None:
            scores.append(self._G2M_score)
        if self._S_score is not None:
            scores.append(self._S_score)

        # loop over groups and scores
        for group in groups:
            mask = rc_labels == group
            for score in scores:
                a, b = score[mask], score[~mask]
                statistic, pvalue = ranksums(a, b)
                if statistic > 0 and pvalue < p_thresh:
                    logg.warning(f"Group `{group!r}` appears to be cell-cycle driven")
                    break
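
A standalone sketch of the rank-sum check used in `_detect_cc_stages` above, using only numpy and scipy; the group size, the score shift and the threshold are illustrative.

import numpy as np
from scipy.stats import ranksums

rng = np.random.default_rng(0)
score = rng.normal(size=1000)      # e.g. a G2M score over all cells
mask = np.zeros(1000, dtype=bool)  # cells belonging to one candidate group
mask[:200] = True
score[mask] += 2.0                 # make the group score-enriched

statistic, pvalue = ranksums(score[mask], score[~mask])
# the group is flagged as cell-cycle driven when the statistic is positive
# and the p-value falls below `p_thresh`
print(statistic > 0 and pvalue < 1e-15)  # True for this synthetic example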
Example #4
    def __init__(
        self,
        obj: Union[AnnData, np.ndarray, spmatrix, KernelExpression],
        key: Optional[str] = None,
        obsp_key: Optional[str] = None,
        write_to_adata: bool = True,
    ):
        if isinstance(obj, KernelExpression):
            self._kernel = obj
        elif isinstance(obj, (np.ndarray, spmatrix)):
            self._kernel = PrecomputedKernel(obj)
        elif isinstance(obj, AnnData):
            if obsp_key is None:
                raise ValueError(
                    "Specify `obsp_key=...` when supplying an `AnnData` object."
                )
            elif obsp_key not in obj.obsp.keys():
                raise KeyError(
                    f"Key `{obsp_key!r}` not found in `adata.obsp`.")
            self._kernel = PrecomputedKernel(obj.obsp[obsp_key], adata=obj)
        else:
            raise TypeError(
                f"Expected an object of type `KernelExpression`, `numpy.ndarray`, `scipy.sparse.spmatrix` "
                f"or `anndata.AnnData`, got `{type(obj).__name__!r}`.")

        if self.kernel._transition_matrix is None:
            # access the private attribute to avoid accidentally computing the transition matrix
            # in principle, it doesn't make a difference, apart from not seeing the message
            logg.warning(
                "Computing transition matrix using the default parameters")
            self.kernel.compute_transition_matrix()

        if write_to_adata:
            self.kernel.write_to_adata(key=key)
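
A hedged usage sketch of the constructor above, wrapping a precomputed transition matrix stored in `adata.obsp`; the estimator name `GPCCA` and the key "T_fwd" are illustrative placeholders, not part of the snippet.

import numpy as np
from anndata import AnnData
from scipy.sparse import csr_matrix

adata = AnnData(np.zeros((3, 1), dtype=np.float32))
adata.obsp["T_fwd"] = csr_matrix(np.full((3, 3), 1 / 3))

# estimator = GPCCA(adata, obsp_key="T_fwd")  # hypothetical subclass defining the __init__ above
# omitting `obsp_key` for an AnnData input would raise the ValueError shown above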
Example #5
    def _fit_terminal_states(
        self,
        n_lineages: Optional[int] = None,
        cluster_key: Optional[str] = None,
        method: str = "krylov",
        **kwargs,
    ) -> None:
        if n_lineages is None or n_lineages == 1:
            self.compute_eigendecomposition()
            if n_lineages is None:
                n_lineages = self.eigendecomposition["eigengap"] + 1

        if n_lineages > 1:
            self.compute_schur(n_lineages, method=method)

        try:
            self.compute_macrostates(n_states=n_lineages,
                                     cluster_key=cluster_key,
                                     **kwargs)
        except ValueError:
            logg.warning(
                f"Computing `{n_lineages}` macrostates cuts through a block of complex conjugates. "
                f"Increasing `n_lineages` to {n_lineages + 1}")
            self.compute_macrostates(n_states=n_lineages + 1,
                                     cluster_key=cluster_key,
                                     **kwargs)

        fs_kwargs = {
            "n_cells": kwargs["n_cells"]
        } if "n_cells" in kwargs else {}

        if n_lineages is None:
            self.compute_terminal_states(method="eigengap", **fs_kwargs)
        else:
            self.set_terminal_states_from_macrostates(**fs_kwargs)
Example #6
    def _(self,
          data: pd.Series,
          prop: str,
          discrete: bool = False,
          **kwargs) -> None:
        if discrete and kwargs.get("mode", "embedding") == "time":
            logg.warning(
                "`mode='time'` is implemented in continuous case, plotting in continuous mode"
            )
            discrete = False

        if discrete:
            self._plot_discrete(data, prop, **kwargs)
        elif prop == P.MACRO.v:  # GPCCA
            prop = P.MACRO_MEMBER.v
            self._plot_continuous(getattr(self, prop, None), prop, **kwargs)
        elif prop == P.TERM.v:
            probs = getattr(self, A.TERM_ABS_PROBS.s, None)
            # we have this only in GPCCA
            if isinstance(probs, Lineage):
                self._plot_continuous(probs, P.TERM_PROBS.v, **kwargs)
            else:
                logg.warning(
                    f"Unable to plot continuous observations for `{prop!r}`, plotting in discrete mode"
                )
                self._plot_discrete(data, prop, **kwargs)
        else:
            raise NotImplementedError(
                f"Unable to plot property `.{prop}` in continuous mode.")
Example #7
    def create_col_categorical_color(cluster_key: str,
                                     rng: np.ndarray) -> np.ndarray:
        if not is_categorical_dtype(adata.obs[cluster_key]):
            raise TypeError(
                f"Expected `adata.obs[{cluster_key!r}]` to be categorical, "
                f"found `{adata.obs[cluster_key].dtype.name!r}`.")

        color_key = f"{cluster_key}_colors"
        if color_key not in adata.uns:
            logg.warning(
                f"Color key `{color_key!r}` not found in `adata.uns`. Creating new colors"
            )
            colors = _create_categorical_colors(
                len(adata.obs[cluster_key].cat.categories))
            adata.uns[color_key] = colors
        else:
            colors = adata.uns[color_key]

        time_series = adata.obs[time_key]
        ixs = find_indices(time_series, rng)

        mapper = dict(zip(adata.obs[cluster_key].cat.categories, colors))

        return np.array([
            mcolors.to_hex(mapper[v])
            for v in adata[ixs, :].obs[cluster_key].values
        ])
Example #8
    def _check_states_validity(self, n_states: int) -> int:
        if self._invalid_n_states is not None and n_states in self._invalid_n_states:
            logg.warning(
                f"Unable to compute macrostates with `n_states={n_states}` because it will "
                f"split the conjugate eigenvalues. Increasing `n_states` to `{n_states + 1}`"
            )
            n_states += 1  # cannot force recomputation of the Schur decomposition
            assert n_states not in self._invalid_n_states, "Sanity check failed."

        return n_states
Example #9
def lineage_drivers(
    adata: AnnData,
    lineage: str,
    backward: bool = False,
    n_genes: int = 8,
    ncols: Optional[int] = None,
    use_raw: bool = False,
    title_fmt: str = "{gene} qval={qval:.4e}",
    **kwargs,
) -> None:
    """
    Plot lineage drivers that were uncovered using :func:`cellrank.tl.lineage_drivers`.

    Parameters
    ----------
    %(adata)s
    %(backward)s
    %(plot_lineage_drivers.parameters)s

    Returns
    -------
    %(just_plots)s
    """

    pk = DummyKernel(adata, backward=backward)
    mc = GPCCA(pk, read_from_adata=True, write_to_adata=False)

    if use_raw and adata.raw is None:
        logg.warning("No raw attribute set. Using `adata.var` instead")
        use_raw = False

    direction = DirPrefix.BACKWARD if backward else DirPrefix.FORWARD
    needle = f"{direction} {lineage} corr"

    haystack = adata.raw.var if use_raw else adata.var

    if needle not in haystack:
        raise RuntimeError(
            f"Unable to find lineage drivers in "
            f"`{'adata.raw.var' if use_raw else 'adata.var'}[{needle!r}]`. "
            f"Compute lineage drivers first as `cellrank.tl.lineage_drivers(lineages={lineage!r}, "
            f"use_raw={use_raw}, backward={backward}).`")

    drivers = pd.DataFrame(haystack[[needle, f"{direction} {lineage} qval"]])
    drivers.columns = [f"{lineage} corr", f"{lineage} qval"]
    mc._set(A.LIN_DRIVERS, drivers)

    mc.plot_lineage_drivers(
        lineage,
        n_genes=n_genes,
        use_raw=use_raw,
        ncols=ncols,
        title_fmt=title_fmt,
        **kwargs,
    )
Example #10
 def test_formats(self, capsys, logging_state):
     settings.logfile = sys.stderr
     settings.verbosity = Verbosity.debug
     logg.error("0")
     assert capsys.readouterr().err == "ERROR: 0\n"
     logg.warning("1")
     assert capsys.readouterr().err == "WARNING: 1\n"
     logg.info("2")
     assert capsys.readouterr().err == "2\n"
     logg.hint("3")
     assert capsys.readouterr().err == "--> 3\n"
Example #11
def _filter_models(
    models,
    return_models: bool = False,
    filter_all_failed: bool = True
) -> Tuple[_return_model_type, _return_model_type, Sequence[str],
           Sequence[str]]:
    def is_valid(x: Union[BaseModel, BulkRes]) -> bool:
        if return_models:
            assert isinstance(
                x, BaseModel
            ), f"Expected `BaseModel`, found `{type(x).__name__!r}`."
            return bool(x)

        return (x.x_test is not None and x.y_test is not None
                and np.all(np.isfinite(x.y_test)))

    modelmat = pd.DataFrame(models).T
    modelmask = modelmat.applymap(is_valid)
    to_keep = modelmask[modelmask.any(axis=1)]
    to_keep = to_keep.loc[:, to_keep.any(axis=0)].T

    filtered_models = {
        gene: {
            ln: models[gene][ln]
            for ln in (ln for ln in v.keys() if (
                is_valid(models[gene][ln]) if filter_all_failed else True))
        }
        for gene, v in to_keep.to_dict().items()
    }

    if not len(filtered_models):
        if not return_models:
            raise RuntimeError(
                "Fitting has failed for all gene/lineage combinations. "
                "Specify `return_models=True` for more information.")
        for ms in models.values():
            for model in ms.values():
                assert isinstance(
                    model, FailedModel
                ), f"Expected `FailedModel`, found `{type(model).__name__!r}`."
                model.reraise()

    if not np.all(modelmask.values):
        failed_models = modelmat.values[~modelmask.values]
        logg.warning(
            f"Unable to fit `{len(failed_models)}` models."
            + ("" if return_models
               else " Consider specifying `return_models=True` for further inspection."))
        logg.debug("The failed models were:\n`{}`".format("\n".join(
            f"    {m}" for m in failed_models)))

    # lineages is the max number of lineages
    return models, filtered_models, tuple(filtered_models.keys()), tuple(
        to_keep.index)
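
A minimal sketch of the masking step in `_filter_models`, using plain pandas; the nested {gene: {lineage: model-or-None}} mapping and the validity test stand in for `BaseModel`/`BulkRes`.

import pandas as pd

models = {"g1": {"l1": object(), "l2": None}, "g2": {"l1": None, "l2": None}}

modelmat = pd.DataFrame(models).T                      # rows = genes, columns = lineages
modelmask = modelmat.applymap(lambda m: m is not None)

# keep genes/lineages with at least one valid fit, as above
to_keep = modelmask[modelmask.any(axis=1)]
to_keep = to_keep.loc[:, to_keep.any(axis=0)]
print(to_keep)  # only gene `g1` and lineage `l1` survive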
Example #12
 def _maybe_compute_cond_num(self):
     if self._compute_cond_num and self._cond_num is None:
         logg.debug(f"Computing condition number of `{repr(self)}`")
         self._cond_num = np.linalg.cond(
             self._transition_matrix.toarray()
             if issparse(self._transition_matrix)
             else self._transition_matrix
         )
         if self._cond_num > _cond_num_tolerance:
             logg.warning(
                 f"`{repr(self)}` may be ill-conditioned, its condition number is `{self._cond_num:.2e}`"
             )
         else:
             logg.info(f"Condition number is `{self._cond_num:.2e}`")
Example #13
def _remove_zero_rows(a: np.ndarray, b: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    if a.shape[0] != b.shape[0]:
        raise ValueError("Lineage objects have unequal cell numbers")

    bool_a = (a == 0).any(axis=1)
    bool_b = (b == 0).any(axis=1)
    mask = ~np.logical_or(bool_a, bool_b)

    logg.warning(
        f"Removed {a.shape[0] - np.sum(mask)} rows because they contained zeros"
    )

    return a[mask, :], b[mask, :]
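
A worked example of the masking logic in `_remove_zero_rows`, with two small arrays standing in for the lineage objects.

import numpy as np

a = np.array([[0.2, 0.8], [0.0, 1.0], [0.5, 0.5]])
b = np.array([[0.3, 0.7], [0.4, 0.6], [1.0, 0.0]])

mask = ~np.logical_or((a == 0).any(axis=1), (b == 0).any(axis=1))
print(a[mask, :], b[mask, :], sep="\n")  # only the first row is kept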
Example #14
def _check_collection(
    adata: AnnData,
    needles: Iterable[str],
    attr_name: str,
    key_name: str = "Gene",
    use_raw: bool = False,
    raise_exc: bool = True,
) -> List[str]:
    """
    Check if a given collection contains all the keys.

    Parameters
    ----------
    adata
        Annotated data object.
    needles
        Keys to check.
    attr_name
        Attribute of ``adata`` where the needles are stored.
    key_name
        Pretty name of the key, displayed when an error is found.
    use_raw
        Whether to access ``adata.raw`` or just ``adata``.
    raise_exc
        Whether to raise a :class:`KeyError` if a needle is not found.

    Returns
    -------
    The needles that were found. If ``raise_exc`` is `True`, a :class:`KeyError` is raised
    as soon as a needle is missing.
    """
    adata_name = "adata"

    if use_raw and adata.raw is None:
        logg.warning(
            "Argument `use_raw` was set to `True`, but no `raw` attribute is found. Ignoring"
        )
        use_raw = False
    if use_raw:
        adata_name = "adata.raw"
        adata = adata.raw

    haystack, res = getattr(adata, attr_name), []
    for needle in needles:
        if needle not in haystack:
            if raise_exc:
                raise KeyError(
                    f"{key_name} `{needle}` not found in `{adata_name}.{attr_name}`."
                )
        else:
            res.append(needle)

    return res
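
A minimal sketch of the needle/haystack check, assuming the keys live in `adata.var_names`; the gene names are made up.

import numpy as np
import pandas as pd
from anndata import AnnData

adata = AnnData(np.zeros((2, 3)), var=pd.DataFrame(index=["GeneA", "GeneB", "GeneC"]))

needles, found = ["GeneA", "GeneX"], []
haystack = getattr(adata, "var_names")
for needle in needles:
    if needle in haystack:  # with `raise_exc=True`, a missing needle raises a KeyError instead
        found.append(needle)
print(found)  # ['GeneA']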
Example #15
    def _get_n_states_from_minchi(
            self, n_states: Union[Tuple[int, int], List[int],
                                  Dict[str, int]]) -> int:
        if self._gpcca is None:
            raise RuntimeError(
                "Compute Schur decomposition first as `.compute_schur()` when `use_min_chi=True`."
            )

        if not isinstance(n_states, (dict, tuple, list)):
            raise TypeError(
                f"Expected `n_states` to be either `dict`, `tuple` or a `list`, "
                f"found `{type(n_states).__name__}`.")
        if len(n_states) != 2:
            raise ValueError(
                f"Expected `n_states` to be of size `2`, found `{len(n_states)}`."
            )

        if isinstance(n_states, dict):
            if "min" not in n_states or "max" not in n_states:
                raise KeyError(
                    f"Expected the dictionary to have `'min'` and `'max'` keys, "
                    f"found `{tuple(n_states.keys())}`.")
            minn, maxx = n_states["min"], n_states["max"]
        else:
            minn, maxx = n_states

        if minn > maxx:
            logg.debug(
                f"Swapping minimum and maximum because `{minn}` > `{maxx}`")
            minn, maxx = maxx, minn

        if minn <= 1:
            raise ValueError(f"Minimum value must be > `1`, found `{minn}`.")
        elif minn == 2:
            logg.warning(
                "In most cases, 2 clusters will always be optimal. "
                "If you really expect 2 clusters, use `n_states=2` and `use_minchi=False`. Setting minimum to `3`"
            )
            minn = 3

        if minn >= maxx:
            maxx = minn + 1
            logg.debug(
                f"Setting maximum to `{maxx}` because it was <= than minimum `{minn}`"
            )

        logg.info(f"Calculating minChi within interval `[{minn}, {maxx}]`")

        return int(
            np.arange(minn, maxx + 1)[np.argmax(self._gpcca.minChi(minn, maxx))]
        )
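
The interval handling above, reduced to a standalone sketch with illustrative values.

n_states = {"min": 2, "max": 6}

minn, maxx = (n_states["min"], n_states["max"]) if isinstance(n_states, dict) else n_states
if minn > maxx:
    minn, maxx = maxx, minn  # swap a reversed interval
if minn == 2:
    minn = 3                 # 2 clusters are almost always "optimal", so the lower bound is bumped
if minn >= maxx:
    maxx = minn + 1
print(minn, maxx)  # 3 6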
Example #16
    def _compute_meta_for_one_state(
        self,
        n_cells: int,
        cluster_key: Optional[str],
        en_cutoff: Optional[float],
        p_thresh: float,
    ) -> None:
        start = logg.info("Computing metastable states")
        logg.warning("For `n_states=1`, stationary distribution is computed")

        eig = self._get(P.EIG)
        if (eig is not None and "stationary_dist" in eig
                and eig["params"]["which"] == "LR"):
            stationary_dist = eig["stationary_dist"]
        else:
            self.compute_eigendecomposition(only_evals=False, which="LR")
            stationary_dist = self._get(P.EIG)["stationary_dist"]

        self._set_meta_states(
            memberships=stationary_dist[:, None],
            n_cells=n_cells,
            cluster_key=cluster_key,
            p_thresh=p_thresh,
            en_cutoff=en_cutoff,
        )
        self._set(
            A.META_PROBS,
            Lineage(
                stationary_dist,
                names=list(self._get(A.META).cat.categories),
                colors=self._get(A.META_COLORS),
            ),
        )

        # reset all the things
        for key in (
                A.ABS_PROBS,
                A.SCHUR,
                A.SCHUR_MAT,
                A.COARSE_T,
                A.COARSE_STAT_D,
        ):
            self._set(key.s, None)

        logg.info(
            f"Adding `.{P.META_PROBS}`\n        `.{P.META}`\n    Finish",
            time=start,
        )
Example #17
    def maybe_create_lineage(
        direction: Union[str, Direction], pretty_name: Optional[str] = None
    ):
        if isinstance(direction, Direction):
            lin_key = str(
                AbsProbKey.FORWARD
                if direction == Direction.FORWARD
                else AbsProbKey.BACKWARD
            )
        else:
            lin_key = direction

        pretty_name = "" if pretty_name is None else (pretty_name + " ")
        names_key, colors_key = _lin_names(lin_key), _colors(lin_key)

        if lin_key in adata.obsm.keys():
            n_cells, n_lineages = adata.obsm[lin_key].shape
            logg.info(f"Creating {pretty_name}`Lineage` from `adata.obsm[{lin_key!r}]`")

            if names_key not in adata.uns.keys():
                logg.warning(
                    f"    Lineage names not found in `adata.uns[{names_key!r}]`, creating new names"
                )
                names = [f"Lineage {i}" for i in range(n_lineages)]
            elif len(adata.uns[names_key]) != n_lineages:
                logg.warning(
                    f"    Lineage names are don't have the required length ({n_lineages}), creating new names"
                )
                names = [f"Lineage {i}" for i in range(n_lineages)]
            else:
                logg.info("    Successfully loaded names")
                names = adata.uns[names_key]

            if colors_key not in adata.uns.keys():
                logg.warning(
                    f"    Lineage colors not found in `adata.uns[{colors_key!r}]`, creating new colors"
                )
                colors = _create_categorical_colors(n_lineages)
            elif len(adata.uns[colors_key]) != n_lineages or not all(
                map(lambda c: is_color_like(c), adata.uns[colors_key])
            ):
                logg.warning(
                    f"    Lineage colors don't have the required length ({n_lineages}) "
                    f"or are not color-like, creating new colors"
                )
                colors = _create_categorical_colors(n_lineages)
            else:
                logg.info("    Successfully loaded colors")
                colors = adata.uns[colors_key]

            adata.obsm[lin_key] = Lineage(
                adata.obsm[lin_key], names=names, colors=colors
            )
            adata.uns[colors_key] = colors
            adata.uns[names_key] = names
        else:
            logg.debug(
                f"Unable to load {pretty_name}`Lineage` from `adata.obsm[{lin_key!r}]`"
            )

    def __init__(
        self,
        transition_matrix: Optional[Union[np.ndarray, spmatrix, str]] = None,
        adata: Optional[AnnData] = None,
        backward: bool = False,
        compute_cond_num: bool = False,
    ):
        from anndata import AnnData as _AnnData

        if transition_matrix is None:
            transition_matrix = _transition(
                Direction.BACKWARD if backward else Direction.FORWARD)
            logg.debug(
                f"Setting transition matrix key to `{transition_matrix!r}`")

        if isinstance(transition_matrix, str):
            if adata is None:
                raise ValueError(
                    "When `transition_matrix` specifies a key to `adata.obsp`, `adata` cannot be None."
                )
            transition_matrix = _read_graph_data(adata, transition_matrix)

        if not isinstance(transition_matrix, (np.ndarray, spmatrix)):
            raise TypeError(
                f"Expected transition matrix to be of type `numpy.ndarray` or `scipy.sparse.spmatrix`, "
                f"found `{type(transition_matrix).__name__!r}`.")

        if transition_matrix.shape[0] != transition_matrix.shape[1]:
            raise ValueError(
                f"Expected transition matrix to be square, found `{transition_matrix.shape}`."
            )

        if not np.allclose(np.sum(transition_matrix, axis=1), 1.0, rtol=_RTOL):
            raise ValueError(
                "Not a valid transition matrix: not all rows sum to 1.")

        if adata is None:
            logg.warning("Creating empty `AnnData` object")
            adata = _AnnData(
                csr_matrix((transition_matrix.shape[0], 1), dtype=np.float32))

        super().__init__(adata,
                         backward=backward,
                         compute_cond_num=compute_cond_num)
        self._transition_matrix = csr_matrix(transition_matrix)
        self._maybe_compute_cond_num()
Example #19
    def _(self, data: Lineage, prop: str, discrete: bool = False, **kwargs) -> None:
        if discrete and kwargs.get("mode", "embedding") == "time":
            logg.warning(
                "`mode='time'` is implemented in continuous case, plotting in continuous mode"
            )
            discrete = False

        if not discrete:
            self._plot_continuous(data, prop, **kwargs)
        elif prop == P.ABS_PROBS.v:
            # for discrete and abs. probs, plot the terminal states
            prop = P.TERM.v
            self._plot_discrete(getattr(self, prop, None), prop, **kwargs)
        else:
            raise NotImplementedError(
                f"Unable to plot property `.{prop}` in discrete mode."
            )
Example #20
    def _read_from_adata(
        self,
        conn_key: Optional[str] = "connectivities",
        read_conn: bool = True,
        **kwargs: Any,
    ) -> None:
        """
        Import the base-KNN graph and optionally check for symmetry and connectivity.

        Parameters
        ----------
        conn_key
            Key in :attr:`anndata.AnnData.uns` where connectivities are stored.
        read_conn
            Whether to read connectivities or set them to `None`. Useful when not exposing density normalization or
            when KNN connectivities are not used to compute the transition matrix.
        kwargs
            Additional keyword arguments.

        Returns
        -------
        Nothing, just sets :attr:`_conn`.
        """
        if not read_conn:
            self._conn = None
            return

        self._conn = _get_neighs(self.adata,
                                 mode="connectivities",
                                 key=conn_key).astype(_dtype)

        check_connectivity = kwargs.pop("check_connectivity", False)
        if check_connectivity:
            start = logg.debug("Checking the KNN graph for connectedness")
            if not _connected(self._conn):
                logg.warning("KNN graph is not connected", time=start)
            else:
                logg.debug("KNN graph is connected", time=start)

        start = logg.debug("Checking the KNN graph for symmetry")
        if not _symmetric(self._conn):
            logg.warning("KNN graph is not symmetric", time=start)
        else:
            logg.debug("KNN graph is symmetric", time=start)
Example #21
def _ensure_numeric_ordered(adata: AnnData, key: str) -> pd.Series:
    if key not in adata.obs.keys():
        raise KeyError(f"Unable to find data in `adata.obs[{key!r}]`.")

    exp_time = adata.obs[key].copy()
    if not is_numeric_dtype(np.asarray(exp_time)):
        try:
            exp_time = np.asarray(exp_time).astype(float)
        except ValueError as e:
            raise TypeError(
                f"Unable to convert `adata.obs[{key!r}]` of type `{infer_dtype(adata.obs[key])}` to `float`."
            ) from e

    if not is_categorical_dtype(exp_time):
        logg.debug(f"Converting `adata.obs[{key!r}]` to `categorical`")
        exp_time = np.asarray(exp_time)
        categories = sorted(set(exp_time[~np.isnan(exp_time)]))
        if len(categories) > 100:
            raise ValueError(
                f"Unable to convert `adata.obs[{key!r}]` to `categorical` since it "
                f"would create `{len(categories)}` categories."
            )
        exp_time = pd.Series(
            pd.Categorical(
                exp_time,
                categories=categories,
                ordered=True,
            )
        )

    if not exp_time.cat.ordered:
        logg.warning("Categories are not ordered. Using ascending order")
        exp_time = exp_time.cat.as_ordered()

    exp_time = pd.Series(pd.Categorical(exp_time, ordered=True), index=adata.obs_names)
    if exp_time.isnull().any():
        raise ValueError("Series contains NaN value(s).")

    n_cats = len(exp_time.cat.categories)
    if n_cats < 2:
        raise ValueError(f"Expected to find at least `2` categories, found `{n_cats}`.")

    return exp_time
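
A condensed sketch of the numeric-to-ordered-categorical conversion performed above, with made-up time points.

import numpy as np
import pandas as pd

exp_time = np.array([0.0, 1.0, 1.0, 2.0, 0.0])
categories = sorted(set(exp_time[~np.isnan(exp_time)]))
exp_time = pd.Series(pd.Categorical(exp_time, categories=categories, ordered=True))
print(exp_time.cat.ordered, list(exp_time.cat.categories))  # True [0.0, 1.0, 2.0]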
Example #22
def _invert_matrix(mat, use_petsc: bool = True, **kwargs) -> np.ndarray:
    if use_petsc:
        try:
            import petsc4py  # noqa
        except ImportError:
            global _PETSC_ERROR_MSG_SHOWN
            if not _PETSC_ERROR_MSG_SHOWN:
                _PETSC_ERROR_MSG_SHOWN = True
                logg.warning(_PETSC_ERROR_MSG.format(_DEFAULT_SOLVER))
            kwargs["solver"] = _DEFAULT_SOLVER
            use_petsc = False

    if use_petsc:
        return _solve_lin_system(mat,
                                 speye(mat.shape[0]),
                                 use_petsc=True,
                                 **kwargs)

    return sinv(mat).toarray() if issparse(mat) else np.linalg.inv(mat)
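
The non-PETSc fallback path above, spelled out for a tiny sparse matrix.

import numpy as np
from scipy.sparse import csr_matrix, issparse
from scipy.sparse.linalg import inv as sinv

mat = csr_matrix(np.array([[2.0, 0.0], [1.0, 1.0]]))
inverse = sinv(mat).toarray() if issparse(mat) else np.linalg.inv(mat)
print(inverse)  # [[ 0.5  0. ] [-0.5  1. ]]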
Example #23
    def _write_terminal_states(self, time=None) -> None:
        super()._write_terminal_states(time=time)

        term_abs_probs = self._get(A.TERM_ABS_PROBS)
        if term_abs_probs is None:
            # possibly remove previous value if it's inconsistent
            term_abs_probs = self.adata.obsm.get(self._term_abs_prob_key, None)

        if term_abs_probs is not None:
            new = list(self._get(P.TERM).cat.categories)
            old = list(term_abs_probs.names)
            if term_abs_probs.shape[1] == len(new) and new == old:
                self.adata.obsm[self._term_abs_prob_key] = term_abs_probs
            else:
                logg.warning(
                    f"Removing previously computed `adata.obsm[{self._term_abs_prob_key!r}]` because the "
                    f"names mismatch `{new}` (new), `{old}` (old).")

                self._set(A.TERM_ABS_PROBS, None)
                self.adata.obsm.pop(self._term_abs_prob_key, None)
Example #24
    def _read_from_adata(self, **kwargs):
        """Import the base-KNN graph and optionally check for symmetry and connectivity."""

        if not _has_neighs(self.adata):
            raise KeyError("Compute KNN graph first as `scanpy.pp.neighbors()`.")

        self._conn = _get_neighs(self.adata, "connectivities").astype(_dtype)

        check_connectivity = kwargs.pop("check_connectivity", False)
        if check_connectivity:
            start = logg.debug("Checking the KNN graph for connectedness")
            if not _connected(self._conn):
                logg.warning("KNN graph is not connected", time=start)
            else:
                logg.debug("KNN graph is connected", time=start)

        start = logg.debug("Checking the KNN graph for symmetry")
        if not _symmetric(self._conn):
            logg.warning("KNN graph is not symmetric", time=start)
        else:
            logg.debug("KNN graph is symmetric", time=start)
Example #25
    def compute_partition(self) -> None:
        """
        Compute communication classes for the Markov chain.

        Returns
        -------
        None
            Nothing, but updates the following fields:

                - :attr:`recurrent_classes`
                - :attr:`transient_classes`
                - :attr:`is_irreducible`
        """

        start = logg.info("Computing communication classes")
        n_states = len(self)

        rec_classes, trans_classes = _partition(self.transition_matrix)

        self._is_irreducible = len(rec_classes) == 1 and len(
            trans_classes) == 0

        if not self._is_irreducible:
            self._trans_classes = _make_cat(trans_classes, n_states,
                                            self.adata.obs_names)
            self._rec_classes = _make_cat(rec_classes, n_states,
                                          self.adata.obs_names)
            logg.info(
                f"Found `{(len(rec_classes))}` recurrent and `{len(trans_classes)}` transient classes\n"
                f"Adding `.recurrent_classes`\n"
                f"       `.transient_classes`\n"
                f"       `.is_irreducible`\n"
                f"    Finish",
                time=start,
            )
        else:
            logg.warning(
                "The transition matrix is irreducible, cannot further partition it\n    Finish",
                time=start,
            )
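
A rough, self-contained sketch of the irreducibility question in terms of strongly connected components; the internal `_partition` helper additionally splits the classes into recurrent and transient ones.

import numpy as np
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import connected_components

tmat = csr_matrix(np.array([[0.5, 0.5, 0.0],
                            [0.5, 0.5, 0.0],
                            [0.0, 0.5, 0.5]]))
n_comps, labels = connected_components(tmat, directed=True, connection="strong")
print(n_comps == 1, labels)  # False -> reducible; {0, 1} form a recurrent class, {2} is transient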
Example #26
    def write(self,
              fname: Union[str, Path],
              ext: Optional[str] = "pickle") -> None:
        """
        %(pickleable.full_desc)s

        Parameters
        ----------
        %(pickleable.parameters)s

        Returns
        -------
        %(pickleable.returns)s
        """  # noqa

        fname = str(fname)
        if ext is not None:
            if not ext.startswith("."):
                ext = "." + ext
            if not fname.endswith(ext):
                fname += ext

        logg.debug(f"Writing to `{fname}`")

        with open(fname, "wb") as fout:
            if version_info[:2] > (3, 6):
                pickle.dump(self, fout)
            else:
                # we need to use PrecomputedKernel because Python3.6 can't pickle Enums
                # and they are present in VelocityKernel
                logg.warning(
                    "Saving kernel as `cellrank.tl.kernels.PrecomputedKernel`")
                orig_kernel = self.kernel
                self._kernel = PrecomputedKernel(self.kernel)
                try:
                    pickle.dump(self, fout)
                finally:
                    self._kernel = orig_kernel
Example #27
def _write_graph_data(
    adata: AnnData,
    data: Union[np.ndarray, spmatrix],
    key: str,
):
    """
    Write graph data to :class:`anndata.AnnData`.

    :mod:`anndata` >=0.7 stores `(n_obs x n_obs)` matrices in ``.obsp`` rather than ``.uns``.
    This function exists for backward compatibility.

    Parameters
    ----------
    adata
        Annotated data object.
    data
        The graph data we want to write.
    key
        Key in either ``adata.obsp`` or ``adata.uns``.

    Returns
    -------
    None
        Nothing, just writes the data.
    """

    try:
        adata.obsp[key] = data
        write_to = "obsp"

        if data.shape[0] != data.shape[1]:
            logg.warning(
                f"`adata.obsp` attribute should only contain square matrices, found shape `{data.shape}`"
            )

    except AttributeError:
        adata.uns[key] = data
        write_to = "uns"

    logg.debug(f"Writing graph data to `adata.{write_to}[{key!r}]`")
Example #28
    def _compute_transition_matrix(
        self,
        matrix: spmatrix,
        density_normalize: bool = True,
    ):
        # density correction based on node degrees in the KNN graph
        matrix = csr_matrix(matrix) if not isspmatrix_csr(matrix) else matrix

        if density_normalize:
            matrix = self._density_normalize(matrix)

        # check for zero-rows
        problematic_indices = np.where(
            np.array(matrix.sum(1)).flatten() == 0)[0]
        if len(problematic_indices):
            logg.warning(
                f"Detected `{len(problematic_indices)}` absorbing states in the transition matrix. "
                f"This matrix won't be reducible")
            matrix[problematic_indices, problematic_indices] = 1.0

        # setting this property automatically row-normalizes
        self.transition_matrix = matrix
        self._maybe_compute_cond_num()
Example #29
    def _compute_transition_matrix(
        self,
        matrix: Union[np.ndarray, spmatrix],
        density_normalize: bool = True,
        check_irreducibility: bool = False,
    ):
        if matrix.shape[0] != matrix.shape[1]:
            raise ValueError(
                f"Expected a square matrix, found `{matrix.shape}`.")
        if matrix.shape[0] != self.adata.n_obs:
            raise ValueError(
                f"Expected matrix to be of shape `{(self.adata.n_obs, self.adata.n_obs)}`, "
                f"found `{matrix.shape}`.")

        matrix = matrix.astype(_dtype)
        if issparse(matrix) and not isspmatrix_csr(matrix):
            matrix = csr_matrix(matrix)

        # density correction based on node degrees in the KNN graph
        if density_normalize:
            matrix = self._density_normalize(matrix)

        # check for zero-rows
        problematic_indices = np.where(
            np.array(matrix.sum(1)).flatten() == 0)[0]
        if len(problematic_indices):
            logg.warning(
                f"Detected `{len(problematic_indices)}` absorbing states in the transition matrix. "
                f"This matrix won't be irreducible")
            matrix[problematic_indices, problematic_indices] = 1.0

        if check_irreducibility:
            _irreducible(matrix)

        # setting this property automatically row-normalizes
        self.transition_matrix = matrix
        self._maybe_compute_cond_num()
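
A standalone sketch of the zero-row (absorbing state) patch above; row-normalization happens later, when the `transition_matrix` property is set.

import numpy as np
from scipy.sparse import csr_matrix

matrix = csr_matrix(np.array([[0.0, 1.0, 0.0],
                              [0.5, 0.0, 0.5],
                              [0.0, 0.0, 0.0]]))  # last row is all zeros

problematic_indices = np.where(np.array(matrix.sum(1)).flatten() == 0)[0]
if len(problematic_indices):
    # SciPy may emit a SparseEfficiencyWarning when writing into a CSR matrix here
    matrix[problematic_indices, problematic_indices] = 1.0
print(matrix.toarray()[2])  # [0. 0. 1.]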
Example #30
    def __init__(
        self,
        adata: AnnData,
        n_splines: Optional[int] = 10,
        spline_order: int = 3,
        distribution: str = "gamma",
        link: str = "log",
        max_iter: int = 2000,
        expectile: Optional[float] = None,
        use_default_conf_int: bool = False,
        grid: Optional[Mapping] = None,
        spline_kwargs: Mapping = MappingProxyType({}),
        **kwargs,
    ):
        term = s(
            0,
            spline_order=spline_order,
            n_splines=n_splines,
            penalties=["derivative", "l2"],
            **_filter_kwargs(s, **{**{"lam": 3}, **spline_kwargs}),
        )
        link = GamLinkFunction(link)
        distribution = GamDistribution(distribution)
        if distribution == GamDistribution.GAUSS:
            distribution = GamDistribution.NORMAL

        if expectile is not None:
            if not (0 < expectile < 1):
                raise ValueError(
                    f"Expected `expectile` to be in `(0, 1)`, found `{expectile}`."
                )
            if distribution != "normal" or link != "identity":
                logg.warning(
                    f"Expectile GAM works only with `normal` distribution and `identity` link function,"
                    f"found `{distribution!r}` distribution and {link!r} link functions."
                )
            model = ExpectileGAM(
                term, expectile=expectile, max_iter=max_iter, verbose=False, **kwargs
            )
        else:
            gam = _gams[
                distribution, link
            ]  # doing it like this ensures that the user can specify the scale
            kwargs["link"] = link.s
            kwargs["distribution"] = distribution.s
            model = gam(
                term,
                max_iter=max_iter,
                verbose=False,
                **_filter_kwargs(gam.__init__, **kwargs),
            )
        super().__init__(adata, model=model)
        self._use_default_conf_int = use_default_conf_int

        if grid is None:
            self._grid = None
        elif isinstance(grid, dict):
            self._grid = _copy(grid)
        elif isinstance(grid, str):
            self._grid = object() if grid == "default" else None
        else:
            raise TypeError(
                f"Expected `grid` to be `dict`, `str` or `None`, found `{type(grid).__name__!r}`."
            )