Example #1
    def test_one_feature(self):
        # create some data
        adata = sc.datasets.blobs(n_observations=100, n_variables=1)

        # cluster the same data with kmeans and with louvain
        labels_kmeans = _cluster_X(adata.X, n_clusters=5, method="kmeans")
        labels_louvain = _cluster_X(adata.X, n_clusters=5, method="louvain")

        assert len(labels_kmeans) == len(labels_louvain) == adata.n_obs
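`_cluster_X` is a private helper that is not shown on this page. As a rough sketch only, a dispatcher like it could look as follows, assuming scikit-learn for k-means and scanpy for the graph-based methods; the name `cluster_X_sketch` and its body are illustrative, not the library's actual implementation:

    import numpy as np
    import scanpy as sc
    from anndata import AnnData
    from sklearn.cluster import KMeans

    def cluster_X_sketch(X, method="kmeans", n_clusters=5, n_neighbors=20, resolution=0.1):
        """Cluster the rows of X and return one label per row."""
        if method == "kmeans":
            # KMeans returns integer labels, one per row of X
            return list(KMeans(n_clusters=n_clusters).fit_predict(X))
        if method in ("louvain", "leiden"):
            # graph-based clustering needs a KNN graph, so wrap X in a temporary AnnData
            tmp = AnnData(np.asarray(X, dtype=np.float32))
            sc.pp.neighbors(tmp, n_neighbors=n_neighbors)
            getattr(sc.tl, method)(tmp, resolution=resolution)
            return list(tmp.obs[method])
        raise ValueError(f"Invalid method `{method!r}`.")

Either way the helper returns one label per observation, which is exactly what the test above asserts.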
Example #2
    def compute_terminal_states(
        self,
        use: Optional[Union[int, Tuple[int, ...], List[int], range]] = None,
        percentile: Optional[int] = 98,
        method: str = "kmeans",
        cluster_key: Optional[str] = None,
        n_clusters_kmeans: Optional[int] = None,
        n_neighbors: int = 20,
        resolution: float = 0.1,
        n_matches_min: Optional[int] = 0,
        n_neighbors_filtering: int = 15,
        basis: Optional[str] = None,
        n_comps: int = 5,
        scale: bool = False,
        en_cutoff: Optional[float] = 0.7,
        p_thresh: float = 1e-15,
    ) -> None:
        """
        Find approximate recurrent classes of the Markov chain.

        Filter to obtain recurrent states in left eigenvectors.
        Cluster to obtain approximate recurrent classes in right eigenvectors.

        Parameters
        ----------
        use
            Which or how many first eigenvectors to use as features for clustering/filtering.
            If `None`, use the `eigengap` statistic.
        percentile
            Threshold used for filtering out cells which are most likely transient states. Cells whose
            absolute value falls below the ``percentile`` cutoff in every eigenvector are removed from
            the data matrix.
        method
            Method to be used for clustering. Must be one of `'louvain'`, `'leiden'` or `'kmeans'`.
        cluster_key
            If a key to cluster labels is given, :attr:`{fs}` will get associated with these for naming and colors.
        n_clusters_kmeans
            If `None`, this is set to ``len(use)``, plus `1` if ``percentile`` is `None`.
        n_neighbors
            If `'louvain'` or `'leiden'` is used for clustering, a KNN graph needs to be built first.
            This is its :math:`K` parameter, the number of neighbors for each cell.
        resolution
            Resolution parameter for `'louvain'` or `'leiden'` clustering. Should be chosen relatively small.
        n_matches_min
            Filters out cells which don't have at least ``n_matches_min`` neighbors from the same class.
            This filters out some cells which are transient but have been misassigned.
        n_neighbors_filtering
            Parameter for filtering cells. Cells are filtered out if they don't have at least ``n_matches_min``
            neighbors among their ``n_neighbors_filtering`` nearest cells.
        basis
            Key from :paramref:`adata` ``.obsm`` to be used as additional features for the clustering.
        n_comps
            Number of embedding components to be used when ``basis`` is not `None`.
        scale
            Scale to z-scores. Consider using this if appending embedding to features.
        %(en_cutoff_p_thresh)s

        Returns
        -------
        None
            Nothing, but updates the following fields:

                - :attr:`{fsp}`
                - :attr:`{fs}`
        """

        def _compute_macrostates_prob() -> Series:
            """Compute a global score of being an approximate recurrent class."""

            # get the truncated eigendecomposition
            V, evals = eig["V_l"].real[:, use], eig["D"].real[use]

            # shift and scale
            V_pos = np.abs(V)
            V_shifted = V_pos - np.min(V_pos, axis=0)
            V_scaled = V_shifted / np.max(V_shifted, axis=0)
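            # e.g. for one column v = [-0.2, 0.1, 0.4]:
            #   |v| = [0.2, 0.1, 0.4] -> shift by the min: [0.1, 0.0, 0.3]
            #   -> divide by the max: [1/3, 0.0, 1.0], so every column spans [0, 1]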

            # check the ranges are correct
            assert np.allclose(np.min(V_scaled, axis=0), 0), "Lower limit is not zero."
            assert np.allclose(np.max(V_scaled, axis=0), 1), "Upper limit is not one."

            # further scale by the eigenvalues
            V_eigs = V_scaled / evals

            # sum over cols and scale
            c_ = np.sum(V_eigs, axis=1)
            c = c_ / np.max(c_)

            return Series(c, index=self.adata.obs_names)

        def check_use(use) -> List[int]:
            if method not in ["kmeans", "louvain", "leiden"]:
                raise ValueError(
                    f"Invalid method `{method!r}`. Valid options are `'louvain'`, `'leiden'` or `'kmeans'`."
                )

            if use is None:
                use = eig["eigengap"] + 1  # add one b/c indexing starts at 0
            if isinstance(use, int):
                use = list(range(use))
            elif not isinstance(use, (tuple, list, range)):
                raise TypeError(
                    f"Argument `use` must be either `int`, `tuple`, `list` or `range`, "
                    f"found `{type(use).__name__!r}`."
                )
            else:
                if not all(map(lambda u: isinstance(u, int), use)):
                    raise TypeError("Not all values in `use` argument are integers.")
            use = list(use)
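            # `use` is now a list of 0-based eigenvector indices,
            # e.g. `use=3` -> [0, 1, 2] and `use=(0, 2)` -> [0, 2]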

            if len(use) == 0:
                raise ValueError(
                    f"Number of eigenvector must be larger than `0`, found `{len(use)}`."
                )

            muse = max(use)
            if muse >= eig["V_l"].shape[1] or muse >= eig["V_r"].shape[1]:
                raise ValueError(
                    f"Maximum specified eigenvector `{muse}` is larger "
                    f'than the number of computed eigenvectors `{eig["V_l"].shape[1]}`. '
                    f"Use `.compute_eigendecomposition(k={muse})` to recompute the eigendecomposition."
                )

            return use

        eig = self._get(P.EIG)
        if eig is None:
            raise RuntimeError(
                "Compute eigendecomposition first as `.compute_eigendecomposition()`."
            )
        use = check_use(use)

        start = logg.info("Computing approximate recurrent classes")
        # check for complex values only in the left eigenvectors; this is fine because
        # the complex pattern is identical for left and right
        V_l, V_r = eig["V_l"][:, use], eig["V_r"].real[:, use]
        V_l = _complex_warning(V_l, use, use_imag=False)

        # compute a rc probability
        logg.debug("Computing probabilities of approximate recurrent classes")
        self._set(A.TERM_PROBS, _compute_macrostates_prob())

        # retrieve embedding and concatenate
        if basis is not None:
            bkey = f"X_{basis}"
            if bkey not in self.adata.obsm.keys():
                raise KeyError(f"Basis key `{bkey!r}` not found in `adata.obsm`")

            X_em = self.adata.obsm[bkey][:, :n_comps]
            X = np.concatenate([V_r, X_em], axis=1)
        else:
            logg.debug("Basis is `None`. Setting X equal to the right eigenvectors")
            X = V_r

        # filter out cells which are in the lowest q percentile in abs value in each eigenvector
        if percentile is not None:
            logg.debug("Filtering out cells according to percentile")
            if percentile < 0 or percentile > 100:
                raise ValueError(
                    f"Percentile must be in interval `[0, 100]`, found `{percentile}`."
                )
            cutoffs = np.percentile(np.abs(V_l), percentile, axis=0)
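            # a cell is kept iff its absolute value reaches the cutoff in at least one
            # eigenvector; cells below the cutoff in every eigenvector are dropped as transient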
            ixs = np.sum(np.abs(V_l) < cutoffs, axis=1) < V_l.shape[1]
            X = X[ixs, :]

        # scale
        if scale:
            X = zscore(X, axis=0)

        # cluster X
        if method == "kmeans" and n_clusters_kmeans is None:
            n_clusters_kmeans = len(use) + (percentile is None)
            if X.shape[0] < n_clusters_kmeans:
                raise ValueError(
                    f"Filtering resulted in only {X.shape[0]} cell(s), insufficient to cluster into "
                    f"`{n_clusters_kmeans}` clusters. Consider decreasing the value of `percentile`."
                )

        logg.debug(
            f"Using `{use}` eigenvectors, basis `{basis!r}` and method `{method!r}` for clustering"
        )
        labels = _cluster_X(
            X,
            method=method,
            n_clusters=n_clusters_kmeans,
            n_neighbors=n_neighbors,
            resolution=resolution,
        )

        # fill in the labels in case we filtered out cells before
        if percentile is not None:
            rc_labels = np.repeat(None, self.adata.n_obs)
            rc_labels[ixs] = labels
        else:
            rc_labels = labels

        rc_labels = Series(rc_labels, index=self.adata.obs_names, dtype="category")
        # stringify the category labels (assignment to `.cat.categories` is deprecated in newer pandas)
        rc_labels = rc_labels.cat.rename_categories(str)

        # filtering to get rid of some of the left over transient states
        if n_matches_min > 0:
            logg.debug(f"Filtering according to `n_matches_min={n_matches_min}`")
            distances = _get_connectivities(
                self.adata, mode="distances", n_neighbors=n_neighbors_filtering
            )
            rc_labels = _filter_cells(
                distances, rc_labels=rc_labels, n_matches_min=n_matches_min
            )

        self.set_terminal_states(
            labels=rc_labels,
            cluster_key=cluster_key,
            en_cutoff=en_cutoff,
            p_thresh=p_thresh,
            add_to_existing=False,
            time=start,
        )
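For context, a hedged usage sketch: the snippet below assumes the CellRank 1.x public API (`cr.tl.kernels.VelocityKernel`, `cr.tl.estimators.CFLARE`, `cr.datasets.pancreas_preprocessed`), which may differ in your installed version. The eigendecomposition has to be computed first, as the method itself enforces:

    import cellrank as cr

    adata = cr.datasets.pancreas_preprocessed()

    # build a transition matrix from RNA velocity and wrap it in an estimator
    vk = cr.tl.kernels.VelocityKernel(adata).compute_transition_matrix()
    estimator = cr.tl.estimators.CFLARE(vk)

    # eigendecomposition first, otherwise compute_terminal_states() raises a RuntimeError
    estimator.compute_eigendecomposition(k=20)
    estimator.compute_terminal_states(use=3, method="kmeans", percentile=98)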