Exemple #1
0
def _sanitize_anndata(adata: AnnData) -> None:
    """Sanitization and sanity checks on a TCR-anndata object.

    Should be executed by every ``read_xxx`` function. Modifies `adata`
    in place:

    * ``adata.obs["has_tcr"]`` is rewritten to hold the strings
      ``"True"`` / ``"False"`` (truthiness decided by ``_is_true``).
    * ``adata._sanitize()`` is called at the end.

    Raises
    ------
    AssertionError
        If ``adata.X`` is not two-dimensional.
    """
    # A 2-D X is required, otherwise concatenating AnnData objects fails.
    assert len(adata.X.shape) == 2, (
        "X needs to have two dimensions, otherwise concat doesn't work. "
    )

    # This should always be a categorical with True / False
    has_tcr_mask = _is_true(adata.obs["has_tcr"])
    adata.obs["has_tcr"] = ["True" if x else "False" for x in has_tcr_mask]
    adata._sanitize()
def _sanitize_anndata(adata: AnnData) -> None:
    """Sanitization and sanity checks on an IR-anndata object.

    Should be executed by every ``read_xxx`` function. Modifies `adata`
    in place:

    * ``adata.obs["has_ir"]`` becomes a categorical holding the strings
      ``"True"`` / ``"False"`` (truthiness decided by ``_is_true``).
    * every ``obs`` column whose name ends with ``locus``, ``v_gene``,
      ``d_gene``, ``j_gene``, ``c_gene`` or ``multichain`` is converted to a
      categorical.
    * ``adata._sanitize()`` is called at the end.

    Raises
    ------
    AssertionError
        If ``adata.X`` is not two-dimensional.
    """
    # A 2-D X is required, otherwise concatenating AnnData objects fails.
    assert len(adata.X.shape) == 2, (
        "X needs to have two dimensions, otherwise concat doesn't work. "
    )

    # str.endswith accepts a tuple, so one call checks every suffix.
    categorical_suffixes = (
        "locus",
        "v_gene",
        "d_gene",
        "j_gene",
        "c_gene",
        "multichain",
    )

    # Sanitize has_ir column into categorical
    # This should always be a categorical with True / False
    has_ir_mask = _is_true(adata.obs["has_ir"])
    adata.obs["has_ir"] = pd.Categorical(
        ["True" if x else "False" for x in has_ir_mask]
    )

    # Turn other columns into categorical
    for col in adata.obs.columns:
        if col.endswith(categorical_suffixes):
            adata.obs[col] = pd.Categorical(adata.obs[col])

    adata._sanitize()
Exemple #3
0
def adata_cdr3():
    """Build a small IR AnnData object with junction and locus annotations.

    Contains five cells indexed by ``cell_id``; missing values are encoded
    as the literal string ``"nan"``. Every cell is flagged
    ``has_ir = "True"`` and ``uns["scirpy_version"]`` is set to ``"0.7"``.
    """
    column_names = [
        "cell_id",
        "IR_VJ_1_junction_aa",
        "IR_VJ_2_junction_aa",
        "IR_VDJ_1_junction_aa",
        "IR_VDJ_2_junction_aa",
        "IR_VJ_1_junction",
        "IR_VJ_1_locus",
        "IR_VJ_2_locus",
        "IR_VDJ_1_locus",
        "IR_VDJ_2_locus",
    ]
    # Locus columns for cells where all four chains are present.
    full_loci = ["TRA", "TRB", "TRA", "TRB"]
    rows = [
        ["cell1", "AAA", "AHA", "KKY", "KKK", "GCGGCGGCG"] + full_loci,
        ["cell2", "AHA", "nan", "KK", "KKK", "GCGAUGGCG"] + full_loci,
        # This row has no chains, but "has_ir" = True. That can happen if
        # the user does not filter the data.
        ["cell3"] + ["nan"] * 9,
        ["cell4", "AAA", "AAA", "LLL", "AAA", "GCUGCUGCU"] + full_loci,
        ["cell5", "AAA", "nan", "LLL", "nan", "nan", "nan", "TRB", "TRA", "nan"],
    ]

    obs = (
        pd.DataFrame(rows, columns=column_names)
        .set_index("cell_id")
        .assign(has_ir="True")
    )
    result = AnnData(obs=obs)
    result._sanitize()
    result.uns["scirpy_version"] = "0.7"
    return result
