def _plot_mcse(
    ax,
    plotters,
    length_plotters,
    rows,
    cols,
    figsize,
    errorbar,
    rug,
    data,
    probs,
    extra_kwargs,
    extra_methods,
    mean_mcse,
    sd_mcse,
    rug_kwargs,
    idata,
    rug_kind,
    _markersize,
    _linewidth,
    show,
):
    """Bokeh MCSE plot, one panel per ``(var_name, selection, x)`` in ``plotters``.

    ``x`` holds the MCSE values at the quantiles ``probs``.

    * ``errorbar=True``: the sample quantiles of ``data[var_name]`` are drawn as
      dashes with vertical ``quantile - mcse`` .. ``quantile + mcse`` segments.
    * ``errorbar=False``: the MCSE values are drawn as circles and the y range
      is pinned to ``[-0.05, 1]``.
    * ``extra_methods``: horizontal spans at the MCSE of the mean (double line
      width) and of the sd, styled from ``extra_kwargs["linewidth"]`` and
      ``extra_kwargs["alpha"]``.
    * ``rug``: dashes at the normalised ranks of the samples selected by
      ``idata.sample_stats[rug_kind]`` (used as an index mask, presumably
      boolean -- TODO confirm), drawn on a baseline span. Only ``line_color``
      is read from ``rug_kwargs``; the caller's dict is not modified.

    If ``ax`` is None a grid is created with ``_create_axes_grid``. When
    ``show`` is true the figures are arranged with ``gridplot`` and displayed.

    Returns
    -------
    ax
        The figure(s) that were drawn on.

    Raises
    ------
    ValueError
        If ``rug`` is requested and ``idata`` has no ``sample_stats`` group or
        that group has no ``rug_kind`` variable.
    """
    if ax is None:
        _, ax = _create_axes_grid(length_plotters, rows, cols, figsize=figsize, backend="bokeh")

    # Only "line_color" is honoured from the user's rug options.
    rug_line_color = (rug_kwargs or {}).get("line_color", "black")

    for (var_name, selection, x), ax_ in zip(plotters, np.ravel(ax)):
        if errorbar or rug:
            values = data[var_name].sel(**selection).values.flatten()
        if errorbar:
            quantile_values = _quantile(values, probs)
            ax_.dash(probs, quantile_values)
            ax_.multi_line(
                list(zip(probs, probs)),
                [(quant - err, quant + err) for quant, err in zip(quantile_values, x)],
            )
        else:
            ax_.circle(probs, x)

        if extra_methods:
            mean_mcse_i = mean_mcse[var_name].sel(**selection).values.item()
            sd_mcse_i = sd_mcse[var_name].sel(**selection).values.item()
            hline_mean = Span(
                location=mean_mcse_i,
                dimension="width",
                line_color="black",
                line_width=extra_kwargs["linewidth"] * 2,
                line_alpha=extra_kwargs["alpha"],
            )
            ax_.renderers.append(hline_mean)
            hline_sd = Span(
                location=sd_mcse_i,
                dimension="width",
                line_color="black",
                line_width=extra_kwargs["linewidth"],
                line_alpha=extra_kwargs["alpha"],
            )
            ax_.renderers.append(hline_sd)

        if rug:
            if not hasattr(idata, "sample_stats"):
                raise ValueError("InferenceData object must contain sample_stats for rug plot")
            if not hasattr(idata.sample_stats, rug_kind):
                raise ValueError("InferenceData does not contain {} data".format(rug_kind))

            _rug_kwargs = {
                "size": 8,
                "line_color": rug_line_color,
                "line_width": 1,
                "line_alpha": 0.35,
                "angle": np.pi / 2,
            }

            mask = idata.sample_stats[rug_kind].values.flatten()
            values = rankdata(values)[mask]

            # Baseline for the rug: 0 in the MCSE-only view; with error bars it is
            # 5% of the quantile spread below the lowest quantile (capped at 0).
            if errorbar:
                q_low, q_high = min(quantile_values), max(quantile_values)
                rug_baseline = min(0, q_low - (q_high - q_low) * 0.05)
            else:
                rug_baseline = 0

            rug_x = values / (len(mask) - 1)
            rug_y = np.full_like(values, rug_baseline)

            hline = Span(
                location=rug_baseline,
                dimension="width",
                line_color="black",
                line_width=_linewidth,
                line_alpha=0.7,
            )
            ax_.renderers.append(hline)

            glyph = Dash(x="rug_x", y="rug_y", **_rug_kwargs)
            cds_rug = ColumnDataSource({"rug_x": np.asarray(rug_x), "rug_y": np.asarray(rug_y)})
            ax_.add_glyph(cds_rug, glyph)

        title = Title()
        title.text = make_label(var_name, selection)
        ax_.title = title

        ax_.xaxis.axis_label = "Quantile"
        ax_.yaxis.axis_label = (
            r"Value $\pm$ MCSE for quantiles" if errorbar else "MCSE for quantiles"
        )

        if not errorbar:
            ax_.y_range._property_values["start"] = -0.05  # pylint: disable=protected-access
            ax_.y_range._property_values["end"] = 1  # pylint: disable=protected-access

    if show:
        # atleast_2d so a single figure or a flat list of figures can be laid out
        # as well as a 2-D grid (iterating a lone figure raised TypeError).
        grid = gridplot([list(item) for item in np.atleast_2d(ax)], toolbar_location="above")
        bkp.show(grid)

    return ax
