Code example #1
    def compute_eigen(
        self,
        n_comps: int = 15,
        sym: Optional[bool] = None,
        sort: Literal['decrease', 'increase'] = 'decrease',
    ):
        """\
        Compute eigen decomposition of transition matrix.

        Parameters
        ----------
        n_comps
            Number of eigenvalues/vectors to be computed, set `n_comps = 0` if
            you need all eigenvectors.
        sym
            Instead of computing the eigendecomposition of the asymmetric
            transition matrix, compute the eigendecomposition of the symmetric
            Ktilde matrix.

        Returns
        -------
        Writes the following attributes.

        eigen_values : numpy.ndarray
            Eigenvalues of transition matrix.
        eigen_basis : numpy.ndarray
            Matrix of eigenvectors (stored in columns). `.eigen_basis` is the
            projection of the data matrix on the right eigenvectors, that is,
            the projection on the diffusion components. These are simply the
            components of the right eigenvectors and can directly be used for
            plotting.
        """
        np.set_printoptions(precision=10)
        if self._transitions_sym is None:
            raise ValueError('Run `.compute_transitions` first.')
        matrix = self._transitions_sym
        # compute the spectrum
        if n_comps == 0:
            evals, evecs = scipy.linalg.eigh(matrix)
        else:
            n_comps = min(matrix.shape[0] - 1, n_comps)
            # ncv = max(2 * n_comps + 1, int(np.sqrt(matrix.shape[0])))
            ncv = None
            which = 'LM' if sort == 'decrease' else 'SM'
            # it pays off to increase the stability with a bit more precision
            matrix = matrix.astype(np.float64)
            evals, evecs = scipy.sparse.linalg.eigsh(matrix,
                                                     k=n_comps,
                                                     which=which,
                                                     ncv=ncv)
            evals, evecs = evals.astype(np.float32), evecs.astype(np.float32)
        if sort == 'decrease':
            evals = evals[::-1]
            evecs = evecs[:, ::-1]
        logg.info('    eigenvalues of transition matrix\n'
                  '    {}'.format(str(evals).replace('\n', '\n    ')))
        if self._number_connected_components > len(evals) / 2:
            logg.warning('Transition matrix has many disconnected components!')
        self._eigen_values = evals
        self._eigen_basis = evecs
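
The sparse branch above is a top-k symmetric eigendecomposition followed by a reordering. A minimal standalone sketch of that pattern, with a random symmetric matrix standing in for `_transitions_sym` (an assumption for illustration):

import numpy as np
import scipy.sparse
import scipy.sparse.linalg

rng = np.random.default_rng(0)
A = rng.random((50, 50))
matrix = scipy.sparse.csr_matrix((A + A.T) / 2)  # symmetric stand-in matrix

n_comps = 5
evals, evecs = scipy.sparse.linalg.eigsh(matrix.astype(np.float64),
                                         k=n_comps, which='LM')
# eigsh returns eigenvalues in ascending order; reverse for 'decrease'
evals, evecs = evals[::-1], evecs[:, ::-1]
print(evals)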
Code example #2
def get_igraph_from_adjacency(adj: scipy.sparse.csr_matrix,
                              edge_type: Optional[str] = None):
    """Get igraph graph from adjacency matrix.

    Better than Graph.Adjacency for sparse matrices.

    Parameters
    ----------
    adj
        (weighted) adjacency matrix
    edge_type
        A type attribute added to all edges
    """
    g = ig.Graph(directed=False)
    g.add_vertices(adj.shape[0])  # this adds adj.shape[0] vertices

    sources, targets = scipy.sparse.triu(adj, k=1).nonzero()
    weights = adj[sources, targets].astype("float")
    g.add_edges(list(zip(sources, targets)))

    if isinstance(weights, np.matrix):
        weights = weights.A1

    g.es["weight"] = weights
    if edge_type is not None:
        g.es["type"] = edge_type

    if g.vcount() != adj.shape[0]:
        logging.warning(f"The constructed graph has only {g.vcount()} nodes. "
                        "Your adjacency matrix contained redundant nodes.")
    return g
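
A hypothetical usage sketch for the helper above, assuming `get_igraph_from_adjacency` and its imports (`ig`, `numpy`, `scipy.sparse`, `logging`) are in scope:

import numpy as np
import scipy.sparse

adj = scipy.sparse.csr_matrix(np.array([
    [0.0, 0.5, 0.0],
    [0.5, 0.0, 1.0],
    [0.0, 1.0, 0.0],
]))
g = get_igraph_from_adjacency(adj, edge_type="knn")
print(g.vcount(), g.ecount())  # 3 nodes, 2 undirected edges
print(g.es["weight"])          # [0.5, 1.0]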
Code example #3
    def add_image(self, layer: str) -> bool:
        """
        Add a new :mod:`napari` image layer.

        Parameters
        ----------
        layer
            Layer in the underlying :class:`ImageContainer` which contains the image.

        Returns
        -------
        `True` if the layer has been added, otherwise `False`.
        """
        if layer in self.view.layernames:
            self._handle_already_present(layer)
            return False

        img: np.ndarray = self.model.container.data[layer].transpose(
            "y", "x", ...).values
        if img.shape[-1] > 4:
            logg.warning(f"Unable to show image of shape `{img.shape}`")
            return False

        logg.info(f"Creating image `{layer}` layer")
        self.view.viewer.add_image(
            img_as_float(img),
            name=layer,
            rgb=True,
            colormap=self.model.cmap,
            blending=self.model.blending,
        )

        return True
Code example #4
File: _container.py Project: sophial05/squidpy
    def _(self,
          img: xr.DataArray,
          copy: bool = True,
          **_: Any) -> xr.DataArray:
        logg.debug(f"Loading data `xarray.DataArray` of shape `{img.shape}`")

        if img.ndim == 2:
            img = img.expand_dims("channels", -1)
        if img.ndim != 3:
            raise ValueError(
                f"Expected image to have `3` dimensions, found `{img.ndim}`.")

        mapping: Dict[Hashable, str] = {}
        if "y" not in img.dims:
            logg.warning(
                f"Dimension `y` not found in the data. Assuming it's `{img.dims[0]}`"
            )
            mapping[img.dims[0]] = "y"
        if "x" not in img.dims:
            logg.warning(
                f"Dimension `x` not found in the data. Assuming it's `{img.dims[1]}`"
            )
            mapping[img.dims[1]] = "x"

        img = img.rename(mapping)
        channel_dim = [d for d in img.dims if d not in ("y", "x")][0]
        try:
            img = img.reset_index(dims_or_levels=channel_dim, drop=True)
        except KeyError:
            # might not be present, ignore
            pass

        return img.copy() if copy else img
Code example #5
    def _set_iroot_via_xroot(self, xroot):
        """Determine the index of the root cell.

        Given an expression vector, find the observation index that is closest
        to this vector.

        Parameters
        ----------
        xroot : np.ndarray
            Vector that marks the root cell; it stores the initial condition
            and is only relevant for computing pseudotime.
        """
        if self._adata.shape[1] != xroot.size:
            raise ValueError('The root vector you provided does not have the '
                             'correct dimension.')
        # this is the squared distance
        dsqroot = 1e10
        iroot = 0
        for i in range(self._adata.shape[0]):
            diff = self._adata.X[i, :] - xroot
            dsq = diff @ diff
            if dsq < dsqroot:
                dsqroot = dsq
                iroot = i
                if np.sqrt(dsqroot) < 1e-10: break
        logg.debug(f'setting root index to {iroot}')
        if self.iroot is not None and iroot != self.iroot:
            logg.warning(
                f'Changing index of iroot from {self.iroot} to {iroot}.')
        self.iroot = iroot
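
The loop above is a nearest-observation search; for a dense data matrix the same index can be computed vectorized. A sketch with plain numpy, where `X` stands in for `adata.X`:

import numpy as np

X = np.array([[0.0, 0.0], [1.0, 1.0], [0.1, 0.0]])
xroot = np.array([0.0, 0.1])
iroot = int(np.argmin(((X - xroot) ** 2).sum(axis=1)))
print(iroot)  # 0: the observation closest to the root vector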
Code example #6
File: __init__.py Project: zktuong/scirpy
def _get_igraph_from_adjacency(adj: csr_matrix, simplify=True):
    """Get an undirected igraph graph from adjacency matrix.
    Better than Graph.Adjacency for sparse matrices.

    Parameters
    ----------
    adj
        sparse, weighted, symmetrical adjacency matrix.
    """
    sources, targets = adj.nonzero()
    weights = adj[sources, targets]
    if isinstance(weights, np.matrix):
        weights = weights.A1
    if isinstance(weights, csr_matrix):
        # this is the case when len(sources) == len(targets) == 0, see #236
        weights = weights.toarray()

    g = ig.Graph(directed=not simplify)
    g.add_vertices(adj.shape[0])  # this adds adj.shape[0] vertices
    g.add_edges(list(zip(sources, targets)))

    g.es["weight"] = weights

    if g.vcount() != adj.shape[0]:
        logging.warning(
            f"The constructed graph has only {g.vcount()} nodes. "
            "Your adjacency matrix contained redundant nodes.")  # type: ignore

    if simplify:
        # since we start from a symmetrical matrix, and the graph is undirected,
        # it is fine to take either of the two edges when simplifying.
        g.simplify(combine_edges="first")

    return g
Code example #7
File: _utils.py Project: dpeerlab/cellrank
def _vec_mat_corr(X: Union[np.ndarray, spmatrix], y: np.ndarray) -> np.ndarray:
    """
    Compute the correlation between columns in matrix X and a vector y.

    Return NaN for genes which don't vary across cells.

    Params
    ------
    X
        Matrix of `NxM` elements.
    y
        Vector of `N` elements.

    Returns
    -------
    :class:`numpy.ndarray`
        The computed correlation.
    """

    X_bar, y_std, n = np.array(
        X.mean(axis=0)).reshape(-1), np.std(y), X.shape[0]
    num = X.T.dot(y) - n * X_bar * np.mean(y)
    denom = ((n - 1) * np.std(X.A, axis=0) * y_std if issparse(X) else
             (X.shape[0] - 1) * np.std(X, axis=0) * y_std)

    if np.sum(denom == 0) > 0:
        logg.warning(
            f"No variation found in `{np.sum(denom == 0)}` genes. Setting correlation for these to `NaN`"
        )

    return num / denom
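
A hypothetical usage sketch, assuming the helper and its imports (`numpy`, `issparse`, `logg`) are in scope; the constant column has no variation and triggers the NaN warning path:

import numpy as np

rng = np.random.default_rng(42)
X = rng.normal(size=(100, 3))
X[:, 2] = 1.0                                    # no variation -> NaN
y = 2.0 * X[:, 0] + rng.normal(scale=0.1, size=100)

print(_vec_mat_corr(X, y))  # roughly [1.0, ~0.0, nan]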
Code example #8
File: _utils.py Project: dpeerlab/cellrank
def _complex_warning(X: np.ndarray,
                     use: Union[list, int, tuple, range],
                     use_imag: bool = False) -> np.ndarray:
    """
    Check for imaginary components in columns of X specified by `use`.

    Params
    ------
    X
        Matrix containing the eigenvectors
    use
        Selection of columns of `X`
    use_imag
        For eigenvectors that are complex, use real or imaginary part

    Returns
    -------
    :class:`numpy.ndarray`
        The eigenvectors with complex entries resolved to their real or
        imaginary part.
    """

    complex_mask = np.sum(X.imag != 0, axis=0) > 0
    complex_ixs = np.array(use)[np.where(complex_mask)[0]]
    complex_key = "imaginary" if use_imag else "real"
    if len(complex_ixs) > 0:
        logg.warning(
            f"The eigenvectors with indices {complex_ixs} have an imaginary part. Showing their {complex_key} part."
        )
    X_ = X.real.copy()  # copy so assigning below does not mutate `X` via the `.real` view
    if use_imag:
        X_[:, complex_mask] = X.imag[:, complex_mask]

    return X_
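
A hypothetical usage sketch (assumes the helper and `logg` are in scope): a rotation matrix has complex eigenvectors, so both selected columns trigger the warning:

import numpy as np

T = np.array([[0.0, -1.0], [1.0, 0.0]])  # 90-degree rotation
evals, evecs = np.linalg.eig(T)          # complex eigenpairs
X_real = _complex_warning(evecs, use=[0, 1], use_imag=False)
print(X_real.dtype)  # float64: only the real parts are kept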
Code example #9
    def _read_from_adata(self, **kwargs):
        """
        Import the base-KNN graph and check for symmetry and connectivity.
        """

        if not has_neighs(self.adata):
            raise KeyError("Compute the KNN graph first with `scanpy.pp.neighbors()`.")

        self._conn = get_neighs(self.adata, "connectivities").astype(_dtype)

        start = logg.debug("Checking the KNN graph for connectedness")
        if not is_connected(self._conn):
            logg.warning("KNN graph is not connected", time=start)

        start = logg.debug("Checking the KNN graph for symmetry")
        if not is_symmetric(self._conn):
            logg.warning("KNN graph is not symmetric", time=start)

        variance_key = kwargs.pop("variance_key", None)
        if variance_key is not None:
            logg.debug(f"DEBUG: Loading variances from `adata.uns[{variance_key!r}]`")
            variance_key = f"{variance_key}_variances"
            if variance_key in self.adata.uns.keys():
                # keep it sparse
                self._variances = csr_matrix(
                    self.adata.uns[variance_key].astype(_dtype)
                )
            else:
                self._variances = None
                logg.debug(
                    f"DEBUG: Unable to load variances`{variance_key}` from `adata.uns`"
                )
        else:
            logg.debug("DEBUG: No variance key specified")
Code example #10
def check_var_names_type(var_names, var_group_labels, var_group_positions):
    """
    checks if var_names is a dict. Is this is the cases, then set the
    correct values for var_group_labels and var_group_positions

    Returns
    -------
    var_names, var_group_labels, var_group_positions

    """
    if isinstance(var_names, cabc.Mapping):
        if var_group_labels is not None or var_group_positions is not None:
            logg.warning(
                "`var_names` is a dictionary. This will reset the current "
                "value of `var_group_labels` and `var_group_positions`.")
        var_group_labels = []
        _var_names = []
        var_group_positions = []
        start = 0
        for label, vars_list in var_names.items():
            if isinstance(vars_list, str):
                vars_list = [vars_list]
            # use list() in case vars_list is a numpy array or pandas series
            _var_names.extend(list(vars_list))
            var_group_labels.append(label)
            var_group_positions.append((start, start + len(vars_list) - 1))
            start += len(vars_list)
        var_names = _var_names

    elif isinstance(var_names, str):
        var_names = [var_names]

    return var_names, var_group_labels, var_group_positions
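
A hypothetical usage sketch, assuming the function and its imports (`cabc`, `logg`) are in scope:

markers = {"T-cell": ["CD3D", "CD3E"], "myeloid": "CST3"}
names, labels, positions = check_var_names_type(markers, None, None)
print(names)      # ['CD3D', 'CD3E', 'CST3']
print(labels)     # ['T-cell', 'myeloid']
print(positions)  # [(0, 1), (2, 2)]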
Code example #11
def test_formats(capsys, logging_state):
    s.logfile = sys.stderr
    s.verbosity = Verbosity.debug
    l.error('0')
    assert capsys.readouterr().err == 'ERROR: 0\n'
    l.warning('1')
    assert capsys.readouterr().err == 'WARNING: 1\n'
    l.info('2')
    assert capsys.readouterr().err == '2\n'
    l.hint('3')
    assert capsys.readouterr().err == '--> 3\n'
    l.debug('4')
    assert capsys.readouterr().err == '    4\n'
Code example #12
File: _lineage.py Project: dpeerlab/cellrank
def _remove_zero_rows(a: Lineage, b: Lineage) -> Tuple[Lineage, Lineage]:
    if a.shape[0] != b.shape[0]:
        raise ValueError("Lineage objects have unequal cell numbers")

    bool_a = (a.X == 0).any(axis=1)
    bool_b = (b.X == 0).any(axis=1)
    mask = ~np.logical_or(bool_a, bool_b)

    n_removed = a.shape[0] - np.sum(mask)
    if n_removed:  # only warn when rows were actually dropped
        logg.warning(
            f"Removed {n_removed} rows because they contained zeros")

    return a[mask, :], b[mask, :]
Code example #13
File: _utils.py Project: dpeerlab/cellrank
def _series_from_one_hot_matrix(a: np.ndarray,
                                index: Optional[Iterable] = None,
                                names: Optional[Iterable] = None) -> pd.Series:
    """
    Create a pandas Series based on a one-hot encoded matrix.

    Params
    ------
    a
        One-hot encoded membership matrix of shape `(n_samples, n_clusters)`, i.e. a `1` in position `i, j`
        signifies that sample `i` belongs to cluster `j`.
    index
        Index for the Series. Careful, if this is not given, categories are removed when writing to AnnData.

    Returns
    -------
    cluster_series
        Pandas Series, indicating cluster membership for each sample. The dtype of the categories is `str`
        and samples that belong to no cluster are assigned `NaN`.
    """
    if not isinstance(a, np.ndarray):
        raise TypeError(
            f"Expected `a` to be of type `numpy.ndarray`, found `{type(a).__name__!r}`."
        )
    a = np.asarray(a)  # change the type in case a lineage object was passed.
    n_samples, n_clusters = a.shape
    if a.dtype != bool:  # `np.bool` was removed from numpy
        raise TypeError(
            f"Expected `a`'s elements to be boolean, found `{a.dtype.name}`.")

    if not np.all(a.sum(axis=1) <= 1):
        raise ValueError("Not all items are one-hot encoded or empty.")
    if (a.sum(0) == 0).any():
        logg.warning(f"Detected {np.sum((a.sum(0) == 0))} empty categorie(s) ")

    if index is None:
        index = range(n_samples)
    if names is not None:
        if len(names) != n_clusters:
            raise ValueError(
                f"Shape mismatch, length of `names` is `{len(names)}`, but `n_clusters` = {n_clusters}."
            )
    else:
        names = np.arange(n_clusters).astype("str")

    target_series = pd.Series(index=index, dtype="category")
    for vec, name in zip(a.T, names):
        # `inplace=True` was removed from pandas; re-assign instead
        target_series = target_series.cat.add_categories(name)
        target_series[np.where(vec)[0]] = name

    return target_series
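
A hypothetical usage sketch (assumes the helper plus `numpy` and `pandas` are in scope):

import numpy as np

a = np.array([
    [True, False],
    [False, True],
    [False, False],   # belongs to no cluster -> NaN
])
s = _series_from_one_hot_matrix(a, names=["stem", "mature"])
print(s.tolist())  # ['stem', 'mature', nan]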
Code example #14
def _get_reference(
    adata: AnnData,
    reference_key: Union[str, None],
    reference_cat: Union[None, str, Sequence[str]],
    reference: Union[np.ndarray, None],
) -> np.ndarray:
    """Parameter validation extraction of reference gene expression.

    If multiple reference categories are given, compute the mean per
    category.

    Returns a 2D array with reference categories in rows, cells in columns.
    If there's just one category, it's still a 2D array.
    """
    if reference is None:
        if reference_key is None or reference_cat is None:
            logging.warning(
                "Using mean of all cells as reference. For better results, "
                "provide either `reference`, or both `reference_key` and `reference_cat`. "
            )  # type: ignore
            reference = np.mean(adata.X, axis=0)

        else:
            obs_col = adata.obs[reference_key]
            if isinstance(reference_cat, str):
                reference_cat = [reference_cat]
            reference_cat = np.array(reference_cat)
            reference_cat_in_obs = np.isin(reference_cat, obs_col)
            if not np.all(reference_cat_in_obs):
                raise ValueError(
                    "The following reference categories were not found in "
                    "adata.obs[reference_key]: "
                    f"{reference_cat[~reference_cat_in_obs]}")

            reference = np.vstack([
                np.mean(adata.X[obs_col == cat, :], axis=0)
                for cat in reference_cat
            ])

    if reference.ndim == 1:
        reference = reference[np.newaxis, :]

    if reference.shape[1] != adata.shape[1]:
        raise ValueError(
            "Reference must match the number of genes in AnnData. ")

    return reference
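
A hypothetical usage sketch with a tiny AnnData object (assumes the function and its `numpy`/`logging` imports are in scope):

import numpy as np
import pandas as pd
from anndata import AnnData

adata = AnnData(
    X=np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]),
    obs=pd.DataFrame({"cell_type": ["T", "T", "B"]}),
)
ref = _get_reference(adata, "cell_type", ["T", "B"], None)
print(ref)  # row 0: mean of the two "T" cells, row 1: the single "B" cell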
Code example #15
    def export(self, _: Viewer) -> None:
        """Export shapes into :class:`AnnData` object."""
        for layer in self.view.layers:
            if not isinstance(layer, Shapes) or not layer.selected:
                continue
            if not len(layer.data):
                logg.warning(
                    f"Shape layer `{layer.name}` has no visible shapes")
                continue

            key = f"{layer.name}_{self.model.key_added}"

            logg.info(
                f"Adding `adata.obs[{key!r}]`\n       `adata.uns[{key!r}]['meshes']`"
            )
            self._save_shapes(layer, key=key)
            self._update_obs_items(key)
Code example #16
def _regress_out_chunk(data):
    # data is a tuple containing the selected columns from adata.X
    # and the regressors DataFrame
    data_chunk = data[0]
    regressors = data[1]
    variable_is_categorical = data[2]

    responses_chunk_list = []
    import statsmodels.api as sm
    from statsmodels.tools.sm_exceptions import PerfectSeparationError

    for col_index in range(data_chunk.shape[1]):

        # if all values are identical, statsmodels.api.GLM throws an error;
        # but then no regression is necessary anyways...
        if not (data_chunk[:, col_index] != data_chunk[0, col_index]).any():
            responses_chunk_list.append(data_chunk[:, col_index])
            continue

        if variable_is_categorical:
            regres = np.c_[np.ones(regressors.shape[0]), regressors[:, col_index]]
        else:
            regres = regressors
        try:
            result = sm.GLM(data_chunk[:, col_index], regres,
                            family=sm.families.Gaussian()).fit()
            new_column = result.resid_response
        except PerfectSeparationError:  # this emulates R's behavior
            logg.warning('Encountered PerfectSeparationError, setting to 0 as in R.')
            new_column = np.zeros(data_chunk.shape[0])

        responses_chunk_list.append(new_column)

    return np.vstack(responses_chunk_list)
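
A standalone sketch of the per-column fit the worker performs: regress a response on the regressors with a Gaussian GLM and keep the residuals (toy data; statsmodels only):

import numpy as np
import statsmodels.api as sm

rng = np.random.default_rng(0)
n_cells = 200
regressors = sm.add_constant(rng.normal(size=(n_cells, 1)))  # intercept + covariate
response = 3.0 * regressors[:, 1] + rng.normal(size=n_cells)

result = sm.GLM(response, regressors, family=sm.families.Gaussian()).fit()
corrected = result.resid_response  # response with the covariate regressed out
print(round(corrected.std(), 2))   # ~1.0 once the covariate effect is removed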
Code example #17
    def _init_iroot(self):
        self.iroot = None
        # set iroot directly
        if 'iroot' in self._adata.uns:
            if self._adata.uns['iroot'] >= self._adata.n_obs:
                logg.warning(
                    f'Root cell index {self._adata.uns["iroot"]} does not '
                    f'exist for {self._adata.n_obs} samples. It’s ignored.')
            else:
                self.iroot = self._adata.uns['iroot']
            return
        # set iroot via xroot
        xroot = None
        if 'xroot' in self._adata.uns: xroot = self._adata.uns['xroot']
        elif 'xroot' in self._adata.var: xroot = self._adata.var['xroot']
        # see whether we can set self.iroot using the full data matrix
        if xroot is not None and xroot.size == self._adata.shape[1]:
            self._set_iroot_via_xroot(xroot)
Code example #18
File: _utils.py Project: stuarteberg/schist
def get_graph_tool_from_adjacency(adjacency, directed=None):
    """Get graph-tool graph from adjacency matrix."""
    idx = np.nonzero(np.triu(adjacency.todense(), 1))
    weights = adjacency[idx]
    if isinstance(weights, np.matrix):
        weights = weights.A1
    g = gt.Graph(directed=directed)
    g.add_edge_list(np.transpose(idx))  # add edges from the upper triangle
    try:
        ew = g.new_edge_property("double")
        ew.a = weights
        g.ep['weight'] = ew
    except Exception:  # attaching weights failed; keep the unweighted graph
        pass
    if g.num_vertices() != adjacency.shape[0]:
        logg.warning(
            f'The constructed graph has only {g.num_vertices()} nodes. '
            'Your adjacency matrix contained redundant nodes.')
    return g
Code example #19
    def __init__(
        self,
        adata: AnnData,
        *,
        metric: Union[Literal["alignment", "identity", "levenshtein",
                              "hamming"], DistanceCalculator] = "identity",
        cutoff: Union[int, None] = None,
        receptor_arms: Literal["VJ", "VDJ", "all", "any"] = "all",
        dual_ir: Literal["primary_only", "all", "any"] = "primary_only",
        sequence: Literal["aa", "nt"] = "aa",
    ):
        """Class to compute Neighborhood graphs of CDR3 sequences.

        For documentation of the parameters, see :func:`ir_neighbors`.
        """
        start = logging.info("Initializing IrNeighbors object...")
        if metric == "identity" and cutoff != 0:
            raise ValueError("Identity metric only works with cutoff == 0")
        if metric != "identity" and cutoff == 0:
            logging.warning(f"Running with {metric} metric, but cutoff == 0. ")
        if sequence == "nt" and metric == "alignment":
            raise ValueError(
                "Using nucleotide sequences with alignment metric is not supported. "
            )
        if receptor_arms not in ["VJ", "VDJ", "all", "any"]:
            raise ValueError(
                "Invalid value for `receptor_arms`. Note that starting with v0.5 "
                "`TRA` and `TRB` are not longer valid values.")
        if dual_ir not in ["primary_only", "all", "any"]:
            raise ValueError("Invalid value for `dual_ir")
        if sequence not in ["aa", "nt"]:
            raise ValueError("Invalid value for `sequence`")
        self.adata = adata
        self.metric = metric
        self.cutoff = cutoff
        self.receptor_arms = receptor_arms
        self.dual_ir = dual_ir
        self.sequence = sequence
        self._build_index_dict()
        self._dist_mat = None
        logging.info("Finished initalizing IrNeighbors object. ", time=start)
Code example #20
def neighbors(
    adata: AnnData,
    use_rep: str = "cnv_pca",
    key_added: str = "cnv_neighbors",
    inplace: bool = True,
    **kwargs,
):
    """Compute the neighborhood graph based on the result from
    :func:`infercnvpy.tl.infercnv`.

    Parameters
    ----------
    use_rep
        Key under which the PCA of the results of :func:`infercnvpy.tl.infercnv`
        are stored in anndata. If not present, attempts to run :func:`infercnvpy.tl.pca`
        with default parameters.
    key_added
        Distances are stored in `.obsp[key_added + '_distances']` and connectivities in
        `.obsp[key_added + '_connectivities']`.
    inplace
        If `True`, store the neighborhood graph in adata, otherwise return
        the distance and connectivity matrices.
    **kwargs
        Arguments passed to :func:`scanpy.pp.neighbors`.

    Returns
    -------
    Depending on the value of inplace, updates anndata or returns the distance
    and connectivity matrices.
    """
    if f"X_{use_rep}" not in adata.obsm and use_rep == "cnv_pca":
        logging.warning(
            "X_cnv_pca not found in adata.obsm. Computing PCA with default parameters"
        )  # type: ignore
        tl.pca(adata)

    return sc.pp.neighbors(adata,
                           use_rep=f"X_{use_rep}",
                           key_added=key_added,
                           copy=not inplace,
                           **kwargs)
Code example #21
File: _base_estimator.py Project: dpeerlab/cellrank
    def compute_partition(self) -> None:
        """
        Compute communication classes for the Markov chain.

        Returns
        -------
        None
            Nothing, but updates the following fields:
                - :paramref:`recurrent_classes`
                - :paramref:`transient_classes`
                - :paramref:`irreducible`
        """

        start = logg.info("Computing communication classes")

        rec_classes, trans_classes = partition(self._T)

        self._is_irreducible = len(rec_classes) == 1 and len(
            trans_classes) == 0

        if not self._is_irreducible:
            self._trans_classes = _make_cat(trans_classes, self._n_states,
                                            self._adata.obs_names)
            self._rec_classes = _make_cat(rec_classes, self._n_states,
                                          self._adata.obs_names)
            self._adata.obs[f"{self._rc_key}_rec_classes"] = self._rec_classes
            self._adata.obs[
                f"{self._rc_key}_trans_classes"] = self._trans_classes
            logg.info(
                f"Found `{(len(rec_classes))}` recurrent and `{len(trans_classes)}` transient classes\n"
                f"Adding `.recurrent_classes`\n"
                f"       `.transient_classes`\n"
                f"       `.irreducible`\n"
                f"    Finish",
                time=start,
            )
        else:
            logg.warning(
                "The transition matrix is irreducible - cannot further partition it\n    Finish",
                time=start,
            )
Code example #22
File: tools.py Project: gtca/muon
    def _affinity_matrix(dist, k, sigma):
        """
        Compute the affinity matrix for a distance matrix

        Reference implementation can be found in the SNFtool R package:
        https://github.com/cran/SNFtool/blob/master/R/affinityMatrix.R

        Parameters
        ----------
        dist
                Distance matrix between observations
        k
                Number of neighbours to be used in the K-nearest neighbours step
        sigma
                Variance for the local model when calculating affinity matrices
        """
        dist = (dist + dist.T) / 2
        if issparse(dist):
            dist.setdiag(0)
            dist.eliminate_zeros()
        else:
            np.fill_diagonal(dist, 0)

        # FIXME: adopt for sparse matrices
        if issparse(dist):
            logging.warning(
                f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Using dense distance matrix when computing affinity matrix..."
            )
            dist = dist.todense()
        sorted_columns = np.apply_along_axis(np.sort, 1, dist)

        def finite_mean(x, *args, **kwargs):
            return np.mean(x[~np.isinf(x)], *args, **kwargs)

        means = np.apply_along_axis(finite_mean, 1, sorted_columns[:, 1 : k + 1]) + eps
        sig = np.add.outer(means, means) / 3 + dist / 3 + eps
        densities = stats.norm(0, sigma * sig).pdf(dist)

        w = (densities + densities.T) / 2
        return w
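
A hypothetical usage sketch, assuming the helper above is available at module scope together with its dependencies (`numpy`, `scipy.stats` as `stats`, `issparse`, and a module-level `eps` constant, which is assumed here):

import numpy as np
from scipy import stats
from scipy.sparse import issparse

eps = np.finfo(np.float64).eps  # assumed stand-in for the module-level constant

rng = np.random.default_rng(0)
pts = rng.normal(size=(6, 2))
dist = np.linalg.norm(pts[:, None, :] - pts[None, :, :], axis=-1)

w = _affinity_matrix(dist, k=3, sigma=0.5)
print(w.shape, np.allclose(w, w.T))  # (6, 6) True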
Code example #23
File: _utils.py Project: stuarteberg/schist
def read(prefix: str = 'adata',
         key: str = 'nsbm',
         h5ad_fname: Optional[str] = None,
         pkl_fname: Optional[str] = None) -> Optional[AnnData]:
    """Read anndata object when a NestedBlockState has been saved separately.
    This function reads the h5ad and the pkl files, then rebuilds the `adata` properly,
    returning it to the user. Note that if pkl is not found, an AnnData object
    is returned anyway

    Parameters
    ----------
    prefix
        The prefix for .h5ad and .pkl files; it is supposed to be the same for
        both. If it is not, specify the file names (see below)
    key
        The slot in `AnnData.uns` in which nsbm information is placed
    h5ad_fname
        If `prefix` is not shared between h5ad and pkl, specify the h5ad file here
    pkl_fname
        If `prefix` is not shared between h5ad and pkl, specify the pkl file here
    """
    if not h5ad_fname:
        h5ad_fname = "%s.h5ad" % prefix
    if not pkl_fname:
        pkl_fname = "%s.pkl" % prefix

    # read the anndata
    adata = read_h5ad(h5ad_fname)

    try:
        with open(pkl_fname, 'rb') as fh:
            state = pickle.load(fh)
            adata.uns[key]['state'] = state
    except IOError:
        logg.warning(
            f'The specified file for state {pkl_fname} does not exist. '
            'Proceeding anyway')
    return adata
Code example #24
    def maybe_create_lineage(direction: Direction):
        lin_key = str(LinKey.FORWARD if direction ==
                      Direction.FORWARD else LinKey.BACKWARD)
        names_key, colors_key = _lin_names(lin_key), _colors(lin_key)
        if lin_key in adata.obsm.keys():
            n_cells, n_lineages = adata.obsm[lin_key].shape
            logg.info(
                f"Creating {'forward' if direction == Direction.FORWARD else 'backward'} `Lineage` object"
            )

            if names_key not in adata.uns.keys():
                logg.warning(
                    f"Lineage names not found in `adata.uns[{names_key!r}]`, creating dummy names"
                )
                names = [f"Lineage {i}" for i in range(n_lineages)]
            elif len(adata.uns[names_key]) != n_lineages:
                logg.warning(
                    f"Lineage names are don't have the required length ({n_lineages}), creating dummy names"
                )
                names = [f"Lineage {i}" for i in range(n_lineages)]
            else:
                logg.info("Succesfully loaded names")
                names = adata.uns[names_key]

            if colors_key not in adata.uns.keys():
                logg.warning(
                    f"Lineage colors not found in `adata.uns[{colors_key!r}]`, creating new colors"
                )
                colors = _create_categorical_colors(n_lineages)
            elif len(adata.uns[colors_key]) != n_lineages or not all(
                    map(lambda c: is_color_like(c), adata.uns[colors_key])):
                logg.warning(
                    f"Lineage colors don't have the required length ({n_lineages}) "
                    f"or are not color-like, creating new colors")
                colors = _create_categorical_colors(n_lineages)
            else:
                logg.info("Succesfully loaded colors")
                colors = adata.uns[colors_key]

            adata.obsm[lin_key] = Lineage(adata.obsm[lin_key],
                                          names=names,
                                          colors=colors)
            adata.uns[colors_key] = colors
            adata.uns[names_key] = names
        else:
            logg.debug(
                f"DEBUG: Unable to load {'forward' if direction == Direction.FORWARD else 'backward'} "
                f"`Lineage` from `adata.obsm[{lin_key!r}]`")
Code example #25
File: __init__.py Project: rgranit/infercnvpy
def tsne(
    adata: AnnData,
    use_rep: str = "cnv_pca",
    key_added: str = "cnv_tsne",
    inplace: bool = True,
    **kwargs,
):
    """Compute the t-SNE on the result of :func:`infercnvpy.tl.infercnv`.

    Thin wrapper around :func:`scanpy.tl.tsne`

    Parameters
    ----------
    adata
        annotated data matrix
    use_rep
        Key under which the result of :func:`infercnvpy.tl.pca` is stored
        in adata
    key_added
        Key under which the result of t-SNE will be stored in adata.obsm
    inplace
        If True, store the result in adata.obsm, otherwise return the result of t-SNE.
    **kwargs
        Additional arguments passed to :func:`scanpy.tl.tsne`.
    """
    if f"X_{use_rep}" not in adata.obsm and use_rep == "cnv_pca":
        logging.warning(
            "X_cnv_pca not found in adata.obsm. Computing PCA with default parameters"
        )  # type: ignore
        pca(adata)
    tmp_adata = sc.tl.tsne(adata, use_rep=f"X_{use_rep}", copy=True, **kwargs)

    if inplace:
        adata.obsm[f"X_{key_added}"] = tmp_adata.obsm["X_tsne"]
    else:
        return tmp_adata.obsm["X_tsne"]
Code example #26
File: _base_estimator.py Project: dpeerlab/cellrank
    def _detect_cc_stages(self,
                          rc_labels: Series,
                          p_thresh: float = 1e-15) -> None:
        """
        Detect cell-cycle driven start or endpoints.

        Params
        ------
        rc_labels
            Approximate recurrent classes.
        p_thresh
            P-value threshold for the rank-sum test for the group to be considered cell-cycle driven.

        Returns
        -------
        None
            Nothing, but warns if a group is cell-cycle driven.
        """

        # initialise the groups (start or end clusters) and scores
        groups = rc_labels.cat.categories
        scores = []
        if self._G2M_score is not None:
            scores.append(self._G2M_score)
        if self._S_score is not None:
            scores.append(self._S_score)

        # loop over groups and scores
        for group in groups:
            mask = rc_labels == group
            for score in scores:
                a, b = score[mask], score[~mask]
                result = ranksums(a, b)
                if result.statistic > 0 and result.pvalue < p_thresh:
                    logg.warning(
                        f"Group `{group}` appears to be cell-cycle driven")
                    break
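
A standalone sketch of the test used above: scipy's rank-sum test comparing scores inside and outside a group (toy data; scipy only):

import numpy as np
from scipy.stats import ranksums

rng = np.random.default_rng(0)
in_group = rng.normal(loc=1.0, size=50)  # e.g. G2M scores inside the group
rest = rng.normal(loc=0.0, size=200)     # scores everywhere else

result = ranksums(in_group, rest)
print(result.statistic, result.pvalue)   # large positive statistic, tiny p-value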
Code example #27
def transition_matrix(
    adata: AnnData,
    vkey: str = "velocity",
    backward: bool = False,
    self_transitions: Optional[str] = None,
    sigma_corr: Optional[float] = None,
    diff_kernel: Optional[str] = None,
    weight_diffusion: float = 0.2,
    density_normalize: bool = True,
    backward_mode: str = "transpose",
    inplace: bool = True,
) -> Optional[csr_matrix]:
    """
    Computes transition probabilities from velocity graph.

    THIS FUNCTION HAS BEEN DEPRECATED.
    Interact with kernels via the Kernel class or via cellrank.tools_transition_matrix.transition_matrix

    Employs ideas of both scvelo as well as velocyto.

    Parameters
    ----------
    adata : :class:`anndata.AnnData`
        Annotated Data Matrix
    vkey
        Name of the velocity estimates to be used
    backward
        Whether to use the transition matrix to push forward (`False`) or to pull backward (`True`)
    self_transitions
        How to fill the diagonal. Can be either 'velocyto' or 'scvelo'. Two different
        heuristics are used. Can prevent dividing by zero in unlucky situations for the
        reverse process
    sigma_corr
        Kernel width for exp kernel to be used to compute transition probabilities
        from the velocity graph. If None, the median cosine correlation of all
        positive cosine correlations will be used.
    diff_kernel
        Whether to multiply the velocity connectivities with transcriptomic distances to make them more robust.
        Options are ('sum', 'mult', 'both')
    weight_diffusion
        Relative weight given to the diffusion kernel. Must be in [0, 1]. Only matters when using 'sum' or 'both'
        for the diffusion kernel.
    density_normalize
        Whether to use the transcriptomic KNN graph for density normalization as performed in scanpy when
        computing diffusion maps
    backward_mode
        Options are ['transpose', 'negate'].
    inplace
        If True, adds to adata. Otherwise returns.

    Returns
    --------
    T: :class:`scipy.sparse.csr_matrix`
        Transition matrix
    """
    logg.info("Computing transition probability from velocity graph")

    # get the direction of the process
    direction = Direction.BACKWARD if backward else Direction.FORWARD

    # get the velocity correlations
    if (vkey + "_graph" not in adata.uns.keys()) or (vkey + "_graph_neg"
                                                     not in adata.uns.keys()):
        raise ValueError(
            "You need to run `tl.velocity_graph` first to compute cosine correlations"
        )
    velo_corr, velo_corr_neg = (
        csr_matrix(adata.uns[vkey + "_graph"]).copy(),
        csr_matrix(adata.uns[vkey + "_graph_neg"]).copy(),
    )
    velo_corr_comb_ = (velo_corr + velo_corr_neg).astype(np.float64)
    if backward:
        if backward_mode == "negate":
            velo_corr_comb = velo_corr_comb_.multiply(-1)
        elif backward_mode == "transpose":
            velo_corr_comb = velo_corr_comb_.T
        else:
            raise ValueError(f"Unknown backward_mode `{backward_mode}`.")
    else:
        velo_corr_comb = velo_corr_comb_
    med_corr = np.median(np.abs(velo_corr_comb.data))

    # compute the raw transition matrix. At the moment, this is just an exponential kernel
    logg.debug("DEBUG: Computing the raw transition matrix")
    if sigma_corr is None:
        sigma_corr = 1 / med_corr
    velo_graph = velo_corr_comb.copy()
    velo_graph.data = np.exp(velo_graph.data * sigma_corr)

    # should we row-normalize the transcriptomic connectivities?
    if diff_kernel is not None or density_normalize:
        params = _get_neighs_params(adata)
        logg.debug(
            f'DEBUG: Using KNN graph computed in basis {params.get("use_rep", "Unknown")!r} '
            f'with {params["n_neighbors"]} neighbors')
        trans_graph = _get_neighs(adata, "connectivities")
        dev = norm((trans_graph - trans_graph.T), ord="fro")
        if dev > 1e-4:
            logg.warning("KNN base graph not symmetric, `dev={dev}`")

    # KNN smoothing
    if diff_kernel is not None:
        logg.debug("DEBUG: Smoothing KNN graph with diffusion kernel")
        velo_graph = _knn_smooth(diff_kernel, velo_graph, trans_graph,
                                 weight_diffusion)
    # return velo_graph

    # set the diagonal elements. This is important especially for the backwards direction
    logg.debug("DEBUG: Setting diagonal elements")
    velo_graph = _self_loops(self_transitions, velo_graph)

    # density normalisation - taken from scanpy
    if density_normalize:
        logg.debug("DEBUG: Density correcting the velocity graph")
        velo_graph = density_normalization(velo_graph, trans_graph)

    # normalize
    T = _normalize(velo_graph)

    if not inplace:
        logg.info("Computed transition matrix")
        return T

    if _transition(direction) in adata.uns.keys():
        logg.warning(
            f"`.uns` already contains a field `{_transition(direction)!r}`. Overwriting"
        )

    params = {
        "backward": backward,
        "self_transitions": self_transitions,
        "sigma_corr": np.round(sigma_corr, 3),
        "diff_kernel": diff_kernel,
        "weight_diffusion": weight_diffusion,
        "density_normalize": density_normalize,
    }

    adata.uns[_transition(direction)] = {"T": T, "params": params}
    logg.info(
        f"Computed transition matrix and added the key `{_transition(direction)!r}` to `adata.uns`"
    )
Code example #28
File: _louvain.py Project: brianhie/trajectorama
def louvain(
    adata: AnnData,
    resolution: Optional[float] = None,
    random_state: Optional[Union[int, RandomState]] = 0,
    log_fname: str = '',
    restrict_to: Optional[Tuple[str, Sequence[str]]] = None,
    key_added: Optional[str] = 'louvain',
    adjacency: Optional[spmatrix] = None,
    flavor: str = 'vtraag',
    directed: bool = True,
    use_weights: bool = False,
    partition_type: Optional[Type[MutableVertexPartition]] = None,
    partition_kwargs: Optional[Mapping[str, Any]] = None,
    copy: bool = False,
) -> Optional[AnnData]:
    """Cluster cells into subgroups [Blondel08]_ [Levine15]_ [Traag17]_.

    Cluster cells using the Louvain algorithm [Blondel08]_ in the implementation
    of [Traag17]_. The Louvain algorithm has been proposed for single-cell
    analysis by [Levine15]_.

    This requires having run :func:`~scanpy.pp.neighbors` or :func:`~scanpy.external.pp.bbknn` first,
    or explicitly passing an ``adjacency`` matrix.

    Parameters
    ----------
    adata
        The annotated data matrix.
    resolution
        For the default flavor (``'vtraag'``), you can provide a resolution
        (higher resolution means finding more and smaller clusters),
        which defaults to 1.0. See “Time as a resolution parameter” in [Lambiotte09]_.
    random_state
        Change the initialization of the optimization.
    restrict_to
        Restrict the clustering to the categories within the key for sample
        annotation, tuple needs to contain ``(obs_key, list_of_categories)``.
    key_added
        Key under which to add the cluster labels. (default: ``'louvain'``)
    adjacency
        Sparse adjacency matrix of the graph, defaults to
        ``adata.uns['neighbors']['connectivities']``.
    flavor : {``'vtraag'``, ``'igraph'``}
        Choose between two packages for computing the clustering.
        ``'vtraag'`` is much more powerful, and the default.
    directed
        Interpret the ``adjacency`` matrix as directed graph?
    use_weights
        Use weights from knn graph.
    partition_type
        Type of partition to use.
        Only a valid argument if ``flavor`` is ``'vtraag'``.
    partition_kwargs
        Key word arguments to pass to partitioning,
        if ``vtraag`` method is being used.
    copy
        Copy adata or modify it inplace.

    Returns
    -------
    :obj:`None`
        By default (``copy=False``), updates ``adata`` with the following fields:

        ``adata.obs['louvain']`` (:class:`pandas.Series`, dtype ``category``)
            Array of dim (number of samples) that stores the subgroup id
            (``'0'``, ``'1'``, ...) for each cell.

    :class:`~anndata.AnnData`
        When ``copy=True`` is set, a copy of ``adata`` with those fields is returned.
    """
    start = logg.info('running Louvain clustering')
    if (flavor != 'vtraag') and (partition_type is not None):
        raise ValueError(
            '`partition_type` is only a valid argument when `flavor` is "vtraag"'
        )
    adata = adata.copy() if copy else adata
    if adjacency is None and 'neighbors' not in adata.uns:
        raise ValueError(
            'You need to run `pp.neighbors` first to compute a neighborhood graph.'
        )
    if adjacency is None:
        adjacency = adata.uns['neighbors']['connectivities']
    if restrict_to is not None:
        restrict_key, restrict_categories = restrict_to
        adjacency, restrict_indices = restrict_adjacency(
            adata, restrict_key, restrict_categories, adjacency)
    if flavor in {'vtraag', 'igraph'}:
        if flavor == 'igraph' and resolution is not None:
            logg.warning(
                '`resolution` parameter has no effect for flavor "igraph"')
        if directed and flavor == 'igraph':
            directed = False
        if not directed: logg.debug('    using the undirected graph')
        g = utils.get_igraph_from_adjacency(adjacency, directed=directed)
        if use_weights:
            weights = np.array(g.es["weight"]).astype(np.float64)
        else:
            weights = None
        if flavor == 'vtraag':
            import louvain
            if partition_kwargs is None:
                partition_kwargs = {}
            if partition_type is None:
                partition_type = louvain.RBConfigurationVertexPartition
            if resolution is not None:
                partition_kwargs["resolution_parameter"] = resolution
            if use_weights:
                partition_kwargs["weights"] = weights
            logg.info('    using the "louvain" package of Traag (2017)')
            louvain.set_rng_seed(random_state)
            part = louvain.find_partition(
                g,
                partition_type,
                log_fname=log_fname,
                **partition_kwargs,
            )
            # adata.uns['louvain_quality'] = part.quality()
        else:
            part = g.community_multilevel(weights=weights)
        groups = np.array(part.membership)
    elif flavor == 'taynaud':
        # this is deprecated
        import networkx as nx
        import community
        g = nx.Graph(adjacency)
        partition = community.best_partition(g)
        groups = np.zeros(len(partition), dtype=int)
        for k, v in partition.items():
            groups[k] = v
    else:
        raise ValueError(
            '`flavor` needs to be "vtraag" or "igraph" or "taynaud".')
    if restrict_to is not None:
        if key_added == 'louvain':
            key_added += '_R'
        groups = rename_groups(adata, key_added, restrict_key,
                               restrict_categories, restrict_indices, groups)
    adata.obs[key_added] = pd.Categorical(
        values=groups.astype('U'),
        categories=natsorted(np.unique(groups).astype('U')),
    )
    adata.uns['louvain'] = {}
    adata.uns['louvain']['params'] = {
        'resolution': resolution,
        'random_state': random_state
    }
    logg.info('    finished', time=start)
    return adata if copy else None
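
A standalone sketch of the 'igraph' flavor fallback used above: Louvain-style community detection via `community_multilevel` on a toy undirected graph (python-igraph only):

import igraph as ig

g = ig.Graph(edges=[(0, 1), (1, 2), (0, 2), (3, 4), (4, 5), (3, 5), (2, 3)])
part = g.community_multilevel()
print(part.membership)  # e.g. [0, 0, 0, 1, 1, 1]: two triangle communities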
Code example #29
def stacked_violin_t(
    adata: AnnData,
    var_names: Union[_VarNames, Mapping[str, _VarNames]],
    groupby: Optional[str] = None,
    log: bool = False,
    use_raw: Optional[bool] = None,
    num_categories: int = 7,
    figsize: Optional[Tuple[float, float]] = None,
    dendrogram: Union[bool, str] = False,
    gene_symbols: Optional[str] = None,
    var_group_positions: Optional[Sequence[Tuple[int, int]]] = None,
    var_group_labels: Optional[Sequence[str]] = None,
    standard_scale: Optional[Literal['var', 'obs']] = None,
    var_group_rotation: Optional[float] = None,
    layer: Optional[str] = None,
    stripplot: bool = False,
    jitter: Union[float, bool] = False,
    size: int = 1,
    scale: Literal['area', 'count', 'width'] = 'width',
    order: Optional[Sequence[str]] = None,
    show: Optional[bool] = None,
    save: Union[bool, str, None] = None,
    row_palette: str = 'muted',
    **kwds,
):
    """\
    Stacked violin plots.
    Makes a compact image composed of individual violin plots
    (from :func:`~seaborn.violinplot`) stacked on top of each other.
    Useful to visualize gene expression per cluster.
    Wraps :func:`seaborn.violinplot` for :class:`~anndata.AnnData`.

    Parameters
    ----------
    {common_plot_args}
    stripplot
        Add a stripplot on top of the violin plot.
        See :func:`~seaborn.stripplot`.
    jitter
        Add jitter to the stripplot (only when stripplot is True)
        See :func:`~seaborn.stripplot`.
    size
        Size of the jitter points.
    order
        Order in which to show the categories. Note: if `dendrogram=True`
        the categories order will be given by the dendrogram and `order`
        will be ignored.
    scale
        The method used to scale the width of each violin.
        If 'width' (the default), each violin will have the same width.
        If 'area', each violin will have the same area.
        If 'count', a violin’s width corresponds to the number of observations.
    row_palette
        The row palette determines the colors to use for the stacked violins.
        The value should be a valid seaborn or matplotlib palette name
        (see :func:`~seaborn.color_palette`).
        Alternatively, a single color name or hex value can be passed,
        e.g. `'red'` or `'#cc33ff'`.
    standard_scale
        Whether or not to standardize a dimension between 0 and 1,
        meaning for each variable or observation,
        subtract the minimum and divide each by its maximum.
    {show_save_ax}
    **kwds
        Are passed to :func:`~seaborn.violinplot`.

    Returns
    -------
    List of :class:`~matplotlib.axes.Axes`

    Examples
    --------
    >>> import scanpy as sc
    >>> adata = sc.datasets.pbmc68k_reduced()
    >>> markers = ['C1QA', 'PSAP', 'CD79A', 'CD79B', 'CST3', 'LYZ']
    >>> sc.pl.stacked_violin(adata, markers, groupby='bulk_labels', dendrogram=True)

    Using var_names as dict:

    >>> markers = {{'T-cell': 'CD3D', 'B-cell': 'CD79A', 'myeloid': 'CST3'}}
    >>> sc.pl.stacked_violin(adata, markers, groupby='bulk_labels', dendrogram=True)

    See also
    --------
    rank_genes_groups_stacked_violin: to plot marker genes identified using the :func:`~scanpy.tl.rank_genes_groups` function.
    """
    import seaborn as sns  # Slow import, only import if called

    if use_raw is None and adata.raw is not None:
        use_raw = True
    var_names, var_group_labels, var_group_positions = check_var_names_type(
        var_names, var_group_labels, var_group_positions)
    has_var_groups = (True if var_group_positions is not None
                      and len(var_group_positions) > 0 else False)
    categories, obs_tidy = prepare_dataframe(
        adata,
        var_names,
        groupby,
        use_raw,
        log,
        num_categories,
        gene_symbols=gene_symbols,
        layer=layer,
    )

    if standard_scale == 'obs':
        obs_tidy = obs_tidy.sub(obs_tidy.min(1), axis=0)
        obs_tidy = obs_tidy.div(obs_tidy.max(1), axis=0).fillna(0)
    elif standard_scale == 'var':
        obs_tidy -= obs_tidy.min(0)
        obs_tidy = (obs_tidy / obs_tidy.max(0)).fillna(0)
    elif standard_scale is None:
        pass
    else:
        logg.warning('Unknown type for standard_scale, ignored')

    if 'color' in kwds:
        row_palette = kwds['color']
        # remove color from kwds in case is set to avoid an error caused by
        # double parameters
        del kwds['color']
    if 'linewidth' not in kwds:
        # for the tiny violin plots used, it is best
        # to use a thin linewidth.
        kwds['linewidth'] = 0.5

    # plot image in which x = group by and y = var_names
    dendro_width = 3 if dendrogram else 0
    vargroups_height = 0.45 if has_var_groups else 0
    if figsize is None:
        width = len(var_names) * 0.3 + dendro_width
        height = len(categories) * 0.4 + vargroups_height
    else:
        width, height = figsize

    fig = pl.figure(figsize=(width, height))

    # define a layout of nrows = 1 x var_names columns
    # if plot dendrogram a col is added
    # each col is one violin plot.
    num_cols = len(var_names) + 1  # +1 to account for dendrogram
    width_ratios = [dendro_width] + ([1] * len(var_names))

    axs = gridspec.GridSpec(
        # 20200116 by yuanzan
        nrows=2,
        ncols=num_cols,
        height_ratios=[width - vargroups_height, vargroups_height],
        wspace=0,
        width_ratios=width_ratios,
    )

    axs_list = []
    if dendrogram:
        dendro_ax = fig.add_subplot(axs[0])
        _plot_dendrogram(
            dendro_ax,
            adata,
            groupby,
            orientation='top',
            dendrogram_key=dendrogram,
        )
        axs_list.append(dendro_ax)
    first_ax = None
    if is_color_like(row_palette):
        row_colors = [row_palette] * len(var_names)
    else:
        row_colors = sns.color_palette(row_palette, n_colors=len(var_names))
    for idx, y in enumerate(obs_tidy.columns):
        ax_idx = idx + 1  # +1 to account that idx 0 is the dendrogram
        ax = fig.add_subplot(axs[0, ax_idx])
        if first_ax is None:
            first_ax = ax
        axs_list.append(ax)
        ax = sns.violinplot(
            x=y,
            y=obs_tidy.index,
            data=obs_tidy,
            inner=None,
            orient='h',
            scale=scale,
            ax=ax,
            #color=row_colors,
            **kwds,
        )

        ax.set_ylabel("")

        ax.tick_params(bottom=False, top=False, left=False, right=False)

        if idx == (len(obs_tidy.columns) - 1):
            ax.yaxis.tick_right()
        else:
            ax.set_yticklabels("")

        if stripplot:
            ax = sns.stripplot(
                # 20200116 by yuanzan
                x=obs_tidy.index,
                y=y,
                data=obs_tidy,
                jitter=jitter,
                color='black',
                size=size,
                ax=ax,
            )

        ax.set_xlabel(
            var_names[idx],
            rotation=45,
            fontsize='small',
            labelpad=8,
            ha='center',
            va='top',
        )
        ax.grid(False)
        ax.tick_params(
            axis='x',
            #right=True,
            top=False,
            #labelright=True,
            left=False,
            bottom=False,
            labelleft=False,
            labeltop=False,
            labelbottom=False,
            labelrotation=0,
            labelsize='x-small',
        )
    pl.subplots_adjust(wspace=0, hspace=0)

    _utils.savefig_or_show('stacked_violin', show=show, save=save)

    return axs_list
Code example #30
def _map_names_and_colors(
    series_reference: Series,
    series_query: Series,
    colors_reference: Optional[np.ndarray] = None,
    en_cutoff: Optional[float] = None,
) -> Union[Series, Tuple[Series, List[Any]]]:
    """
    Map annotations and colors from one series to another.

    Params
    ------
    series_reference
        Series object with categorical annotations.
    series_query
        Series for which we would like to query the category names.
    colors_reference
        If given, colors for the query categories are pulled from this color array.
    en_cutoff
        In case of a non-perfect overlap between categories of the two series,
        this decides when to label a category in the query as 'Unknown'.

    Returns
    -------
    :class:`pandas.Series`, :class:`list`
        Series with updated category names and a corresponding array of colors.
    """

    # checks: dtypes, matching indices, make sure colors match the categories
    if not is_categorical_dtype(series_reference):
        raise TypeError(
            f"Reference series must be `categorical`, found `{infer_dtype(series_reference)}`."
        )
    if not is_categorical_dtype(series_query):
        raise TypeError(
            f"Query series must be `categorical`, found `{infer_dtype(series_query)}`."
        )
    index_query, index_reference = series_query.index, series_reference.index
    if not np.all(index_reference == index_query):
        raise ValueError(
            "Series indices do not match, cannot map names/colors.")

    process_colors = colors_reference is not None
    if process_colors:
        if len(series_reference.cat.categories) != len(colors_reference):
            raise ValueError(
                "Length of reference colors does not match length of reference series."
            )
        if not all(mcolors.is_color_like(c) for c in colors_reference):
            raise ValueError("Not all colors are color-like.")

    # create dataframe to store the associations between reference and query
    cats_query = series_query.cat.categories
    cats_reference = series_reference.cat.categories
    association_df = DataFrame(None, index=cats_query, columns=cats_reference)

    # populate the dataframe - compute the overlap
    for cl in cats_query:
        row = [
            np.sum(series_reference.loc[np.array(series_query == cl)] == key)
            for key in cats_reference
        ]
        association_df.loc[cl] = row
    association_df = association_df.apply(to_numeric)

    # find the mapping which maximizes overlap and compute entropy
    names_query = association_df.T.idxmax()
    association_df["entropy"] = entropy(association_df.T)
    association_df["name"] = names_query

    # assign query colors
    if process_colors:
        colors_query = []
        for name in names_query:
            mask = cats_reference == name
            color = np.array(colors_reference)[mask][0]
            colors_query.append(color)
        association_df["color"] = colors_query

    # next, we need to make sure that we have unique names and colors. In a first step, compute how many repetitions
    # we have
    names_query_series = Series(names_query, dtype="category")
    frequ = {
        key: np.sum(names_query == key)
        for key in names_query_series.cat.categories
    }

    names_query_new = np.array(names_query.copy())
    if process_colors:
        colors_query_new = np.array(colors_query.copy())

    # Create unique names by adding suffixes "..._1, ..._2" etc and unique colors by shifting the original color
    for key, value in frequ.items():
        if value == 1:
            continue  # already unique, skip

        # deal with non-unique names
        suffix = list(np.arange(1, value + 1).astype("str"))
        unique_names = [f"{key}_{rep}" for rep in suffix]
        names_query_new[names_query_series == key] = unique_names
        if process_colors:
            color = association_df[association_df["name"] ==
                                   key]["color"].values[0]
            shifted_colors = _create_colors(color,
                                            value,
                                            saturation_range=None)
            colors_query_new[np.array(colors_query) == color] = shifted_colors

    association_df["name"] = names_query_new
    if process_colors:
        association_df["color"] = _convert_to_hex_colors(
            colors_query_new
        )  # original colors can be still there, convert to hex

    # issue a warning for mapping with high entropy
    if en_cutoff is not None:
        critical_cats = list(
            association_df.loc[association_df["entropy"] > en_cutoff,
                               "name"].values)
        if len(critical_cats) > 0:
            logg.warning(
                f"The following states could not be mapped uniquely: `{', '.join(map(str, critical_cats))}`"
            )

    return ((association_df["name"], list(association_df["color"]))
            if process_colors else association_df["name"])
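
A hypothetical usage sketch (assumes the helper and its imports, e.g. pandas and scipy.stats' `entropy`, are in scope): map query cluster labels onto a reference annotation without colors:

import pandas as pd

reference = pd.Series(["T", "T", "B", "B"], dtype="category")
query = pd.Series(["0", "0", "1", "1"], dtype="category")
print(_map_names_and_colors(reference, query).to_dict())
# {'0': 'T', '1': 'B'}: each query category takes the best-overlapping reference name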