Exemple #4
0
def test_slicing_remove_unused_categories():
    """Slicing to rows that only carry ``k == "b"`` leaves only that category."""
    matrix = np.array([[1, 2], [3, 4], [5, 6], [7, 8]])
    annotations = dict(k=["a", "a", "b", "b"])
    adata = AnnData(matrix, annotations)
    adata._sanitize()

    sliced_categories = adata[2:4].obs["k"].cat.categories.tolist()
    assert sliced_categories == ["b"]
Exemple #5
0
def test_slicing_remove_unused_categories():
    """Slicing to rows that only carry ``k == 'b'`` leaves only that category.

    Rows 2 and 3 (0-based) both have ``k == 'b'``.
    """
    adata = AnnData(np.array([[1, 2], [3, 4], [5, 6], [7, 8]]),
                    dict(k=['a', 'a', 'b', 'b']))
    adata._sanitize()
    # In-bounds slice over the last two observations; the previous ``3:5``
    # ran past the 4 rows and selected only a single one.
    assert adata[2:4].obs['k'].cat.categories.tolist() == ['b']
Exemple #6
0
def embedding(
    adata: AnnData,
    basis: str,
    *,
    color: Union[str, Sequence[str], None] = None,
    panel_size: Tuple[float, float] = (4, 4),
    palette: Union[str, Cycler, Sequence[str], Sequence[Cycler], None] = None,
    legend_loc: str = "right margin",
    ax: Optional[Union[plt.Axes, Sequence[plt.Axes]]] = None,
    ncols: int = 3,
    show: Optional[bool] = False,
    hspace: float = 0.25,
    wspace: Optional[float] = None,
    **kwargs,
) -> Union[None, Sequence[plt.Axes]]:
    """A customized wrapper to the :func:`scanpy.pl.embedding` function.

    The differences to the scanpy embedding function are:
        * allows to specify a `panel_size`
        * Allows to specify a different `basis`, `legend_loc` and `palette`
          for each panel. The number of panels is defined by the `color` parameter.
        * Use a patched version for adding "on data" labels. The original
          raises a flood of warnings when coords are `nan`.
        * For categorical `obs` columns with more distinct values than
          `scanpy.pl.palettes.default_102` has colors, cycles through colors
          instead of reverting to grey
        * allows to specify axes, even if multiple colors are set.

    Parameters
    ----------
    adata
        annotated data matrix
    basis
        embedding to plot.
        Get the coordinates from the "X_{basis}" key in `adata.obsm`.
        This can be a list of the same length as `color` to specify
        different bases for each panel.
    color
        Keys for annotations of observations/cells or variables/genes, e.g.,
        `'ann1'` or `['ann1', 'ann2']`.
    panel_size
        Size tuple (`width`, `height`) of a single panel in inches
    palette
        Colors to use for plotting categorical annotation groups.
        The palette can be a valid :class:`~matplotlib.colors.ListedColormap` name
        (`'Set2'`, `'tab20'`, …) or a :class:`~cycler.Cycler` object.
        It is possible to specify a list of the same size as `color` to choose
        a different color map for each panel.
    legend_loc
        Location of legend, either `'on data'`, `'right margin'` or a valid keyword
        for the `loc` parameter of :class:`~matplotlib.legend.Legend`.
    ax
        A matplotlib axes object, or a list / numpy array of axes objects with
        one entry per `color`.
    ncols
        Number of columns for multi-panel plots
    show
        If True, show the figure. If False, return a list of Axes objects
    wspace
        Horizontal space between multiple panels. Defaults to
        `rcParams["figure.subplot.wspace"]`. Also widens the figure.
    hspace
        Vertical space between multiple panels. Also makes the figure taller.
    **kwargs
        Arguments to pass to :func:`scanpy.pl.embedding`.

    Returns
    -------
    axes
        A list of axes objects, containing one
        element for each `color`, or None if `show == True`.

    See also
    --------
    :func:`scanpy.pl.embedding`
    """
    adata._sanitize()

    def _make_iterable(var, singleton_types=(str,)):
        # Scalars (and None) apply to every panel; sequences are per-panel.
        return (
            itertools.repeat(var)
            if isinstance(var, singleton_types) or var is None
            else list(var)
        )

    color = [color] if isinstance(color, str) or color is None else list(color)
    basis = _make_iterable(basis)
    legend_loc = _make_iterable(legend_loc)
    palette = _make_iterable(palette, (str, Cycler))

    # set-up grid, if no axes are provided
    if ax is None:
        n_panels = len(color)
        nrows = int(np.ceil(float(n_panels) / ncols))
        ncols = min(n_panels, ncols)
        hspace = (
            rcParams.get("figure.subplot.hspace", 0.0) if hspace is None else hspace
        )
        wspace = (
            rcParams.get("figure.subplot.wspace", 0.0) if wspace is None else wspace
        )
        # Don't ask about +/- 1 but appears to be most faithful to the panel size.
        # Horizontal spacing (wspace) widens the figure, vertical spacing
        # (hspace) makes it taller.
        fig_width = panel_size[0] * ncols + wspace * (ncols + 1)
        fig_height = panel_size[1] * nrows + hspace * (nrows - 1)
        fig, axs = plt.subplots(
            nrows=nrows,
            ncols=ncols,
            figsize=(fig_width, fig_height),
            gridspec_kw={"wspace": wspace, "hspace": hspace},
            squeeze=False,
        )
        axs = axs.flatten()
    else:
        if isinstance(ax, np.ndarray):
            # e.g. the (possibly 2-D) array returned by `plt.subplots`
            axs = list(ax.flatten())
        elif isinstance(ax, Sequence):
            axs = list(ax)
        else:
            axs = [ax]
        fig = axs[0].get_figure()

    # use the scanpy plotting api to fill individual components
    for panel_ax, tmp_color, tmp_basis, tmp_legend_loc, tmp_palette in zip(
        axs, color, basis, legend_loc, palette
    ):
        # cycle colors for categories with many values instead of
        # coloring them in grey. Only `obs` columns can be categorical;
        # genes (var names) must not be looked up in `adata.obs`.
        if (
            tmp_palette is None
            and tmp_color is not None
            and tmp_color in adata.obs.columns
            and str(adata.obs[tmp_color].dtype) == "category"
            and adata.obs[tmp_color].unique().size > len(sc.pl.palettes.default_102)
        ):
            tmp_palette = cycler(color=sc.pl.palettes.default_102)

        add_labels = tmp_legend_loc == "on data"
        tmp_legend_loc = None if add_labels else tmp_legend_loc

        sc.pl.embedding(
            adata,
            tmp_basis,
            ax=panel_ax,
            show=False,
            color=tmp_color,
            legend_loc=tmp_legend_loc,
            palette=tmp_palette,
            **kwargs,
        )

        # manually add labels for "on data", as missing entries in `obsm` will cause
        # a flood of matplotlib warnings.
        # TODO: this could eventually be fixed upstream in scanpy
        if add_labels:
            _add_labels(
                panel_ax,
                adata.obsm["X_" + tmp_basis],
                adata.obs[tmp_color].values,
                legend_fontweight=kwargs.get("legend_fontweight", "bold"),
                legend_fontsize=kwargs.get("legend_fontsize", None),
                legend_fontoutline=kwargs.get("legend_fontoutline", None),
            )

    # hide unused panels in grid
    for unused_ax in axs[len(color) :]:
        unused_ax.axis("off")

    if show:
        fig.show()
    else:
        # only return axes that actually contain a plot.
        return axs[: len(color)]
