Beispiel #1
0
def _get_colors(
    adata, obs_key: str, palette: Union[str, Sequence[str], Cycler, None] = None
) -> Dict[str, str]:
    """Map every category of ``adata.obs[obs_key]`` to its colour.

    Colours live in ``adata.uns[f"{obs_key}_colors"]``. If ``palette`` is
    given it is applied first. Otherwise the stored colours are validated
    when there are at least as many of them as categories. When no colours
    are stored, or too few are stored, a default palette is assigned instead.

    Only ``obs`` keys are supported. Expression values are not plotted here,
    whereas scanpy's ``values_to_plot`` (the counterpart of ``obs_key``) may
    also name a ``var`` key.

    Parameters
    ----------
    adata
        AnnData object. ``adata._sanitize()`` is called on it, and the helpers
        used here are expected to (re)write ``adata.uns[f"{obs_key}_colors"]``.
    obs_key
        Column of ``adata.obs`` holding the categories.
    palette
        Optional palette forwarded to scanpy's colour-setting helper.

    Returns
    -------
    Dict[str, str]
        Category -> colour, paired in category order.

    TODO: relies on private scanpy functions; this is Evil and should be
    replaced in the future.
    """
    # `_sanitize` is what turns string columns of `obs` into categoricals,
    # so `.categories` is available below.
    adata._sanitize()
    values = adata.obs[obs_key].values
    color_key = f"{obs_key}_colors"

    if palette is not None:
        _set_colors_for_categorical_obs(adata, obs_key, palette)
    else:
        enough_colors = color_key in adata.uns and len(
            adata.uns[color_key]
        ) >= len(values.categories)
        if enough_colors:
            _validate_palette(adata, obs_key)
        else:
            # nothing stored yet, or fewer colours than categories
            _set_default_colors_for_categorical_obs(adata, obs_key)

    return dict(zip(values.categories, adata.uns[color_key]))
Beispiel #2
0
def trajectory_tree(
    adata: AnnData,
    figsize: tuple = (10, 8),
    node_size: int = 1200,
    font_size: int = 18,
    show: bool = True,
    dpi: int = 600,
    save: Union[str, bool] = False,
):
    """\
    Plot a trajectory tree.

    The adjacency matrix in ``adata.uns["capital"]["tree"]["tree"]`` is drawn
    as a directed graph with a graphviz ``dot`` layout. Each node is coloured
    with the colour of the matching category of
    ``adata.obs[adata.uns["capital"]["tree"]["annotation"]]``.

    Parameters
    ----------
    adata : AnnData
        The annotated data matrix.
    figsize : tuple
        Size of figure, by default (10, 8).
    node_size : int
        Size of nodes, by default 1200.
    font_size : int
        Font size, by default 18.
    show : bool
        If `True`, show the figure. If `False`, return figure, by default `True`.
    dpi : int
        Resolution used when saving the figure, by default 600.
    save : Union[str, bool]
        If `True` or a path (`str` or path-like), save the figure. If a path
        is given, the figure is saved there (parent directories are created);
        if `True`, it is saved to "./figures/tree.png". By default `False`.

    Returns
    -------
    matplotlib.figure.Figure or None
        The (already closed) figure when ``show`` is `False`, else `None`.

    Raises
    ------
    ValueError
        If ``adata`` is not an ``AnnData`` instance.
    """
    if not isinstance(adata, AnnData):
        raise ValueError("trajectory_tree() expects an AnnData argument.")

    tree_info = adata.uns["capital"]["tree"]
    groupby = tree_info["annotation"]
    groupby_colors = groupby + "_colors"
    if groupby_colors not in adata.uns:
        # Private scanpy helper; expected to populate adata.uns[groupby_colors].
        from scanpy.plotting._utils import _set_default_colors_for_categorical_obs
        _set_default_colors_for_categorical_obs(adata, groupby)

    tree = nx.convert_matrix.from_pandas_adjacency(
        tree_info["tree"], create_using=nx.DiGraph)

    clusters = list(adata.obs[groupby].cat.categories)
    dic = dict(zip(clusters, adata.uns[groupby_colors]))
    # Nodes without a matching category keep their own name, exactly as the
    # previous Series.replace lookup did.
    colorlist = [dic.get(node, node) for node in tree.nodes()]

    fig = plt.figure(figsize=figsize)
    pos = graphviz_layout(tree, prog='dot')
    nx.draw(tree,
            pos,
            node_color=colorlist,
            font_weight='bold',
            font_size=font_size,
            node_size=node_size,
            with_labels=True,
            edge_color="gray",
            arrows=True)

    if save:
        if isinstance(save, (str, os.PathLike)):
            save_dir = os.fspath(save)
            # A bare file name has no directory part; os.makedirs("") raises.
            save_parent = os.path.dirname(save_dir)
            if save_parent:
                os.makedirs(save_parent, exist_ok=True)
        else:
            os.makedirs("./figures", exist_ok=True)
            save_dir = "./figures/tree.png"
        fig.savefig(save_dir, bbox_inches='tight', pad_inches=0.1, dpi=dpi)

    if show:
        plt.show()
    plt.close(fig)
    if not show:
        return fig
