def interactive_hist(adata, keys=['n_counts', 'n_genes'],
                     bins='auto', max_bins=100,
                     groups=None, fill_alpha=0.4,
                     palette=None, display_all=True,
                     tools='pan, reset, wheel_zoom, save',
                     legend_loc='top_right',
                     plot_width=None, plot_height=None, save=None,
                     *args, **kwargs):
    """Utility function to plot distributions with variable number of bins.

    One figure (with its own bin slider and a mute toggle button) is created
    per entry of `keys`; inside each figure one histogram is drawn per
    non-empty group.

    Params
    --------
    adata: AnnData object
        annotated data object
    keys: list(str), optional (default: `['n_counts', 'n_genes']`)
        keys in `adata.obs`, `adata.var` or `adata.var_names`
        where the distributions are stored (the default list is only read,
        never mutated)
    bins: int; str, optional (default: `auto`)
        number of bins used for plotting or str from numpy.histogram
    max_bins: int, optional (default: `100`)
        maximum number of bins selectable with the slider; the slider range
        of a figure is extended if the initial histogram has more bins
    groups: list(str), (default: `None`)
        keys in `adata.obs.obs_keys()`, groups by all possible
        combinations of values, e.g. for
        3 plates and 2 time points, we would create total of 6 groups
    fill_alpha: float[0.0, 1.0], (default: `0.4`)
        alpha channel of the fill color
    palette: list(str), optional (default: `None`)
        palette to use; colors are reused cyclically when there are more
        groups than colors
    display_all: bool, optional (default: `True`)
        display the statistics for all data (only relevant if `groups`
        is not `None`)
    tools: str, optional (default: `'pan, reset, wheel_zoom, save'`)
        palette of interactive tools for the user
    legend_loc: str, (default: `'top_right'`)
        position of the legend, if `None`, no legend is created
    plot_width: int, optional (default: `None`)
        width of the plot
    plot_height: int, optional (default: `None`)
        height of the plot
    save: Union[os.PathLike, Str, NoneType], optional (default: `None`)
        path where to save the plot (`.html` is appended if missing);
        if `None`, the plot is shown instead
    *args, **kwargs: arguments, keyword arguments
        additional arguments passed to `figure`

    Returns
    --------
    None

    Raises
    --------
    ValueError
        if `max_bins < 1` or if a key cannot be found in
        `adata.obs`, `adata.var` or `adata.var_names`
    """

    if max_bins < 1:
        raise ValueError('`max_bins` must >= 1')

    palette = Set1[9] + Set2[8] + Set3[12] if palette is None else palette

    # check the input
    for key in keys:
        if key not in adata.obs.keys() and \
           key not in adata.var.keys() and \
           key not in adata.var_names:
            raise ValueError(f'The key `{key}` does not exist in `adata.obs`, `adata.var` or `adata.var_names`.')

    def _create_adata_groups():
        """Return two parallel lists: data subsets and their legend labels."""
        if groups is None:
            return [adata], ['all']

        # dict.fromkeys keeps the order of first appearance, so the
        # group -> color assignment is reproducible between runs
        combs = list(product(*[dict.fromkeys(adata.obs[g]) for g in groups]))
        subsets = [adata[reduce(lambda l, r: l & r,
                                (adata.obs[k] == v for k, v in zip(groups, vals)), True)]
                   for vals in combs]
        labels = [', '.join(': '.join(map(str, gv)) for gv in zip(groups, vals))
                  for vals in combs]

        if display_all:
            subsets.append(adata)
            labels.append('all')

        return subsets, labels

    # the grouping does not depend on the key, so build it once
    # and drop the combinations that select no observation
    nonempty_groups = [(ad, label)
                       for ad, label in zip(*_create_adata_groups())
                       if ad.n_obs > 0]

    cols = []
    for key in keys:
        callbacks = []
        plots = []
        # tracked per figure so that one key with many automatic bins
        # does not enlarge the slider range of the other keys
        key_max_bins = max_bins

        fig = figure(*args, tools=tools, **kwargs)
        # `value` is only a placeholder, it is set from the histogram below
        slider = Slider(start=1, end=max_bins, value=0, step=1,
                        title='Bins')

        for j, (ad, label) in enumerate(nonempty_groups):

            if key in ad.obs.keys():
                values = ad.obs[key]
            elif key in ad.var.keys():
                values = ad.var[key]
            else:
                # gene expression: `.X` may be sparse and/or 2-dimensional
                values = ad[:, key].X
                if hasattr(values, 'toarray'):
                    values = values.toarray()
                values = np.asarray(values).ravel()

            hist, edges = np.histogram(values, density=True, bins=bins)

            slider.value = len(hist)
            # case when automatic bins
            key_max_bins = max(key_max_bins, slider.value)

            # original data, used for recalculation of histogram in JS code
            orig = ColumnDataSource(data=dict(values=values))
            # data that we update in JS code
            source = ColumnDataSource(data=dict(hist=hist, l_edges=edges[:-1], r_edges=edges[1:]))

            # only pass `legend_label` when a legend is requested,
            # `legend_label=None` is not a valid label
            legend_kwargs = {} if legend_loc is None else {'legend_label': label}
            p = fig.quad(source=source, top='hist', bottom=0,
                         left='l_edges', right='r_edges',
                         fill_color=palette[j % len(palette)],
                         muted_alpha=0,
                         line_color="#555555", fill_alpha=fill_alpha,
                         **legend_kwargs)

            # create callback and slider
            callback = CustomJS(args=dict(source=source, orig=orig, bins=slider),
                                code=_inter_hist_js_code)
            callbacks.append(callback)

            # add the current plot so that we can set it
            # visible/invisible in JS code
            plots.append(p)

        slider.end = key_max_bins

        # slider now updates all values
        if callbacks:
            slider.js_on_change('value', *callbacks)

        button = Button(label='Toggle', button_type='primary')
        toggle = CustomJS(
            args={'plots': plots},
            code='''
                 for (var i = 0; i < plots.length; i++) {
                     plots[i].muted = !plots[i].muted;
                 }
                 '''
        )
        # the `callback` property is not available on newer Bokeh versions
        if hasattr(button, 'js_on_click'):
            button.js_on_click(toggle)
        else:
            button.callback = toggle

        if legend_loc is not None:
            fig.legend.location = legend_loc
            fig.legend.click_policy = 'mute'

        fig.xaxis.axis_label = key
        fig.yaxis.axis_label = 'normalized frequency'
        _set_plot_wh(fig, plot_width, plot_height)

        cols.append(column(slider, button, fig))

    if _bokeh_version > (1, 0, 4):
        from bokeh.layouts import grid
        plot = grid(children=cols, ncols=2)
    else:
        cols = list(map(list, np.array_split(cols, np.ceil(len(cols) / 2))))
        plot = layout(children=cols, sizing_mode='fixed', ncols=2)

    if save is not None:
        save = save if str(save).endswith('.html') else str(save) + '.html'
        bokeh_save(plot, save)
    else:
        show(plot)