def plot_trace(
    data,
    var_names,
    divergences,
    kind,
    figsize,
    rug,
    lines,
    combined,
    chain_prop,
    legend,
    plot_kwargs: [Dict],
    fill_kwargs: [Dict],
    rug_kwargs: [Dict],
    hist_kwargs: [Dict],
    trace_kwargs: [Dict],
    rank_kwargs: [Dict],
    plotters,
    divergence_data,
    axes,
    backend_config,
    backend_kwargs: [Dict],
    show,
):
    """Bokeh traceplot.

    One row per entry of ``plotters``: column 0 is the density panel and
    column 1 the trace panel (trace panels after the first share the first
    row's x range). Per-chain columns of ``data`` are packed into one
    ``ColumnDataSource`` per chain and drawn by ``_plot_chains_bokeh``.

    Additionally draws:

    * reference values from ``lines`` as a vertical span on the density panel
      and a horizontal span on the trace panel;
    * divergence markers (when ``divergences`` is truthy) taken from
      ``divergence_data``, at ``y=0`` on the density panel and at the max
      (``divergences == "top"``) or min of the values on the trace panel.
      Trace markers are placed at the corresponding ``draw`` coordinate
      values, matching the trace x axis.

    ``trace_kwargs``, ``plot_kwargs`` and ``rug_kwargs`` are updated in place
    (``line_width`` defaults; ``rug_kwargs["y"]``). ``trace_kwargs`` must
    contain ``"alpha"`` when ``lines`` is used.

    Returns
    -------
    numpy.ndarray
        2-D array of bokeh figures, shape ``(len(plotters), 2)`` when created here.
    """
    # If divergences are plotted they must be provided
    if divergences is not False:
        assert divergence_data is not None

    if backend_config is None:
        backend_config = {}

    backend_config = {
        **backend_kwarg_defaults(("bounds_y_range", "plot.bokeh.bounds_y_range")),
        **backend_config,
    }

    # Set plot default backend kwargs
    if backend_kwargs is None:
        backend_kwargs = {}

    backend_kwargs = {
        **backend_kwarg_defaults(("dpi", "plot.bokeh.figure.dpi")),
        **backend_kwargs,
    }
    dpi = backend_kwargs.pop("dpi")

    backend_kwargs.setdefault("height", int(figsize[1] * dpi // len(plotters)))
    backend_kwargs.setdefault("width", int(figsize[0] * dpi // 2))

    figsize, _, _, _, linewidth, _ = _scale_fig_size(figsize, 10, rows=len(plotters), cols=2)

    trace_kwargs.setdefault("line_width", linewidth)
    plot_kwargs.setdefault("line_width", linewidth)

    if rank_kwargs is None:
        rank_kwargs = {}

    if axes is None:
        axes = []
        for i in range(len(plotters)):
            if i != 0:
                _axes = [
                    bkp.figure(**backend_kwargs),
                    bkp.figure(x_range=axes[0][1].x_range, **backend_kwargs),
                ]
            else:
                _axes = [bkp.figure(**backend_kwargs), bkp.figure(**backend_kwargs)]
            axes.append(_axes)

    axes = np.atleast_2d(axes)

    def _cds_column_name(var_name, selection):
        """Return the ColumnDataSource column name for a variable/selection pair."""
        if not selection:
            return var_name
        return "{}_ARVIZ_CDS_SELECTION_{}".format(
            var_name,
            "_".join(
                str(item)
                for key, value in selection.items()
                for item in (
                    [key, value]
                    if (isinstance(value, str) or not isinstance(value, Iterable))
                    else [key, *value]
                )
            ),
        )

    # Build one column dict per chain: {column_name: values_for_that_chain}.
    cds_data = {}
    cds_var_groups = {}
    draw_name = "draw"

    for var_name, selection, value in xarray_var_iter(data, var_names=var_names, combined=True):
        cds_name = _cds_column_name(var_name, selection)
        cds_var_groups.setdefault(var_name, []).append(cds_name)
        for chain_idx, _ in enumerate(data.chain.values):
            cds_data.setdefault(chain_idx, {})[cds_name] = value[chain_idx]

    # Pick a draw column name that cannot collide with a variable column.
    while draw_name in cds_data[0]:
        draw_name += "w"

    for chain_idx in cds_data:
        cds_data[chain_idx][draw_name] = data.draw.values

    cds_data = {chain_idx: ColumnDataSource(cds) for chain_idx, cds in cds_data.items()}

    # Shared style for divergence markers (density and trace panels alike).
    div_marker_kwargs = {
        "size": 14,
        "line_color": "red",
        "line_width": 2,
        "line_alpha": 0.50,
        "angle": np.pi / 2,
    }

    for idx, (var_name, selection, value) in enumerate(plotters):
        value = np.atleast_2d(value)

        # A 2-D value is a single (chain, draw) series; higher-rank values
        # (compact mode) draw every column registered for this variable.
        if len(value.shape) == 2:
            y_names = [_cds_column_name(var_name, selection)]
        else:
            y_names = cds_var_groups[var_name]

        for y_name in y_names:
            if rug:
                rug_kwargs["y"] = y_name
            _plot_chains_bokeh(
                ax_density=axes[idx, 0],
                ax_trace=axes[idx, 1],
                data=cds_data,
                x_name=draw_name,
                y_name=y_name,
                chain_prop=chain_prop,
                combined=combined,
                rug=rug,
                kind=kind,
                legend=legend,
                trace_kwargs=trace_kwargs,
                hist_kwargs=hist_kwargs,
                plot_kwargs=plot_kwargs,
                fill_kwargs=fill_kwargs,
                rug_kwargs=rug_kwargs,
                rank_kwargs=rank_kwargs,
            )

        for col in (0, 1):
            _title = Title()
            _title.text = make_label(var_name, selection)
            axes[idx, col].title = _title
            axes[idx, col].y_range = DataRange1d(
                bounds=backend_config["bounds_y_range"], min_interval=0.1
            )

        for _, _, vlines in (j for j in lines if j[0] == var_name and j[1] == selection):
            if isinstance(vlines, (float, int)):
                line_values = [vlines]
            else:
                line_values = np.atleast_1d(vlines).ravel()

            for line_value in line_values:
                vline = Span(
                    location=line_value,
                    dimension="height",
                    line_color="black",
                    line_width=1.5,
                    line_alpha=0.75,
                )
                hline = Span(
                    location=line_value,
                    dimension="width",
                    line_color="black",
                    line_width=1.5,
                    line_alpha=trace_kwargs["alpha"],
                )

                axes[idx, 0].renderers.append(vline)
                axes[idx, 1].renderers.append(hline)

        if legend:
            for col in (0, 1):
                axes[idx, col].legend.location = "top_left"
                axes[idx, col].legend.click_policy = "hide"
        else:
            for col in (0, 1):
                if axes[idx, col].legend:
                    axes[idx, col].legend.visible = False

        if divergences:
            div_selection = {k: v for k, v in selection.items() if k in divergence_data.dims}
            divs = np.atleast_2d(divergence_data.sel(**div_selection).values)
            draw_values = data.draw.values

            for chain, chain_divs in enumerate(divs):
                div_idxs = np.arange(len(chain_divs))[chain_divs]
                if div_idxs.size > 0:
                    values = value[chain, div_idxs]
                    # The trace x axis is the draw coordinate (see draw_name column),
                    # not the positional index, so map indices to draw values.
                    tmp_cds = ColumnDataSource({"y": values, "x": draw_values[div_idxs]})
                    if divergences == "top":
                        y_div_trace = value.max()
                    else:
                        y_div_trace = value.min()
                    glyph_density = Dash(x="y", y=0.0, **div_marker_kwargs)
                    glyph_trace = Dash(x="x", y=y_div_trace, **div_marker_kwargs)

                    axes[idx, 0].add_glyph(tmp_cds, glyph_density)
                    axes[idx, 1].add_glyph(tmp_cds, glyph_trace)

    show_layout(axes, show)

    return axes
def plot_kde(
    density,
    lower,
    upper,
    density_q,
    xmin,
    xmax,
    ymin,
    ymax,
    gridsize,
    values,
    values2,
    rug,
    label,  # pylint: disable=unused-argument
    quantiles,
    rotated,
    contour,
    fill_last,
    figsize,
    textsize,
    plot_kwargs,
    fill_kwargs,
    rug_kwargs,
    contour_kwargs,
    contourf_kwargs,
    pcolormesh_kwargs,
    is_circular,  # pylint: disable=unused-argument
    ax,
    legend,  # pylint: disable=unused-argument
    backend_kwargs,
    show,
    return_glyph,
):
    """Bokeh kde plot.

    1-D case (``values2 is None``): draws the precomputed ``density`` over
    ``np.linspace(lower, upper, len(density))`` as a line. It optionally adds
    filled patches: one per band between consecutive ``quantiles`` (compared
    against ``density_q``), or a single fill when ``fill_kwargs["fill_alpha"]``
    is truthy. It also optionally adds a rug of ``values`` (or of the
    source(s) given in ``rug_kwargs["cds"]``). ``rotated`` swaps the x/y roles.

    2-D case: draws ``density`` as filled contours (``contour=True``) or as an
    image. ``density`` is presumably shaped ``(gridsize[0], gridsize[1])``
    with x along the first axis (TODO confirm against the caller). The grid
    passed to the contour generator is built with that shape.
    The ``"levels"`` key (from ``contour_kwargs`` or ``contourf_kwargs``)
    selects the contour levels and is not forwarded to the patch glyphs. The
    ``"cmap"`` key selects the colormap. The caller's contour/contourf/
    pcolormesh dicts are not modified.

    Returns
    -------
    ax, or (ax, glyphs) when ``return_glyph`` is true
        ``glyphs`` lists every renderer/glyph created here.
    """
    if backend_kwargs is None:
        backend_kwargs = {}

    backend_kwargs = {
        **backend_kwarg_defaults(),
        **backend_kwargs,
    }

    figsize, *_ = _scale_fig_size(figsize, textsize)

    if ax is None:
        ax = create_axes_grid(
            1,
            figsize=figsize,
            squeeze=True,
            backend_kwargs=backend_kwargs,
        )

    glyphs = []
    if values2 is None:
        if plot_kwargs is None:
            plot_kwargs = {}
        plot_kwargs.setdefault("line_color", mpl_rcParams["axes.prop_cycle"].by_key()["color"][0])

        if fill_kwargs is None:
            fill_kwargs = {}

        fill_kwargs.setdefault("fill_color", mpl_rcParams["axes.prop_cycle"].by_key()["color"][0])

        if rug:
            if rug_kwargs is None:
                rug_kwargs = {}

            rug_kwargs = rug_kwargs.copy()
            if "cds" in rug_kwargs:
                cds_rug = rug_kwargs.pop("cds")
                rug_varname = rug_kwargs.pop("y", "y")
            else:
                rug_varname = "y"
                cds_rug = ColumnDataSource({rug_varname: np.asarray(values)})

            rug_kwargs.setdefault("size", 8)
            rug_kwargs.setdefault("line_color", plot_kwargs["line_color"])
            rug_kwargs.setdefault("line_width", 1)
            rug_kwargs.setdefault("line_alpha", 0.35)
            if not rotated:
                rug_kwargs.setdefault("angle", np.pi / 2)

            # A dict of sources draws one rug per source; record every glyph
            # (previously only the last one reached ``glyphs``).
            rug_sources = cds_rug.values() if isinstance(cds_rug, dict) else [cds_rug]
            for _cds_rug in rug_sources:
                if not rotated:
                    glyph = Dash(x=rug_varname, y=0.0, **rug_kwargs)
                else:
                    glyph = Dash(x=0.0, y=rug_varname, **rug_kwargs)
                ax.add_glyph(_cds_rug, glyph)
                glyphs.append(glyph)

        x = np.linspace(lower, upper, len(density))

        if quantiles is not None:
            fill_kwargs.setdefault("fill_alpha", 0.75)
            fill_kwargs.setdefault("line_color", None)

            quantiles = sorted(np.clip(quantiles, 0, 1))
            if quantiles[0] != 0:
                quantiles = [0] + quantiles
            if quantiles[-1] != 1:
                quantiles = quantiles + [1]

            for quant_0, quant_1 in zip(quantiles[:-1], quantiles[1:]):
                idx = (density_q > quant_0) & (density_q < quant_1)
                if idx.sum():
                    # Closed polygon: along the x axis, up the right edge, back
                    # along the curve, down to the starting point.
                    patch_x = np.concatenate((x[idx], [x[idx][-1]], x[idx][::-1], [x[idx][0]]))
                    patch_y = np.concatenate(
                        (np.zeros_like(density[idx]), [density[idx][-1]], density[idx][::-1], [0])
                    )
                    if not rotated:
                        patch = ax.patch(patch_x, patch_y, **fill_kwargs)
                    else:
                        patch = ax.patch(patch_y, patch_x, **fill_kwargs)
                    glyphs.append(patch)
        else:
            if fill_kwargs.get("fill_alpha", False):
                patch_x = np.concatenate((x, [x[-1]], x[::-1], [x[0]]))
                patch_y = np.concatenate((np.zeros_like(density), [density[-1]], density[::-1], [0]))
                if not rotated:
                    patch = ax.patch(patch_x, patch_y, **fill_kwargs)
                else:
                    patch = ax.patch(patch_y, patch_x, **fill_kwargs)
                glyphs.append(patch)

        if not rotated:
            line = ax.line(x, density, **plot_kwargs)
        else:
            line = ax.line(density, x, **plot_kwargs)
        glyphs.append(line)

    else:
        # Work on copies: "cmap"/"levels" are popped below and must not be
        # stripped from dicts a caller may reuse across calls.
        contour_kwargs = {} if contour_kwargs is None else dict(contour_kwargs)
        contourf_kwargs = {} if contourf_kwargs is None else dict(contourf_kwargs)
        pcolormesh_kwargs = {} if pcolormesh_kwargs is None else dict(pcolormesh_kwargs)

        # Each axis gets its own number of points so the grid matches a
        # non-square density (previously gridsize[0] was used for both).
        x_x, y_y = np.mgrid[
            xmin:xmax:complex(gridsize[0]), ymin:ymax:complex(gridsize[1])
        ]

        if contour:
            scaled_density, *scaled_density_args = _scale_axis(density)

            contour_generator = _contour.QuadContourGenerator(
                x_x, y_y, scaled_density, None, True, 0
            )

            # "levels" only configures the contour computation; pop it so it is
            # not forwarded to ``ax.patch``, which rejects unknown properties.
            if "levels" in contour_kwargs:
                levels = contour_kwargs.pop("levels")
                contourf_kwargs.pop("levels", None)
            elif "levels" in contourf_kwargs:
                levels = contourf_kwargs.pop("levels")
            else:
                levels = 11

            if isinstance(levels, Integral):
                levels_scaled = np.linspace(0, 1, levels)
            else:
                levels_scaled_nonclip = _scale_axis(np.asarray(levels), scaled_density_args)
                levels_scaled = np.clip(levels_scaled_nonclip, 0, 1)

            cmap = contourf_kwargs.pop("cmap", "viridis")
            if isinstance(cmap, str):
                cmap = get_cmap(cmap)
            if isinstance(cmap, Callable):
                colors = [
                    rgb2hex(item) for item in cmap(np.linspace(0, 1, len(levels_scaled) + 1))
                ]
            else:
                colors = cmap

            contour_kwargs.update(contourf_kwargs)
            contour_kwargs.setdefault("line_color", "black")
            contour_kwargs.setdefault("line_alpha", 0.25)
            contour_kwargs.setdefault("fill_alpha", 1)

            for i, (level, level_upper, color) in enumerate(
                zip(levels_scaled[:-1], levels_scaled[1:], colors[1:])
            ):
                # Without fill_last the lowest band is left unfilled.
                if not fill_last and (i == 0):
                    continue
                vertices, _ = contour_generator.create_filled_contour(level, level_upper)
                for seg in vertices:
                    patch = ax.patch(*seg.T, fill_color=color, **contour_kwargs)
                    glyphs.append(patch)

            if fill_last:
                ax.background_fill_color = colors[0]

            ax.xgrid.grid_line_color = None
            ax.ygrid.grid_line_color = None

            ax.x_range = Range1d(xmin, xmax)
            ax.y_range = Range1d(ymin, ymax)

        else:
            cmap = pcolormesh_kwargs.pop("cmap", "viridis")
            if isinstance(cmap, str):
                cmap = get_cmap(cmap)
            if isinstance(cmap, Callable):
                colors = [rgb2hex(item) for item in cmap(np.linspace(0, 1, 256))]
            else:
                colors = cmap

            # Bokeh's ``dw``/``dh`` are the full extent the image covers in data
            # units, not the size of one pixel.
            image = ax.image(
                image=[density.T],
                x=xmin,
                y=ymin,
                dw=xmax - xmin,
                dh=ymax - ymin,
                palette=colors,
                **pcolormesh_kwargs,
            )
            glyphs.append(image)
            ax.x_range.range_padding = ax.y_range.range_padding = 0

    show_layout(ax, show)

    if return_glyph:
        return ax, glyphs

    return ax
# Demo: N Dash markers along the parabola y = x**2 with marker sizes growing
# linearly from 10 to 20, on a bare Plot with explicit axes and grid lines.
N = 9
x = np.linspace(-2, 2, N)
y = x ** 2
sizes = np.linspace(10, 20, N)

source = ColumnDataSource({"x": x, "y": y, "sizes": sizes})

plot = Plot(
    title=None,
    plot_width=300,
    plot_height=300,
    min_border=0,
    toolbar_location=None,
)

glyph = Dash(
    x="x",
    y="y",
    size="sizes",
    line_color="#3288bd",
    line_width=1,
    fill_color=None,
)
plot.add_glyph(source, glyph)

xaxis = LinearAxis()
plot.add_layout(xaxis, "below")

yaxis = LinearAxis()
plot.add_layout(yaxis, "left")

# Grid lines follow the tick positions of the matching axis.
plot.add_layout(Grid(dimension=0, ticker=xaxis.ticker))
plot.add_layout(Grid(dimension=1, ticker=yaxis.ticker))

curdoc().add_root(plot)
def _plot_trace_bokeh(
    data,
    var_names=None,
    coords=None,
    divergences="bottom",
    figsize=None,
    rug=False,
    lines=None,
    compact=False,
    combined=False,
    legend=False,
    plot_kwargs=None,
    fill_kwargs=None,
    rug_kwargs=None,
    hist_kwargs=None,
    trace_kwargs=None,
    backend_kwargs=None,
    show=True,
):
    """Plot distribution (histogram or kernel density estimates) and sampled values.

    Parameters
    ----------
    data : obj
        Any object that can be converted to an az.InferenceData object
        Refer to documentation of az.convert_to_dataset for details
    var_names : string, or list of strings
        One or more variables to be plotted.
    coords : mapping, optional
        Coordinates of var_names to be plotted. Passed to `Dataset.sel`
    divergences : {"bottom", "top", None, False}
        Plot location of divergences on the traceplots. Options are "bottom", "top", or
        False-y. Divergences are read from ``sample_stats.diverging``; if that is missing
        they are silently skipped. Markers are drawn at y=0 on the density panel and at
        the min ("bottom") or max ("top") of the values on the trace panel, at the
        corresponding ``draw`` coordinate.
    figsize : figure size tuple
        If None, size is (12, variables * 2)
    rug : bool
        If True adds a rugplot. Defaults to False. Ignored for 2D KDE. Only affects continuous
        variables.
    lines : tuple
        Tuple of (var_name, {'coord': selection}, [line, positions]) to be overplotted as
        vertical lines on the density and horizontal lines on the trace.
    compact : bool
        Plot multidimensional variables in a single plot.
    combined : bool
        Flag for combining multiple chains into a single line. If False (default), chains will be
        plotted separately.
    legend : bool
        Add a legend to the figure with the chain color code.
    plot_kwargs : dict
        Extra keyword arguments passed to `arviz.plot_dist`. Only affects continuous variables.
    fill_kwargs : dict
        Extra keyword arguments passed to `arviz.plot_dist`. Only affects continuous variables.
    rug_kwargs : dict
        Extra keyword arguments passed to `arviz.plot_dist`. Only affects continuous variables.
    hist_kwargs : dict
        Extra keyword arguments passed to `arviz.plot_dist`. Only affects discrete variables.
    trace_kwargs : dict
        Extra keyword arguments for the trace lines (forwarded to `_plot_chains_bokeh`).
        ``alpha`` defaults to 0.35 and is also used for the horizontal reference lines.
    backend_kwargs : dict
        Extra keyword arguments passed to `bokeh.plotting.figure`
    show : bool
        Call `bokeh.plotting.show` for gridded plots `bokeh.layouts.gridplot(axes.tolist())`

    Returns
    -------
    ndarray
        axes (bokeh figures)


    Examples
    --------
    Plot a subset variables

    .. plot::
        :context: close-figs

        >>> import arviz as az
        >>> data = az.load_arviz_data('non_centered_eight')
        >>> coords = {'school': ['Choate', 'Lawrenceville']}
        >>> az.plot_trace(data, var_names=('theta_t', 'theta'), coords=coords)

    Show all dimensions of multidimensional variables in the same plot

    .. plot::
        :context: close-figs

        >>> az.plot_trace(data, compact=True)

    Combine all chains into one distribution

    .. plot::
        :context: close-figs

        >>> az.plot_trace(data, var_names=('theta_t', 'theta'), coords=coords, combined=True)


    Plot reference lines against distribution and trace

    .. plot::
        :context: close-figs

        >>> lines = (('theta_t',{'school': "Choate"}, [-1]),)
        >>> az.plot_trace(data, var_names=('theta_t', 'theta'), coords=coords, lines=lines)

    """
    if divergences:
        try:
            divergence_data = convert_to_dataset(data, group="sample_stats").diverging
        except (ValueError, AttributeError):  # No sample_stats, or no `.diverging`
            divergences = False

    if coords is None:
        coords = {}

    data = get_coords(convert_to_dataset(data, group="posterior"), coords)
    var_names = _var_names(var_names, data)

    if divergences:
        divergence_data = get_coords(
            divergence_data, {k: v for k, v in coords.items() if k in ("chain", "draw")}
        )

    if lines is None:
        lines = ()

    # One extra color is reserved when chains are combined.
    num_colors = len(data.chain) + 1 if combined else len(data.chain)
    colors = [
        prop
        for _, prop in zip(
            range(num_colors), cycle(plt.rcParams["axes.prop_cycle"].by_key()["color"])
        )
    ]

    if compact:
        skip_dims = set(data.dims) - {"chain", "draw"}
    else:
        skip_dims = set()

    plotters = list(xarray_var_iter(data, var_names=var_names, combined=True, skip_dims=skip_dims))
    max_plots = rcParams["plot.max_subplots"]
    max_plots = len(plotters) if max_plots is None else max_plots
    if len(plotters) > max_plots:
        warnings.warn(
            "rcParams['plot.max_subplots'] ({max_plots}) is smaller than the number "
            "of variables to plot ({len_plotters}), generating only {max_plots} "
            "plots".format(max_plots=max_plots, len_plotters=len(plotters)),
            SyntaxWarning,
        )
        plotters = plotters[:max_plots]

    if figsize is None:
        figsize = (12, len(plotters) * 2)

    if trace_kwargs is None:
        trace_kwargs = {}

    trace_kwargs.setdefault("alpha", 0.35)

    if hist_kwargs is None:
        hist_kwargs = {}
    if plot_kwargs is None:
        plot_kwargs = {}
    if fill_kwargs is None:
        fill_kwargs = {}
    if rug_kwargs is None:
        rug_kwargs = {}

    hist_kwargs.setdefault("alpha", 0.35)

    figsize, _, _, _, linewidth, _ = _scale_fig_size(figsize, 10, rows=len(plotters), cols=2)
    figsize = int(figsize[0] * 90 // 2), int(figsize[1] * 90 // len(plotters))

    trace_kwargs.setdefault("line_width", linewidth)
    plot_kwargs.setdefault("line_width", linewidth)

    if backend_kwargs is None:
        backend_kwargs = dict()

    backend_kwargs.setdefault(
        "tools",
        ("pan,wheel_zoom,box_zoom," "lasso_select,poly_select," "undo,redo,reset,save,hover"),
    )
    backend_kwargs.setdefault("output_backend", "webgl")
    backend_kwargs.setdefault("height", figsize[1])
    backend_kwargs.setdefault("width", figsize[0])

    # Column 0: density panel; column 1: trace panel. Trace panels share the
    # first row's x range.
    axes = []
    for i in range(len(plotters)):
        if i != 0:
            _axes = [
                bkp.figure(**backend_kwargs),
                bkp.figure(x_range=axes[0][1].x_range, **backend_kwargs),
            ]
        else:
            _axes = [bkp.figure(**backend_kwargs), bkp.figure(**backend_kwargs)]
        axes.append(_axes)

    axes = np.array(axes)

    def _cds_column_name(var_name, selection):
        """Return the ColumnDataSource column name for a variable/selection pair."""
        if not selection:
            return var_name
        return "{}_ARVIZ_CDS_SELECTION_{}".format(
            var_name,
            "_".join(
                str(item)
                for key, value in selection.items()
                for item in (
                    [key, value]
                    if (isinstance(value, str) or not isinstance(value, Iterable))
                    else [key, *value]
                )
            ),
        )

    # Build one column dict per chain: {column_name: values_for_that_chain}.
    cds_data = {}
    cds_var_groups = {}
    draw_name = "draw"

    for var_name, selection, value in xarray_var_iter(data, var_names=var_names, combined=True):
        cds_name = _cds_column_name(var_name, selection)
        cds_var_groups.setdefault(var_name, []).append(cds_name)
        for chain_idx, _ in enumerate(data.chain.values):
            cds_data.setdefault(chain_idx, {})[cds_name] = value[chain_idx]

    # Pick a draw column name that cannot collide with a variable column.
    while draw_name in cds_data[0]:
        draw_name += "w"

    for chain_idx in cds_data:
        cds_data[chain_idx][draw_name] = data.draw.values

    cds_data = {chain_idx: ColumnDataSource(cds) for chain_idx, cds in cds_data.items()}

    # Shared style for divergence markers (density and trace panels alike).
    div_marker_kwargs = {
        "size": 14,
        "line_color": "red",
        "line_width": 2,
        "line_alpha": 0.50,
        "angle": np.pi / 2,
    }

    for idx, (var_name, selection, value) in enumerate(plotters):
        value = np.atleast_2d(value)

        # A 2-D value is a single (chain, draw) series; higher-rank values
        # (compact mode) draw every column registered for this variable.
        if len(value.shape) == 2:
            y_names = [_cds_column_name(var_name, selection)]
        else:
            y_names = cds_var_groups[var_name]

        for y_name in y_names:
            if rug:
                rug_kwargs["y"] = y_name
            _plot_chains_bokeh(
                ax_density=axes[idx, 0],
                ax_trace=axes[idx, 1],
                data=cds_data,
                x_name=draw_name,
                y_name=y_name,
                colors=colors,
                combined=combined,
                rug=rug,
                legend=legend,
                trace_kwargs=trace_kwargs,
                hist_kwargs=hist_kwargs,
                plot_kwargs=plot_kwargs,
                fill_kwargs=fill_kwargs,
                rug_kwargs=rug_kwargs,
            )

        for col in (0, 1):
            _title = Title()
            _title.text = make_label(var_name, selection)
            axes[idx, col].title = _title

        for _, _, vlines in (j for j in lines if j[0] == var_name and j[1] == selection):
            if isinstance(vlines, (float, int)):
                line_values = [vlines]
            else:
                line_values = np.atleast_1d(vlines).ravel()

            for line_value in line_values:
                vline = Span(
                    location=line_value,
                    dimension="height",
                    line_color="black",
                    line_width=1.5,
                    line_alpha=0.75,
                )
                hline = Span(
                    location=line_value,
                    dimension="width",
                    line_color="black",
                    line_width=1.5,
                    line_alpha=trace_kwargs["alpha"],
                )

                axes[idx, 0].renderers.append(vline)
                axes[idx, 1].renderers.append(hline)

        if legend:
            for col in (0, 1):
                axes[idx, col].legend.location = "top_left"
                axes[idx, col].legend.click_policy = "hide"
        else:
            for col in (0, 1):
                if axes[idx, col].legend:
                    axes[idx, col].legend.visible = False

        if divergences:
            div_selection = {k: v for k, v in selection.items() if k in divergence_data.dims}
            divs = np.atleast_2d(divergence_data.sel(**div_selection).values)
            draw_values = data.draw.values

            for chain, chain_divs in enumerate(divs):
                div_idxs = np.arange(len(chain_divs))[chain_divs]
                if div_idxs.size > 0:
                    values = value[chain, div_idxs]
                    # The trace x axis is the draw coordinate (see draw_name column),
                    # not the positional index, so map indices to draw values.
                    tmp_cds = ColumnDataSource({"y": values, "x": draw_values[div_idxs]})
                    if divergences == "top":
                        y_div_trace = value.max()
                    else:
                        y_div_trace = value.min()
                    glyph_density = Dash(x="y", y=0.0, **div_marker_kwargs)
                    glyph_trace = Dash(x="x", y=y_div_trace, **div_marker_kwargs)

                    axes[idx, 0].add_glyph(tmp_cds, glyph_density)
                    axes[idx, 1].add_glyph(tmp_cds, glyph_trace)

    if show:
        grid = gridplot([list(item) for item in axes], toolbar_location="above")
        bkp.show(grid)

    return axes
line_color="#F0027F", fill_color=None, line_width=2)), ("x", X(x="x", y="y", size="sizes", line_color="thistle", fill_color=None, line_width=2)), ("hex", Hex(x="x", y="y", size="sizes", line_color="#99D594", line_width=2)), ("dash", Dash(x="x", y="y", size="sizes", angle=0.5, line_color="#386CB0", line_width=1)), ] def make_tab(title, glyph): plot = Plot() plot.title.text = title plot.add_glyph(source, glyph) xaxis = LinearAxis() plot.add_layout(xaxis, 'below') yaxis = LinearAxis()
def plot_ess(
    ax,
    plotters,
    xdata,
    ess_tail_dataset,
    mean_ess,
    sd_ess,
    idata,
    data,
    text_x,
    text_va,
    kind,
    extra_methods,
    rows,
    cols,
    figsize,
    kwargs,
    extra_kwargs,
    text_kwargs,
    _linewidth,
    _markersize,
    n_samples,
    relative,
    min_ess,
    xt_labelsize,
    titlesize,
    ax_labelsize,
    ylabel,
    rug,
    rug_kind,
    rug_kwargs,
    hline_kwargs,
    backend_kwargs,
    show,
):
    """Bokeh essplot.

    For each ``(var_name, selection, x)`` in ``plotters`` the ESS values ``x``
    are drawn against ``xdata`` as circles, with these additions:

    * ``kind == "evolution"``: lines through the bulk points, plus the tail ESS
      from ``ess_tail_dataset`` in orange, and a clickable legend;
    * otherwise, if ``rug``: dashes at the normalised ranks of the samples
      selected by ``idata.sample_stats[rug_kind]``, placed
      ``rug_kwargs["space"]`` (default 0.1) times ``max(x)`` below zero, plus a
      span at 0. Only ``"space"`` and ``"line_color"`` are read from
      ``rug_kwargs``, and the caller's dict is not modified;
    * ``extra_methods``: dashed spans at the mean and sd ESS;
    * always: a red dashed threshold span at ``400 / n_samples`` when
      ``relative`` is true, else at ``min_ess``.

    ``text_x``, ``text_va``, ``kwargs``, ``extra_kwargs``, ``text_kwargs``,
    ``_markersize``, ``xt_labelsize``, ``titlesize``, ``ax_labelsize`` and
    ``hline_kwargs`` are accepted for signature compatibility but unused here.

    Raises
    ------
    ValueError
        If a rug is requested and ``idata`` has no ``sample_stats`` group or
        that group has no ``rug_kind`` variable.
    """
    if backend_kwargs is None:
        backend_kwargs = {}

    backend_kwargs = {
        **backend_kwarg_defaults(("dpi", "plot.bokeh.figure.dpi")),
        **backend_kwargs,
    }

    if ax is None:
        _, ax = _create_axes_grid(
            len(plotters),
            rows,
            cols,
            figsize=figsize,
            squeeze=False,
            constrained_layout=True,
            backend="bokeh",
            backend_kwargs=backend_kwargs,
        )
    else:
        ax = np.atleast_2d(ax)

    # Read rug options without mutating the caller's dict. Previously the code
    # did setdefault("space") + pop("space") per panel, which discarded a
    # user-supplied "space" after the first panel (later panels used 0.1).
    rug_options = {} if rug_kwargs is None else rug_kwargs
    rug_space_fraction = rug_options.get("space", 0.1)

    for (var_name, selection, x), ax_ in zip(
        plotters, (item for item in ax.flatten() if item is not None)
    ):
        bulk_points = ax_.circle(np.asarray(xdata), np.asarray(x), size=6)
        if kind == "evolution":
            bulk_line = ax_.line(np.asarray(xdata), np.asarray(x))
            ess_tail = ess_tail_dataset[var_name].sel(**selection)
            tail_line = ax_.line(np.asarray(xdata), np.asarray(ess_tail), color="orange")
            tail_points = ax_.circle(
                np.asarray(xdata), np.asarray(ess_tail), size=6, color="orange"
            )
        elif rug:
            if not hasattr(idata, "sample_stats"):
                raise ValueError("InferenceData object must contain sample_stats for rug plot")
            if not hasattr(idata.sample_stats, rug_kind):
                raise ValueError("InferenceData does not contain {} data".format(rug_kind))

            _rug_kwargs = {
                "size": 8,
                "line_color": rug_options.get("line_color", "black"),
                "line_width": 1,
                "line_alpha": 0.35,
                "angle": np.pi / 2,
            }

            values = data[var_name].sel(**selection).values.flatten()
            mask = idata.sample_stats[rug_kind].values.flatten()
            values = rankdata(values)[mask]
            rug_space = np.max(x) * rug_space_fraction
            rug_x, rug_y = values / (len(mask) - 1), np.zeros_like(values) - rug_space

            glyph = Dash(x="rug_x", y="rug_y", **_rug_kwargs)
            cds_rug = ColumnDataSource({"rug_x": np.asarray(rug_x), "rug_y": np.asarray(rug_y)})
            ax_.add_glyph(cds_rug, glyph)

            hline = Span(
                location=0,
                dimension="width",
                line_color="black",
                line_width=_linewidth,
                line_alpha=0.7,
            )
            ax_.renderers.append(hline)

        if extra_methods:
            mean_ess_i = mean_ess[var_name].sel(**selection).values.item()
            sd_ess_i = sd_ess[var_name].sel(**selection).values.item()

            hline = Span(
                location=mean_ess_i,
                dimension="width",
                line_color="black",
                line_width=2,
                line_dash="dashed",
                line_alpha=1.0,
            )
            ax_.renderers.append(hline)

            hline = Span(
                location=sd_ess_i,
                dimension="width",
                line_color="black",
                line_width=1,
                line_dash="dashed",
                line_alpha=1.0,
            )
            ax_.renderers.append(hline)

        hline = Span(
            location=400 / n_samples if relative else min_ess,
            dimension="width",
            line_color="red",
            line_width=3,
            line_dash="dashed",
            line_alpha=1.0,
        )
        ax_.renderers.append(hline)

        if kind == "evolution":
            legend = Legend(
                items=[
                    ("bulk", [bulk_points, bulk_line]),
                    ("tail", [tail_points, tail_line]),
                ],
                location="center_right",
                orientation="horizontal",
            )
            ax_.add_layout(legend, "above")
            ax_.legend.click_policy = "hide"

        title = Title()
        title.text = make_label(var_name, selection)
        ax_.title = title

        ax_.xaxis.axis_label = "Total number of draws" if kind == "evolution" else "Quantile"
        ax_.yaxis.axis_label = ylabel.format("Relative ESS" if relative else "ESS")

    show_layout(ax, show)

    return ax
def whisker_sentiment(df, calc_attr):
    """Build a bokeh figure of the per-day mean of ``calc_attr`` with +/- 1 std whiskers.

    Rows whose ``calc_attr`` equals 0 are dropped first. For each distinct
    ``day_range`` value (in ascending order) the figure shows:

    * every remaining observation as a faint teal circle,
    * a pink whisker from ``mean - std`` to ``mean + std``,
    * a black dash at the mean.

    X tick labels are mapped to dates through ``create_tick_date_dict(df)``.
    ``day_range`` presumably holds integer day offsets, since the ticks are cast
    to int (TODO confirm). Styling comes from module-level settings
    (``text_color``, ``font``, ``font_style``, ``title_font_size``,
    ``axis_label_font_size``).

    Returns the bokeh figure; it is not shown.
    """
    # Drop zero values (not negative ones) so they do not drag the per-day
    # means toward zero.
    df = df[df[calc_attr] != 0]
    df = df.sort_values(by=['day_range']).copy()
    calc_attr_capitalized = calc_attr.capitalize()

    # create figure
    p = figure(plot_width=400, plot_height=350, y_axis_label=calc_attr_capitalized)

    # Single pass over the days: collect the whisker/mean statistics and draw
    # the raw scatter for each day (previously every day was filtered twice).
    base, lower, upper, mean = [], [], [], []
    for date in df['day_range'].unique():
        day_pol = df.loc[df['day_range'] == date, calc_attr]
        pol_mean = day_pol.mean()
        pol_std = day_pol.std()
        lower.append(pol_mean - pol_std)
        upper.append(pol_mean + pol_std)
        base.append(date)
        mean.append(pol_mean)

        p.circle(x=date, y=day_pol, color='teal', size=3, alpha=0.1)

    source_error = ColumnDataSource(data={
        'base': base,
        'lower': lower,
        'upper': upper
    })
    source_mean = ColumnDataSource(data={'x': base, 'y': mean})

    p.add_layout(
        Whisker(source=source_error,
                base="base",
                upper="upper",
                lower="lower",
                line_color='pink',
                line_width=5))

    # indicator for mean values as dash added to plot
    dash = Dash(x="x",
                y="y",
                size=10,
                line_color="black",
                line_width=1,
                fill_color=None)
    p.add_glyph(source_mean, dash)

    # adjust min and max of x axis and distribute major axis ticks over date
    # ranges
    p.x_range = Range1d(0, max(base) + 1)
    p.xaxis.ticker = np.linspace(min(base), max(base), 6, dtype='int').tolist()

    # create_tick_date_dict maps the day numbers of the series to date labels;
    # FuncTickFormatter's JS code uses that mapping to relabel the x ticks.
    date_tick_dict = create_tick_date_dict(df)
    p.xaxis.formatter = FuncTickFormatter(code="""
        var mapping = %s;
        return mapping[tick];
    """ % date_tick_dict)

    # configure visual properties on a plot's title attribute
    p.title.text = f"Mean {calc_attr_capitalized} Over Time"
    p.title.align = "center"
    p.title.text_color = text_color
    p.title.text_font_size = title_font_size
    p.title.text_font = font
    p.title.text_font_style = font_style

    # configure axis labels
    p.axis.axis_label_text_font_style = font_style
    p.axis.axis_label_text_font_size = axis_label_font_size
    p.axis.axis_label_text_font = font
    p.axis.axis_label_text_color = text_color
    p.axis.major_label_text_font_size = '10pt'
    p.axis.major_label_text_font = font
    p.axis.major_label_text_color = text_color
    p.axis.axis_line_color = text_color

    # turning off minor labels
    p.xaxis.minor_tick_line_color = None
    p.yaxis.minor_tick_line_color = None

    return p
def _plot_kde_bokeh(
    density,
    lower,
    upper,
    density_q,
    xmin,
    xmax,
    ymin,
    ymax,
    gridsize,
    values,
    values2=None,
    rug=False,
    label=None,
    quantiles=None,
    rotated=False,
    contour=True,
    fill_last=True,
    plot_kwargs=None,
    fill_kwargs=None,
    rug_kwargs=None,
    contour_kwargs=None,
    contourf_kwargs=None,
    pcolormesh_kwargs=None,
    ax=None,
    legend=True,
    show=True,
):
    """Draw a precomputed 1D or 2D density estimate on a Bokeh figure.

    When ``values2`` is None the 1D branch draws ``density`` as a line over
    ``np.linspace(lower, upper, len(density))``. It optionally adds a rug of
    ``values``, filled quantile bands selected through ``density_q``, or a
    single fill when ``fill_kwargs`` has a truthy ``fill_alpha``. ``rotated``
    swaps the x/y roles.

    Otherwise the 2D branch draws either filled contours (``contour=True``) or
    an image (``contour=False``) on the grid ``[xmin, xmax] x [ymin, ymax]``.
    It uses ``gridsize[0]`` points per axis.

    All ``*_kwargs`` dicts are copied before defaults are filled in, so the
    caller's dictionaries are never modified.

    Returns
    -------
    The Bokeh figure the density was drawn on (``ax`` or a newly created one).
    """
    if ax is None:
        tools = rcParams["plot.bokeh.tools"]
        output_backend = rcParams["plot.bokeh.output_backend"]
        ax = bkp.figure(
            width=rcParams["plot.bokeh.figure.width"],
            height=rcParams["plot.bokeh.figure.height"],
            output_backend=output_backend,
            tools=tools,
        )

    # plot_kwargs must exist before the legend label is stored in it
    # (previously a None plot_kwargs with a label raised TypeError).
    plot_kwargs = {} if plot_kwargs is None else dict(plot_kwargs)
    if legend and label is not None:
        plot_kwargs["legend_label"] = label

    if values2 is None:
        default_color = mpl_rcParams["axes.prop_cycle"].by_key()["color"][0]
        plot_kwargs.setdefault("line_color", default_color)
        fill_kwargs = {} if fill_kwargs is None else dict(fill_kwargs)
        fill_kwargs.setdefault("fill_color", default_color)

        if rug:
            rug_kwargs = {} if rug_kwargs is None else rug_kwargs.copy()

            # A caller may supply its own data source (or a dict of sources)
            # through the "cds" key together with the column name under "y".
            if "cds" in rug_kwargs:
                cds_rug = rug_kwargs.pop("cds")
                rug_varname = rug_kwargs.pop("y", "y")
            else:
                rug_varname = "y"
                cds_rug = ColumnDataSource({rug_varname: np.asarray(values)})

            rug_kwargs.setdefault("size", 8)
            rug_kwargs.setdefault("line_color", plot_kwargs["line_color"])
            rug_kwargs.setdefault("line_width", 1)
            rug_kwargs.setdefault("line_alpha", 0.35)
            if not rotated:
                rug_kwargs.setdefault("angle", np.pi / 2)

            sources = cds_rug.values() if isinstance(cds_rug, dict) else [cds_rug]
            for _cds_rug in sources:
                if not rotated:
                    glyph = Dash(x=rug_varname, y=0.0, **rug_kwargs)
                else:
                    glyph = Dash(x=0.0, y=rug_varname, **rug_kwargs)
                ax.add_glyph(_cds_rug, glyph)

        x = np.linspace(lower, upper, len(density))

        if quantiles is not None:
            fill_kwargs.setdefault("fill_alpha", 0.75)
            fill_kwargs.setdefault("line_color", None)

            quantiles = sorted(np.clip(quantiles, 0, 1))
            if quantiles[0] != 0:
                quantiles = [0] + quantiles
            if quantiles[-1] != 1:
                quantiles = quantiles + [1]

            # One closed patch per band between consecutive quantiles.
            for quant_0, quant_1 in zip(quantiles[:-1], quantiles[1:]):
                idx = (density_q > quant_0) & (density_q < quant_1)
                if idx.sum():
                    patch_x = np.concatenate((x[idx], [x[idx][-1]], x[idx][::-1], [x[idx][0]]))
                    patch_y = np.concatenate(
                        (np.zeros_like(density[idx]), [density[idx][-1]], density[idx][::-1], [0])
                    )
                    if not rotated:
                        ax.patch(patch_x, patch_y, **fill_kwargs)
                    else:
                        ax.patch(patch_y, patch_x, **fill_kwargs)
        elif fill_kwargs.get("fill_alpha", False):
            patch_x = np.concatenate((x, [x[-1]], x[::-1], [x[0]]))
            patch_y = np.concatenate((np.zeros_like(density), [density[-1]], density[::-1], [0]))
            if not rotated:
                ax.patch(patch_x, patch_y, **fill_kwargs)
            else:
                ax.patch(patch_y, patch_x, **fill_kwargs)

        if not rotated:
            ax.line(x, density, **plot_kwargs)
        else:
            ax.line(density, x, **plot_kwargs)

    else:
        contour_kwargs = {} if contour_kwargs is None else dict(contour_kwargs)
        contourf_kwargs = {} if contourf_kwargs is None else dict(contourf_kwargs)
        pcolormesh_kwargs = {} if pcolormesh_kwargs is None else dict(pcolormesh_kwargs)

        g_s = complex(gridsize[0])
        x_x, y_y = np.mgrid[xmin:xmax:g_s, ymin:ymax:g_s]

        if contour:
            scaled_density, *scaled_density_args = _scale_axis(density)

            # NOTE(review): matplotlib's private ``_contour`` module is not a
            # stable API; confirm it exists in the pinned matplotlib version.
            contour_generator = _contour.QuadContourGenerator(
                x_x, y_y, scaled_density, None, True, 0
            )

            if "levels" in contour_kwargs:
                levels = contour_kwargs["levels"]
            elif "levels" in contourf_kwargs:
                levels = contourf_kwargs["levels"]
            else:
                levels = 11
            # "levels" only drives the contour computation. It is not a Bokeh
            # Patch property, so it must not be forwarded to ``ax.patch``.
            contour_kwargs.pop("levels", None)
            contourf_kwargs.pop("levels", None)

            if isinstance(levels, Integral):
                levels_scaled = np.linspace(0, 1, levels)
            else:
                levels_scaled_nonclip = _scale_axis(np.asarray(levels), scaled_density_args)
                levels_scaled = np.clip(levels_scaled_nonclip, 0, 1)

            cmap = contourf_kwargs.pop("cmap", "viridis")
            if isinstance(cmap, str):
                cmap = get_cmap(cmap)
            if isinstance(cmap, Callable):
                colors = [
                    rgb2hex(item) for item in cmap(np.linspace(0, 1, len(levels_scaled) + 1))
                ]
            else:
                colors = cmap

            contour_kwargs.update(contourf_kwargs)
            contour_kwargs.setdefault("line_color", "black")
            contour_kwargs.setdefault("line_alpha", 0.25)
            contour_kwargs.setdefault("fill_alpha", 1)

            for i, (level, level_upper, color) in enumerate(
                zip(levels_scaled[:-1], levels_scaled[1:], colors[1:])
            ):
                if not fill_last and (i == 0):
                    continue
                vertices, _ = contour_generator.create_filled_contour(level, level_upper)
                for seg in vertices:
                    ax.patch(*seg.T, fill_color=color, **contour_kwargs)

            if fill_last:
                ax.background_fill_color = colors[0]

            ax.xgrid.grid_line_color = None
            ax.ygrid.grid_line_color = None

            ax.x_range = Range1d(xmin, xmax)
            ax.y_range = Range1d(ymin, ymax)

        else:
            cmap = pcolormesh_kwargs.pop("cmap", "viridis")
            if isinstance(cmap, str):
                cmap = get_cmap(cmap)
            if isinstance(cmap, Callable):
                colors = [rgb2hex(item) for item in cmap(np.linspace(0, 1, 256))]
            else:
                colors = cmap

            # Bokeh's ``dw``/``dh`` are the full data-space extent of the
            # image, not the size of one pixel. Dividing by the grid shape
            # squeezed the whole density into one cell.
            ax.image(
                image=[density.T],
                x=xmin,
                y=ymin,
                dw=xmax - xmin,
                dh=ymax - ymin,
                palette=colors,
                **pcolormesh_kwargs,
            )
            ax.x_range.range_padding = ax.y_range.range_padding = 0

    if show:
        bkp.show(ax, toolbar_location="above")
    return ax
def _plot_ess(
    ax,
    plotters,
    xdata,
    ess_tail_dataset,
    mean_ess,
    sd_ess,
    idata,
    data,
    text_x,
    text_va,
    kind,
    extra_methods,
    rows,
    cols,
    figsize,
    kwargs,
    extra_kwargs,
    text_kwargs,
    _linewidth,
    _markersize,
    n_samples,
    relative,
    min_ess,
    xt_labelsize,
    titlesize,
    ax_labelsize,
    ylabel,
    rug,
    rug_kind,
    rug_kwargs,
    hline_kwargs,
    show,
):
    """Draw ESS panels with Bokeh, one per ``(var_name, selection, x)`` plotter.

    Each panel shows ``x`` against ``xdata`` as circles. When
    ``kind == "evolution"`` it adds bulk/tail lines (the tail values come from
    ``ess_tail_dataset``). Otherwise, with ``rug=True``, it adds a rug of the
    ranks of draws flagged by ``idata.sample_stats[rug_kind]``. Optional
    mean/sd reference lines are drawn when ``extra_methods`` is set. A red
    dashed threshold line is always drawn at ``400 / n_samples`` (relative) or
    ``min_ess``.

    ``text_x``, ``text_va``, ``kwargs``, ``extra_kwargs``, ``text_kwargs``,
    ``_markersize``, ``xt_labelsize``, ``titlesize``, ``ax_labelsize`` and
    ``hline_kwargs`` are accepted for interface compatibility but not read here.

    Raises
    ------
    ValueError
        If a rug is requested but ``idata`` lacks ``sample_stats`` or the
        ``rug_kind`` variable.
    """
    if ax is None:
        _, ax = _create_axes_grid(
            len(plotters),
            rows,
            cols,
            figsize=figsize,
            squeeze=False,
            constrained_layout=True,
            backend="bokeh",
        )

    # Resolve rug options once, without mutating the caller's dict. The old
    # code popped "space" inside the loop and re-applied the 0.1 default on
    # the next iteration, so a user-supplied value only affected the first
    # panel.
    rug_kwargs = {} if rug_kwargs is None else rug_kwargs
    rug_space_fraction = rug_kwargs.get("space", 0.1)
    rug_glyph_kwargs = {
        "size": 8,
        "line_color": rug_kwargs.get("line_color", "black"),
        "line_width": 1,
        "line_alpha": 0.35,
        "angle": np.pi / 2,
    }

    for (var_name, selection, x), ax_ in zip(plotters, np.ravel(ax)):
        ax_.circle(np.asarray(xdata), np.asarray(x), size=6)
        if kind == "evolution":
            ax_.line(np.asarray(xdata), np.asarray(x), legend_label="bulk")
            ess_tail = ess_tail_dataset[var_name].sel(**selection)
            ax_.line(
                np.asarray(xdata), np.asarray(ess_tail), color="orange", legend_label="tail"
            )
            ax_.circle(np.asarray(xdata), np.asarray(ess_tail), size=6, color="orange")
        elif rug:
            if not hasattr(idata, "sample_stats"):
                raise ValueError("InferenceData object must contain sample_stats for rug plot")
            if not hasattr(idata.sample_stats, rug_kind):
                raise ValueError("InferenceData does not contain {} data".format(rug_kind))

            values = data[var_name].sel(**selection).values.flatten()
            # Presumably a boolean flag per draw (e.g. divergences) selecting
            # which ranks get a rug tick — confirm against callers.
            mask = idata.sample_stats[rug_kind].values.flatten()
            values = rankdata(values)[mask]

            # Place the rug just below zero, offset proportionally to the panel's max.
            rug_space = np.max(x) * rug_space_fraction
            rug_x, rug_y = values / (len(mask) - 1), np.zeros_like(values) - rug_space

            glyph = Dash(x="rug_x", y="rug_y", **rug_glyph_kwargs)
            cds_rug = ColumnDataSource({"rug_x": np.asarray(rug_x), "rug_y": np.asarray(rug_y)})
            ax_.add_glyph(cds_rug, glyph)

            hline = Span(
                location=0,
                dimension="width",
                line_color="black",
                line_width=_linewidth,
                line_alpha=0.7,
            )
            ax_.renderers.append(hline)

        if extra_methods:
            mean_ess_i = mean_ess[var_name].sel(**selection).values.item()
            sd_ess_i = sd_ess[var_name].sel(**selection).values.item()

            hline = Span(
                location=mean_ess_i,
                dimension="width",
                line_color="black",
                line_width=2,
                line_dash="dashed",
                line_alpha=1.0,
            )
            ax_.renderers.append(hline)

            hline = Span(
                location=sd_ess_i,
                dimension="width",
                line_color="black",
                line_width=1,
                line_dash="dashed",
                line_alpha=1.0,
            )
            ax_.renderers.append(hline)

        hline = Span(
            location=400 / n_samples if relative else min_ess,
            dimension="width",
            line_color="red",
            line_width=3,
            line_dash="dashed",
            line_alpha=1.0,
        )
        ax_.renderers.append(hline)

        title = Title()
        title.text = make_label(var_name, selection)
        ax_.title = title

        ax_.xaxis.axis_label = "Total number of draws" if kind == "evolution" else "Quantile"
        ax_.yaxis.axis_label = ylabel.format("Relative ESS" if relative else "ESS")

    if show:
        grid = gridplot([list(item) for item in ax], toolbar_location="above")
        bkp.show(grid)

    return ax
def _plot_kde_bokeh(
    density,
    lower,
    upper,
    density_q,
    xmin,
    xmax,
    ymin,
    ymax,
    gridsize,
    values,
    values2=None,
    rug=False,
    label=None,
    bw=4.5,
    quantiles=None,
    rotated=False,
    contour=True,
    fill_last=True,
    textsize=None,
    plot_kwargs=None,
    fill_kwargs=None,
    rug_kwargs=None,
    contour_kwargs=None,
    contourf_kwargs=None,
    pcolormesh_kwargs=None,
    ax=None,
    legend=True,
    show=True,
):
    """Draw a precomputed 1D density as a Bokeh line, with an optional rug.

    Only the 1D case is implemented: ``density`` is drawn over
    ``np.linspace(lower, upper, len(density))``. When ``rug`` is set, a dash
    is placed at ``y=0`` for every value in ``values``. Alternatively, a
    caller-provided source (or dict of sources) can be passed via
    ``rug_kwargs["cds"]`` with its column name under ``rug_kwargs["y"]``.

    The remaining parameters (``density_q``, ``xmin`` ... ``gridsize``,
    ``bw``, ``quantiles``, ``rotated``, ``contour``, ``fill_last``,
    ``textsize``, ``fill_kwargs`` and the contour/pcolormesh kwargs) are
    accepted for interface compatibility but are not used by this
    implementation.

    Raises
    ------
    NotImplementedError
        If ``values2`` is given (2D densities need the matplotlib backend).
    """
    if ax is None:
        ax = bkp.figure(sizing_mode="stretch_both")

    # plot_kwargs must exist before the legend label is stored in it
    # (previously a None plot_kwargs with a label raised TypeError). Copy so
    # the caller's dict is not modified.
    plot_kwargs = {} if plot_kwargs is None else dict(plot_kwargs)
    if legend and label is not None:
        plot_kwargs["legend_label"] = label

    if values2 is not None:
        # todo
        raise NotImplementedError("Use matplotlib backend")

    plot_kwargs.setdefault("line_color", plt.rcParams["axes.prop_cycle"].by_key()["color"][0])

    if rug:
        rug_kwargs = {} if rug_kwargs is None else rug_kwargs.copy()

        if "cds" in rug_kwargs:
            cds_rug = rug_kwargs.pop("cds")
            rug_varname = rug_kwargs.pop("y", "y")
        else:
            rug_varname = "y"
            cds_rug = ColumnDataSource({rug_varname: values})

        rug_kwargs.setdefault("size", 8)
        rug_kwargs.setdefault("line_color", plot_kwargs["line_color"])
        rug_kwargs.setdefault("line_width", 1)
        rug_kwargs.setdefault("line_alpha", 0.35)
        rug_kwargs.setdefault("angle", np.pi / 2)

        sources = cds_rug.values() if isinstance(cds_rug, dict) else [cds_rug]
        for _cds_rug in sources:
            glyph = Dash(x=rug_varname, y=0.0, **rug_kwargs)
            ax.add_glyph(_cds_rug, glyph)

    x = np.linspace(lower, upper, len(density))
    ax.line(x, density, **plot_kwargs)

    if show:
        bkp.show(ax)

    return ax
def plot_mcse(
    ax,
    plotters,
    length_plotters,
    rows,
    cols,
    figsize,
    errorbar,
    rug,
    data,
    probs,
    kwargs,  # pylint: disable=unused-argument
    extra_methods,
    mean_mcse,
    sd_mcse,
    textsize,
    labeller,
    text_kwargs,  # pylint: disable=unused-argument
    rug_kwargs,
    extra_kwargs,
    idata,
    rug_kind,
    backend_kwargs,
    show,
):
    """Bokeh mcse plot.

    For each ``(var_name, selection, isel, x)`` plotter:

    - With ``errorbar``, draw the quantile values of the data at ``probs`` as
      dashes, with vertical ``+/- x`` error segments.
    - Otherwise, draw ``x`` against ``probs`` as circles. If ``extra_methods``
      is set, also draw horizontal lines at the mean and sd MCSE, styled by
      ``extra_kwargs``.

    With ``rug``, ticks are added for the ranks of draws flagged by
    ``idata.sample_stats[rug_kind]``. ``extra_kwargs`` and ``rug_kwargs`` are
    read but never modified.

    Raises
    ------
    ValueError
        If a rug is requested but ``idata`` lacks ``sample_stats`` or the
        ``rug_kind`` variable.
    """
    if backend_kwargs is None:
        backend_kwargs = {}

    backend_kwargs = {
        **backend_kwarg_defaults(),
        **backend_kwargs,
    }

    (figsize, *_, _linewidth, _markersize) = _scale_fig_size(figsize, textsize, rows, cols)

    # Copy so the defaults below do not leak into the caller's dict.
    extra_kwargs = {} if extra_kwargs is None else dict(extra_kwargs)
    extra_kwargs.setdefault("linewidth", _linewidth / 2)
    extra_kwargs.setdefault("color", "black")
    extra_kwargs.setdefault("alpha", 0.5)

    # Only the rug line colour is taken from rug_kwargs.
    rug_line_color = "black" if rug_kwargs is None else rug_kwargs.get("line_color", "black")

    if ax is None:
        ax = create_axes_grid(
            length_plotters,
            rows,
            cols,
            figsize=figsize,
            backend_kwargs=backend_kwargs,
        )
    else:
        ax = np.atleast_2d(ax)

    for (var_name, selection, isel, x), ax_ in zip(
        plotters, (item for item in ax.flatten() if item is not None)
    ):
        if errorbar or rug:
            values = data[var_name].sel(**selection).values.flatten()
        if errorbar:
            quantile_values = _quantile(values, probs)
            ax_.dash(probs, quantile_values)
            ax_.multi_line(
                list(zip(probs, probs)),
                [(quant - err, quant + err) for quant, err in zip(quantile_values, x)],
            )
        else:
            ax_.circle(probs, x)
            if extra_methods:
                mean_mcse_i = mean_mcse[var_name].sel(**selection).values.item()
                sd_mcse_i = sd_mcse[var_name].sel(**selection).values.item()
                hline_mean = Span(
                    location=mean_mcse_i,
                    dimension="width",
                    line_color=extra_kwargs["color"],
                    line_width=extra_kwargs["linewidth"] * 2,
                    line_alpha=extra_kwargs["alpha"],
                )
                ax_.renderers.append(hline_mean)
                # Use the same configurable colour as the mean line (this was
                # previously hard-coded to "black").
                hline_sd = Span(
                    location=sd_mcse_i,
                    dimension="width",
                    line_color=extra_kwargs["color"],
                    line_width=extra_kwargs["linewidth"],
                    line_alpha=extra_kwargs["alpha"],
                )
                ax_.renderers.append(hline_sd)

        if rug:
            if not hasattr(idata, "sample_stats"):
                raise ValueError("InferenceData object must contain sample_stats for rug plot")
            if not hasattr(idata.sample_stats, rug_kind):
                raise ValueError("InferenceData does not contain {} data".format(rug_kind))

            rug_glyph_kwargs = {
                "size": 8,
                "line_color": rug_line_color,
                "line_width": 1,
                "line_alpha": 0.35,
                "angle": np.pi / 2,
            }

            # Presumably a boolean flag per draw selecting which ranks get a
            # rug tick — confirm against callers.
            mask = idata.sample_stats[rug_kind].values.flatten()
            values = rankdata(values, method="average")[mask]

            # With error bars, drop the rug 5% of the quantile span below the
            # lowest quantile (never above 0). Without them, place it at 0.
            if errorbar:
                q_min = min(quantile_values)
                q_span = max(quantile_values) - q_min
                rug_offset = min(0, q_min - q_span * 0.05)
            else:
                rug_offset = 0

            rug_x = values / (len(mask) - 1)
            rug_y = np.full_like(values, rug_offset)

            hline = Span(
                location=rug_offset,
                dimension="width",
                line_color="black",
                line_width=_linewidth,
                line_alpha=0.7,
            )
            ax_.renderers.append(hline)

            glyph = Dash(x="rug_x", y="rug_y", **rug_glyph_kwargs)
            cds_rug = ColumnDataSource({"rug_x": np.asarray(rug_x), "rug_y": np.asarray(rug_y)})
            ax_.add_glyph(cds_rug, glyph)

        title = Title()
        title.text = labeller.make_label_vert(var_name, selection, isel)
        ax_.title = title

        ax_.xaxis.axis_label = "Quantile"
        ax_.yaxis.axis_label = (
            r"Value $\pm$ MCSE for quantiles" if errorbar else "MCSE for quantiles"
        )

        if not errorbar:
            ax_.y_range._property_values["start"] = -0.05  # pylint: disable=protected-access
            ax_.y_range._property_values["end"] = 1  # pylint: disable=protected-access

    show_layout(ax, show)

    return ax
def _plot_trace_bokeh(
    data,
    var_names=None,
    coords=None,
    divergences="bottom",
    figsize=None,
    rug=False,
    lines=None,
    compact=False,
    combined=False,
    legend=False,
    plot_kwargs=None,
    fill_kwargs=None,
    rug_kwargs=None,
    hist_kwargs=None,
    trace_kwargs=None,
    backend_kwargs=None,
    show=True,
):
    """Draw a Bokeh trace plot: one row per plotted variable selection.

    Column 0 of each row is the density panel and column 1 the trace panel
    (see the ``_plot_chains_bokeh(ax_density=axes[idx, 0],
    ax_trace=axes[idx, 1], ...)`` calls). All trace panels after the first
    share the first row's trace x-range.

    Parameters
    ----------
    data
        Anything ``convert_to_dataset`` accepts; the "posterior" group is
        plotted and the "sample_stats" group supplies ``diverging``.
    var_names, coords
        Variable and coordinate selection applied to the posterior.
        Only the "chain"/"draw" entries of ``coords`` are applied to the
        divergence data.
    divergences
        Falsy to disable divergence markers, ``"top"`` to place trace-panel
        markers at the variable's maximum, anything else truthy for the
        minimum. It is silently disabled when no divergence data is found.
    figsize
        Defaults to ``(12, 2 * number_of_rows)``.
    rug, compact, combined, legend
        Forwarded/used as flags. ``compact`` keeps non chain/draw dims
        together in one row.
    lines
        Iterable of ``(var_name, selection, values)`` triples. Matching values
        are drawn as vertical lines on the density panel and horizontal lines
        on the trace panel.
    plot_kwargs, fill_kwargs, rug_kwargs, hist_kwargs, trace_kwargs
        Style dicts forwarded to ``_plot_chains_bokeh``. When not None they
        are updated in place with defaults; ``rug_kwargs["y"]`` is overwritten
        per plotted column.
    backend_kwargs
        Passed to every ``bkp.figure`` call; updated in place with defaults
        when provided.
    show
        If true, lay the panels out in a ``gridplot`` and show it.

    Returns
    -------
    numpy.ndarray
        2D array of Bokeh figures with shape ``(n_rows, 2)``.
    """
    if divergences:
        try:
            divergence_data = convert_to_dataset(data, group="sample_stats").diverging
        except (ValueError, AttributeError):  # No sample_stats, or no `.diverging`
            divergences = False

    if coords is None:
        coords = {}

    data = get_coords(convert_to_dataset(data, group="posterior"), coords)
    var_names = _var_names(var_names, data)

    if divergences:
        # Only chain/draw selections make sense for the per-draw divergence flags.
        divergence_data = get_coords(
            divergence_data, {k: v for k, v in coords.items() if k in ("chain", "draw")}
        )

    if lines is None:
        lines = ()

    # One colour per chain, plus one extra when chains are also combined.
    num_colors = len(data.chain) + 1 if combined else len(data.chain)
    colors = [
        prop
        for _, prop in zip(
            range(num_colors), cycle(plt.rcParams["axes.prop_cycle"].by_key()["color"])
        )
    ]

    if compact:
        skip_dims = set(data.dims) - {"chain", "draw"}
    else:
        skip_dims = set()

    plotters = list(
        xarray_var_iter(data, var_names=var_names, combined=True, skip_dims=skip_dims)
    )
    max_plots = rcParams["plot.max_subplots"]
    max_plots = len(plotters) if max_plots is None else max_plots
    if len(plotters) > max_plots:
        # NOTE(review): SyntaxWarning looks like an unusual category for this
        # message (UserWarning is conventional) — confirm before changing,
        # callers may filter on it.
        warnings.warn(
            "rcParams['plot.max_subplots'] ({max_plots}) is smaller than the number "
            "of variables to plot ({len_plotters}), generating only {max_plots} "
            "plots".format(max_plots=max_plots, len_plotters=len(plotters)),
            SyntaxWarning,
        )
        plotters = plotters[:max_plots]

    if figsize is None:
        figsize = (12, len(plotters) * 2)

    if trace_kwargs is None:
        trace_kwargs = {}

    trace_kwargs.setdefault("alpha", 0.35)

    if hist_kwargs is None:
        hist_kwargs = {}
    if plot_kwargs is None:
        plot_kwargs = {}
    if fill_kwargs is None:
        fill_kwargs = {}
    if rug_kwargs is None:
        rug_kwargs = {}

    hist_kwargs.setdefault("alpha", 0.35)

    figsize, _, _, _, linewidth, _ = _scale_fig_size(figsize, 10, rows=len(plotters), cols=2)

    trace_kwargs.setdefault("line_width", linewidth)
    plot_kwargs.setdefault("line_width", linewidth)

    if backend_kwargs is None:
        backend_kwargs = dict()

    backend_kwargs.setdefault("tools", rcParams["plot.bokeh.tools"])
    backend_kwargs.setdefault("output_backend", rcParams["plot.bokeh.output_backend"])
    # Figure size in pixels: the total height is split across rows and the
    # total width across the two columns.
    backend_kwargs.setdefault(
        "height", int(figsize[1] * rcParams["plot.bokeh.figure.dpi"] // len(plotters))
    )
    backend_kwargs.setdefault("width", int(figsize[0] * rcParams["plot.bokeh.figure.dpi"] // 2))

    axes = []
    for i in range(len(plotters)):
        if i != 0:
            # Link every trace panel to the first row's trace x-range.
            _axes = [
                bkp.figure(**backend_kwargs),
                bkp.figure(x_range=axes[0][1].x_range, **backend_kwargs),
            ]
        else:
            _axes = [bkp.figure(**backend_kwargs), bkp.figure(**backend_kwargs)]
        axes.append(_axes)

    axes = np.array(axes)

    # Build one ColumnDataSource per chain holding every variable selection
    # as its own column.
    cds_data = {}
    cds_var_groups = {}
    draw_name = "draw"

    for var_name, selection, value in list(
        xarray_var_iter(data, var_names=var_names, combined=True)
    ):
        if selection:
            # Flatten the selection into a unique column name. The inner
            # ``value`` is local to the generator expression and does not
            # shadow the loop variable.
            cds_name = "{}_ARVIZ_CDS_SELECTION_{}".format(
                var_name,
                "_".join(
                    str(item)
                    for key, value in selection.items()
                    for item in (
                        [key, value]
                        if (isinstance(value, str) or not isinstance(value, Iterable))
                        else [key, *value]
                    )
                ),
            )
        else:
            cds_name = var_name

        if var_name not in cds_var_groups:
            cds_var_groups[var_name] = []
        cds_var_groups[var_name].append(cds_name)

        for chain_idx, _ in enumerate(data.chain.values):
            if chain_idx not in cds_data:
                cds_data[chain_idx] = {}
            _data = value[chain_idx]
            cds_data[chain_idx][cds_name] = _data

    # Pick a draw-index column name that does not collide with a variable column.
    while any(key == draw_name for key in cds_data[0]):
        draw_name += "w"

    for chain_idx in cds_data:
        cds_data[chain_idx][draw_name] = data.draw.values

    cds_data = {chain_idx: ColumnDataSource(cds) for chain_idx, cds in cds_data.items()}

    for idx, (var_name, selection, value) in enumerate(plotters):
        value = np.atleast_2d(value)

        if len(value.shape) == 2:
            # Single column: rebuild the same name used when filling cds_data.
            y_name = (
                var_name
                if not selection
                else "{}_ARVIZ_CDS_SELECTION_{}".format(
                    var_name,
                    "_".join(
                        str(item)
                        for key, value in selection.items()
                        for item in (
                            (key, value)
                            if (isinstance(value, str) or not isinstance(value, Iterable))
                            else (key, *value)
                        )
                    ),
                )
            )
            if rug:
                rug_kwargs["y"] = y_name
            _plot_chains_bokeh(
                ax_density=axes[idx, 0],
                ax_trace=axes[idx, 1],
                data=cds_data,
                x_name=draw_name,
                y_name=y_name,
                colors=colors,
                combined=combined,
                rug=rug,
                legend=legend,
                trace_kwargs=trace_kwargs,
                hist_kwargs=hist_kwargs,
                plot_kwargs=plot_kwargs,
                fill_kwargs=fill_kwargs,
                rug_kwargs=rug_kwargs,
            )
        else:
            # Compact row: draw every column recorded for this variable.
            for y_name in cds_var_groups[var_name]:
                if rug:
                    rug_kwargs["y"] = y_name
                _plot_chains_bokeh(
                    ax_density=axes[idx, 0],
                    ax_trace=axes[idx, 1],
                    data=cds_data,
                    x_name=draw_name,
                    y_name=y_name,
                    colors=colors,
                    combined=combined,
                    rug=rug,
                    legend=legend,
                    trace_kwargs=trace_kwargs,
                    hist_kwargs=hist_kwargs,
                    plot_kwargs=plot_kwargs,
                    fill_kwargs=fill_kwargs,
                    rug_kwargs=rug_kwargs,
                )

        for col in (0, 1):
            _title = Title()
            _title.text = make_label(var_name, selection)
            axes[idx, col].title = _title

        # Reference lines: vertical on the density panel, horizontal on the trace panel.
        for _, _, vlines in (j for j in lines if j[0] == var_name and j[1] == selection):
            if isinstance(vlines, (float, int)):
                line_values = [vlines]
            else:
                line_values = np.atleast_1d(vlines).ravel()

            for line_value in line_values:
                vline = Span(
                    location=line_value,
                    dimension="height",
                    line_color="black",
                    line_width=1.5,
                    line_alpha=0.75,
                )
                hline = Span(
                    location=line_value,
                    dimension="width",
                    line_color="black",
                    line_width=1.5,
                    line_alpha=trace_kwargs["alpha"],
                )

                axes[idx, 0].renderers.append(vline)
                axes[idx, 1].renderers.append(hline)

        if legend:
            for col in (0, 1):
                axes[idx, col].legend.location = "top_left"
                axes[idx, col].legend.click_policy = "hide"
        else:
            for col in (0, 1):
                if axes[idx, col].legend:
                    axes[idx, col].legend.visible = False

        if divergences:
            div_density_kwargs = {}
            div_density_kwargs.setdefault("size", 14)
            div_density_kwargs.setdefault("line_color", "red")
            div_density_kwargs.setdefault("line_width", 2)
            div_density_kwargs.setdefault("line_alpha", 0.50)
            div_density_kwargs.setdefault("angle", np.pi / 2)

            div_trace_kwargs = {}
            div_trace_kwargs.setdefault("size", 14)
            div_trace_kwargs.setdefault("line_color", "red")
            div_trace_kwargs.setdefault("line_width", 2)
            div_trace_kwargs.setdefault("line_alpha", 0.50)
            div_trace_kwargs.setdefault("angle", np.pi / 2)

            div_selection = {k: v for k, v in selection.items() if k in divergence_data.dims}
            divs = divergence_data.sel(**div_selection).values
            divs = np.atleast_2d(divs)

            for chain, chain_divs in enumerate(divs):
                # NOTE(review): x positions are 0-based draw indices, while the
                # trace column uses data.draw.values; they only line up if the
                # draw coordinate starts at 0 — confirm.
                div_idxs = np.arange(len(chain_divs))[chain_divs]
                if div_idxs.size > 0:
                    values = value[chain, div_idxs]
                    tmp_cds = ColumnDataSource({"y": values, "x": div_idxs})
                    if divergences == "top":
                        y_div_trace = value.max()
                    else:
                        y_div_trace = value.min()
                    glyph_density = Dash(x="y", y=0.0, **div_density_kwargs)
                    glyph_trace = Dash(x="x", y=y_div_trace, **div_trace_kwargs)

                    axes[idx, 0].add_glyph(tmp_cds, glyph_density)
                    axes[idx, 1].add_glyph(tmp_cds, glyph_trace)

    if show:
        grid = gridplot([list(item) for item in axes], toolbar_location="above")
        bkp.show(grid)

    return axes
def plot_trace(
    data,
    var_names,
    divergences,
    kind,
    figsize,
    rug,
    lines,
    circ_var_names,  # pylint: disable=unused-argument
    circ_var_units,  # pylint: disable=unused-argument
    compact,
    compact_prop,
    combined,
    chain_prop,
    legend,
    labeller,
    plot_kwargs,
    fill_kwargs,
    rug_kwargs,
    hist_kwargs,
    trace_kwargs,
    rank_kwargs,
    plotters,
    divergence_data,
    axes,
    backend_kwargs,
    backend_config,
    show,
):
    """Bokeh traceplot.

    Draw one row per entry of ``plotters`` (``(var_name, selection, isel,
    value)`` tuples). Column 0 is the density panel and column 1 the trace or
    rank panel (chosen by ``kind``). When ``axes`` is None the figures are
    created here, and every trace panel shares the first row's x-range.

    ``lines`` is an iterable of ``(var_name, selection, values)`` triples that
    are drawn as reference lines. When ``divergences`` is truthy, red dashes
    mark divergent draws: on the density panel always, and on the trace panel
    only for ``kind == "trace"`` (at the max for ``"top"``, else the min).

    The style dicts (``plot_kwargs``, ``trace_kwargs``, ``hist_kwargs``,
    ``rug_kwargs``) are copied before defaults are filled in, so the caller's
    dictionaries are not modified.

    Returns
    -------
    numpy.ndarray
        2D array of Bokeh figures with shape ``(len(plotters), 2)``.
    """
    # If divergences are plotted they must be provided
    if divergences is not False:
        assert divergence_data is not None

    if backend_config is None:
        backend_config = {}

    backend_config = {
        **backend_kwarg_defaults(
            ("bounds_y_range", "plot.bokeh.bounds_y_range"),
        ),
        **backend_config,
    }

    # Set plot default backend kwargs
    if backend_kwargs is None:
        backend_kwargs = {}

    backend_kwargs = {
        **backend_kwarg_defaults(
            ("dpi", "plot.bokeh.figure.dpi"),
        ),
        **backend_kwargs,
    }
    dpi = backend_kwargs.pop("dpi")

    if figsize is None:
        figsize = (12, len(plotters) * 2)

    figsize, _, _, _, linewidth, _ = _scale_fig_size(figsize, 10, rows=len(plotters), cols=2)

    # Pixel size per figure: total height split across rows, width across 2 columns.
    backend_kwargs.setdefault("height", int(figsize[1] * dpi // len(plotters)))
    backend_kwargs.setdefault("width", int(figsize[0] * dpi // 2))

    if lines is None:
        lines = ()

    # One style entry per chain, plus one extra when chains are also combined.
    num_chain_props = len(data.chain) + 1 if combined else len(data.chain)
    if not compact:
        chain_prop = (
            {"line_color": plt.rcParams["axes.prop_cycle"].by_key()["color"]}
            if chain_prop is None
            else chain_prop
        )
    else:
        chain_prop = (
            {
                "line_dash": ("solid", "dotted", "dashed", "dashdot"),
            }
            if chain_prop is None
            else chain_prop
        )
        compact_prop = (
            {"line_color": plt.rcParams["axes.prop_cycle"].by_key()["color"]}
            if compact_prop is None
            else compact_prop
        )

    if isinstance(chain_prop, str):
        chain_prop = {chain_prop: plt.rcParams["axes.prop_cycle"].by_key()[chain_prop]}
    if isinstance(chain_prop, tuple):
        warnings.warn(
            "chain_prop as a tuple will be deprecated in a future warning, use a dict instead",
            FutureWarning,
        )
        chain_prop = {chain_prop[0]: chain_prop[1]}
    chain_prop = {
        prop_name: [prop for _, prop in zip(range(num_chain_props), cycle(props))]
        for prop_name, props in chain_prop.items()
    }

    # NOTE(review): compact_prop is normalised here but not forwarded to
    # _plot_chains_bokeh below — confirm whether that is intended.
    if isinstance(compact_prop, str):
        compact_prop = {compact_prop: plt.rcParams["axes.prop_cycle"].by_key()[compact_prop]}
    if isinstance(compact_prop, tuple):
        warnings.warn(
            "compact_prop as a tuple will be deprecated in a future warning, use a dict instead",
            FutureWarning,
        )
        compact_prop = {compact_prop[0]: compact_prop[1]}

    # Copy the style dicts: defaults are added below and rug_kwargs["y"] is
    # rewritten per plotted column, which must not leak into caller state.
    trace_kwargs = {} if trace_kwargs is None else dict(trace_kwargs)
    trace_kwargs.setdefault("alpha", 0.35)
    hist_kwargs = {} if hist_kwargs is None else dict(hist_kwargs)
    hist_kwargs.setdefault("alpha", 0.35)
    plot_kwargs = {} if plot_kwargs is None else dict(plot_kwargs)
    fill_kwargs = {} if fill_kwargs is None else fill_kwargs
    rug_kwargs = {} if rug_kwargs is None else dict(rug_kwargs)
    rank_kwargs = {} if rank_kwargs is None else rank_kwargs

    trace_kwargs.setdefault("line_width", linewidth)
    plot_kwargs.setdefault("line_width", linewidth)

    if axes is None:
        axes = []
        trace_backend_kwargs = backend_kwargs.copy()
        for _ in plotters:
            row = [bkp.figure(**backend_kwargs), bkp.figure(**trace_backend_kwargs)]
            # After the first row, every trace panel reuses the first trace x-range.
            trace_backend_kwargs.setdefault("x_range", row[1].x_range)
            axes.append(row)

    axes = np.atleast_2d(axes)

    def _cds_column_name(var_name, selection):
        """Return the ColumnDataSource column name used for one selection."""
        if not selection:
            return var_name
        parts = []
        for key, sel_value in selection.items():
            if isinstance(sel_value, str) or not isinstance(sel_value, Iterable):
                parts.extend((key, sel_value))
            else:
                parts.extend((key, *sel_value))
        return "{}_ARVIZ_CDS_SELECTION_{}".format(var_name, "_".join(str(part) for part in parts))

    # One ColumnDataSource per chain, each holding every selection as a column.
    cds_data = {}
    cds_var_groups = {}
    draw_name = "draw"

    for var_name, selection, _, value in xarray_var_iter(
        data, var_names=var_names, combined=True
    ):
        cds_name = _cds_column_name(var_name, selection)
        cds_var_groups.setdefault(var_name, []).append(cds_name)
        for chain_idx, _ in enumerate(data.chain.values):
            cds_data.setdefault(chain_idx, {})[cds_name] = value[chain_idx]

    # Pick a draw-index column name that does not collide with a variable column.
    while any(key == draw_name for key in cds_data[0]):
        draw_name += "w"

    for chain_idx in cds_data:
        cds_data[chain_idx][draw_name] = data.draw.values

    cds_data = {chain_idx: ColumnDataSource(cds) for chain_idx, cds in cds_data.items()}

    # Divergence marker styles are identical for every row; build them once.
    div_density_kwargs = {
        "size": 14,
        "line_color": "red",
        "line_width": 2,
        "line_alpha": 0.50,
        "angle": np.pi / 2,
    }
    div_trace_kwargs = dict(div_density_kwargs)

    for idx, (var_name, selection, isel, value) in enumerate(plotters):
        value = np.atleast_2d(value)

        # A 2D value is a single column; otherwise (compact rows) draw every
        # column recorded for this variable.
        if len(value.shape) == 2:
            y_names = [_cds_column_name(var_name, selection)]
        else:
            y_names = cds_var_groups[var_name]

        for y_name in y_names:
            if rug:
                rug_kwargs["y"] = y_name
            _plot_chains_bokeh(
                ax_density=axes[idx, 0],
                ax_trace=axes[idx, 1],
                data=cds_data,
                x_name=draw_name,
                y_name=y_name,
                chain_prop=chain_prop,
                combined=combined,
                rug=rug,
                kind=kind,
                legend=legend,
                trace_kwargs=trace_kwargs,
                hist_kwargs=hist_kwargs,
                plot_kwargs=plot_kwargs,
                fill_kwargs=fill_kwargs,
                rug_kwargs=rug_kwargs,
                rank_kwargs=rank_kwargs,
            )

        for col in (0, 1):
            _title = Title()
            _title.text = labeller.make_label_vert(var_name, selection, isel)
            axes[idx, col].title = _title
            axes[idx, col].y_range = DataRange1d(
                bounds=backend_config["bounds_y_range"], min_interval=0.1
            )

        # Reference lines: vertical on the density panel, horizontal on the trace panel.
        for _, _, vlines in (j for j in lines if j[0] == var_name and j[1] == selection):
            if isinstance(vlines, (float, int)):
                line_values = [vlines]
            else:
                line_values = np.atleast_1d(vlines).ravel()

            for line_value in line_values:
                vline = Span(
                    location=line_value,
                    dimension="height",
                    line_color="black",
                    line_width=1.5,
                    line_alpha=0.75,
                )
                hline = Span(
                    location=line_value,
                    dimension="width",
                    line_color="black",
                    line_width=1.5,
                    line_alpha=trace_kwargs["alpha"],
                )

                axes[idx, 0].renderers.append(vline)
                axes[idx, 1].renderers.append(hline)

        if legend:
            for col in (0, 1):
                axes[idx, col].legend.location = "top_left"
                axes[idx, col].legend.click_policy = "hide"
        else:
            for col in (0, 1):
                if axes[idx, col].legend:
                    axes[idx, col].legend.visible = False

        if divergences:
            div_selection = {k: v for k, v in selection.items() if k in divergence_data.dims}
            divs = divergence_data.sel(**div_selection).values
            divs = np.atleast_2d(divs)

            for chain, chain_divs in enumerate(divs):
                # NOTE(review): x positions are 0-based draw indices while the
                # trace column uses data.draw.values — they only align if the
                # draw coordinate starts at 0; confirm.
                div_idxs = np.arange(len(chain_divs))[chain_divs]
                if div_idxs.size > 0:
                    values = value[chain, div_idxs]
                    tmp_cds = ColumnDataSource({"y": values, "x": div_idxs})
                    if divergences == "top":
                        y_div_trace = value.max()
                    else:
                        y_div_trace = value.min()
                    glyph_density = Dash(x="y", y=0.0, **div_density_kwargs)
                    if kind == "trace":
                        glyph_trace = Dash(x="x", y=y_div_trace, **div_trace_kwargs)
                        axes[idx, 1].add_glyph(tmp_cds, glyph_trace)

                    axes[idx, 0].add_glyph(tmp_cds, glyph_density)

    show_layout(axes, show)

    return axes