Beispiel #3
0
def dtw(
    aligned_data: CapitalData,
    gene: Union[str, list],
    alignment: Union[str, list, None] = None,
    data1_name: Optional[str] = "data1",
    data2_name: Optional[str] = "data2",
    ncols: int = 2,
    widthspace: float = 0.10,
    heightspace: float = 0.30,
    fontsize: float = 12,
    legend_fontsize: float = 12,
    ticksize: float = 12,
    linecolor: str = 'grey',
    dpi: int = 600,
    show: bool = True,
    save: Union[str, bool] = False,
):
    """\
    Plot the results of dynamic time warping.

    One panel is drawn per (alignment, gene) pair produced by
    ``__set_plot_list``. In each panel, the cells of data1 are drawn at y=1
    and those of data2 at y=0. Cells are placed on the x-axis by their
    ``"{alignment}_dpt_pseudotime"`` value and coloured by the tree
    annotation. Each pair of indices in the DTW path is joined by a line.

    Parameters
    ----------
    aligned_data : CapitalData
        The data matrices containing the results of CAPITAL.
    gene : Union[str, list]
        Keys for annotations of genes.
    alignment : Union[str, list, None], optional
        Keys for alignments to be plotted. If `None`, all alignments will be plotted, by default `None`.
    data1_name : Optional[str], optional
        Text of data1's legend, by default "data1".
    data2_name : Optional[str], optional
        Text of data2's legend, by default "data2".
    ncols : int
        Number of panels per row, by default 2.
    widthspace : float
        Width of space in the panels, by default 0.10.
    heightspace : float
        Height of space in the panels, by default 0.30.
    fontsize : float
        Font size of the title, by default 12.
    legend_fontsize : float, optional
        Font size of the legend title, by default 12.
    ticksize : float, optional
        Tick size of the x-axis, by default 12.
    linecolor : str
        Color of lines that connect the cells between data1 and data2.
    dpi : int
        Resolution used when saving the figure, by default 600.
    show : bool
        If `True`, show the figure. If `False`, return the figure, by default `True`.
    save : Union[str, bool]
        If `True` or a path (`str` or path-like), save the figure. If a path
        is given, the figure is saved there (parent directories are created);
        if `True`, it is saved to "./figures/dtw.png". By default `False`.

    Returns
    -------
    matplotlib.figure.Figure or None
        The (already closed) figure when ``show`` is `False`, else `None`.

    Raises
    ------
    ValueError
        If ``aligned_data`` is not a ``CapitalData`` instance.
    """
    if not isinstance(aligned_data, CapitalData):
        raise ValueError("dtw() expects a CapitalData argument.")

    plot_list = __set_plot_list(aligned_data, alignment, gene)

    data1 = aligned_data.adata1
    data2 = aligned_data.adata2
    groupby1 = data1.uns["capital"]["tree"]["annotation"]
    groupby2 = data2.uns["capital"]["tree"]["annotation"]
    groupby_colors1 = groupby1 + "_colors"
    groupby_colors2 = groupby2 + "_colors"

    for adata, groupby, color_key in ((data1, groupby1, groupby_colors1),
                                      (data2, groupby2, groupby_colors2)):
        if color_key not in adata.uns:
            # Private scanpy helper; expected to populate adata.uns[color_key].
            from scanpy.plotting._utils import _set_default_colors_for_categorical_obs
            _set_default_colors_for_categorical_obs(adata, groupby)

    def _cluster_legend(ax, clusters, color_dic, loc, title):
        # Marker-only legend listing `clusters` in the given order.
        handles = [
            Line2D(range(1),
                   range(1),
                   marker='o',
                   color=color_dic[cluster],
                   label=cluster,
                   linewidth=0)
            for cluster in clusters
        ]
        return ax.legend(handles=handles,
                         labels=clusters,
                         bbox_to_anchor=(1.05, 0.5),
                         loc=loc,
                         ncol=int(np.ceil(len(color_dic) / 5)),
                         title="{}".format(title),
                         title_fontsize=legend_fontsize)

    fig, grid = __panel(ncols,
                        len(plot_list),
                        widthspace=widthspace,
                        heightspace=heightspace)

    for count, (align_name, genename) in enumerate(plot_list):
        ax = fig.add_subplot(grid[count])
        dtw_dic = aligned_data.alignmentdict[align_name]
        path = dtw_dic[genename]["path"]
        ordered_cells1 = dtw_dic[genename]["ordered_cells1"]
        ordered_cells2 = dtw_dic[genename]["ordered_cells2"]

        ax.set_title("{}_{}".format(align_name, genename), fontsize=fontsize)
        ax.set_xlabel("Pseudotime", fontsize=fontsize)
        ax.set_aspect(0.8)
        ax.tick_params(labelbottom=True,
                       labelleft=False,
                       labelright=False,
                       labeltop=False,
                       bottom=False,
                       left=False,
                       right=False,
                       top=False,
                       labelsize=ticksize)
        ax.grid(False)

        # Build each subset view once per panel instead of once per access.
        # NOTE(review): colours are read from the view's `uns` (not the
        # parent's), presumably because AnnData trims `<key>_colors` to the
        # categories left in the subset — confirm.
        sub1 = data1[ordered_cells1, :]
        sub2 = data2[ordered_cells2, :]

        pseudotime_key = "{}_dpt_pseudotime".format(align_name)
        # Plain arrays, so the integer DTW path indices below are positional
        # rather than relying on Series label/position fallback.
        y1 = sub1.obs[pseudotime_key].to_numpy()
        y2 = sub2.obs[pseudotime_key].to_numpy()

        clusters1 = sub1.obs[groupby1]
        clusters2 = sub2.obs[groupby2]
        # A single stored colour may be a scalar; a list must stay 1-D.
        colors1 = np.atleast_1d(sub1.uns[groupby_colors1])
        colors2 = np.atleast_1d(sub2.uns[groupby_colors2])

        dic1 = dict(zip(clusters1.cat.categories, colors1))
        dic2 = dict(zip(clusters2.cat.categories, colors2))

        # Unmapped clusters keep their label, as the former Series.replace did.
        colorlist1 = [dic1.get(cluster, cluster) for cluster in clusters1]
        colorlist2 = [dic2.get(cluster, cluster) for cluster in clusters2]

        ax.scatter(
            y1,
            np.ones(len(y1)),
            color=colorlist1,
            zorder=-2,
        )

        ax.scatter(
            y2,
            np.zeros(len(y2)),
            color=colorlist2,
            zorder=-2,
        )

        ordered_clusters1 = [
            cluster for cluster in dtw_dic['data1'] if cluster != "#"
        ]
        ordered_clusters2 = [
            cluster for cluster in dtw_dic['data2'] if cluster != "#"
        ]

        legend1 = _cluster_legend(ax, ordered_clusters1, dic1,
                                  'lower left', data1_name)
        # ax.legend() replaces the Axes' current legend, so keep the first
        # one as a plain artist before creating the second.
        ax.add_artist(legend1)
        _cluster_legend(ax, ordered_clusters2, dic2, 'upper left', data2_name)

        for i, j in path:
            i = int(i)
            j = int(j)
            ax.plot((y1[i], y2[j]), (1, 0),
                    color=linecolor,
                    alpha=0.5,
                    zorder=-3)

        # Artists below zorder -1 (scatter points and path lines) are
        # rasterized when saving to vector formats.
        ax.set_rasterization_zorder(-1)

    if save:
        if isinstance(save, (str, os.PathLike)):
            save_dir = os.fspath(save)
            # A bare file name has no directory part; os.makedirs("") raises.
            save_parent = os.path.dirname(save_dir)
            if save_parent:
                os.makedirs(save_parent, exist_ok=True)
        else:
            os.makedirs("./figures", exist_ok=True)
            save_dir = "./figures/dtw.png"
        fig.savefig(save_dir, bbox_inches='tight', pad_inches=0.1, dpi=dpi)

    if show:
        plt.show()
    plt.close(fig)
    if not show:
        return fig