示例#1
0
def interactive_hist(adata, keys=['n_counts', 'n_genes'],
                     bins='auto',  max_bins=100,
                     groups=None, fill_alpha=0.4,
                     palette=None, display_all=True,
                     tools='pan, reset, wheel_zoom, save',
                     legend_loc='top_right',
                     plot_width=None, plot_height=None, save=None,
                     *args, **kwargs):
    """Utility function to plot distributions with variable number of bins.

    One figure (with its own bin slider and toggle button) is created per
    entry of `keys`; within a figure, one histogram is drawn per non-empty
    group.

    Params
    --------
    adata: AnnData object
        annotated data object
    keys: list(str), optional (default: `['n_counts', 'n_genes']`)
        keys in `adata.obs`, `adata.var` or `adata.var_names` where the
        distributions are stored
    bins: int; str, optional (default: `auto`)
        number of bins used for plotting or str from numpy.histogram
    max_bins: int, optional (default: `100`)
        maximum number of bins selectable with the slider; if the initial
        histogram of a key has more bins, the slider range of that key is
        extended accordingly
    groups: list(str), (default: `None`)
        keys in `adata.obs.obs_keys()`, groups by all possible combinations of values, e.g. for
        3 plates and 2 time points, we would create total of 6 groups;
        `None` or an empty list plots all data as a single group
    fill_alpha: float[0.0, 1.0], (default: `0.4`)
        alpha channel of the fill color
    palette: list(str), optional (default: `None`)
        palette to use; colors are reused cyclically if there are more
        groups than colors
    display_all: bool, optional (default: `True`)
        display the statistics for all data (only relevant when `groups`
        is specified)
    tools: str, optional (default: `'pan, reset, wheel_zoom, save'`)
        palette of interactive tools for the user
    legend_loc: str, (default: `'top_right'`)
        position of the legend; if `None`, no legend entries are created
    plot_width: int, optional (default: `None`)
        width of the plot
    plot_height: int, optional (default: `None`)
        height of the plot
    save: Union[os.PathLike, Str, NoneType], optional (default: `None`)
        path where to save the plot; `'.html'` is appended if missing.
        If `None`, the plot is shown instead
    *args, **kwargs: arguments, keyword arguments
        additional arguments passed through to `figure`

    Returns
    --------
    None

    Raises
    --------
    ValueError
        if `max_bins < 1` or if a key is found in none of `adata.obs`,
        `adata.var` and `adata.var_names`
    """

    if max_bins < 1:
        raise ValueError('`max_bins` must >= 1')

    palette = Set1[9] + Set2[8] + Set3[12] if palette is None else palette

    # check the input
    for key in keys:
        if key not in adata.obs.keys() and \
           key not in adata.var.keys() and \
           key not in adata.var_names:
            raise ValueError(f'The key `{key}` does not exist in `adata.obs`, `adata.var` or `adata.var_names`.')

    def _create_adata_groups():
        """Return a list of `(adata subset, legend label)` pairs."""
        if not groups:
            return [(adata, 'all')]

        pairs = []
        for vals in product(*[set(adata.obs[g]) for g in groups]):
            # cells that match every (group key, value) pair of this combination
            mask = reduce(lambda l, r: l & r,
                          (adata.obs[k] == v for k, v in zip(groups, vals)), True)
            label = ', '.join(f'{g}: {v}' for g, v in zip(groups, vals))
            pairs.append((adata[mask], label))

        if display_all:
            # the unsplit data gets a plain 'all' label
            pairs.append((adata, 'all'))

        return pairs

    # empty combinations are dropped up front, so palette indices are
    # assigned only to groups that are actually drawn
    ad_gs = [(ad, label) for ad, label in _create_adata_groups() if ad.n_obs > 0]

    cols = []
    for key in keys:
        callbacks = []
        plots = []
        fig = figure(*args, tools=tools, **kwargs)
        slider = Slider(start=1, end=max_bins, value=0, step=1,
                        title='Bins')
        # tracked per key so that one key's automatic binning does not
        # widen the slider range of the following keys
        slider_end = max_bins

        for j, (ad, legend) in enumerate(ad_gs):

            if key in ad.obs.keys():
                orig = ad.obs[key]
            elif key in ad.var.keys():
                orig = ad.var[key]
            else:
                orig = ad[:, key].X
            hist, edges = np.histogram(orig, density=True, bins=bins)

            # the slider starts at the bin count of the last drawn group
            slider.value = len(hist)
            # case when automatic bins
            slider_end = max(slider_end, slider.value)

            # original data, handed to the JS callback alongside `source`
            orig = ColumnDataSource(data=dict(values=orig))
            # data rendered by the quad glyph
            source = ColumnDataSource(data=dict(hist=hist, l_edges=edges[:-1], r_edges=edges[1:]))

            quad_kwargs = dict(source=source, top='hist', bottom=0,
                               left='l_edges', right='r_edges',
                               fill_color=palette[j % len(palette)],
                               muted_alpha=0,
                               line_color="#555555", fill_alpha=fill_alpha)
            if legend_loc is not None:
                # only pass a label when a legend is wanted; `None` is not
                # a valid legend label
                quad_kwargs['legend_label'] = legend
            p = fig.quad(**quad_kwargs)

            # one callback per histogram, all driven by the same slider
            callbacks.append(CustomJS(args=dict(source=source, orig=orig, bins=slider),
                                      code=_inter_hist_js_code))

            # add the current plot so that we can set it
            # visible/invisible in JS code
            plots.append(p)

        slider.end = slider_end

        # slider now updates all values
        slider.js_on_change('value', *callbacks)

        button = Button(label='Toggle', button_type='primary')
        button.callback = CustomJS(
            args={'plots': plots},
            code='''
                 for (var i = 0; i < plots.length; i++) {
                     plots[i].muted = !plots[i].muted;
                 }
                 '''
        )

        if legend_loc is not None:
            fig.legend.location = legend_loc
            fig.legend.click_policy = 'mute'

        fig.xaxis.axis_label = key
        fig.yaxis.axis_label = 'normalized frequency'
        _set_plot_wh(fig, plot_width, plot_height)

        cols.append(column(slider, button, fig))

    if _bokeh_version > (1, 0, 4):
        from bokeh.layouts import grid
        plot = grid(children=cols, ncols=2)
    else:
        cols = list(map(list, np.array_split(cols, np.ceil(len(cols) / 2))))
        plot = layout(children=cols, sizing_mode='fixed', ncols=2)

    if save is not None:
        save = save if str(save).endswith('.html') else str(save) + '.html'
        bokeh_save(plot, save)
    else:
        show(plot)