Exemple #7
0
def test_slicing_remove_unused_categories():
    """Slicing to rows that only carry ``k == 'b'`` leaves only that category.

    Rows 2 and 3 (0-based) both have ``k == 'b'``.
    """
    adata = AnnData(
        np.array([[1, 2], [3, 4], [5, 6], [7, 8]]),
        dict(k=['a', 'a', 'b', 'b']))
    adata._sanitize()
    # In-bounds slice over the last two observations; the previous ``3:5``
    # ran past the 4 rows and selected only a single one.
    assert adata[2:4].obs['k'].cat.categories.tolist() == ['b']
Exemple #8
0
def _indexed_expression_df(
    adata: AnnData,
    var_names: Optional[Union[_VarNames, Mapping[str, _VarNames]]] = None,
    groupby: Optional[Union[str, Sequence[str]]] = None,
    use_raw: Optional[bool] = None,
    log: bool = False,
    num_categories: int = 7,
    layer: Optional[str] = None,
    gene_symbols: Optional[str] = None,
    concat_indices: bool = True,
):
    """
    Given the anndata object, prepares a data frame in which the row index are the categories
    defined by group by and the columns correspond to var_names.

    ``adata._sanitize()`` is called on the input first (presumably modifying
    `adata` in place — confirm against anndata).

    Parameters
    ----------
    adata
        Annotated data matrix.
    var_names
        `var_names` should be a valid subset of `adata.var_names`. All genes are used if not
        given. A single string is treated as a one-element list.
    groupby
        The key (or keys) of the observation grouping to consider. It is expected that
        groupby is a categorical. If a single groupby is not a categorical observation,
        it will be subdivided into `num_categories`.
    use_raw
        Use `raw` attribute of `adata` if present. Defaults to True when `adata.raw`
        exists.
    log
        Use the log of the values (``np.log1p``).
    num_categories
        Only used if groupby observation is not categorical. This value
        determines the number of groups into which the groupby observation
        should be subdivided.
    layer
        Name of a layer in `adata.layers` to take the values from. Takes
        precedence over `use_raw`.
    gene_symbols
        Key for field in .var that stores gene symbols. If this column exists in
        `adata.var`, `var_names` are interpreted as gene symbols, translated to
        var names for the lookup, and the resulting columns are renamed back to
        the symbols.
    concat_indices
        Concatenates categorical indices into a single categorical index, if
        groupby is a sequence. True by default.

    Returns
    -------
    Tuple ``(idx_categories, obs_tidy)``: the categories of the row index (a list
    with one categories index per groupby key when `concat_indices` is False) and
    the expression `pandas.DataFrame`. Returns None after logging an error if a
    gene symbol cannot be found.

    Raises
    ------
    ValueError
        If a groupby key is not a column of `adata.obs`.
    KeyError
        If `layer` is not in `adata.layers`.
    """
    from scipy.sparse import issparse

    adata._sanitize()
    if use_raw is None and adata.raw is not None:
        use_raw = True
    if isinstance(var_names, str):
        var_names = [var_names]
    if var_names is None:
        if use_raw:
            var_names = adata.raw.var_names.values
        else:
            var_names = adata.var_names.values

    if groupby is not None:
        if isinstance(groupby, str):
            # if not a list, turn into a list
            groupby = [groupby]
        for group in groupby:
            if group not in adata.obs_keys():
                raise ValueError(
                    'groupby has to be a valid observation. '
                    f'Given {group}, is not in observations: {adata.obs_keys()}'
                )

    # The caller's original names; only set when they were translated from
    # gene symbols and must be restored as column names at the end.
    symbols = None
    if gene_symbols is not None and gene_symbols in adata.var.columns:
        # translate gene_symbols to var_names
        # slow method but gives a meaningful error if no gene symbol is found:
        translated_var_names = []
        # if we're using raw to plot, we should also do gene symbol translations
        # using raw
        if use_raw:
            adata_or_raw = adata.raw
        else:
            adata_or_raw = adata
        for symbol in var_names:
            if symbol not in adata_or_raw.var[gene_symbols].values:
                logg.error(f"Gene symbol {symbol!r} not found in given "
                           f"gene_symbols column: {gene_symbols!r}")
                return
            translated_var_names.append(adata_or_raw.var[
                adata_or_raw.var[gene_symbols] == symbol].index[0])
        symbols = var_names
        var_names = translated_var_names
    if layer is not None:
        if layer not in adata.layers.keys():
            raise KeyError(
                f'Selected layer: {layer} is not in the layers list. '
                f'The list of valid layers is: {adata.layers.keys()}')
        matrix = adata[:, var_names].layers[layer]
    elif use_raw:
        matrix = adata.raw[:, var_names].X
    else:
        matrix = adata[:, var_names].X

    if issparse(matrix):
        matrix = matrix.toarray()
    if log:
        matrix = np.log1p(matrix)

    obs_tidy = pd.DataFrame(matrix, columns=var_names)
    if groupby is None:
        # A single dummy category so that every row shares the same group.
        obs_tidy_idx = pd.Series(np.repeat('',
                                           len(obs_tidy))).astype('category')
        idx_categories = obs_tidy_idx.cat.categories
    else:
        if len(groupby) == 1 and not is_categorical_dtype(
                adata.obs[groupby[0]]):
            # if the groupby column is not categorical, turn it into one
            # by subdividing into  `num_categories` categories
            obs_tidy_idx = pd.cut(adata.obs[groupby[0]], num_categories)
            idx_categories = obs_tidy_idx.cat.categories
        else:
            assert all(
                is_categorical_dtype(adata.obs[group]) for group in groupby)
            if concat_indices:
                obs_tidy_idx = adata.obs[groupby[0]]
                if len(groupby) > 1:
                    for group in groupby[1:]:
                        # create new category by merging the given groupby categories
                        obs_tidy_idx = (
                            obs_tidy_idx.astype(str) + "_" +
                            adata.obs[group].astype(str)).astype('category')
                obs_tidy_idx.name = "_".join(groupby)
                idx_categories = obs_tidy_idx.cat.categories
            else:
                obs_tidy_idx = [adata.obs[group]
                                for group in groupby]  # keep as multiindex
                idx_categories = [x.cat.categories for x in obs_tidy_idx]

    obs_tidy.set_index(obs_tidy_idx, inplace=True)
    if symbols is not None:
        # translate the column names back to the symbol names
        obs_tidy.rename(columns=dict(zip(var_names, symbols)), inplace=True)

    return idx_categories, obs_tidy
