Example #1
def var_df(adata: AnnData, keys: List[str], layer: Optional[str] = None):
    """Extract layer as Pandas DataFrame indexed by features.

    Arguments
    ---------
    adata
        Annotated data matrix (reference data set).
    keys
        Observations for which to extract data.
    layer
        Name of layer to turn into a Pandas DataFrame.

    Returns
    -------
    DataFrame
        DataFrame indexed by features. Columns correspond to observations of specified
        layer.
    """

    lookup_keys = [k for k in keys if k in adata.obs_names]
    if len(lookup_keys) < len(keys):
        logg.warn(f"Keys {[k for k in keys if k not in adata.obs_names]} "
                  f"were not found in `adata.obs_names`.")

    df = pd.DataFrame(index=adata.var_names)
    for lookup_key in lookup_keys:
        df[lookup_key] = adata.var_vector(lookup_key, layer=layer)
    return df
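A minimal usage sketch (the `AnnData` object, the layer name "Ms", and the cell/gene names below are made up for illustration):

import numpy as np
from anndata import AnnData

# Toy AnnData with 3 cells, 2 genes and a custom layer.
adata = AnnData(
    X=np.random.rand(3, 2),
    layers={"Ms": np.random.rand(3, 2)},
)
adata.obs_names = ["cell_1", "cell_2", "cell_3"]
adata.var_names = ["gene_a", "gene_b"]

# One column per requested observation, indexed by var_names;
# unknown keys ("cell_x") only trigger a warning and are skipped.
df = var_df(adata, keys=["cell_1", "cell_x"], layer="Ms")
print(df)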
Example #2
def verify_dtypes(adata: AnnData) -> None:
    """Verify that AnnData object is not corrupted.

    Arguments
    ---------
    adata
        Annotated data matrix to check.

    Returns
    -------
    None
    """

    try:
        _ = adata[:, 0]
    except Exception:
        uns = adata.uns
        adata.uns = {}
        try:
            _ = adata[:, 0]
            logg.warn(
                "Safely deleted unstructured annotations (adata.uns), \n"
                "as these do not comply with permissible anndata datatypes.")
        except Exception:
            logg.warn(
                "The data might be corrupted. Please verify all annotation datatypes."
            )
            adata.uns = uns
Example #3
def verify_neighbors(adata):
    valid = "neighbors" in adata.uns.keys() and "params" in adata.uns["neighbors"]
    if valid:
        n_neighs = (get_neighs(adata, "distances") > 0).sum(1)
        # test whether the graph is corrupted
        valid = n_neighs.min() * 2 > n_neighs.max()
    if not valid:
        logg.warn(
            "The neighbor graph has an unexpected format "
            "(e.g. computed outside scvelo) \n"
            "or is corrupted (e.g. due to subsetting). "
            "Consider recomputing with `pp.neighbors`."
        )
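The heuristic in the second check can be illustrated on toy distance matrices: in an intact KNN graph every cell keeps roughly the same number of neighbors, so twice the smallest neighbor count should still exceed the largest. A sketch under that assumption (not scvelo API):

import numpy as np
from scipy.sparse import csr_matrix

# Intact graph: every cell has exactly 2 neighbors.
intact = csr_matrix(np.array([
    [0.0, 0.5, 0.7],
    [0.5, 0.0, 0.4],
    [0.7, 0.4, 0.0],
]))
n_neighs = (intact > 0).sum(1)
print(n_neighs.min() * 2 > n_neighs.max())  # True -> graph looks valid

# After subsetting, neighbor counts become uneven (1 vs. 3 here).
corrupted = csr_matrix(np.array([
    [0.0, 0.5, 0.0, 0.0],
    [0.5, 0.0, 0.4, 0.6],
    [0.0, 0.4, 0.0, 0.3],
    [0.0, 0.6, 0.3, 0.0],
]))
n_neighs = (corrupted > 0).sum(1)
print(n_neighs.min() * 2 > n_neighs.max())  # False -> warning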
Example #4
    def compute_deterministic(self, fit_offset=False, perc=None):
        subset = self._groups_for_fit
        Ms = self._Ms if subset is None else self._Ms[subset]
        Mu = self._Mu if subset is None else self._Mu[subset]

        lr = LinearRegression(fit_intercept=fit_offset, percentile=perc)
        lr.fit(Ms, Mu)
        self._offset = lr.intercept_
        self._gamma = lr.coef_

        if self._constrain_ratio is not None:
            if np.size(self._constrain_ratio) < 2:
                self._constrain_ratio = [None, self._constrain_ratio]
            cr = self._constrain_ratio
            self._gamma = np.clip(self._gamma, cr[0], cr[1])

        self._residual = self._Mu - self._gamma * self._Ms
        if fit_offset:
            self._residual -= self._offset
        _residual = self._residual

        # velocity genes
        if self._r2_adjusted:
            lr = LinearRegression(fit_intercept=fit_offset)
            lr.fit(Ms, Mu)
            _offset = lr.intercept_
            _gamma = lr.coef_

            _residual = self._Mu - _gamma * self._Ms
            if fit_offset:
                _residual -= _offset

        self._qreg_ratio = np.array(self._gamma)  # quantile regression ratio

        self._r2 = R_squared(_residual, total=self._Mu - self._Mu.mean(0))
        self._velocity_genes = ((self._r2 > self._min_r2)
                                & (self._gamma > self._min_ratio)
                                & (np.max(self._Ms > 0, 0) > 0)
                                & (np.max(self._Mu > 0, 0) > 0))

        if self._highly_variable is not None:
            self._velocity_genes &= self._highly_variable

        if np.sum(self._velocity_genes) < 2:
            min_r2 = np.percentile(self._r2, 80)
            self._velocity_genes = self._r2 > min_r2
            min_r2 = np.round(min_r2, 4)
            logg.warn(
                f"You seem to have very low signal in splicing dynamics.\n"
                f"The correlation threshold has been reduced to {min_r2}.\n"
                f"Please be cautious when interpreting results.")
Example #5
def verify_roots(adata, roots, modality="Ms"):
    if "gene_count_corr" in adata.var.keys():
        p = get_plasticity_score(adata, modality)
        p_ub, root_ub = p > 0.5, roots > 0.9
        n_right_assignments = np.sum(root_ub * p_ub) / np.sum(p_ub)
        n_false_assignments = np.sum(root_ub * np.invert(p_ub)) / np.sum(
            np.invert(p_ub))
        n_randn_assignments = np.mean(root_ub)
        if n_right_assignments > 3 * n_randn_assignments:  # mu + 2*mu (std=mu)
            roots *= p_ub
        elif (n_false_assignments > n_randn_assignments
              or n_right_assignments < n_randn_assignments):
            logg.warn(
                "Uncertain or fuzzy root cell identification. Please verify.")
    return roots
Example #6
def get_igraph_from_adjacency(adjacency, directed=None):
    """Get igraph graph from adjacency matrix."""
    import igraph as ig

    sources, targets = adjacency.nonzero()
    weights = adjacency[sources, targets]
    if isinstance(weights, np.matrix):
        weights = weights.A1
    g = ig.Graph(directed=directed)
    g.add_vertices(adjacency.shape[0])  # this adds adjacency.shape[0] vertices
    g.add_edges(list(zip(sources, targets)))
    try:
        g.es["weight"] = weights
    except Exception:
        pass
    if g.vcount() != adjacency.shape[0]:
        logg.warn(f"The constructed graph has only {g.vcount()} nodes. "
                  "Your adjacency matrix contained redundant nodes.")
    return g
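A usage sketch with a small symmetric adjacency matrix (assumes python-igraph is installed):

import numpy as np
from scipy.sparse import csr_matrix

adjacency = csr_matrix(np.array([
    [0.0, 1.0, 0.5],
    [1.0, 0.0, 0.0],
    [0.5, 0.0, 0.0],
]))
g = get_igraph_from_adjacency(adjacency, directed=False)
print(g.vcount(), g.ecount())  # 3 vertices, 4 edges (each stored twice)
print(g.es["weight"])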
Example #7
def make_sparse(adata: AnnData,
                modalities: Union[List[str], str],
                inplace: bool = True) -> Optional[AnnData]:
    """Make AnnData entry sparse.

    Arguments
    ---------
    adata
        Annotated data object.
    modalities
        Modality or list of modalities to make sparse.
    inplace
        Boolean flag to perform operations inplace or not. Defaults to `True`.

    Returns
    -------
    Optional[AnnData]
        Copy of annotated data `adata` with sparse modalities if `inplace=False`.
    """

    if not inplace:
        adata = adata.copy()

    if isinstance(modalities, str):
        modalities = [modalities]

    # Make modalities sparse
    for modality in modalities:
        count_data = get_modality(adata=adata, modality=modality)
        if modality == "X":
            logg.warn("Making `X` sparse is not supported.")
        elif not issparse(count_data):
            set_modality(adata=adata,
                         modality=modality,
                         new_value=csr_matrix(count_data))

    return adata if not inplace else None
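A usage sketch (get_modality/set_modality are the module's own helpers; the layer name is made up):

import numpy as np
from anndata import AnnData
from scipy.sparse import issparse

adata = AnnData(
    X=np.random.rand(5, 3),
    layers={"unspliced": np.random.rand(5, 3)},
)

# In-place conversion of the dense layer to CSR.
make_sparse(adata, modalities="unspliced")
print(issparse(adata.layers["unspliced"]))  # True

# Non-destructive variant: returns a sparse copy instead.
adata_sparse = make_sparse(adata, modalities=["unspliced"], inplace=False)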
Example #8
def velocity_pseudotime(
    adata,
    vkey="velocity",
    groupby=None,
    groups=None,
    root_key=None,
    end_key=None,
    n_dcs=10,
    use_velocity_graph=True,
    save_diffmap=None,
    return_model=None,
    **kwargs,
):
    """Computes a pseudotime based on the velocity graph.

    Velocity pseudotime is a random-walk based distance measure on the velocity graph.
    After computing a distribution over root cells obtained from the velocity-inferred
    transition matrix, it measures the average number of steps it takes to reach a cell
    when starting a walk from one of the root cells. In contrast to diffusion
    pseudotime, it implicitly infers the root cells and is based on the directed
    velocity graph instead of the similarity-based diffusion kernel.

    .. code:: python

        scv.tl.velocity_pseudotime(adata)
        scv.pl.scatter(adata, color='velocity_pseudotime', color_map='gnuplot')

    .. image:: https://user-images.githubusercontent.com/31883718/69545487-33fbc000-0f92-11ea-969b-194dc68400b0.png
       :width: 600px

    Arguments
    ---------
    adata: :class:`~anndata.AnnData`
        Annotated data matrix
    vkey: `str` (default: `'velocity'`)
        Name of velocity estimates to be used.
    groupby: `str`, `list` or `np.ndarray` (default: `None`)
        Key of observations grouping to consider.
    groups: `str`, `list` or `np.ndarray` (default: `None`)
        Groups selected to find terminal states on. Must be an element of
        adata.obs[groupby]. Only to be set if each group is assumed to have a distinct
        lineage with an independent root and end point.
    root_key: `int` (default: `None`)
        Index of root cell to be used.
        Computed from velocity-inferred transition matrix if not specified.
    end_key: `int` (default: `None`)
        Index of end point to be used.
        Computed from velocity-inferred transition matrix if not specified.
    n_dcs: `int` (default: 10)
        The number of diffusion components to use.
    use_velocity_graph: `bool` (default: `True`)
        Whether to use the velocity graph.
        If False, it uses the similarity-based diffusion kernel.
    save_diffmap: `bool` (default: `None`)
        Whether to store diffmap coordinates.
    return_model: `bool` (default: `None`)
        Whether to return the vpt object for further inspection.
    **kwargs:
        Further arguments to pass to VPT (e.g. min_group_size, allow_kendall_tau_shift).

    Returns
    -------
    velocity_pseudotime: `.obs`
        Velocity pseudotime obtained from velocity graph.
    """  # noqa E501

    strings_to_categoricals(adata)
    if root_key is None and "root_cells" in adata.obs.keys():
        root0 = adata.obs["root_cells"][0]
        if not isinstance(root0, str) and not np.isnan(root0):
            root_key = "root_cells"
    if end_key is None and "end_points" in adata.obs.keys():
        end0 = adata.obs["end_points"][0]
        if not isinstance(end0, str) and not np.isnan(end0):
            end_key = "end_points"

    groupby = ("cell_fate" if groupby is None
               and "cell_fate" in adata.obs.keys() else groupby)
    if groupby is not None:
        logg.warn(
            "Only set `groupby` when you have clearly distinct clusters/lineages,"
            " each with its own root and end point.")
    categories = (adata.obs[groupby].cat.categories
                  if groupby is not None and groups is None else [None])
    for cat in categories:
        groups = cat if cat is not None else groups
        if root_key is None or (root_key in adata.obs.keys() and np.max(
                adata.obs[root_key]) == np.min(adata.obs[root_key])):
            terminal_states(adata, vkey=vkey, groupby=groupby, groups=groups)
            root_key, end_key = "root_cells", "end_points"
        cell_subset = groups_to_bool(adata, groups=groups, groupby=groupby)
        data = adata.copy(
        ) if cell_subset is None else adata[cell_subset].copy()
        if "allow_kendall_tau_shift" not in kwargs:
            kwargs["allow_kendall_tau_shift"] = True
        vpt = VPT(data, n_dcs=n_dcs, **kwargs)

        if use_velocity_graph:
            T = data.uns[f"{vkey}_graph"] - data.uns[f"{vkey}_graph_neg"]
            vpt._connectivities = T + T.T

        vpt.compute_transitions()
        vpt.compute_eigen(n_comps=n_dcs)

        vpt.set_iroot(root_key)
        vpt.compute_pseudotime()
        dpt_root = vpt.pseudotime

        if end_key is not None:
            vpt.set_iroot(end_key)
            vpt.compute_pseudotime(inverse=True)
            dpt_end = vpt.pseudotime

            # merge dpt_root and inverse dpt_end together
            vpt.pseudotime = np.nan_to_num(dpt_root) + np.nan_to_num(dpt_end)
            vpt.pseudotime[np.isfinite(dpt_root) & np.isfinite(dpt_end)] /= 2
            vpt.pseudotime = scale(vpt.pseudotime)
            vpt.pseudotime[np.isnan(dpt_root) & np.isnan(dpt_end)] = np.nan

        if "n_branchings" in kwargs and kwargs["n_branchings"] > 0:
            vpt.branchings_segments()
        else:
            vpt.indices = vpt.pseudotime.argsort()

        if f"{vkey}_pseudotime" not in adata.obs.keys():
            pseudotime = np.empty(adata.n_obs)
            pseudotime[:] = np.nan
        else:
            pseudotime = adata.obs[f"{vkey}_pseudotime"].values
        pseudotime[cell_subset] = vpt.pseudotime
        adata.obs[f"{vkey}_pseudotime"] = np.array(pseudotime,
                                                   dtype=np.float64)

        if save_diffmap:
            diffmap = np.empty(shape=(adata.n_obs, n_dcs))
            diffmap[:] = np.nan
            diffmap[cell_subset] = vpt.eigen_basis
            adata.obsm[f"X_diffmap_{groups}"] = diffmap

    return vpt if return_model else None
Example #9
def _compute_pos(
    adjacency_solid,
    layout=None,
    random_state=0,
    init_pos=None,
    adj_tree=None,
    root=0,
    layout_kwds=None,
):
    import networkx as nx

    np.random.seed(random_state)
    random.seed(random_state)
    if layout_kwds is None:
        layout_kwds = {}  # guard: layout_kwds is indexed/updated below
    nx_g_solid = nx.Graph(adjacency_solid)
    if layout is None:
        layout = "fr"
    if layout == "fa":
        try:
            import fa2
        except Exception:
            logg.warn(
                "Package 'fa2' is not installed, falling back to layout 'fr'. "
                "To use the faster and better ForceAtlas2 layout, "
                "install package 'fa2' (`pip install fa2`).")
            layout = "fr"
    if layout == "fa":
        init_coords = (np.random.random(
            (adjacency_solid.shape[0],
             2)) if init_pos is None else init_pos.copy())
        forceatlas2 = fa2.ForceAtlas2(
            outboundAttractionDistribution=False,
            linLogMode=False,
            adjustSizes=False,
            edgeWeightInfluence=1.0,
            jitterTolerance=1.0,
            barnesHutOptimize=True,
            barnesHutTheta=1.2,
            multiThreaded=False,
            scalingRatio=2.0,
            strongGravityMode=False,
            gravity=1.0,
            verbose=False,
        )
        iterations = (
            layout_kwds["maxiter"] if "maxiter" in layout_kwds else
            layout_kwds["iterations"] if "iterations" in layout_kwds else 500)
        pos_list = forceatlas2.forceatlas2(adjacency_solid,
                                           pos=init_coords,
                                           iterations=iterations)
        pos = {n: [p[0], -p[1]] for n, p in enumerate(pos_list)}
    elif layout == "eq_tree":
        nx_g_tree = nx.Graph(adj_tree)
        from scanpy.plotting._utils import hierarchy_pos

        pos = hierarchy_pos(nx_g_tree, root)
        if len(pos) < adjacency_solid.shape[0]:
            raise ValueError("This is a forest and not a single tree. "
                             "Try another `layout`, e.g., {'fr'}.")
    else:
        # igraph layouts
        g = get_igraph_from_adjacency(adjacency_solid)
        if "rt" in layout:
            g_tree = get_igraph_from_adjacency(adj_tree)
            root = root if isinstance(root, list) else [root]
            pos_list = g_tree.layout(layout, root=root).coords
        elif layout == "circle":
            pos_list = g.layout(layout).coords
        else:
            if init_pos is None:
                init_coords = np.random.random(
                    (adjacency_solid.shape[0], 2)).tolist()
            else:
                init_pos = init_pos.copy()
                init_pos[:, 1] *= -1  # to be checked
                init_coords = init_pos.tolist()
            try:
                layout_kwds.update({"seed": init_coords})
                pos_list = g.layout(layout, weights="weight",
                                    **layout_kwds).coords
            except AttributeError:  # hack for empty graphs...
                pos_list = g.layout(layout, **layout_kwds).coords
        pos = {n: [p[0], -p[1]] for n, p in enumerate(pos_list)}
    if len(pos) == 1:
        pos[0] = (0.5, 0.5)
    pos_array = np.array([pos[n] for n in nx_g_solid])
    return pos_array
Example #10
def parallelize(
    callback: Callable[[Any], Any],
    collection: Union[spmatrix, Sequence[Any]],
    n_jobs: Optional[int] = None,
    n_split: Optional[int] = None,
    unit: str = "",
    as_array: bool = True,
    use_ixs: bool = False,
    backend: str = "loky",
    extractor: Optional[Callable[[Any], Any]] = None,
    show_progress_bar: bool = True,
) -> Union[np.ndarray, Any]:
    """
    Parallelize function call over a collection of elements.

    Parameters
    ----------
    callback
        Function to parallelize.
    collection
        Sequence of items which to chunkify.
    n_jobs
        Number of parallel jobs.
    n_split
        Split :paramref:`collection` into :paramref:`n_split` chunks.
        If `None`, split into :paramref:`n_jobs` chunks.
    unit
        Unit of the progress bar.
    as_array
        Whether to convert the results to a :class:`numpy.ndarray`.
    use_ixs
        Whether to pass indices to the callback.
    backend
        Which backend to use for multiprocessing. See :class:`joblib.Parallel` for valid
        options.
    extractor
        Function to apply to the result after all jobs have finished.
    show_progress_bar
        Whether to show a progress bar.

    Returns
    -------
    :class:`numpy.ndarray`
        Result depending on :paramref:`extractor` and :paramref:`as_array`.
    """

    if show_progress_bar:
        try:
            try:
                from tqdm.notebook import tqdm
            except ImportError:
                from tqdm import tqdm_notebook as tqdm
            import ipywidgets  # noqa
        except ImportError:
            global _msg_shown
            tqdm = None

            if not _msg_shown:
                logg.warn(
                    "Unable to create progress bar. "
                    "Consider installing `tqdm` as `pip install tqdm` "
                    "and `ipywidgets` as `pip install ipywidgets`,\n"
                    "or disable the progress bar using `show_progress_bar=False`."
                )
                _msg_shown = True
    else:
        tqdm = None

    def update(pbar, queue, n_total):
        n_finished = 0
        while n_finished < n_total:
            try:
                res = queue.get()
            except EOFError as e:
                if n_finished != n_total:
                    raise RuntimeError(
                        f"Finished only `{n_finished}` out of `{n_total}` tasks."
                    ) from e
                break
            assert res in (None, (1, None), 1)  # (1, None) means only 1 job
            if res == (1, None):
                n_finished += 1
                if pbar is not None:
                    pbar.update()
            elif res is None:
                n_finished += 1
            elif pbar is not None:
                pbar.update()

        if pbar is not None:
            pbar.close()

    def wrapper(*args, **kwargs):
        if pass_queue and show_progress_bar:
            pbar = None if tqdm is None else tqdm(total=col_len, unit=unit)
            queue = Manager().Queue()
            thread = Thread(target=update,
                            args=(pbar, queue, len(collections)))
            thread.start()
        else:
            pbar, queue, thread = None, None, None

        res = Parallel(n_jobs=n_jobs, backend=backend)(delayed(callback)(
            *((i, cs) if use_ixs else (cs, )),
            *args,
            **kwargs,
            queue=queue,
        ) for i, cs in enumerate(collections))

        res = np.array(res) if as_array else res
        if thread is not None:
            thread.join()

        return res if extractor is None else extractor(res)

    col_len = collection.shape[0] if issparse(collection) else len(collection)

    if n_split is None:
        n_split = get_n_jobs(n_jobs=n_jobs)

    if issparse(collection):
        if n_split == collection.shape[0]:
            collections = [
                collection[[ix], :] for ix in range(collection.shape[0])
            ]
        else:
            step = collection.shape[0] // n_split

            ixs = [
                np.arange(i * step, min((i + 1) * step, collection.shape[0]))
                for i in range(n_split)
            ]
            ixs[-1] = np.append(
                ixs[-1], np.arange(ixs[-1][-1] + 1, collection.shape[0]))

            collections = [collection[ix, :] for ix in filter(len, ixs)]
    else:
        collections = list(filter(len, np.array_split(collection, n_split)))

    pass_queue = not hasattr(callback,
                             "py_func")  # we'd be inside a numba function

    return wrapper
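The callback protocol is implicit in update: each worker receives a queue keyword argument and reports progress by putting 1 per processed element and None once it is done (or (1, None) for single-element jobs). A minimal callback sketch under those assumptions:

import numpy as np

def square_chunk(chunk, queue=None):
    """Toy callback: squares each element and reports progress."""
    out = []
    for x in chunk:
        out.append(x ** 2)
        if queue is not None:
            queue.put(1)  # one element finished -> progress tick
    if queue is not None:
        queue.put(None)  # this worker is done
    return np.array(out)

# parallelize returns the wrapper; calling it runs the jobs.
res = parallelize(square_chunk, np.arange(12), n_jobs=2, n_split=4,
                  show_progress_bar=False)()
print(res.shape)  # (4, 3): one row of squared values per chunk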
Example #11
def velocity_graph(
    data,
    vkey="velocity",
    xkey="Ms",
    tkey=None,
    basis=None,
    n_neighbors=None,
    n_recurse_neighbors=None,
    random_neighbors_at_max=None,
    sqrt_transform=None,
    variance_stabilization=None,
    gene_subset=None,
    compute_uncertainties=None,
    approx=None,
    mode_neighbors="distances",
    copy=False,
    n_jobs=None,
    backend="loky",
):
    """Computes velocity graph based on cosine similarities.

    The cosine similarities are computed between velocities and potential cell state
    transitions, i.e. it measures how well a corresponding change in gene expression
    :math:`\\delta_{ij} = x_j - x_i` matches the predicted change according to the
    velocity vector :math:`\\nu_i`,

    .. math::
        \\pi_{ij} = \\cos\\angle(\\delta_{ij}, \\nu_i)
        = \\frac{\\delta_{ij}^T \\nu_i}{\\left\\lVert\\delta_{ij}\\right\\rVert
        \\left\\lVert \\nu_i \\right\\rVert}.

    Arguments
    ---------
    data: :class:`~anndata.AnnData`
        Annotated data matrix.
    vkey: `str` (default: `'velocity'`)
        Name of velocity estimates to be used.
    xkey: `str` (default: `'Ms'`)
        Layer key to extract count data from.
    tkey: `str` (default: `None`)
        Observation key to extract time data from.
    basis: `str` (default: `None`)
        Basis / Embedding to use.
    n_neighbors: `int` or `None` (default: None)
        Use fixed number of neighbors or do recursive neighbor search (if `None`).
    n_recurse_neighbors: `int` (default: `None`)
        Number of recursions for neighbors search. Defaults to
        2 if mode_neighbors is 'distances', and 1 if mode_neighbors is 'connectivities'.
    random_neighbors_at_max: `int` or `None` (default: `None`)
        If number of iterative neighbors for an individual cell is higher than this
        threshold, a random selection of such are chosen as reference neighbors.
    sqrt_transform: `bool` (default: `None`)
        Whether to variance-transform the cell state changes
        and velocities before computing cosine similarities.
    gene_subset: `list` of `str`, subset of adata.var_names or `None` (default: `None`)
        Subset of genes to compute velocity graph on exclusively.
    compute_uncertainties: `bool` (default: `None`)
        Whether to compute uncertainties along with cosine correlation.
    approx: `bool` or `None` (default: `None`)
        If `True`, the first 30 PCs are used instead of the full count matrix.
    mode_neighbors: `str` (default: `'distances'`)
        Determines the type of KNN graph used. Options are 'distances' or
        'connectivities'. The latter yields a symmetric graph.
    copy: `bool` (default: `False`)
        Return a copy instead of writing to adata.
    n_jobs: `int` or `None` (default: `None`)
        Number of parallel jobs.
    backend: `str` (default: "loky")
        Backend used for multiprocessing. See :class:`joblib.Parallel` for valid
        options.

    Returns
    -------
    velocity_graph: `.uns`
        sparse matrix with correlations of cell state transitions with velocities
    """

    adata = data.copy() if copy else data
    verify_neighbors(adata)
    if vkey not in adata.layers.keys():
        velocity(adata, vkey=vkey)
    if sqrt_transform is None:
        sqrt_transform = variance_stabilization

    vgraph = VelocityGraph(
        adata,
        vkey=vkey,
        xkey=xkey,
        tkey=tkey,
        basis=basis,
        n_neighbors=n_neighbors,
        approx=approx,
        n_recurse_neighbors=n_recurse_neighbors,
        random_neighbors_at_max=random_neighbors_at_max,
        sqrt_transform=sqrt_transform,
        gene_subset=gene_subset,
        compute_uncertainties=compute_uncertainties,
        report=True,
        mode_neighbors=mode_neighbors,
    )

    if isinstance(basis, str):
        logg.warn(
            f"The velocity graph is computed on {basis} embedding coordinates.\n"
            f"        Consider computing the graph in an unbiased manner \n"
            f"        on full expression space by not specifying basis.\n")

    n_jobs = get_n_jobs(n_jobs=n_jobs)
    logg.info(
        f"computing velocity graph (using {n_jobs}/{os.cpu_count()} cores)",
        r=True)
    vgraph.compute_cosines(n_jobs=n_jobs, backend=backend)

    adata.uns[f"{vkey}_graph"] = vgraph.graph
    adata.uns[f"{vkey}_graph_neg"] = vgraph.graph_neg

    if vgraph.uncertainties is not None:
        adata.uns[f"{vkey}_graph_uncertainties"] = vgraph.uncertainties

    adata.obs[f"{vkey}_self_transition"] = vgraph.self_prob

    if f"{vkey}_params" in adata.uns.keys():
        if "embeddings" in adata.uns[f"{vkey}_params"]:
            del adata.uns[f"{vkey}_params"]["embeddings"]
    else:
        adata.uns[f"{vkey}_params"] = {}
    adata.uns[f"{vkey}_params"]["mode_neighbors"] = mode_neighbors
    adata.uns[f"{vkey}_params"][
        "n_recurse_neighbors"] = vgraph.n_recurse_neighbors

    logg.info("    finished",
              time=True,
              end=" " if settings.verbosity > 2 else "\n")
    logg.hint(
        "added \n"
        f"    '{vkey}_graph', sparse matrix with cosine correlations (adata.uns)"
    )

    return adata if copy else None
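Typical usage within the standard scvelo workflow (a sketch; dataset and parameter choices are illustrative):

import scvelo as scv

adata = scv.datasets.pancreas()
scv.pp.filter_and_normalize(adata, min_shared_counts=20, n_top_genes=2000)
scv.pp.moments(adata, n_pcs=30, n_neighbors=30)
scv.tl.velocity(adata)
scv.tl.velocity_graph(adata, n_jobs=4)
# resulting graph: adata.uns["velocity_graph"]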
Example #12
def terminal_states(
    data,
    vkey="velocity",
    modality="Ms",
    groupby=None,
    groups=None,
    self_transitions=False,
    eps=1e-3,
    random_state=0,
    copy=False,
    **kwargs,
):
    """Computes terminal states (root and end points).

    The end points and root cells are obtained as stationary states of the
    velocity-inferred transition matrix and its transpose, respectively,
    which are given by left eigenvectors corresponding to an eigenvalue of 1, i.e.

    .. math::
        μ^{\\textrm{end}}=μ^{\\textrm{end}} \\pi, \\quad
        μ^{\\textrm{root}}=μ^{\\textrm{root}} \\pi^{\\small \\textrm{T}}.

    .. code:: python

        scv.tl.terminal_states(adata)
        scv.pl.scatter(adata, color=['root_cells', 'end_points'])

    .. image:: https://user-images.githubusercontent.com/31883718/69496183-bcfdf300-0ecf-11ea-9aae-685300a0b1ba.png

    Alternatively, we recommend using :func:`cellrank.tl.terminal_states`,
    which provides an improved/generalized approach to identifying terminal states.

    Arguments
    ---------
    data: :class:`~anndata.AnnData`
        Annotated data matrix.
    vkey: `str` (default: `'velocity'`)
        Name of velocity estimates to be used.
    modality: `str` (default: `'Ms'`)
        Layer used to calculate terminal states.
    groupby: `str`, `list` or `np.ndarray` (default: `None`)
        Key of observations grouping to consider. Only to be set if each group is
        assumed to have a distinct lineage with an independent root and end point.
    groups: `str`, `list` or `np.ndarray` (default: `None`)
        Groups selected to find terminal states on. Must be an element of .obs[groupby].
        To be specified only for very distinct/disconnected clusters.
    self_transitions: `bool` (default: `False`)
        Allow transitions from one node to itself.
    eps: `float` (default: 1e-3)
        Tolerance for eigenvalue selection.
    random_state: `int` or None (default: 0)
        Seed used by the random number generator.
        If `None`, use the `RandomState` instance by `np.random`.
    copy: `bool` (default: `False`)
        Return a copy instead of writing to data.
    **kwargs:
        Passed to scvelo.tl.transition_matrix(), e.g. basis, weight_diffusion.

    Returns
    -------
    root_cells: `.obs`
        probability of each cell being a root cell.
    end_points: `.obs`
        probability of each cell being an end point.
    """  # noqa E501

    adata = data.copy() if copy else data
    verify_neighbors(adata)

    logg.info("computing terminal states", r=True)

    strings_to_categoricals(adata)
    if groupby is not None:
        logg.warn(
            "Only set `groupby` when you have clearly distinct clusters/lineages,"
            " each with its own root and end point.")

    kwargs.update({"self_transitions": self_transitions})
    categories = [None]
    if groupby is not None and groups is None:
        categories = adata.obs[groupby].cat.categories
    for cat in categories:
        groups = cat if cat is not None else groups
        cell_subset = groups_to_bool(adata, groups=groups, groupby=groupby)
        _adata = adata if groups is None else adata[cell_subset]
        connectivities = get_connectivities(_adata, "distances")

        T = transition_matrix(_adata, vkey=vkey, backward=True, **kwargs)
        eigvecs_roots = eigs(T,
                             eps=eps,
                             perc=[2, 98],
                             random_state=random_state)[1]
        roots = csr_matrix.dot(connectivities, eigvecs_roots).sum(1)
        roots = scale(np.clip(roots, 0, np.percentile(roots, 98)))
        roots = verify_roots(_adata, roots, modality)
        write_to_obs(adata, "root_cells", roots, cell_subset)

        T = transition_matrix(_adata, vkey=vkey, backward=False, **kwargs)
        eigvecs_ends = eigs(T,
                            eps=eps,
                            perc=[2, 98],
                            random_state=random_state)[1]
        ends = csr_matrix.dot(connectivities, eigvecs_ends).sum(1)
        ends = scale(np.clip(ends, 0, np.percentile(ends, 98)))
        write_to_obs(adata, "end_points", ends, cell_subset)

        n_roots, n_ends = eigvecs_roots.shape[1], eigvecs_ends.shape[1]
        groups_str = f" ({groups})" if isinstance(groups, str) else ""
        roots_str = f"{n_roots} {'regions' if n_roots > 1 else 'region'}"
        ends_str = f"{n_ends} {'regions' if n_ends > 1 else 'region'}"

        logg.info(f"    identified {roots_str} of root cells "
                  f"and {ends_str} of end points {groups_str}.")

    logg.info("    finished",
              time=True,
              end=" " if settings.verbosity > 2 else "\n")
    logg.hint(
        "added\n"
        "    'root_cells', root cells of Markov diffusion process (adata.obs)\n"
        "    'end_points', end points of Markov diffusion process (adata.obs)")
    return adata if copy else None
Example #13
def velocity_embedding(
    data,
    basis=None,
    vkey="velocity",
    scale=10,
    self_transitions=True,
    use_negative_cosines=True,
    direct_pca_projection=None,
    retain_scale=False,
    autoscale=True,
    all_comps=True,
    T=None,
    copy=False,
):
    """Projects the single cell velocities into any embedding.

    Given the normalized difference of the embedding positions
    :math:`\\tilde \\delta_{ij} = \\frac{x_j-x_i}{\\left\\lVert x_j-x_i \\right\\rVert}`,
    the projections are obtained as expected displacements with respect to the
    transition matrix :math:`\\tilde \\pi_{ij}` as

    .. math::
        \\tilde \\nu_i = E_{\\tilde \\pi_{i\\cdot}} [\\tilde \\delta_{i \\cdot}]
        = \\sum_{j \\neq i} \\left( \\tilde \\pi_{ij} - \\frac1n \\right)
        \\tilde \\delta_{ij}.


    Arguments
    ---------
    data: :class:`~anndata.AnnData`
        Annotated data matrix.
    basis: `str` (default: `None`)
        Which embedding to use; if `None`, the last available of
        ['pca', 'tsne', 'umap'] in `adata.obsm` is used.
    vkey: `str` (default: `'velocity'`)
        Name of velocity estimates to be used.
    scale: `int` (default: 10)
        Scale parameter of Gaussian kernel for transition matrix.
    self_transitions: `bool` (default: `True`)
        Whether to allow self transitions, based on the confidences of transitioning to
        neighboring cells.
    use_negative_cosines: `bool` (default: `True`)
        Whether to project cell-to-cell transitions with negative cosines into
        negative/opposite direction.
    direct_pca_projection: `bool` (default: `None`)
        Whether to directly project the velocities into PCA space,
        thus skipping the velocity graph.
    retain_scale: `bool` (default: `False`)
        Whether to retain scale from high dimensional space in embedding.
    autoscale: `bool` (default: `True`)
        Whether to scale the embedded velocities by a scalar multiplier,
        which simply ensures that the arrows in the embedding are properly scaled.
    all_comps: `bool` (default: `True`)
        Whether to compute the velocities on all embedding components.
    T: `csr_matrix` (default: `None`)
        Allows the user to directly pass a transition matrix.
    copy: `bool` (default: `False`)
        Return a copy instead of writing to `adata`.

    Returns
    -------
    velocity_umap: `.obsm`
        coordinates of velocity projection on embedding (e.g., basis='umap')
    """

    adata = data.copy() if copy else data

    if basis is None:
        keys = [
            key for key in ["pca", "tsne", "umap"] if f"X_{key}" in adata.obsm.keys()
        ]
        if len(keys) > 0:
            basis = "pca" if direct_pca_projection else keys[-1]
        else:
            raise ValueError("No basis specified")

    if f"X_{basis}" not in adata.obsm_keys():
        raise ValueError("You need to compute the embedding first.")

    if direct_pca_projection and "pca" in basis:
        logg.warn(
            "Directly projecting velocities into PCA space is for exploratory analysis "
            "on principal components.\n"
            "         It does not reflect the actual velocity field from high "
            "dimensional gene expression space.\n"
            "         To visualize velocities, consider applying "
            "`direct_pca_projection=False`.\n"
        )

    logg.info("computing velocity embedding", r=True)

    V = np.array(adata.layers[vkey])
    vgenes = np.ones(adata.n_vars, dtype=bool)
    if f"{vkey}_genes" in adata.var.keys():
        vgenes &= np.array(adata.var[f"{vkey}_genes"], dtype=bool)
    vgenes &= ~np.isnan(V.sum(0))
    V = V[:, vgenes]

    if direct_pca_projection and "pca" in basis:
        PCs = adata.varm["PCs"] if all_comps else adata.varm["PCs"][:, :2]
        PCs = PCs[vgenes]

        X_emb = adata.obsm[f"X_{basis}"]
        V_emb = (V - V.mean(0)).dot(PCs)

    else:
        X_emb = (
            adata.obsm[f"X_{basis}"] if all_comps else adata.obsm[f"X_{basis}"][:, :2]
        )
        V_emb = np.zeros(X_emb.shape)

        T = (
            transition_matrix(
                adata,
                vkey=vkey,
                scale=scale,
                self_transitions=self_transitions,
                use_negative_cosines=use_negative_cosines,
            )
            if T is None
            else T
        )
        T.setdiag(0)
        T.eliminate_zeros()

        densify = adata.n_obs < 1e4
        TA = T.A if densify else None

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            for i in range(adata.n_obs):
                indices = T[i].indices
                dX = X_emb[indices] - X_emb[i, None]  # shape (n_neighbors, 2)
                if not retain_scale:
                    dX /= l2_norm(dX)[:, None]
                dX[np.isnan(dX)] = 0  # zero diff in a steady-state
                probs = TA[i, indices] if densify else T[i].data
                V_emb[i] = probs.dot(dX) - probs.mean() * dX.sum(0)

        if retain_scale:
            X = (
                adata.layers["Ms"]
                if "Ms" in adata.layers.keys()
                else adata.layers["spliced"]
            )
            delta = T.dot(X[:, vgenes]) - X[:, vgenes]
            if issparse(delta):
                delta = delta.A
            cos_proj = (V * delta).sum(1) / l2_norm(delta)
            V_emb *= np.clip(cos_proj[:, None] * 10, 0, 1)

    if autoscale:
        V_emb /= 3 * quiver_autoscale(X_emb, V_emb)

    if f"{vkey}_params" in adata.uns.keys():
        adata.uns[f"{vkey}_params"]["embeddings"] = (
            []
            if "embeddings" not in adata.uns[f"{vkey}_params"]
            else list(adata.uns[f"{vkey}_params"]["embeddings"])
        )
        adata.uns[f"{vkey}_params"]["embeddings"].extend([basis])

    vkey += f"_{basis}"
    adata.obsm[vkey] = V_emb

    logg.info("    finished", time=True, end=" " if settings.verbosity > 2 else "\n")
    logg.hint("added\n" f"    '{vkey}', embedded velocity vectors (adata.obsm)")

    return adata if copy else None
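Typical usage once the velocity graph has been computed (a sketch; scv.pl.velocity_embedding then visualizes the stored vectors):

import scvelo as scv

scv.tl.velocity_embedding(adata, basis="umap")
print(adata.obsm["velocity_umap"].shape)  # (n_obs, 2)
scv.pl.velocity_embedding(adata, basis="umap")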
Example #14
def velocity_genes(
    data,
    vkey="velocity",
    min_r2=0.01,
    min_ratio=0.01,
    use_highly_variable=True,
    copy=False,
):
    """Estimates velocities in a gene-specific manner

    Arguments
    ---------
    data: :class:`~anndata.AnnData`
        Annotated data matrix.
    vkey: `str` (default: `'velocity'`)
        Name under which to refer to the computed velocities.
    min_r2: `float` (default: 0.01)
        Minimum threshold for coefficient of determination.
    min_ratio: `float` (default: 0.01)
        Minimum threshold for quantile regression un/spliced ratio.
    use_highly_variable: `bool` (default: True)
        Whether to use highly variable genes only, stored in .var['highly_variable'].
    copy: `bool` (default: `False`)
        Return a copy instead of writing to `adata`.

    Returns
    -------
    Updates `adata` attributes
    velocity_genes: `.var`
        genes to be used for further velocity analysis (velocity graph and embedding)
    """

    adata = data.copy() if copy else data
    if f"{vkey}_genes" not in adata.var.keys():
        velocity(adata, vkey)
    vgenes = np.ones(adata.n_vars, dtype=bool)

    if "Ms" in adata.layers.keys() and "Mu" in adata.layers.keys():
        vgenes &= np.max(adata.layers["Ms"] > 0, 0) > 0
        vgenes &= np.max(adata.layers["Mu"] > 0, 0) > 0

    if min_r2 is not None and f"{vkey}_r2" in adata.var.keys():
        vgenes &= adata.var[f"{vkey}_r2"] > min_r2

    if min_ratio is not None and f"{vkey}_qreg_ratio" in adata.var.keys():
        vgenes &= adata.var[f"{vkey}_qreg_ratio"] > min_ratio

    if use_highly_variable and "highly_variable" in adata.var.keys():
        vgenes &= adata.var["highly_variable"].values

    if np.sum(vgenes) < 2:
        logg.warn(
            "You seem to have very low signal in splicing dynamics.\n"
            "Consider reducing the thresholds and be cautious with interpretations.\n"
        )

    adata.var[f"{vkey}_genes"] = vgenes

    logg.info("Number of obtained velocity_genes:",
              np.sum(adata.var[f"{vkey}_genes"]))

    return adata if copy else None
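A usage sketch for re-selecting velocity genes with stricter thresholds after tl.velocity has run (the threshold values are illustrative):

import scvelo as scv

scv.tl.velocity_genes(adata, min_r2=0.05, min_ratio=0.05)
print(adata.var["velocity_genes"].sum())  # number of selected genes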
Example #15
def velocity(
    data,
    vkey="velocity",
    mode="stochastic",
    fit_offset=False,
    fit_offset2=False,
    filter_genes=False,
    groups=None,
    groupby=None,
    groups_for_fit=None,
    constrain_ratio=None,
    use_raw=False,
    use_latent_time=None,
    perc=[5, 95],
    min_r2=1e-2,
    min_likelihood=1e-3,
    r2_adjusted=None,
    use_highly_variable=True,
    diff_kinetics=None,
    copy=False,
    **kwargs,
):
    """Estimates velocities in a gene-specific manner.

    The steady-state model [Manno18]_ determines velocities by quantifying how
    observations deviate from a presumed steady-state equilibrium ratio of unspliced to
    spliced mRNA levels. This steady-state ratio is obtained by performing a linear
    regression restricting the input data to the extreme quantiles. By including
    second-order moments, the stochastic model [Bergen19]_ exploits not only the balance
    of unspliced to spliced mRNA levels but also their covariation. By contrast, the
    likelihood-based dynamical model [Bergen19]_ solves the full splicing kinetics and
    generalizes RNA velocity estimation to transient systems. It is also capable of
    capturing non-observed steady states.

    .. image:: https://user-images.githubusercontent.com/31883718/69636491-ff057100-1056-11ea-90b7-d04098112ce1.png

    Arguments
    ---------
    data: :class:`~anndata.AnnData`
        Annotated data matrix.
    vkey: `str` (default: `'velocity'`)
        Name under which to refer to the computed velocities
        for `velocity_graph` and `velocity_embedding`.
    mode: `'deterministic'`, `'stochastic'` or `'dynamical'` (default: `'stochastic'`)
        Whether to run the estimation using the steady-state/deterministic,
        stochastic or dynamical model of transcriptional dynamics.
        The dynamical model requires to run `tl.recover_dynamics` first.
    fit_offset: `bool` (default: `False`)
        Whether to fit with offset for first order moment dynamics.
    fit_offset2: `bool`, (default: `False`)
        Whether to fit with offset for second order moment dynamics.
    filter_genes: `bool` (default: `False`)
        Whether to remove genes that are not used for further velocity analysis.
    groups: `str`, `list` (default: `None`)
        Subset of groups, e.g. ['g1', 'g2', 'g3'],
        to which velocity analysis shall be restricted.
    groupby: `str`, `list` or `np.ndarray` (default: `None`)
        Key of observations grouping to consider.
    groups_for_fit: `str`, `list` or `np.ndarray` (default: `None`)
        Subset of groups, e.g. ['g1', 'g2', 'g3'],
        to which steady-state fitting shall be restricted.
    constrain_ratio: `float`, tuple of type `float` or `None` (default: `None`)
        Bounds for the steady-state ratio.
    use_raw: `bool` (default: `False`)
        Whether to use raw data for estimation.
    use_latent_time: `bool`or `None` (default: `None`)
        Whether to use latent time as a regularization for velocity estimation.
    perc: `float` (default: `[5, 95]`)
        Percentile, e.g. 98, for extreme quantile fit.
    min_r2: `float` (default: 0.01)
        Minimum threshold for coefficient of determination.
    min_likelihood: `float` (default: `1e-3`)
        Minimal likelihood for velocity genes to fit the model on.
    r2_adjusted: `bool` (default: `None`)
        Whether to compute the coefficient of determination
        on the full data fit (adjusted) or the extreme quantile fit (if `None`).
    use_highly_variable: `bool` (default: True)
        Whether to use highly variable genes only, stored in .var['highly_variable'].
    copy: `bool` (default: `False`)
        Return a copy instead of writing to `adata`.

    Returns
    -------
    velocity: `.layers`
        velocity vectors for each individual cell
    velocity_genes, velocity_beta, velocity_gamma, velocity_r2: `.var`
        parameters
    """  # noqa E501

    adata = data.copy() if copy else data
    if not use_raw and "Ms" not in adata.layers.keys():
        moments(adata)

    logg.info("computing velocities", r=True)

    strings_to_categoricals(adata)

    if mode is None or (mode == "dynamical"
                        and "fit_alpha" not in adata.var.keys()):
        mode = "stochastic"
        logg.warn("Falling back to stochastic model. "
                  "For the dynamical model run tl.recover_dynamics first.")

    if mode in {"dynamical", "dynamical_residuals"}:
        from .dynamical_model_utils import get_divergence, get_reads, get_vars

        gene_subset = ~np.isnan(adata.var["fit_alpha"].values)
        vdata = adata[:, gene_subset]
        alpha, beta, gamma, scaling, t_ = get_vars(vdata)

        connect = not adata.uns["recover_dynamics"]["use_raw"]
        kwargs_ = {
            "kernel_width": None,
            "normalized": True,
            "var_scale": True,
            "reg_par": None,
            "min_confidence": 1e-2,
            "constraint_time_increments": False,
            "fit_steady_states": True,
            "fit_basal_transcription": None,
            "use_connectivities": connect,
            "time_connectivities": connect,
            "use_latent_time": use_latent_time,
        }
        kwargs_.update(adata.uns["recover_dynamics"])
        kwargs_.update(**kwargs)

        if "residuals" in mode:
            u, s = get_reads(vdata,
                             use_raw=adata.uns["recover_dynamics"]["use_raw"])
            if kwargs_["fit_basal_transcription"]:
                u, s = u - adata.var["fit_u0"], s - adata.var["fit_s0"]
            o = vdata.layers["fit_t"] < t_
            vt = u * beta - s * gamma  # ds/dt
            wt = (alpha * o - beta * u) * scaling  # du/dt
        else:
            vt, wt = get_divergence(vdata, mode="velocity", **kwargs_)

        vgenes = adata.var.fit_likelihood > min_likelihood
        if min_r2 is not None:
            if "fit_r2" not in adata.var.keys():
                velo = Velocity(
                    adata,
                    groups_for_fit=groups_for_fit,
                    groupby=groupby,
                    constrain_ratio=constrain_ratio,
                    min_r2=min_r2,
                    use_highly_variable=use_highly_variable,
                    use_raw=use_raw,
                )
                velo.compute_deterministic(fit_offset=fit_offset, perc=perc)
                adata.var["fit_r2"] = velo._r2
            vgenes &= adata.var.fit_r2 > min_r2

        lb, ub = np.nanpercentile(adata.var.fit_scaling, [10, 90])
        vgenes = (vgenes
                  & (adata.var.fit_scaling > np.min([lb, 0.03]))
                  & (adata.var.fit_scaling < np.max([ub, 3])))

        adata.var[f"{vkey}_genes"] = vgenes

        adata.layers[vkey] = np.ones(adata.shape) * np.nan
        adata.layers[vkey][:, gene_subset] = vt

        adata.layers[f"{vkey}_u"] = np.ones(adata.shape) * np.nan
        adata.layers[f"{vkey}_u"][:, gene_subset] = wt

        if filter_genes and len(set(vgenes)) > 1:
            adata._inplace_subset_var(vgenes)

    elif mode in {"steady_state", "deterministic", "stochastic"}:
        categories = (adata.obs[groupby].cat.categories
                      if groupby is not None and groups is None
                      and groups_for_fit is None else [None])

        for cat in categories:
            groups = cat if cat is not None else groups

            cell_subset = groups_to_bool(adata, groups, groupby)
            _adata = adata if groups is None else adata[cell_subset]
            velo = Velocity(
                _adata,
                groups_for_fit=groups_for_fit,
                groupby=groupby,
                constrain_ratio=constrain_ratio,
                min_r2=min_r2,
                r2_adjusted=r2_adjusted,
                use_highly_variable=use_highly_variable,
                use_raw=use_raw,
            )
            velo.compute_deterministic(fit_offset=fit_offset, perc=perc)

            if mode == "stochastic":
                if filter_genes and len(set(velo._velocity_genes)) > 1:
                    adata._inplace_subset_var(velo._velocity_genes)
                    residual = velo._residual[:, velo._velocity_genes]
                    _adata = adata if groups is None else adata[cell_subset]
                    velo = Velocity(
                        _adata,
                        residual=residual,
                        groups_for_fit=groups_for_fit,
                        groupby=groupby,
                        constrain_ratio=constrain_ratio,
                        use_highly_variable=use_highly_variable,
                    )
                velo.compute_stochastic(fit_offset,
                                        fit_offset2,
                                        mode,
                                        perc=perc)

            write_residuals(adata, vkey, velo._residual, cell_subset)
            write_residuals(adata, f"variance_{vkey}", velo._residual2,
                            cell_subset)
            write_pars(adata,
                       vkey,
                       velo.get_pars(),
                       velo.get_pars_names(),
                       add_key=cat)

            if filter_genes and len(set(velo._velocity_genes)) > 1:
                adata._inplace_subset_var(velo._velocity_genes)

    else:
        raise ValueError(
            "Mode can only be one of these: deterministic, stochastic or dynamical."
        )

    if f"{vkey}_genes" in adata.var.keys() and np.sum(
            adata.var[f"{vkey}_genes"]) < 10:
        logg.warn(
            "Too few genes are selected as velocity genes. "
            "Consider setting a lower threshold for min_r2 or min_likelihood.")

    if diff_kinetics:
        if not isinstance(diff_kinetics, str):
            diff_kinetics = "fit_diff_kinetics"
        if diff_kinetics in adata.var.keys():
            if diff_kinetics in adata.uns["recover_dynamics"]:
                groupby = adata.uns["recover_dynamics"]["fit_diff_kinetics"]
            else:
                groupby = "clusters"
            clusters = adata.obs[groupby]
            for i, v in enumerate(
                    np.array(adata.var[diff_kinetics].values, dtype=str)):
                if len(v) > 0 and v != "nan":
                    idx = 1 - clusters.isin([a.strip() for a in v.split(",")])
                    adata.layers[vkey][:, i] *= idx
                    if mode == "dynamical":
                        adata.layers[f"{vkey}_u"][:, i] *= idx

    adata.uns[f"{vkey}_params"] = {
        "mode": mode,
        "fit_offset": fit_offset,
        "perc": perc
    }

    logg.info("    finished",
              time=True,
              end=" " if settings.verbosity > 2 else "\n")
    logg.hint(
        "added \n"
        f"    '{vkey}', velocity vectors for each individual cell (adata.layers)"
    )

    return adata if copy else None
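A usage sketch of the estimation modes (the dynamical model requires tl.recover_dynamics first, as noted above):

import scvelo as scv

# Stochastic model (default), using second-order moments:
scv.tl.velocity(adata, mode="stochastic")

# Dynamical model: recover the full splicing kinetics first.
scv.tl.recover_dynamics(adata)
scv.tl.velocity(adata, mode="dynamical")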
Example #16
def normalize_per_cell(
    data,
    counts_per_cell_after=None,
    counts_per_cell=None,
    key_n_counts=None,
    max_proportion_per_cell=None,
    use_initial_size=True,
    layers=None,
    enforce=None,
    copy=False,
):
    """Normalize each cell by total counts over all genes.

    Parameters
    ----------
    data : :class:`~anndata.AnnData`, `np.ndarray`, `sp.sparse`
        The (annotated) data matrix of shape `n_obs` × `n_vars`. Rows correspond
        to cells and columns to genes.
    counts_per_cell_after : `float` or `None`, optional (default: `None`)
        If `None`, after normalization, each cell has a total count equal
        to the median of the *counts_per_cell* before normalization.
    counts_per_cell : `np.array`, optional (default: `None`)
        Precomputed counts per cell.
    key_n_counts : `str`, optional (default: `'n_counts'`)
        Name of the field in `adata.obs` where the total counts per cell are
        stored.
    max_proportion_per_cell : `int` (default: `None`)
        Exclude genes counts that account for more than
        a specific proportion of cell size, e.g. 0.05.
    use_initial_size : `bool` (default: `True`)
        Whether to use initial cell sizes or actual cell sizes.
    layers : `str` or `list` (default: `['spliced', 'unspliced']`)
        Keys for layers to be also considered for normalization.
    copy : `bool`, optional (default: `False`)
        If an :class:`~anndata.AnnData` is passed, determines whether a copy
        is returned.

    Returns
    -------
    Returns or updates `adata` with normalized counts.
    """

    adata = data.copy() if copy else data
    if layers is None:
        layers = ["spliced", "unspliced"]
    elif layers == "all":
        layers = adata.layers.keys()
    elif isinstance(layers, str):
        layers = [layers]
    layers = ["X"
              ] + [layer for layer in layers if layer in adata.layers.keys()]
    modified_layers = []

    if isinstance(counts_per_cell, str):
        if counts_per_cell not in adata.obs.keys():
            _set_initial_size(adata, layers)
        counts_per_cell = (adata.obs[counts_per_cell].values
                           if counts_per_cell in adata.obs.keys() else None)

    for layer in layers:
        check_if_valid_dtype(adata, layer)
        X = adata.X if layer == "X" else adata.layers[layer]

        if not_yet_normalized(X) or enforce:
            counts = (counts_per_cell if counts_per_cell is not None else
                      _get_initial_size(adata, layer)
                      if use_initial_size else _get_size(adata, layer))
            if max_proportion_per_cell is not None and (
                    0 < max_proportion_per_cell < 1):
                counts = counts_per_cell_quantile(X, max_proportion_per_cell,
                                                  counts)
            # equivalent to sc.pp.normalize_per_cell(X, counts_per_cell_after, counts)
            counts_after = (np.median(counts) if counts_per_cell_after is None
                            else counts_per_cell_after)

            counts_after += counts_after == 0
            counts = counts / counts_after
            counts += counts == 0  # to avoid division by zero

            if issparse(X):
                sparsefuncs.inplace_row_scale(X, 1 / counts)
            else:
                X /= np.array(counts[:, None])
            modified_layers.append(layer)
            if (layer == "X" and "gene_count_corr" not in adata.var.keys()
                    and X.shape[-1] > 3e3):
                try:
                    adata.var["gene_count_corr"] = np.round(
                        csr_vcorrcoef(X.T, np.ravel((X > 0).sum(1))), 4)
                except Exception:
                    pass
        else:
            logg.warn(
                f"Did not normalize {layer} as it looks processed already. "
                "To enforce normalization, set `enforce=True`.")

    adata.obs["n_counts"
              if key_n_counts is None else key_n_counts] = _get_size(adata)
    if len(modified_layers) > 0:
        logg.info("Normalized count data:", f"{', '.join(modified_layers)}.")

    return adata if copy else None
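A usage sketch (normalizes X plus the spliced/unspliced layers to the median of counts per cell; the dataset is illustrative):

import scvelo as scv

adata = scv.datasets.pancreas()
scv.pp.normalize_per_cell(adata)
print(adata.obs["n_counts"].head())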
Example #17
def _paga(
    adata,
    threshold=None,
    color=None,
    layout=None,
    layout_kwds=None,
    init_pos=None,
    root=0,
    labels=None,
    single_component=False,
    solid_edges="connectivities",
    dashed_edges=None,
    transitions=None,
    fontsize=None,
    fontweight="bold",
    fontoutline=None,
    text_kwds=None,
    node_size_scale=1,
    node_size_power=0.5,
    edge_width_scale=1,
    min_edge_width=None,
    max_edge_width=None,
    arrowsize=30,
    title=None,
    random_state=0,
    pos=None,
    normalize_to_color=False,
    cmap=None,
    cax=None,
    colorbar=False,
    cb_kwds=None,
    frameon=None,
    add_pos=True,
    export_to_gexf=False,
    use_raw=True,
    colors=None,
    groups=None,
    plot=True,
    show=None,
    save=None,
    ax=None,
    **scatter_kwargs,
):
    """scanpy/_paga with some adjustments for directional graphs.
    To be moved back to scanpy once finalized.
    """
    from scanpy.plotting._utils import setup_axes

    if groups is not None:
        labels = groups
    if colors is None:
        colors = color
    groups_key = adata.uns["paga"]["groups"]

    def is_flat(x):
        has_one_per_category = isinstance(
            x, cabc.Collection) and len(x) == len(
                adata.obs[groups_key].cat.categories)
        return has_one_per_category or x is None or isinstance(x, str)

    if is_flat(colors):
        colors = [colors]
    if is_flat(labels):
        labels = [labels for _ in range(len(colors))]
    title = ([c for c in colors] if title is None and len(colors) > 1 else
             [title for _ in colors] if isinstance(title, str) else
             [None for _ in colors] if title is None else title)

    if colorbar is None:
        colorbars = [
            ((c in adata.obs_keys() and adata.obs[c].dtype.name != "category")
             or c in (adata.var_names
                      if adata.raw is None else adata.raw.var_names))
            for c in colors
        ]
    else:
        colorbars = [False for _ in colors]

    if isinstance(root, str):
        if root not in labels:
            raise ValueError(
                f"If `root` is a string, it needs to be one of {labels} not {root!r}."
            )
        root = list(labels).index(root)
    if isinstance(root, cabc.Sequence) and root[0] in labels:
        root = [list(labels).index(r) for r in root]

    # define the adjacency matrices
    if solid_edges not in adata.uns["paga"]:
        logg.warn(f"{solid_edges} not found, using connectivities instead.")
        solid_edges = "connectivities"
    adjacency_solid = adata.uns["paga"][solid_edges].copy()
    adjacency_dashed = None
    if threshold is None:
        threshold = 0.01  # default threshold
    if threshold > 0:
        adjacency_solid.data[adjacency_solid.data < threshold] = 0
        adjacency_solid.eliminate_zeros()
    if dashed_edges is not None:
        adjacency_dashed = adata.uns["paga"][dashed_edges].copy()
        if threshold > 0:
            adjacency_dashed.data[adjacency_dashed.data < threshold] = 0
            adjacency_dashed.eliminate_zeros()

    cats = adata.obs[groups_key].cat.categories
    if pos is not None:
        if isinstance(pos, str):
            if not pos.startswith("X_"):
                pos = f"X_{pos}"
            if pos in adata.obsm.keys():
                X_pos, cg = adata.obsm[pos], adata.obs[groups_key]
                pos = np.stack(
                    [np.median(X_pos[cg == c], axis=0) for c in cats])
            else:
                pos = None
        if pos is not None and len(pos) != len(cats):
            pos = None
    elif init_pos is not None:
        if isinstance(init_pos, str):
            if not init_pos.startswith("X_"):
                init_pos = f"X_{init_pos}"
            if init_pos in adata.obsm.keys():
                X_pos, cg = adata.obsm[init_pos], adata.obs[groups_key]
                init_pos = np.stack(
                    [np.median(X_pos[cg == c], axis=0) for c in cats])
            else:
                init_pos = None
        if init_pos is not None and len(init_pos) != len(cats):
            init_pos = None

    # compute positions
    if pos is None:
        adj_tree = None
        if layout in {"rt", "rt_circular", "eq_tree"}:
            adj_tree = adata.uns["paga"]["connectivities_tree"]
        pos = _compute_pos(
            adjacency_solid,
            layout=layout,
            random_state=random_state,
            init_pos=init_pos,
            layout_kwds=layout_kwds,
            adj_tree=adj_tree,
            root=root,
        )

    scatter_kwargs.update({"alpha": 0, "color": groups_key})
    x, y = pos[:, 0], pos[:, 1]

    if plot:
        axs_pars = setup_axes(ax=ax, panels=colors, colorbars=colorbars)
        axs, panel_pos, draw_region_width, _ = axs_pars

        if len(colors) == 1 and not isinstance(axs, list):
            axs = [axs]

        for icolor, c in enumerate(colors):
            if title[icolor] is not None:
                axs[icolor].set_title(title[icolor])
            axs[icolor] = scatter(
                adata,
                x=x,
                y=y,
                title=title[icolor],
                ax=axs[icolor],
                save=None,
                zorder=0,
                show=False,
                **scatter_kwargs,
            )
            sct = _paga_graph(
                adata,
                axs[icolor],
                colors=c,
                solid_edges=solid_edges,
                dashed_edges=dashed_edges,
                transitions=transitions,
                threshold=threshold,
                adjacency_solid=adjacency_solid,
                adjacency_dashed=adjacency_dashed,
                root=root,
                labels=labels[icolor],
                fontsize=fontsize,
                fontweight=fontweight,
                fontoutline=fontoutline,
                text_kwds=text_kwds,
                node_size_scale=node_size_scale,
                node_size_power=node_size_power,
                edge_width_scale=edge_width_scale,
                min_edge_width=min_edge_width,
                max_edge_width=max_edge_width,
                normalize_to_color=normalize_to_color,
                frameon=frameon,
                cmap=cmap,
                colorbar=colorbars[icolor],
                cb_kwds=cb_kwds,
                use_raw=use_raw,
                title=title[icolor],
                export_to_gexf=export_to_gexf,
                single_component=single_component,
                arrowsize=arrowsize,
                pos=pos,
            )
            if colorbars[icolor]:
                if cax is None:
                    bottom = panel_pos[0][0]
                    height = panel_pos[1][0] - bottom
                    width = 0.006 * draw_region_width / len(colors)
                    left = panel_pos[2][2 * icolor + 1] + 0.2 * width
                    rectangle = [left, bottom, width, height]
                    fig = pl.gcf()
                    ax_cb = fig.add_axes(rectangle)
                else:
                    ax_cb = cax[icolor]

                pl.colorbar(sct, cax=ax_cb)
    if add_pos:
        adata.uns["paga"]["pos"] = pos
    if plot:
        savefig_or_show("paga", show=show, save=save)
        if len(colors) == 1 and isinstance(axs, list):
            axs = axs[0]
        if show is False:
            return axs
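# A minimal usage sketch, assuming `_paga` is the internal helper behind scvelo's
# public `scv.pl.paga`, and that `scv.tl.paga` writes `adata.uns["paga"]` including a
# velocity-based "transitions_confidence" matrix (standard scvelo workflow):
import scvelo as scv

adata = scv.datasets.pancreas()
scv.pp.filter_and_normalize(adata)
scv.pp.moments(adata)
scv.tl.velocity(adata)
scv.tl.velocity_graph(adata)
scv.tl.paga(adata, groups="clusters")
scv.pl.paga(adata, basis="umap", transitions="transitions_confidence")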
Beispiel #18
0
def neighbors(
    adata,
    n_neighbors=30,
    n_pcs=None,
    use_rep=None,
    use_highly_variable=True,
    knn=True,
    random_state=0,
    method="umap",
    metric="euclidean",
    metric_kwds=None,
    num_threads=-1,
    copy=False,
):
    """
    Compute a neighborhood graph of observations.

    The neighbor graph methods (umap, hnsw, sklearn) only differ in runtime and
    yield the same result as scanpy [Wolf18]_. Connectivities are computed with
    adaptive kernel width as proposed in Haghverdi et al. 2016 (doi:10.1038/nmeth.3971).

    Parameters
    ----------
    adata
        Annotated data matrix.
    n_neighbors
        The size of local neighborhood (in terms of number of neighboring data
        points) used for manifold approximation. Larger values result in more
        global views of the manifold, while smaller values result in more local
        data being preserved. In general values should be in the range 2 to 100.
        If `knn` is `True`, number of nearest neighbors to be searched. If `knn`
        is `False`, a Gaussian kernel width is set to the distance of the
        `n_neighbors` neighbor.
    n_pcs : `int` or `None` (default: None)
        Number of principal components to use.
        If not specified, the full space of a pre-computed PCA is used;
        if PCA is computed internally, 30 components are used.
    use_rep : `None`, `'X'` or any key for `.obsm` (default: None)
        Use the indicated representation. If `None`, the representation is chosen
        automatically: for `.n_vars` < 50, `.X` is used, otherwise `'X_pca'` is used.
    use_highly_variable: `bool` (default: True)
        Whether to use highly variable genes only, stored in .var['highly_variable'].
    knn
        If `True`, use a hard threshold to restrict the number of neighbors to
        `n_neighbors`, that is, consider a knn graph. Otherwise, use a Gaussian
        Kernel to assign low weights to neighbors more distant than the
        `n_neighbors` nearest neighbor.
    random_state
        A numpy random seed.
    method : {{'umap', 'hnsw', 'sklearn'}}  (default: `'umap'`)
        Method to compute neighbors, only differs in runtime.
        The 'hnsw' method is most efficient, but requires `hnswlib` to be
        installed (`pip install hnswlib`).
        Connectivities are computed with adaptive kernel.
    metric
        A known metric's name or a callable that returns a distance.
    metric_kwds
        Options for the metric.
    num_threads
        Number of threads to be used (for runtime).
    copy
        Return a copy instead of writing to adata.

    Returns
    -------
    connectivities : `.obsp`
        Sparse weighted adjacency matrix of the neighborhood graph of data
        points. Weights should be interpreted as connectivities.
    distances : `.obsp`
        Sparse matrix of distances for each pair of neighbors.
    """

    adata = adata.copy() if copy else adata

    if use_rep is None:
        use_rep = "X" if adata.n_vars < 50 or n_pcs == 0 else "X_pca"
        n_pcs = None if use_rep == "X" else n_pcs
    elif use_rep not in adata.obsm.keys() and f"X_{use_rep}" in adata.obsm.keys():
        use_rep = f"X_{use_rep}"

    if use_rep == "X_pca":
        if (
            "X_pca" not in adata.obsm.keys()
            or n_pcs is not None
            and n_pcs > adata.obsm["X_pca"].shape[1]
        ):
            n_vars = (
                np.sum(adata.var["highly_variable"])
                if use_highly_variable and "highly_variable" in adata.var.keys()
                else adata.n_vars
            )
            n_comps = min(30 if n_pcs is None else n_pcs, n_vars - 1, adata.n_obs - 1)
            use_highly_variable &= "highly_variable" in adata.var.keys()
            pca(
                adata,
                n_comps=n_comps,
                use_highly_variable=use_highly_variable,
                svd_solver="arpack",
            )
        elif n_pcs is None and adata.obsm["X_pca"].shape[1] < 10:
            logg.warn(
                f"Neighbors are computed on {adata.obsm['X_pca'].shape[1]} "
                f"principal components only."
            )

        n_duplicate_cells = len(get_duplicate_cells(adata))
        if n_duplicate_cells > 0:
            logg.warn(
                f"You seem to have {n_duplicate_cells} duplicate cells in your data.",
                "Consider removing these via pp.remove_duplicate_cells.",
            )

    if metric_kwds is None:
        metric_kwds = {}

    logg.info("computing neighbors", r=True)

    if method == "sklearn":
        from sklearn.neighbors import NearestNeighbors

        X = adata.X if use_rep == "X" else adata.obsm[use_rep]
        neighbors = NearestNeighbors(
            n_neighbors=n_neighbors - 1,
            metric=metric,
            metric_params=metric_kwds,
            n_jobs=num_threads,
        )
        neighbors.fit(X if n_pcs is None else X[:, :n_pcs])
        knn_distances, neighbors.knn_indices = neighbors.kneighbors()
        knn_distances, neighbors.knn_indices = set_diagonal(
            knn_distances, neighbors.knn_indices
        )
        neighbors.distances, neighbors.connectivities = compute_connectivities_umap(
            neighbors.knn_indices, knn_distances, X.shape[0], n_neighbors=n_neighbors
        )

    elif method == "hnsw":
        X = adata.X if use_rep == "X" else adata.obsm[use_rep]
        neighbors = FastNeighbors(n_neighbors=n_neighbors, num_threads=num_threads)
        neighbors.fit(
            X if n_pcs is None else X[:, :n_pcs],
            metric=metric,
            random_state=random_state,
            **metric_kwds,
        )

    else:
        logg.switch_verbosity("off", module="scanpy")
        with warnings.catch_warnings():  # ignore numba warning (umap/issues/252)
            warnings.simplefilter("ignore")
            neighbors = Neighbors(adata)
            neighbors.compute_neighbors(
                n_neighbors=n_neighbors,
                knn=knn,
                n_pcs=n_pcs,
                method=method,
                use_rep=use_rep,
                random_state=random_state,
                metric=metric,
                metric_kwds=metric_kwds,
                write_knn_indices=True,
            )
        logg.switch_verbosity("on", module="scanpy")

    adata.uns["neighbors"] = {}
    try:
        adata.obsp["distances"] = neighbors.distances
        adata.obsp["connectivities"] = neighbors.connectivities
        adata.uns["neighbors"]["connectivities_key"] = "connectivities"
        adata.uns["neighbors"]["distances_key"] = "distances"
    except Exception:
        adata.uns["neighbors"]["distances"] = neighbors.distances
        adata.uns["neighbors"]["connectivities"] = neighbors.connectivities

    if hasattr(neighbors, "knn_indices"):
        adata.uns["neighbors"]["indices"] = neighbors.knn_indices
    adata.uns["neighbors"]["params"] = {
        "n_neighbors": n_neighbors,
        "method": method,
        "metric": metric,
        "n_pcs": n_pcs,
        "use_rep": use_rep,
    }

    logg.info("    finished", time=True, end=" " if settings.verbosity > 2 else "\n")
    logg.hint(
        "added \n"
        "    'distances' and 'connectivities', weighted adjacency matrices (adata.obsp)"
    )

    return adata if copy else None
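# A minimal usage sketch for `neighbors` above (`scv.datasets.pancreas` is an
# assumed example loader; results land in `.obsp` on current anndata versions):
import scvelo as scv

adata = scv.datasets.pancreas()
neighbors(adata, n_neighbors=30, method="umap")
dists = adata.obsp["distances"]          # sparse kNN distances
conns = adata.obsp["connectivities"]     # adaptive-kernel connectivities
print(adata.uns["neighbors"]["params"])  # {'n_neighbors': 30, 'method': 'umap', ...}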
Beispiel #19
0
def moments(
    data,
    n_neighbors=30,
    n_pcs=None,
    mode="connectivities",
    method="umap",
    use_rep=None,
    use_highly_variable=True,
    copy=False,
):
    """Computes moments for velocity estimation.

    First-/second-order moments are computed for each cell across its nearest neighbors,
    where the neighbor graph is obtained from euclidean distances in PCA space.

    Arguments
    ---------
    data: :class:`~anndata.AnnData`
        Annotated data matrix.
    n_neighbors: `int` (default: 30)
        Number of neighbors to use.
    n_pcs: `int` (default: None)
        Number of principal components to use.
        If not specified, the full space of a pre-computed PCA is used;
        if PCA is computed internally, 30 components are used.
    mode: `'connectivities'` or `'distances'`  (default: `'connectivities'`)
        Whether to use neighbor connectivities or distances for moment computation.
    method : {{'umap', 'hnsw', 'sklearn', `None`}}  (default: `'umap'`)
        Method to compute neighbors, only differs in runtime.
        Connectivities are computed with adaptive kernel width as proposed in
        Haghverdi et al. 2016 (https://doi.org/10.1038/nmeth.3971).
    use_rep : `None`, `'X'` or any key for `.obsm` (default: None)
        Use the indicated representation. If `None`, the representation is chosen
        automatically: for `.n_vars` < 50, `.X` is used, otherwise `'X_pca'` is used.
    use_highly_variable: `bool` (default: True)
        Whether to use highly variable genes only, stored in .var['highly_variable'].
    copy: `bool` (default: `False`)
        Return a copy instead of writing to adata.

    Returns
    -------
    Ms: `.layers`
        dense matrix with first order moments of spliced counts.
    Mu: `.layers`
        dense matrix with first order moments of unspliced counts.
    """

    adata = data.copy() if copy else data

    layers = [
        layer for layer in {"spliced", "unspliced"} if layer in adata.layers
    ]
    if any([not_yet_normalized(adata.layers[layer]) for layer in layers]):
        normalize_per_cell(adata)

    if n_neighbors is not None and n_neighbors > get_n_neighs(adata):
        neighbors(
            adata,
            n_neighbors=n_neighbors,
            use_rep=use_rep,
            use_highly_variable=use_highly_variable,
            n_pcs=n_pcs,
            method=method,
        )
    verify_neighbors(adata)

    if "spliced" not in adata.layers.keys(
    ) or "unspliced" not in adata.layers.keys():
        logg.warn(
            "Skipping moments, because un/spliced counts were not found.")
    else:
        logg.info(f"computing moments based on {mode}", r=True)
        connectivities = get_connectivities(adata,
                                            mode,
                                            n_neighbors=n_neighbors,
                                            recurse_neighbors=False)

        adata.layers["Ms"] = (csr_matrix.dot(
            connectivities,
            csr_matrix(adata.layers["spliced"])).astype(np.float32).A)
        adata.layers["Mu"] = (csr_matrix.dot(
            connectivities,
            csr_matrix(adata.layers["unspliced"])).astype(np.float32).A)
        # if renormalize: normalize_per_cell(adata, layers={'Ms', 'Mu'}, enforce=True)

        logg.info("    finished",
                  time=True,
                  end=" " if settings.verbosity > 2 else "\n")
        logg.hint(
            "added \n"
            "    'Ms' and 'Mu', moments of un/spliced abundances (adata.layers)"
        )
    return adata if copy else None
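# A minimal usage sketch for `moments` above (assumes spliced/unspliced layers are
# present, as in scvelo's example datasets):
import scvelo as scv

adata = scv.datasets.pancreas()
moments(adata, n_neighbors=30, mode="connectivities")
Ms, Mu = adata.layers["Ms"], adata.layers["Mu"]  # smoothed spliced/unspliced counts
assert Ms.shape == adata.shape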
Beispiel #20
0
def scatter(
    adata=None,
    basis=None,
    x=None,
    y=None,
    vkey=None,
    color=None,
    use_raw=None,
    layer=None,
    color_map=None,
    colorbar=None,
    palette=None,
    size=None,
    alpha=None,
    linewidth=None,
    linecolor=None,
    perc=None,
    groups=None,
    sort_order=True,
    components=None,
    projection=None,
    legend_loc=None,
    legend_loc_lines=None,
    legend_fontsize=None,
    legend_fontweight=None,
    legend_fontoutline=None,
    legend_align_text=None,
    xlabel=None,
    ylabel=None,
    title=None,
    fontsize=None,
    figsize=None,
    xlim=None,
    ylim=None,
    add_density=None,
    add_assignments=None,
    add_linfit=None,
    add_polyfit=None,
    add_rug=None,
    add_text=None,
    add_text_pos=None,
    add_margin=None,
    add_outline=None,
    outline_width=None,
    outline_color=None,
    n_convolve=None,
    smooth=None,
    normalize_data=None,
    rescale_color=None,
    color_gradients=None,
    dpi=None,
    frameon=None,
    zorder=None,
    ncols=None,
    nrows=None,
    wspace=None,
    hspace=None,
    show=None,
    save=None,
    ax=None,
    **kwargs,
):
    """\
    Scatter plot along observations or variables axes.

    Arguments
    ---------
    adata: :class:`~anndata.AnnData`
        Annotated data matrix.
    x: `str`, `np.ndarray` or `None` (default: `None`)
        x coordinate
    y: `str`, `np.ndarray` or `None` (default: `None`)
        y coordinate
    {scatter}

    Returns
    -------
    If `show==False` a `matplotlib.Axis`
    """

    if adata is None and (x is not None and y is not None):
        adata = AnnData(np.stack([x, y]).T)

    # restore old conventions
    add_assignments = kwargs.pop("show_assignments", add_assignments)
    add_linfit = kwargs.pop("show_linear_fit", add_linfit)
    add_polyfit = kwargs.pop("show_polyfit", add_polyfit)
    add_density = kwargs.pop("show_density", add_density)
    add_rug = kwargs.pop("rug", add_rug)
    basis = kwargs.pop("var_names", basis)

    # keys for figures (fkeys) and multiple plots (mkeys)
    fkeys = [
        "adata", "show", "save", "groups", "ncols", "nrows", "wspace", "hspace"
    ]
    fkeys += ["add_margin", "ax", "kwargs"]
    mkeys = [
        "color", "layer", "basis", "components", "x", "y", "xlabel", "ylabel"
    ]
    mkeys += ["title", "color_map", "add_text"]
    scatter_kwargs = {"show": False, "save": False}
    for key in signature(scatter).parameters:
        if key not in mkeys + fkeys:
            scatter_kwargs[key] = eval(key)
    mkwargs = {}
    for key in mkeys:  # mkwargs[key] = key for key in mkeys
        mkwargs[key] = eval("{0}[0] if is_list({0}) else {0}".format(key))

    # use c & color and cmap & color_map interchangeably,
    # and plot each group separately if groups is 'all'
    if "c" in kwargs:
        color = kwargs.pop("c")
    if "cmap" in kwargs:
        color_map = kwargs.pop("cmap")
    if "rasterized" not in kwargs:
        kwargs["rasterized"] = settings._vector_friendly
    if isinstance(color_map, (list, tuple)) and all(
        [is_color_like(c) or c == "transparent" for c in color_map]):
        color_map = rgb_custom_colormap(colors=color_map)
    if isinstance(groups, str) and groups == "all":
        if color is None:
            color = default_color(adata)
        if is_categorical(adata, color):
            vc = adata.obs[color].value_counts()
            groups = [[c] for c in vc[vc > 0].index]
    if isinstance(add_text, (list, tuple, np.ndarray, np.record)):
        add_text = list(np.array(add_text, dtype=str))

    # create list of each mkey and check if all bases are valid.
    color = to_list(color, max_len=None)
    layer, components = to_list(layer), to_list(components)
    x, y, basis = to_list(x), to_list(y), to_valid_bases_list(adata, basis)

    # get multikey (with more than one element)
    multikeys = eval(f"[{','.join(mkeys)}]")
    if is_list_of_list(groups):
        multikeys.append(groups)
    key_lengths = np.array(
        [len(key) if is_list(key) else 1 for key in multikeys])
    multikey = (multikeys[np.where(
        key_lengths > 1)[0][0]] if np.max(key_lengths) > 1 else None)

    # gridspec frame for plotting multiple keys (mkeys: list or tuple)
    if multikey is not None:
        if np.sum(key_lengths > 1) == 1 and is_list_of_str(multikey):
            multikey = unique(
                multikey)  # take unique set if no more than one multikey
        if len(multikey) > 20:
            raise ValueError(
                "Please restrict the passed list to max 20 elements.")
        if ax is not None:
            logg.warn("Cannot specify `ax` when plotting multiple panels.")
        if is_list(title):
            title *= int(np.ceil(len(multikey) / len(title)))
        if nrows is None:
            ncols = len(multikey) if ncols is None else min(
                len(multikey), ncols)
            nrows = int(np.ceil(len(multikey) / ncols))
        else:
            ncols = int(np.ceil(len(multikey) / nrows))
        if not frameon or frameon == "artist":
            lloc, llines = "legend_loc", "legend_loc_lines"
            if lloc in scatter_kwargs and scatter_kwargs[lloc] is None:
                scatter_kwargs[lloc] = "none"
            if llines in scatter_kwargs and scatter_kwargs[llines] is None:
                scatter_kwargs[llines] = "none"

        grid_figsize, dpi = get_figure_params(figsize, dpi, ncols)
        grid_figsize = (grid_figsize[0] * ncols, grid_figsize[1] * nrows)
        fig = pl.figure(None, grid_figsize, dpi=dpi)
        hspace = 0.3 if hspace is None else hspace
        gspec = pl.GridSpec(nrows, ncols, fig, hspace=hspace, wspace=wspace)

        ax = []
        for i, gs in enumerate(gspec):
            if i < len(multikey):
                g = groups[i * (len(groups) > i)] if is_list_of_list(
                    groups) else groups
                multi_kwargs = {"groups": g}
                for key in mkeys:  # multi_kwargs[key] = key[i] if is multikey else key
                    multi_kwargs[key] = eval(
                        "{0}[i * (len({0}) > i)] if is_list({0}) else {0}".
                        format(key))
                ax.append(
                    scatter(
                        adata,
                        ax=pl.subplot(gs),
                        **multi_kwargs,
                        **scatter_kwargs,
                        **kwargs,
                    ))

        if not frameon and isinstance(ylabel, str):
            set_label(xlabel, ylabel, fontsize, ax=ax[0], fontweight="bold")
        savefig_or_show(dpi=dpi, save=save, show=show)
        if show is False:
            return ax

    else:
        # make sure that there are no more lists, e.g. ['clusters'] becomes 'clusters'
        color_map = to_val(color_map)
        color, layer, basis = to_val(color), to_val(layer), to_val(basis)
        x, y, components = to_val(x), to_val(y), to_val(components)
        xlabel, ylabel, title = to_val(xlabel), to_val(ylabel), to_val(title)

        # multiple plots within one ax for comma-separated y or layers (string).

        if any([isinstance(key, str) and "," in key for key in [y, layer]]):
            # comma split
            y, layer, color = [
                [k.strip() for k in key.split(",")]
                if isinstance(key, str) and "," in key else to_list(key)
                for key in [y, layer, color]
            ]
            multikey = y if len(y) > 1 else layer if len(layer) > 1 else None

            if multikey is not None:
                for i, mi in enumerate(multikey):
                    ax = scatter(
                        adata,
                        x=x,
                        y=y[i * (len(y) > i)],
                        color=color[i * (len(color) > i)],
                        layer=layer[i * (len(layer) > i)],
                        basis=basis,
                        components=components,
                        groups=groups,
                        xlabel=xlabel,
                        ylabel="expression" if ylabel is None else ylabel,
                        color_map=color_map,
                        title=y[i * (len(y) > i)] if title is None else title,
                        ax=ax,
                        **scatter_kwargs,
                    )
                if legend_loc is None:
                    legend_loc = "best"
                if legend_loc and legend_loc != "none":
                    multikey = [
                        key.replace("Ms", "spliced") for key in multikey
                    ]
                    multikey = [
                        key.replace("Mu", "unspliced") for key in multikey
                    ]
                    ax.legend(multikey,
                              fontsize=legend_fontsize,
                              loc=legend_loc)

                savefig_or_show(dpi=dpi, save=save, show=show)
                if show is False:
                    return ax

        elif color_gradients is not None and color_gradients is not False:
            vals, names, color, scatter_kwargs = gets_vals_from_color_gradients(
                adata, color, **scatter_kwargs)
            cols = zip(adata.obs[color].cat.categories,
                       adata.uns[f"{color}_colors"])
            c_colors = {cat: col for (cat, col) in cols}
            mkwargs.pop("color")
            ax = scatter(
                adata,
                color="grey",
                ax=ax,
                **mkwargs,
                **get_kwargs(scatter_kwargs, {"alpha": 0.05}),
            )  # background
            ax = scatter(
                adata,
                color=color,
                ax=ax,
                **mkwargs,
                **get_kwargs(scatter_kwargs, {"s": 0}),
            )  # set legend
            sorted_idx = np.argsort(vals, 1)[:, ::-1][:, :2]
            for id0 in range(len(names)):
                for id1 in range(id0 + 1, len(names)):
                    cmap = rgb_custom_colormap(
                        [c_colors[names[id0]], "white", c_colors[names[id1]]],
                        alpha=[1, 0, 1],
                    )
                    mkwargs.update({"color_map": cmap})
                    c_vals = np.array(vals[:, id1] - vals[:, id0]).flatten()
                    c_bool = np.array(
                        [id0 in c and id1 in c for c in sorted_idx])
                    if np.sum(c_bool) > 1:
                        _adata = adata[c_bool] if np.sum(
                            ~c_bool) > 0 else adata
                        mkwargs["color"] = c_vals[c_bool]
                        ax = scatter(_adata,
                                     ax=ax,
                                     **mkwargs,
                                     **scatter_kwargs,
                                     **kwargs)
            savefig_or_show(dpi=dpi, save=save, show=show)
            if show is False:
                return ax

        # actual scatter plot
        else:
            # set color, color_map, edgecolor, basis, linewidth, frameon, use_raw
            if color is None:
                color = default_color(adata, add_outline)
            if "cmap" not in kwargs:
                kwargs["cmap"] = (default_color_map(adata, color)
                                  if color_map is None else color_map)
            if "s" not in kwargs:
                kwargs["s"] = default_size(adata) if size is None else size
            if "edgecolor" not in kwargs:
                kwargs["edgecolor"] = "none"
            is_embedding = ((x is None) |
                            (y is None)) and basis not in adata.var_names
            if basis is None and is_embedding:
                basis = default_basis(adata)
            if linewidth is None:
                linewidth = 1
            if frameon is None:
                frameon = True if not is_embedding else settings._frameon
            if isinstance(groups, str):
                groups = [groups]
            if use_raw is None and basis not in adata.var_names:
                use_raw = layer is None and adata.raw is not None

            ax, show = get_ax(ax, show, figsize, dpi, projection)

            # phase portrait: get x and y from .layers (e.g. spliced vs. unspliced)
            if basis in adata.var_names:
                if title is None:
                    title = basis
                if x is None and y is None:
                    x = default_xkey(adata, use_raw=use_raw)
                    y = default_ykey(adata, use_raw=use_raw)
                elif x is None or y is None:
                    raise ValueError("Both x and y have to specified.")
                if isinstance(x, str) and isinstance(y, str):
                    layers_keys = list(adata.layers.keys()) + ["X"]
                    if any([key not in layers_keys for key in [x, y]]):
                        raise ValueError("Could not find x or y in layers.")

                    if xlabel is None:
                        xlabel = x
                    if ylabel is None:
                        ylabel = y

                    x = get_obs_vector(adata, basis, layer=x, use_raw=use_raw)
                    y = get_obs_vector(adata, basis, layer=y, use_raw=use_raw)

                if legend_loc is None:
                    legend_loc = "none"

                if use_raw and perc is not None:
                    ub = np.percentile(
                        x, 99.9 if not isinstance(perc, int) else perc)
                    ax.set_xlim(right=ub * 1.05)
                    ub = np.percentile(
                        y, 99.9 if not isinstance(perc, int) else perc)
                    ax.set_ylim(top=ub * 1.05)

                # velocity model fits (full dynamics and steady-state ratios)
                if any([
                        "gamma" in key or "alpha" in key
                        for key in adata.var.keys()
                ]):
                    plot_velocity_fits(
                        adata,
                        basis,
                        vkey,
                        use_raw,
                        linewidth,
                        linecolor,
                        legend_loc_lines,
                        legend_fontsize,
                        add_assignments,
                        ax=ax,
                    )

            # embedding: set x and y to embedding coordinates
            elif is_embedding:
                X_emb = adata.obsm[
                    f"X_{basis}"][:, get_components(components, basis)]
                x, y = X_emb[:, 0], X_emb[:, 1]
                # todo: 3d plotting
                # z = X_emb[:, 2] if projection == "3d" and X_emb.shape[1] > 2 else None

            elif isinstance(x, str) and isinstance(y, str):
                var_names = (adata.raw.var_names if use_raw
                             and adata.raw is not None else adata.var_names)
                if layer is None:
                    layer = default_xkey(adata, use_raw=use_raw)
                x_keys = list(adata.obs.keys()) + list(adata.layers.keys())
                is_timeseries = y in var_names and x in x_keys
                if xlabel is None:
                    xlabel = x
                if ylabel is None:
                    ylabel = layer if is_timeseries else y
                if title is None:
                    title = y if is_timeseries else color
                if legend_loc is None:
                    legend_loc = "none"

                # gene trend: x and y as gene along obs/layers (e.g. pseudotime)
                if is_timeseries:
                    x = (adata.obs[x] if x in adata.obs.keys() else
                         adata.obs_vector(y, layer=x))
                    y = get_obs_vector(adata,
                                       basis=y,
                                       layer=layer,
                                       use_raw=use_raw)
                # get x and y from var_names, var or obs
                else:
                    if x in var_names and y in var_names:
                        if layer in adata.layers.keys():
                            x = adata.obs_vector(x, layer=layer)
                            y = adata.obs_vector(y, layer=layer)
                        else:
                            data = adata.raw if use_raw else adata
                            x, y = data.obs_vector(x), data.obs_vector(y)
                    elif x in adata.var.keys() and y in adata.var.keys():
                        x, y = adata.var[x], adata.var[y]
                    elif x in adata.obs.keys() and y in adata.obs.keys():
                        x, y = adata.obs[x], adata.obs[y]
                    elif np.any([
                            var_key in x or var_key in y
                            for var_key in adata.var.keys()
                    ]):
                        var_keys = [
                            k for k in adata.var.keys()
                            if not isinstance(adata.var[k][0], str)
                        ]
                        var = adata.var[var_keys]
                        x = var.astype(np.float32).eval(x)
                        y = var.astype(np.float32).eval(y)
                    elif np.any([
                            obs_key in x or obs_key in y
                            for obs_key in adata.obs.keys()
                    ]):
                        obs_keys = [
                            k for k in adata.obs.keys()
                            if not isinstance(adata.obs[k][0], str)
                        ]
                        obs = adata.obs[obs_keys]
                        x = obs.astype(np.float32).eval(x)
                        y = obs.astype(np.float32).eval(y)
                    else:
                        raise ValueError(
                            "x or y is invalid! pass valid observation or a gene name"
                        )

            x, y = make_dense(x).flatten(), make_dense(y).flatten()

            # convolve along x axes (e.g. pseudotime)
            if n_convolve is not None:
                vec_conv = np.ones(n_convolve) / n_convolve
                y[np.argsort(x)] = np.convolve(y[np.argsort(x)],
                                               vec_conv,
                                               mode="same")

            # if color is set to a cell index, plot that cell on top
            if is_int(color) or is_list_of_int(color) and len(color) != len(x):
                color = np.array(np.isin(np.arange(len(x)), color), dtype=bool)
                size = kwargs["s"] * 2 if np.sum(color) == 1 else kwargs["s"]
                if zorder is None:
                    zorder = 10
                ax.scatter(
                    np.ravel(x[color]),
                    np.ravel(y[color]),
                    s=size,
                    zorder=zorder,
                    color=palette[-1] if palette is not None else "darkblue",
                )
                color = (palette[0] if palette is not None and len(palette) > 1
                         else "gold")
                zorder -= 1

            # if color is in {'ascending', 'descending'}
            elif isinstance(color, str):
                if color == "ascending":
                    color = np.linspace(0, 1, len(x))
                elif color == "descending":
                    color = np.linspace(1, 0, len(x))

            # set palette if categorical color vals
            if is_categorical(adata, color):
                set_colors_for_categorical_obs(adata, color, palette)

            # set color
            if (basis in adata.var_names and isinstance(color, str)
                    and color in adata.layers.keys()):
                # phase portrait: color=basis, layer=color
                c = interpret_colorkey(adata, basis, color, perc, use_raw)
            else:
                # embedding, gene trend etc.
                c = interpret_colorkey(adata, color, layer, perc, use_raw)

            if c is not None and not isinstance(c, str) and not isinstance(
                    c[0], str):
                # smooth color values across neighbors and rescale
                if smooth and len(c) == adata.n_obs:
                    n_neighbors = None if isinstance(smooth, bool) else smooth
                    c = get_connectivities(adata,
                                           n_neighbors=n_neighbors).dot(c)
                # rescale color values to min and max acc. to rescale_color tuple
                if rescale_color is not None:
                    try:
                        c += rescale_color[0] - np.nanmin(c)
                        c *= rescale_color[1] / np.nanmax(c)
                    except Exception:
                        logg.warn(
                            "Could not rescale colors. Pass a tuple, e.g. [0,1]."
                        )

            # set vmid to 0 if color values obtained from velocity expression
            if not np.any([v in kwargs
                           for v in ["vmin", "vmid", "vmax"]]) and np.any([
                               isinstance(v, str) and "time" not in v and
                               (v.endswith("velocity")
                                or v.endswith("transition"))
                               for v in [color, layer]
                           ]):
                kwargs["vmid"] = 0

            # introduce vmid by setting vmin and vmax accordingly
            if "vmid" in kwargs:
                vmid = kwargs.pop("vmid")
                if vmid is not None:
                    if not (isinstance(c, str) or isinstance(c[0], str)):
                        lb, ub = np.min(c), np.max(c)
                        crange = max(np.abs(vmid - lb), np.abs(ub - vmid))
                        kwargs.update({
                            "vmin": vmid - crange,
                            "vmax": vmid + crange
                        })

            x, y = np.ravel(x), np.ravel(y)
            if len(x) != len(y):
                raise ValueError("x or y do not share the same dimension.")

            if normalize_data:
                x = (x - np.nanmin(x)) / (np.nanmax(x) - np.nanmin(x))
                y = (y - np.nanmin(y)) / (np.nanmax(y) - np.nanmin(y))

            if not isinstance(c, str):
                c = np.ravel(c) if len(np.ravel(c)) == len(x) else c

            # store original order of color values
            color_array, scatter_array = c, np.stack([x, y]).T

            # set color to grey for NAN values and for cells that are not in groups
            if (groups is not None or is_categorical(adata, color)
                    and np.any(pd.isnull(adata.obs[color]))):
                if isinstance(groups, (list, tuple, np.record)):
                    groups = unique(groups)
                zorder = 0 if zorder is None else zorder
                pop_keys = [
                    "groups", "add_linfit", "add_polyfit", "add_density"
                ]
                _ = [scatter_kwargs.pop(key, None) for key in pop_keys]
                ax = scatter(
                    adata,
                    x=x,
                    y=y,
                    basis=basis,
                    layer=layer,
                    color="lightgrey",
                    ax=ax,
                    **scatter_kwargs,
                )
                if groups is not None and len(groups) == 1:
                    if (isinstance(groups[0], str)
                            and groups[0] in adata.var.keys()
                            and basis in adata.var_names):
                        groups = f"{adata[:, basis].var[groups[0]][0]}"
                idx = groups_to_bool(adata, groups, color)
                if idx is not None:
                    if np.sum(idx) > 0:  # if any group to be highlighted
                        x, y = x[idx], y[idx]
                        if not isinstance(c, str) and len(c) == adata.n_obs:
                            c = c[idx]
                        if isinstance(kwargs["s"], np.ndarray):
                            kwargs["s"] = np.array(kwargs["s"])[idx]
                        if (title is None and groups is not None
                                and len(groups) == 1
                                and isinstance(groups[0], str)):
                            title = groups[0]
                    else:  # if nothing to be highlighted
                        add_linfit, add_polyfit, add_density = None, None, None
            else:
                idx = None

            if not isinstance(c, str) and len(c) != len(x):
                c = "grey"
                if not isinstance(color, str) or color != default_color(adata):
                    logg.warn("Invalid color key. Using grey instead.")

            # check if higher value points should be plotted on top
            if not isinstance(c, str) and len(c) == len(x):
                order = None
                if sort_order and not is_categorical(adata, color):
                    order = np.argsort(c)
                elif not sort_order and is_categorical(adata, color):
                    counts = get_value_counts(
                        adata[idx] if idx is not None else adata, color)
                    np.random.seed(0)
                    nums, p = np.arange(0, len(x)), counts / np.sum(counts)
                    order = np.random.choice(nums, len(x), replace=False, p=p)
                if order is not None:
                    x, y, c = x[order], y[order], c[order]
                    if isinstance(kwargs["s"],
                                  np.ndarray):  # sort sizes if array-type
                        kwargs["s"] = np.array(kwargs["s"])[order]

            marker = kwargs.pop("marker", ".")
            smp = ax.scatter(x,
                             y,
                             c=c,
                             alpha=alpha,
                             marker=marker,
                             zorder=zorder,
                             **kwargs)

            outline_dtypes = (list, tuple, np.ndarray, int, np.int_, str)
            if isinstance(add_outline, outline_dtypes) or add_outline:
                if isinstance(add_outline, (list, tuple, np.record)):
                    add_outline = unique(add_outline)
                if (add_outline is not True
                        and isinstance(add_outline, (int, np.int_))
                        or is_list_of_int(add_outline)
                        and len(add_outline) != len(x)):
                    add_outline = np.isin(np.arange(len(x)), add_outline)
                    add_outline = np.array(add_outline, dtype=bool)
                    if outline_width is None:
                        outline_width = (0.6, 0.3)
                if isinstance(add_outline, str):
                    if (add_outline in adata.var.keys()
                            and basis in adata.var_names):
                        add_outline = f"{adata[:, basis].var[add_outline][0]}"
                idx = groups_to_bool(adata, add_outline, color)
                if idx is not None and np.sum(
                        idx) > 0:  # if anything to be outlined
                    zorder = 2 if zorder is None else zorder + 2
                    if kwargs["s"] is not None:
                        kwargs["s"] *= 1.2
                    # restore order of values
                    x, y = scatter_array[:, 0][idx], scatter_array[:, 1][idx]
                    c = color_array
                    if not isinstance(c, str) and len(c) == adata.n_obs:
                        c = c[idx]
                    if isinstance(kwargs["s"], np.ndarray):
                        kwargs["s"] = np.array(kwargs["s"])[idx]
                    if isinstance(c, np.ndarray) and not isinstance(c[0], str):
                        if "vmid" not in kwargs and "vmin" not in kwargs:
                            kwargs["vmin"] = np.min(color_array)
                        if "vmid" not in kwargs and "vmax" not in kwargs:
                            kwargs["vmax"] = np.max(color_array)
                    ax.scatter(x,
                               y,
                               c=c,
                               alpha=alpha,
                               marker=".",
                               zorder=zorder,
                               **kwargs)
                if idx is None or np.sum(
                        idx) > 0:  # if all or anything to be outlined
                    plot_outline(x,
                                 y,
                                 kwargs,
                                 outline_width,
                                 outline_color,
                                 zorder,
                                 ax=ax)
                if idx is not None and np.sum(
                        idx) == 0:  # if nothing to be outlined
                    add_linfit, add_polyfit, add_density = None, None, None

            # set legend if categorical color vals
            if is_categorical(adata,
                              color) and len(scatter_array) == adata.n_obs:
                legend_loc = default_legend_loc(adata, color, legend_loc)
                g_bool = groups_to_bool(adata, add_outline, color)
                if not (add_outline is None or g_bool is None):
                    groups = add_outline
                set_legend(
                    adata,
                    ax,
                    color,
                    legend_loc,
                    scatter_array,
                    legend_fontweight,
                    legend_fontsize,
                    legend_fontoutline,
                    legend_align_text,
                    groups,
                )
            if add_density:
                plot_density(x, y, add_density, ax=ax)

            if add_linfit:
                if add_linfit is True and basis in adata.var_names:
                    add_linfit = "no_intercept"  # without intercept
                plot_linfit(
                    x,
                    y,
                    add_linfit,
                    legend_loc != "none",
                    linecolor,
                    linewidth,
                    fontsize,
                    ax=ax,
                )

            if add_polyfit:
                if add_polyfit is True and basis in adata.var_names:
                    add_polyfit = "no_intercept"  # without intercept
                plot_polyfit(
                    x,
                    y,
                    add_polyfit,
                    legend_loc != "none",
                    linecolor,
                    linewidth,
                    fontsize,
                    ax=ax,
                )

            if add_rug:
                rug_color = add_rug if isinstance(add_rug, str) else color
                rug_color = np.ravel(interpret_colorkey(adata, rug_color))
                plot_rug(np.ravel(x), color=rug_color, ax=ax)

            if add_text:
                if add_text_pos is None:
                    add_text_pos = [0.05, 0.95]
                ax.text(
                    add_text_pos[0],
                    add_text_pos[1],
                    f"{add_text}",
                    ha="left",
                    va="top",
                    fontsize=fontsize,
                    transform=ax.transAxes,
                    bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.2),
                )

            set_label(xlabel, ylabel, fontsize, basis, ax=ax)
            set_title(title, layer, color, fontsize, ax=ax)
            update_axes(ax, xlim, ylim, fontsize, is_embedding, frameon,
                        figsize)
            if add_margin:
                set_margin(ax, x, y, add_margin)
            if colorbar is not False:
                if not isinstance(c, str) and not is_categorical(adata, color):
                    labelsize = fontsize * 0.75 if fontsize is not None else None
                    set_colorbar(smp, ax=ax, labelsize=labelsize)

            savefig_or_show(dpi=dpi, save=save, show=show)
            if show is False:
                return ax
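# A minimal usage sketch covering the main code paths of `scatter` above
# ("Actb" is an assumed gene name; the pancreas loader is an assumed example dataset):
import scvelo as scv

adata = scv.datasets.pancreas()
scv.pp.filter_and_normalize(adata)
scv.pp.moments(adata)

scatter(adata, basis="umap", color="clusters")                 # embedding branch
ax = scatter(adata, basis="Actb", x="Ms", y="Mu", show=False)  # phase-portrait branch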
Beispiel #21
0
def _paga_graph(
    adata,
    ax,
    solid_edges=None,
    dashed_edges=None,
    adjacency_solid=None,
    adjacency_dashed=None,
    transitions=None,
    threshold=None,
    root=0,
    colors=None,
    labels=None,
    fontsize=None,
    fontweight=None,
    fontoutline=None,
    text_kwds=None,
    node_size_scale=1.0,
    node_size_power=0.5,
    edge_width_scale=1.0,
    normalize_to_color="reference",
    title=None,
    pos=None,
    cmap=None,
    frameon=True,
    min_edge_width=None,
    max_edge_width=None,
    export_to_gexf=False,
    colorbar=None,
    use_raw=True,
    cb_kwds=None,
    single_component=False,
    arrowsize=30,
):
    """scanpy/_paga_graph with some adjustments for directional graphs.
    To be moved back to scanpy once finalized.
    """
    import warnings
    from pathlib import Path

    import networkx as nx
    import pandas as pd
    import scipy
    from pandas.api.types import is_categorical_dtype

    from matplotlib import patheffects
    from matplotlib.colors import is_color_like

    from scanpy.plotting._utils import add_colors_for_categorical_sample_annotation

    node_labels = labels  # rename for clarity
    if (node_labels is not None and isinstance(node_labels, str)
            and node_labels != adata.uns["paga"]["groups"]):
        raise ValueError(
            f"Provide a list of group labels for the PAGA "
            f"groups {adata.uns['paga']['groups']}, not {node_labels}.")
    groups_key = adata.uns["paga"]["groups"]
    if node_labels is None:
        node_labels = adata.obs[groups_key].cat.categories

    if (colors is None or colors == groups_key) and groups_key is not None:
        if f"{groups_key}_colors" not in adata.uns or len(
                adata.obs[groups_key].cat.categories) != len(
                    adata.uns[f"{groups_key}_colors"]):
            add_colors_for_categorical_sample_annotation(adata, groups_key)
        colors = adata.uns[f"{groups_key}_colors"]

    nx_g_solid = nx.Graph(adjacency_solid)
    if dashed_edges is not None:
        nx_g_dashed = nx.Graph(adjacency_dashed)

    # convert pos to array and dict
    if not isinstance(pos, (Path, str)):
        pos_array = pos
    else:
        pos = Path(pos)
        if pos.suffix != ".gdf":
            raise ValueError(
                "Currently only supporting reading positions from .gdf files.")
        s = ""  # read the node definition from the file
        with pos.open() as f:
            f.readline()
            for line in f:
                if line.startswith("edgedef>"):
                    break
                s += line
        from io import StringIO

        df = pd.read_csv(StringIO(s), header=None)  # header=-1 was removed in pandas
        pos_array = df[[4, 5]].values

    # convert to dictionary
    pos = {n: [p[0], p[1]] for n, p in enumerate(pos_array)}

    # uniform color
    if isinstance(colors, str) and is_color_like(colors):
        colors = [colors for c in range(len(node_labels))]

    # color degree of the graph
    if isinstance(colors, str) and colors.startswith("degree"):
        # see also tools.paga.paga_degrees
        if colors == "degree_dashed":
            colors = [d for _, d in nx_g_dashed.degree(weight="weight")]
        elif colors == "degree_solid":
            colors = [d for _, d in nx_g_solid.degree(weight="weight")]
        else:
            raise ValueError(
                '`degree` must be either "degree_dashed" or "degree_solid".')
        colors = (np.array(colors) - np.min(colors)) / (np.max(colors) -
                                                        np.min(colors))

    # plot gene expression
    var_names = adata.var_names if adata.raw is None else adata.raw.var_names
    if isinstance(colors, str) and colors in var_names:
        x_color = []
        cats = adata.obs[groups_key].cat.categories
        for cat in cats:
            subset = (cat == adata.obs[groups_key]).values
            if adata.raw is not None and use_raw:
                adata_gene = adata.raw[:, colors]
            else:
                adata_gene = adata[:, colors]
            x_color.append(np.mean(adata_gene.X[subset]))
        colors = x_color

    # plot continuous annotation
    if (isinstance(colors, str) and colors in adata.obs
            and not is_categorical_dtype(adata.obs[colors])):
        x_color = []
        cats = adata.obs[groups_key].cat.categories
        for cat in cats:
            subset = (cat == adata.obs[groups_key]).values
            x_color.append(adata.obs.loc[subset, colors].mean())
        colors = x_color

    # plot categorical annotation
    if (isinstance(colors, str) and colors in adata.obs
            and is_categorical_dtype(adata.obs[colors])):
        from scanpy._utils import (
            compute_association_matrix_of_groups,
            get_associated_colors_of_groups,
        )

        norm = "reference" if normalize_to_color else "prediction"
        _, asso_matrix = compute_association_matrix_of_groups(
            adata, prediction=groups_key, reference=colors, normalization=norm)
        add_colors_for_categorical_sample_annotation(adata, colors)
        asso_colors = get_associated_colors_of_groups(
            adata.uns[f"{colors}_colors"], asso_matrix)
        colors = asso_colors

    if len(colors) < len(node_labels):
        raise ValueError(
            "`color` list need to be at least as long as `groups`/`node_labels` list."
        )

    # count number of connected components
    n_components, labels = scipy.sparse.csgraph.connected_components(
        adjacency_solid)
    if n_components > 1 and single_component:
        component_sizes = np.bincount(labels)
        largest_component = np.where(
            component_sizes == component_sizes.max())[0][0]
        adjacency_solid = adjacency_solid.tocsr()[labels ==
                                                  largest_component, :]
        adjacency_solid = adjacency_solid.tocsc()[:,
                                                  labels == largest_component]
        colors = np.array(colors)[labels == largest_component]
        node_labels = np.array(node_labels)[labels == largest_component]
        cats_dropped = (adata.obs[groups_key].cat.categories[
            labels != largest_component].tolist())
        logg.info(f"Restricting graph to largest connected component "
                  f"by dropping categories\n{cats_dropped}")
        nx_g_solid = nx.Graph(adjacency_solid)
        if dashed_edges is not None:
            raise ValueError(
                "`single_component` is only supported if `dashed_edges` is `None`.")

    # groups sizes
    if groups_key is not None and f"{groups_key}_sizes" in adata.uns:
        groups_sizes = adata.uns[f"{groups_key}_sizes"]
    else:
        groups_sizes = np.ones(len(node_labels))
    base_scale_scatter = 2000
    base_pie_size = (base_scale_scatter /
                     (np.sqrt(adjacency_solid.shape[0]) + 10) *
                     node_size_scale)
    median_group_size = np.median(groups_sizes)
    groups_sizes = base_pie_size * np.power(groups_sizes / median_group_size,
                                            node_size_power)

    # edge widths
    base_edge_width = edge_width_scale * 5 * rcParams["lines.linewidth"]

    # draw dashed edges
    if dashed_edges is not None:
        widths = [x[-1]["weight"] for x in nx_g_dashed.edges(data=True)]
        widths = base_edge_width * np.array(widths)
        if max_edge_width is not None:
            widths = np.clip(widths, None, max_edge_width)
        nx.draw_networkx_edges(
            nx_g_dashed,
            pos,
            ax=ax,
            width=widths,
            edge_color="grey",
            style="dashed",
            alpha=0.5,
        )

    # draw solid edges
    if transitions is None:
        widths = [x[-1]["weight"] for x in nx_g_solid.edges(data=True)]
        widths = base_edge_width * np.array(widths)
        if min_edge_width is not None or max_edge_width is not None:
            widths = np.clip(widths, min_edge_width, max_edge_width)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            nx.draw_networkx_edges(nx_g_solid,
                                   pos,
                                   ax=ax,
                                   width=widths,
                                   edge_color="black")

    # draw directed edges
    else:
        adjacency_transitions = adata.uns["paga"][transitions].copy()
        if threshold is None:
            threshold = 0.01
        adjacency_transitions.data[adjacency_transitions.data < threshold] = 0
        adjacency_transitions.eliminate_zeros()
        g_dir = nx.DiGraph(adjacency_transitions.T)
        widths = [x[-1]["weight"] for x in g_dir.edges(data=True)]
        widths = base_edge_width * np.array(widths)
        if min_edge_width is not None or max_edge_width is not None:
            widths = np.clip(widths, min_edge_width, max_edge_width)
        nx.draw_networkx_edges(
            g_dir,
            pos,
            ax=ax,
            width=widths,
            edge_color="k",
            arrowsize=arrowsize,
            arrowstyle="-|>",
            node_size=groups_sizes,
        )

    if export_to_gexf:
        if isinstance(colors[0], tuple):
            from matplotlib.colors import rgb2hex

            colors = [rgb2hex(c) for c in colors]
        # `Graph.node` was removed in networkx 2.4; use `Graph.nodes` instead
        for count, n in enumerate(nx_g_solid.nodes()):
            nx_g_solid.nodes[count]["label"] = f"{node_labels[count]}"
            nx_g_solid.nodes[count]["color"] = f"{colors[count]}"
            nx_g_solid.nodes[count]["viz"] = dict(position=dict(
                x=1000 * pos[count][0], y=1000 * pos[count][1], z=0))
        filename = settings.writedir / "paga_graph.gexf"
        logg.warn(f"exporting to {filename}")
        settings.writedir.mkdir(parents=True, exist_ok=True)
        nx.write_gexf(nx_g_solid, filename)

    ax.set_frame_on(frameon)
    ax.set_xticks([])
    ax.set_yticks([])

    if fontsize is None:
        fontsize = rcParams["legend.fontsize"]
    if fontoutline is not None:
        text_kwds = dict(text_kwds)
        text_kwds["path_effects"] = [
            patheffects.withStroke(linewidth=fontoutline, foreground="w")
        ]
    # usual scatter plot
    if not isinstance(colors[0], cabc.Mapping):
        n_groups = len(pos_array)
        sct = ax.scatter(
            pos_array[:, 0],
            pos_array[:, 1],
            s=groups_sizes,
            cmap=cmap,
            c=colors[:n_groups],
            edgecolors="face",
            zorder=2,
        )
        for count, group in enumerate(node_labels):
            ax.text(
                pos_array[count, 0],
                pos_array[count, 1],
                group,
                verticalalignment="center",
                horizontalalignment="center",
                size=fontsize,
                fontweight=fontweight,
                **text_kwds,
            )
    # else pie chart plot
    else:

        def transform_ax_coords(a, b):
            return trans2(trans((a, b)))

        # dummy scatter first: its marker bounding boxes are needed to place the pies
        sct = ax.scatter(
            pos_array[:, 0],
            pos_array[:, 1],
            alpha=0,
            linewidths=0,
            c="w",
            edgecolors="face",
            s=groups_sizes,
            cmap=cmap,
        )
        bboxes = getbb(sct, ax)  # bounding boxes around the scatter markers

        trans = ax.transData.transform
        bbox = ax.get_position().get_points()
        ax_x_min = bbox[0, 0]
        ax_x_max = bbox[1, 0]
        ax_y_min = bbox[0, 1]
        ax_y_max = bbox[1, 1]
        ax_len_x = ax_x_max - ax_x_min
        ax_len_y = ax_y_max - ax_y_min
        trans2 = ax.transAxes.inverted().transform
        pie_axs = []
        for count, (n, box) in enumerate(zip(nx_g_solid.nodes(), bboxes)):
            x0, y0 = transform_ax_coords(box.x0, box.y0)
            x1, y1 = transform_ax_coords(box.x1, box.y1)
            pie_size = np.sqrt(((x0 - x1)**2) + ((y0 - y1)**2))

            xa, ya = transform_ax_coords(*pos[n])
            xa = ax_x_min + (xa - pie_size / 2) * ax_len_x
            ya = ax_y_min + (ya - pie_size / 2) * ax_len_y
            # clip: the Fruchterman-Reingold layout sometimes places nodes
            # below or left of the figure
            ya = max(ya, 0)
            xa = max(xa, 0)
            pie_axs.append(
                pl.axes([xa, ya, pie_size * ax_len_x, pie_size * ax_len_y],
                        frameon=False))
            pie_axs[count].set_xticks([])
            pie_axs[count].set_yticks([])
            if not isinstance(colors[count], cabc.Mapping):
                raise ValueError(
                    f"{colors[count]} is neither a dict of valid "
                    "matplotlib colors nor a valid matplotlib color.")
            color_single = colors[count].keys()
            fracs = [colors[count][c] for c in color_single]
            if sum(fracs) < 1:
                color_single = list(color_single)
                color_single.append("grey")
                fracs.append(1 - sum(fracs))
            wedgeprops = dict(linewidth=0, edgecolor="k", antialiased=True)
            pie_axs[count].pie(fracs,
                               colors=color_single,
                               wedgeprops=wedgeprops,
                               normalize=True)
        if node_labels is not None:
            text_kwds.update(
                dict(verticalalignment="center", fontweight=fontweight))
            text_kwds.update(dict(horizontalalignment="center", size=fontsize))
            for ia, a in enumerate(pie_axs):
                a.text(0.5,
                       0.5,
                       node_labels[ia],
                       transform=a.transAxes,
                       **text_kwds)
    return sct
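
The largest-connected-component restriction near the top of this function is plain scipy graph logic; a minimal standalone sketch (toy adjacency matrix, not scvelo API):

import numpy as np
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import connected_components

# toy 4-node graph with two components: {0, 1} and {2, 3}
adjacency = csr_matrix(np.array([[0, 1, 0, 0],
                                 [1, 0, 0, 0],
                                 [0, 0, 0, 1],
                                 [0, 0, 1, 0]]))
n_components, labels = connected_components(adjacency)
if n_components > 1:
    # keep only rows/columns belonging to the largest component
    mask = labels == np.argmax(np.bincount(labels))
    adjacency = adjacency[mask][:, mask]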
Example #22
0
def get_df(
    data: AnnData,
    keys: Optional[Union[str, List[str]]] = None,
    layer: Optional[str] = None,
    index: Optional[Union[str, List]] = None,
    columns: Optional[Union[str, List]] = None,
    sort_values: Optional[Union[bool, str]] = None,
    dropna: Literal["all", "any"] = "all",
    precision: Optional[int] = None,
) -> DataFrame:
    """Get dataframe for a specified adata key.

    Return values for specified key
    (in obs, var, obsm, varm, obsp, varp, uns, or layers) as a dataframe.

    Arguments
    ---------
    data
        AnnData object or a numpy array to get values from.
    keys
        Keys from `.var_names`, `.obs_names`, `.var`, `.obs`,
        `.obsm`, `.varm`, `.obsp`, `.varp`, `.uns`, or `.layers`.
    layer
        Layer of `adata` to use as expression values.
    index
        List to set as index, or an `.obs` key whose categories to use.
    columns
        List to set as column names, or an `.obs` key whose categories to use.
    sort_values
        Whether to sort values by the first column (`sort_values=True`)
        or by a specified column (pass its name).
    dropna
        Drop columns/rows that contain NaNs in all ('all') or in any entry ('any').
    precision
        Set precision for pandas dataframe.

    Returns
    -------
    :class:`pd.DataFrame`
        A dataframe.
    """

    if precision is not None:
        # the bare 'precision' alias was removed in pandas 1.4
        pd.set_option("display.precision", precision)

    if isinstance(data, AnnData):
        # '*' splits the key into substrings to match; '/' selects a nested sub-key
        if isinstance(keys, str) and "*" in keys:
            keys, keys_split = keys.split("*")
        else:
            keys_split = None
        if isinstance(keys, str) and "/" in keys:
            keys, key_add = keys.split("/")
        else:
            key_add = None
        keys = [keys] if isinstance(keys, str) else keys
        key = keys[0]

        s_keys = ["obs", "var", "obsm", "varm", "uns", "layers"]
        d_keys = [
            data.obs.keys(),
            data.var.keys(),
            data.obsm.keys(),
            data.varm.keys(),
            data.uns.keys(),
            data.layers.keys(),
        ]

        if hasattr(data, "obsp") and hasattr(data, "varp"):
            s_keys.extend(["obsp", "varp"])
            d_keys.extend([data.obsp.keys(), data.varp.keys()])

        if keys is None:
            df = data.to_df()
        elif key in data.var_names:
            df = obs_df(data, keys, layer=layer)
        elif key in data.obs_names:
            df = var_df(data, keys, layer=layer)
        else:
            if keys_split is not None:
                keys = [
                    k for k in list(data.obs.keys()) + list(data.var.keys())
                    if key in k and keys_split in k
                ]
                key = keys[0]
            s_key = [s for (s, d_key) in zip(s_keys, d_keys) if key in d_key]
            if len(s_key) == 0:
                raise ValueError(
                    f"'{key}' not found in any of {', '.join(s_keys)}.")
            if len(s_key) > 1:
                logg.warn(
                    f"'{key}' found multiple times in {', '.join(s_key)}.")

            s_key = s_key[-1]
            df = getattr(data, s_key)[keys if len(keys) > 1 else key]
            if key_add is not None:
                df = df[key_add]
            if index is None:
                if s_key == "varm":
                    index = data.var_names
                elif s_key in {"obsm", "layers"}:
                    index = data.obs_names
                else:
                    index = None
                if index is None and s_key == "uns" and hasattr(df, "shape"):
                    key_cats = np.array([
                        key for key in data.obs.keys()
                        if is_categorical_dtype(data.obs[key])
                    ])
                    num_cats = [
                        len(data.obs[key].cat.categories) == df.shape[0]
                        for key in key_cats
                    ]
                    if np.sum(num_cats) == 1:
                        index = data.obs[key_cats[num_cats][0]].cat.categories
                        if (columns is None and len(df.shape) > 1
                                and df.shape[0] == df.shape[1]):
                            columns = index
            elif isinstance(index, str) and index in data.obs.keys():
                index = pd.Categorical(data.obs[index]).categories
            if columns is None and s_key == "layers":
                columns = data.var_names
            elif isinstance(columns, str) and columns in data.obs.keys():
                columns = pd.Categorical(data.obs[columns]).categories
    elif isinstance(data, pd.DataFrame):
        if isinstance(keys, str) and "*" in keys:
            keys, keys_split = keys.split("*")
            keys = [k for k in data.columns if keys in k and keys_split in k]
        df = data[keys] if keys is not None else data
    else:
        df = data

    if issparse(df):
        df = df.toarray()
    if columns is None and hasattr(df, "names"):
        columns = df.names

    df = pd.DataFrame(df, index=index, columns=columns)

    if dropna:
        df.replace("", np.nan, inplace=True)
        # `dropna` is either 'all'/'any' or a bool (True -> 'any')
        how = (dropna if isinstance(dropna, str)
               else "any" if dropna is True else "all")
        df.dropna(how=how, axis=0, inplace=True)
        df.dropna(how=how, axis=1, inplace=True)

    if sort_values:
        valid_column = isinstance(sort_values, str) and sort_values in df.columns
        sort_by = sort_values if valid_column else df.columns[0]
        df = df.sort_values(by=sort_by, ascending=False)

    if hasattr(data, "var_names"):
        if df.index[0] in data.var_names:
            df.var_names = df.index
        elif df.columns[0] in data.var_names:
            df.var_names = df.columns
    if hasattr(data, "obs_names"):
        if df.index[0] in data.obs_names:
            df.obs_names = df.index
        elif df.columns[0] in data.obs_names:
            df.obs_names = df.columns

    return df
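
A brief usage sketch for `get_df`; the keys below assume an AnnData on which the dynamical model and PAGA have already been computed, so treat them as placeholders:

import scvelo as scv

adata = scv.datasets.pancreas()  # bundled example dataset
# gene keys resolve via obs_df: one column per gene
df = get_df(adata, keys=["Cpe", "Gng12"], layer="spliced")
# '/' drills into nested .uns entries (requires scv.tl.paga to have run)
df = get_df(adata, keys="paga/transitions_confidence")
# '*' matches substrings across .obs/.var keys, e.g. dynamical fit params
df = get_df(adata, keys="fit*", dropna="any")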
Example #23
0
def filter_and_normalize(
    data,
    min_counts=None,
    min_counts_u=None,
    min_cells=None,
    min_cells_u=None,
    min_shared_counts=None,
    min_shared_cells=None,
    n_top_genes=None,
    retain_genes=None,
    subset_highly_variable=True,
    flavor="seurat",
    log=True,
    layers_normalize=None,
    copy=False,
    **kwargs,
):
    """Filtering, normalization and log transform

    Expects non-logarithmized data. If using logarithmized data, pass `log=False`.

    Runs the following steps

    .. code:: python

        scv.pp.filter_genes(adata)
        scv.pp.normalize_per_cell(adata)
        if n_top_genes is not None:
            scv.pp.filter_genes_dispersion(adata)
        if log:
            scv.pp.log1p(adata)


    Arguments
    ---------
    data: :class:`~anndata.AnnData`
        Annotated data matrix.
    min_counts: `int` (default: `None`)
        Minimum number of counts required for a gene to pass filtering (spliced).
    min_counts_u: `int` (default: `None`)
        Minimum number of counts required for a gene to pass filtering (unspliced).
    min_cells: `int` (default: `None`)
        Minimum number of cells expressed required to pass filtering (spliced).
    min_cells_u: `int` (default: `None`)
        Minimum number of cells expressed required to pass filtering (unspliced).
    min_shared_counts: `int`, optional (default: `None`)
        Minimum number of counts (both unspliced and spliced) required for a gene.
    min_shared_cells: `int`, optional (default: `None`)
        Minimum number of cells required to be expressed (both unspliced and spliced).
    n_top_genes: `int` (default: `None`)
        Number of genes to keep.
    retain_genes: `list`, optional (default: `None`)
        List of gene names to be retained independent of thresholds.
    subset_highly_variable: `bool` (default: True)
        Whether to subset highly variable genes or to store in .var['highly_variable'].
    flavor: {'seurat', 'cell_ranger', 'svr'}, optional (default: 'seurat')
        Choose the flavor for computing normalized dispersion.
        If choosing 'seurat', this expects non-logarithmized data.
    log: `bool` (default: `True`)
        Take logarithm.
    layers_normalize: list of `str` (default: None)
        List of layers to be normalized.
        If set to None, the layers {'X', 'spliced', 'unspliced'} are considered for
        normalization upon testing whether they have already been normalized
        (by checking type of entries: int -> unprocessed, float -> processed).
    copy: `bool` (default: `False`)
        Return a copy of `adata` instead of updating it.
    **kwargs:
        Keyword arguments passed to pp.normalize_per_cell (e.g. counts_per_cell).

    Returns
    -------
    Returns or updates `adata` depending on `copy`.
    """

    adata = data.copy() if copy else data

    if "spliced" not in adata.layers.keys(
    ) or "unspliced" not in adata.layers.keys():
        logg.warn("Could not find spliced / unspliced counts.")

    filter_genes(
        adata,
        min_counts=min_counts,
        min_counts_u=min_counts_u,
        min_cells=min_cells,
        min_cells_u=min_cells_u,
        min_shared_counts=min_shared_counts,
        min_shared_cells=min_shared_cells,
        retain_genes=retain_genes,
    )

    if layers_normalize is not None and "enforce" not in kwargs:
        kwargs["enforce"] = True
    normalize_per_cell(adata, layers=layers_normalize, **kwargs)

    if n_top_genes is not None:
        filter_genes_dispersion(
            adata,
            n_top_genes=n_top_genes,
            retain_genes=retain_genes,
            flavor=flavor,
            subset=subset_highly_variable,
        )

    # X is considered unprocessed if it still matches the spliced counts
    log_advised = (np.allclose(adata.X[:10].sum(),
                               adata.layers["spliced"][:10].sum())
                   if "spliced" in adata.layers.keys() else True)

    if log and log_advised:
        log1p(adata)
        logg.info("Logarithmized X.")
    elif log and not log_advised:
        logg.warn("Did not modify X as it looks preprocessed already.")
    elif log_advised and not log:
        logg.warn(
            "Consider logarithmizing X with `scv.pp.log1p` for better results."
        )

    return adata if copy else None
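
Typical invocation, as a sketch (the loom path is a placeholder):

import scvelo as scv

adata = scv.read("path/to/sample.loom", cache=True)  # placeholder path
scv.pp.filter_and_normalize(
    adata,
    min_shared_counts=20,  # joint spliced+unspliced count threshold
    n_top_genes=2000,      # subset to 2000 highly variable genes
)
# X is log1p-transformed in place unless it already looks preprocessed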
Example #24
0
def heatmap(
        adata,
        var_names,
        sortby="latent_time",
        layer="Ms",
        color_map="viridis",
        col_color=None,
        palette="viridis",
        n_convolve=30,
        standard_scale=0,
        sort=True,
        colorbar=None,
        col_cluster=False,
        row_cluster=False,
        context=None,
        font_scale=None,
        figsize=(8, 4),
        show=None,
        save=None,
        **kwargs,
):
    """\
    Plot time series for genes as heatmap.

    Arguments
    ---------
    adata: :class:`~anndata.AnnData`
        Annotated data matrix.
    var_names: `str` or list of `str`
        Names of variables to use for the plot.
    sortby: `str` (default: `'latent_time'`)
        Observation key to extract time data from.
    layer: `str` (default: `'Ms'`)
        Layer key to extract count data from.
    color_map: `str` (default: `'viridis'`)
        String denoting matplotlib color map.
    col_color: `str` or list of `str` (default: `None`)
        String denoting matplotlib color map to use along the columns.
    palette: list of `str` (default: `'viridis'`)
        Colors to use for plotting groups (categorical annotation).
    n_convolve: `int` or `None` (default: `30`)
        If `int` is given, data is smoothed by convolution
        along the x-axis with kernel size n_convolve.
    standard_scale : `int` or `None` (default: `0`)
        Either 0 (rows) or 1 (columns). Standardize that dimension by
        subtracting its minimum and dividing by its maximum.
    sort: `bool` (default: `True`)
        Whether to sort genes by the time point of their maximum expression
        (given by `xkey`).
    colorbar: `bool` or `None` (default: `None`)
        Whether to show colorbar.
    {row,col}_cluster : `bool` or `None`
        If True, cluster the {rows, columns}.
    context : `None`, or one of {paper, notebook, talk, poster}
        A dictionary of parameters or the name of a preconfigured set.
    font_scale : float, optional
        Scaling factor to scale the size of the font elements.
    figsize: tuple (default: `(8,4)`)
        Figure size.
    show: `bool`, optional (default: `None`)
        Show the plot, do not return axis.
    save: `bool` or `str`, optional (default: `None`)
        If `True` or a `str`, save the figure. A string is appended to the default
        filename. Infer the filetype if ending on {'.pdf', '.png', '.svg'}.
    kwargs:
        Arguments passed to seaborn's `clustermap`,
        e.g., set `yticklabels=True` to display all gene names in all rows.

    Returns
    -------
    If `show==False`, the seaborn `ClusterGrid` object.
    """

    import seaborn as sns

    var_names = [name for name in var_names if name in adata.var_names]

    tkey, xkey = kwargs.pop("tkey", sortby), kwargs.pop("xkey", layer)
    time = adata.obs[tkey].values
    time = time[np.isfinite(time)]

    X = (adata[:, var_names].layers[xkey]
         if xkey in adata.layers.keys() else adata[:, var_names].X)
    if issparse(X):
        X = X.A
    df = pd.DataFrame(X[np.argsort(time)], columns=var_names)

    if n_convolve is not None:
        weights = np.ones(n_convolve) / n_convolve
        for gene in var_names:
            try:
                df[gene] = np.convolve(df[gene].values, weights, mode="same")
            except Exception:
                pass  # e.g. all-zero counts or nans cannot be convolved

    if sort:
        max_sort = np.argsort(np.argmax(df.values, axis=0))
        df = pd.DataFrame(df.values[:, max_sort], columns=df.columns[max_sort])
    strings_to_categoricals(adata)

    if col_color is not None:
        col_colors = to_list(col_color)
        col_color = []
        for _, col in enumerate(col_colors):
            if not is_categorical(adata, col):
                obs_col = adata.obs[col]
                # discretize the continuous annotation into rounded bins
                cat_col = np.round(obs_col / np.max(obs_col), 2) * np.max(obs_col)
                adata.obs[f"{col}_categorical"] = pd.Categorical(cat_col)
                col += "_categorical"
                set_colors_for_categorical_obs(adata, col, palette)
            col_color.append(interpret_colorkey(adata, col)[np.argsort(time)])

    if "dendrogram_ratio" not in kwargs:
        kwargs["dendrogram_ratio"] = (
            0.1 if row_cluster else 0,
            0.2 if col_cluster else 0,
        )
    if "cbar_pos" not in kwargs or not colorbar:
        kwargs["cbar_pos"] = None

    kwargs.update(
        dict(
            col_colors=col_color,
            col_cluster=col_cluster,
            row_cluster=row_cluster,
            cmap=color_map,
            xticklabels=False,
            standard_scale=standard_scale,
            figsize=figsize,
        ))

    args = {}
    if font_scale is not None:
        args = {"font_scale": font_scale}
        context = context or "notebook"

    with sns.plotting_context(context=context, **args):
        try:
            cm = sns.clustermap(df.T, **kwargs)
        except Exception:
            logg.warn("Please upgrade seaborn with `pip install -U seaborn`.")
            kwargs.pop("dendrogram_ratio")
            kwargs.pop("cbar_pos")
            cm = sns.clustermap(df.T, **kwargs)

    savefig_or_show("heatmap", save=save, show=show)
    if show is False:
        return cm
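
A usage sketch for `heatmap` (assumes the dynamical model and latent time were computed upstream, e.g. via `scv.tl.recover_dynamics` and `scv.tl.latent_time`, and that a categorical `clusters` column exists in `.obs`):

import scvelo as scv

# rank genes by their dynamical-model fit likelihood
top_genes = (adata.var["fit_likelihood"]
             .sort_values(ascending=False).index[:100])
scv.pl.heatmap(
    adata,
    var_names=top_genes,
    sortby="latent_time",  # order cells along latent time
    layer="Ms",            # smoothed spliced expression
    n_convolve=100,        # smooth each gene along the time axis
    col_color="clusters",  # categorical obs column (assumed to exist)
)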