def highlight_indices(adata, key, basis='diffmap', components=[1, 2], cell_keys='',
                      legend_loc='top_right', plot_width=None, plot_height=None,
                      tools='pan, reset, wheel_zoom, save'):
    """
    Plot cell indices. Useful when trying to set adata.uns['iroot'].

    Cells are drawn as a scatter plot of two components of the chosen basis,
    colored by the categories of `key`. A button toggles the display of the
    positional index next to every cell.

    Params
    --------
    adata: AnnData Object
        annotated data object
    key: str
        key in `adata.obs_keys()` to color
    basis: str, optional (default: `'diffmap'`)
        basis to use, `X_{basis}` must be in `adata.obsm`
    components: list[int], optional (default: `[1, 2]`)
        which two components of the basis to use; for every basis other
        than `'diffmap'`, 1 is subtracted to obtain the column index
    cell_keys: str, list(str), optional (default: `''`)
        keys to display from `adata.obs_keys()` when hovering over cell;
        a string is split on `','`; a passed list is not modified
    legend_loc: str, optional (default `'top_right'`)
        location of the legend, if `None`, no legend is created
    plot_width: int, optional (default: `None`)
        width of the plot
    plot_height: int, optional (default: `None`)
        height of the plot
    tools: str, optional (default: `'pan, reset, wheel_zoom, save'`)
        tools for the plot

    Returns
    --------
    None

    Raises
    --------
    ValueError
        if `key` is not in `adata.obs` or `X_{basis}` is not in `adata.obsm`
    AssertionError
        if some of `cell_keys` are not in `adata.obs.keys()`
    """

    if key not in adata.obs:
        raise ValueError(f'{key} not found in `adata.obs`')

    if f'X_{basis}' not in adata.obsm_keys():
        raise ValueError(f'basis `X_{basis}` not found in `adata.obsm`')

    components = np.asarray(components)

    if cell_keys is None:
        cell_keys = []
    elif isinstance(cell_keys, str):
        # split, strip and deduplicate while keeping the order
        cell_keys = list(dict.fromkeys(map(str.strip, cell_keys.split(','))))
    else:
        # work on a copy, the list is extended below
        cell_keys = list(cell_keys)

    if cell_keys != ['']:
        assert all(k in adata.obs.keys() for k in cell_keys), 'Not all keys are in `adata.obs.keys()`.'
    else:
        cell_keys = []

    # diffmap components are used as given, other bases are 1-based
    df = pd.DataFrame(adata.obsm[f'X_{basis}'][:, components - (basis != 'diffmap')], columns=['x', 'y'])

    for k in cell_keys:
        df[k] = list(map(str, adata.obs[k]))

    df['index'] = range(len(df))
    df[key] = list(adata.obs[key])

    if hasattr(adata, 'obs_names'):
        cell_keys.insert(0, 'name')
        df['name'] = list(adata.obs_names)

    if 'index' not in cell_keys:
        cell_keys.insert(0, 'index')

    key_col = adata.obs[key].astype('category') if adata.obs[key].dtype.name != 'category' else adata.obs[key]
    categories = key_col.cat.categories

    # one color per category (not per unique value, there may be unused
    # categories); the fallback palette is only built when it is needed
    palette = adata.uns.get(f'{key}_colors')
    if palette is None or len(palette) == 0:
        palette = viridis(len(categories))

    p = figure(title=f'{key}', tools=tools)
    _set_plot_wh(p, plot_width, plot_height)

    # every entry is a one-element list, the format expected by `Legend`
    renderers = []
    for i, c in enumerate(categories):
        data = ColumnDataSource(df[df[key] == c])
        renderers.append([p.scatter(x='x', y='y', size=10,
                                    # reuse colors if the palette is too short
                                    color=palette[i % len(palette)],
                                    source=data, muted_alpha=0)])

    # `@{name}` also works for column names containing spaces
    hover_cell = HoverTool(renderers=[r for group in renderers for r in group],
                           tooltips=[(f'{k}', f'@{{{k}}}') for k in cell_keys])

    if legend_loc is not None:
        legend = Legend(items=list(zip(map(str, categories), renderers)),
                        location=legend_loc, click_policy='mute')
        p.add_layout(legend)

    p.xaxis.axis_label = f'{basis}_{components[0]}'
    p.yaxis.axis_label = f'{basis}_{components[1]}'

    # positional indices of all cells, hidden until toggled with the button
    source = ColumnDataSource(df)
    labels = LabelSet(x='x', y='y', text='index',
                      x_offset=4, y_offset=4,
                      level='glyph',
                      source=source, render_mode='canvas')
    labels.visible = False

    p.add_tools(hover_cell)
    p.add_layout(labels)

    button = Button(label='Toggle Indices', button_type='primary')
    toggle = CustomJS(args=dict(l=labels), code='l.visible = !l.visible;')
    # the `callback` property is not available on newer Bokeh versions
    if hasattr(button, 'js_on_click'):
        button.js_on_click(toggle)
    else:
        button.callback = toggle

    show(column(button, p))