Exemple #9
0
def clonotype_network(
    adata: AnnData,
    *,
    color: Union[str, Sequence[str], None] = None,
    basis: str = "clonotype_network",
    panel_size: Tuple[float, float] = (10, 10),
    color_by_n_cells: bool = False,
    scale_by_n_cells: bool = True,
    base_size: Optional[float] = None,
    size_power: Optional[float] = None,
    use_raw: Optional[bool] = None,
    show_labels: bool = True,
    label_fontsize: Optional[int] = None,
    label_fontweight: str = "bold",
    label_fontoutline: int = 3,
    label_alpha: float = 0.6,
    label_y_offset: float = 2,
    legend_fontsize=None,
    legend_width=2,
    show_legend: Optional[bool] = None,
    show_size_legend: bool = True,
    palette: Union[str, Sequence[str], Cycler, None] = None,
    cmap: Union[str, Colormap] = None,
    edges_color: Union[str, None] = None,
    edges_cmap: Union[Colormap, str] = COLORMAP_EDGES,
    edges: bool = True,
    edges_width: float = 0.4,
    frameon: Optional[bool] = None,
    title: Optional[str] = None,
    ax: Optional[Axes] = None,
    fig_kws: Optional[dict] = None,
) -> plt.Axes:
    """\
    Plot the :term:`Clonotype` network.

    Requires running :func:`scirpy.tl.clonotype_network` first, to
    compute the layout.

    {clonotype_network}

    When the network is colored by continuous variables (genes, or numeric columns
    from `obs`), the average of the cells in each dot is computed. When the network
    is colored by categorical variables (categorical columns from `obs`), different
    categories per dot are visualized as pie chart.

    The layouting algorithm of :func:`scirpy.tl.clonotype_network` takes point sizes
    into account. For this reason, we recommend providing `base_size` and `size_power`
    already to the tool function.

    Parameters
    ----------
    adata
        Annotated data matrix.
    color
        Keys for annotations of observations/cells or variables/genes,
        e.g. `patient` or `CD8A`.
    basis
        Key under which the graph layout parameters are stored in `adata.uns`.
        The layout coordinates are read from `adata.obsm` under ``"X_" + basis``.
    panel_size
        Size of the main figure panel in inches.
    color_by_n_cells
        Color the nodes by the number of cells they represent. This overrides
        the `color` option.
    scale_by_n_cells
        Scale the nodes by the number of cells they represent. If this is
        set to `True`, we recommend using a "size-aware" layout in
        :func:`scirpy.tl.clonotype_network` to avoid overlapping nodes (default).
    base_size
        Size of a point representing 1 cell. Per default, the value provided
        to :func:`scirpy.tl.clonotype_network` is used. This option allows to
        override this value without recomputing the layout.
    size_power
        Point sizes are raised to the power of this value. Per default, the
        value provided to :func:`scirpy.tl.clonotype_network` is used. This option
        allows to override this value without recomputing the layout.
    use_raw
        Use `adata.raw` for plotting gene expression values. Default: Use `adata.raw`
        if it exists, and `adata` otherwise.
    show_labels
        If `True` plot clonotype ids on top of the subnetworks.
    label_fontsize
        Fontsize for the clonotype labels
    label_fontweight
        Fontweight for the clonotype labels
    label_fontoutline
        Size of the fontoutline added to the clonotype labels. Set to `None` to disable.
    label_alpha
        Transparency of the clonotype labels
    label_y_offset
        Offset the clonotype label on the y axis for better visibility of the
        subnetworks.
    legend_fontsize
        Font-size for the legend.
    legend_width
        Width (in inches) added to the figure for the legend(s) when `ax` is
        not provided and a legend is shown. Also passed on to the panel.
    show_legend
        Whether to show a legend (when plotting categorical variables)
        or a colorbar (when plotting continuous variables) on the right margin.
        Per default, a legend is shown if the number of categories is smaller than
        50, other wise no legend is shown.
    show_size_legend
        Whether to show a legend for dot sizes on the right margin.
        This option is only applicable if `scale_by_n_cells` is `True`.
    palette
        Colors to use for plotting categorical annotation groups.
        The palette can be a valid :class:`~matplotlib.colors.ListedColormap` name
        (`'Set2'`, `'tab20'`, …) or a :class:`~cycler.Cycler` object.
    cmap
        Colormap to use for plotting continuous variables.
    edges_color
        Color of the edges. Set to `None` to color by connectivity and use the
        color map provided by `edges_cmap`.
    edges_cmap
        Colormap to use for coloring edges by connectivity.
    edges
        Whether to show the edges or not.
    edges_width
        width of the edges
    frameon
        Whether to show a frame around the plot
    title
        The main plot title. Defaults to `color`.
    ax
        Add the plot to a predefined Axes object.
    fig_kws
        Parameters passed to the :func:`matplotlib.pyplot.figure` call
        if no `ax` is specified. The `figsize` entry is always derived from
        `panel_size`. The passed dictionary is not modified.

    Returns
    -------
    The Axes object the network was plotted on.

    Raises
    ------
    KeyError
        If the layout computed by :func:`scirpy.tl.clonotype_network` or the
        clonotype key it references cannot be found.
    """
    # The plotting code borrows a lot from scanpy.plotting._tools.paga._paga_graph.
    adata._sanitize()
    try:
        clonotype_key = adata.uns[basis]["clonotype_key"]
        base_size = adata.uns[basis][
            "base_size"] if base_size is None else base_size
        size_power = (adata.uns[basis]["size_power"]
                      if size_power is None else size_power)
    except KeyError as err:
        raise KeyError(
            f"{basis} not found in `adata.uns`. Did you run `tl.clonotype_network`?"
        ) from err
    if f"X_{basis}" not in adata.obsm_keys():
        raise KeyError(
            f"X_{basis} not found in `adata.obsm`. Did you run `tl.clonotype_network`?"
        )
    if clonotype_key not in adata.obs.columns:
        raise KeyError(f"{clonotype_key} not found in adata.obs.")
    if clonotype_key not in adata.uns:
        raise KeyError(f"{clonotype_key} not found in adata.uns.")

    if use_raw is None:
        use_raw = adata.raw is not None

    if frameon is None:
        frameon = settings._frameon

    if show_legend is None:
        # Hide the legend for categorical columns with many categories.
        if color in adata.obs.columns and is_categorical_dtype(
                adata.obs[color]):
            show_legend = adata.obs[color].nunique() < 50
        else:
            show_legend = True

    clonotype_res = adata.uns[clonotype_key]
    coords, adj_mat = _graph_from_coordinates(adata, clonotype_key)
    nx_graph = nx.Graph(_distance_to_connectivity(adj_mat))

    # Prepare figure
    if ax is None:
        # Copy, so the caller's dictionary is not mutated with our figsize.
        fig_kws = dict() if fig_kws is None else dict(fig_kws)
        # Reserve extra horizontal room on the right for the legend(s).
        fig_width = (panel_size[0] if not (show_legend or show_size_legend)
                     else panel_size[0] + legend_width + 0.5)
        fig_kws["figsize"] = (fig_width, panel_size[1])
        ax = _init_ax(fig_kws)

    if title is None and color is not None:
        title = color
    ax.set_frame_on(frameon)
    ax.set_xticks([])
    ax.set_yticks([])

    _plot_clonotype_network_panel(
        adata,
        ax,
        legend_width=legend_width,
        color=color,
        coords=coords,
        use_raw=use_raw,
        cell_indices=clonotype_res["cell_indices"],
        nx_graph=nx_graph,
        show_legend=show_legend,
        show_size_legend=show_size_legend,
        show_labels=show_labels,
        label_fontsize=label_fontsize,
        label_fontoutline=label_fontoutline,
        label_fontweight=label_fontweight,
        legend_fontsize=legend_fontsize,
        base_size=base_size,
        size_power=size_power,
        cmap=cmap,
        edges=edges,
        edges_width=edges_width,
        edges_color=edges_color,
        edges_cmap=edges_cmap,
        title=title,
        palette=palette,
        label_alpha=label_alpha,
        label_y_offset=label_y_offset,
        scale_by_n_cells=scale_by_n_cells,
        color_by_n_cells=color_by_n_cells,
    )
    return ax