def highlight_indices(adata, key, basis='diffmap', components=[1, 2], cell_keys='',
                     legend_loc='top_right', plot_width=None, plot_height=None,
                     tools='pan, reset, wheel_zoom, save'):
    """
    Plot cell indices. Useful when trying to set adata.uns['iroot'].

    Params
    --------
    adata: AnnData Object
        annotated data object
    key: str
        key in `adata.obs_keys()` to color
    basis: str, optional (default: `'diffmap'`)
        basis to use
    components: list[int], optional (default: `[1, 2]`)
        which components of the basis to use; for every basis other than
        `'diffmap'` the values are shifted by one before indexing
    cell_keys: str, list(str), optional (default: `''`)
        keys to display from `adata.obs_keys()` when hovering over cell;
        a string is split on commas. A passed list is not modified
    legend_loc: str, optional (default `'top_right'`)
        location of the legend; if `None`, no legend is added
    plot_width: int, optional (default: `None`)
        width of the plot
    plot_height: int, optional (default: `None`)
        height of the plot
    tools: str, optional (default: `'pan, reset, wheel_zoom, save'`)
        tools for the plot

    Returns
    --------
    None

    Raises
    --------
    ValueError
        if `key` is not in `adata.obs`, if `X_{basis}` is not in
        `adata.obsm`, or if any of `cell_keys` is not in `adata.obs`
    """

    if key not in adata.obs:
        raise ValueError(f'{key} not found in `adata.obs`')

    if f'X_{basis}' not in adata.obsm_keys():
        raise ValueError(f'basis `X_{basis}` not found in `adata.obsm`')

    # array form is needed for the element-wise shift below
    components = np.asarray(components)

    if cell_keys is None:
        cell_keys = []
    elif isinstance(cell_keys, str):
        cell_keys = map(str.strip, cell_keys.split(','))
    # de-duplicate while keeping order and drop empty entries; this also
    # copies, so the caller's list is never mutated by the inserts below
    cell_keys = [k for k in dict.fromkeys(cell_keys) if k != '']

    if not all(k in adata.obs.keys() for k in cell_keys):
        raise ValueError('Not all keys are in `adata.obs.keys()`.')

    # `basis != 'diffmap'` is a bool: subtracts 1 for every basis except diffmap
    df = pd.DataFrame(adata.obsm[f'X_{basis}'][:, components - (basis != 'diffmap')], columns=['x', 'y'])

    for k in cell_keys:
        df[k] = list(map(str, adata.obs[k]))

    # positional index of each cell, shown by the toggleable labels
    df['index'] = range(len(df))
    df[key] = list(adata.obs[key])

    if hasattr(adata, 'obs_names'):
        cell_keys.insert(0, 'name')
        df['name'] = list(adata.obs_names)

    if 'index' not in cell_keys:
        cell_keys.insert(0, 'index')

    # only build the fallback palette when no colors are stored
    palette = adata.uns.get(f'{key}_colors')
    if palette is None:
        palette = viridis(len(df[key].unique()))

    p = figure(title=f'{key}', tools=tools)
    _set_plot_wh(p, plot_width, plot_height)

    key_col = adata.obs[key].astype('category') if adata.obs[key].dtype.name != 'category' else adata.obs[key]

    # one scatter renderer per category, so each can be muted via the legend
    renderers = []
    legend_items = []
    for c, color in zip(key_col.cat.categories, palette):
        data = ColumnDataSource(df[df[key] == c])
        renderer = p.scatter(x='x', y='y', size=10, color=color, source=data, muted_alpha=0)
        renderers.append(renderer)
        legend_items.append((str(c), [renderer]))
    hover_cell = HoverTool(renderers=renderers, tooltips=[(f'{k}', f'@{k}') for k in cell_keys])

    if legend_loc is not None:
        legend = Legend(items=legend_items, location=legend_loc, click_policy='mute')
        p.add_layout(legend)

    p.xaxis.axis_label = f'{basis}_{components[0]}'
    p.yaxis.axis_label = f'{basis}_{components[1]}'

    source = ColumnDataSource(df)
    labels = LabelSet(x='x', y='y', text='index',
                      x_offset=4, y_offset=4,
                      level='glyph', source=source, render_mode='canvas')

    # index labels start hidden; the button below toggles them
    labels.visible = False
    p.add_tools(hover_cell)
    p.add_layout(labels)

    button = Button(label='Toggle Indices', button_type='primary')
    button.callback = CustomJS(args=dict(l=labels), code='l.visible = !l.visible;')

    show(column(button, p))