コード例 #1
0
def _plot_mcse(
    ax,
    plotters,
    length_plotters,
    rows,
    cols,
    figsize,
    errorbar,
    rug,
    data,
    probs,
    extra_kwargs,
    extra_methods,
    mean_mcse,
    sd_mcse,
    rug_kwargs,
    idata,
    rug_kind,
    _markersize,
    _linewidth,
    show,
):
    """Draw the MCSE-for-quantiles plot on Bokeh figures.

    One figure is filled per entry of ``plotters``.

    Parameters
    ----------
    ax : array-like of bokeh figures or None
        Figures to draw on, consumed in ``np.ravel`` order. When None a grid
        is created with ``_create_axes_grid``.
    plotters : iterable of (var_name, selection, x)
        ``x`` holds the MCSE values that are plotted against ``probs``.
    length_plotters, rows, cols, figsize
        Only used to build the grid when ``ax`` is None.
    errorbar : bool
        If True, plot the quantile values of the draws with ``+/- x`` error
        bars instead of plotting ``x`` itself.
    rug : bool
        If True, add a rug of the ranked draws selected by the ``rug_kind``
        mask, together with a horizontal baseline.
    data
        Indexed as ``data[var_name].sel(**selection)`` to obtain the draws.
    probs : sequence
        Quantile probabilities used as x coordinates.
    extra_kwargs : dict
        Must provide ``"linewidth"`` and ``"alpha"``; only read when
        ``extra_methods`` is True and ``errorbar`` is False.
    extra_methods : bool
        If True (and ``errorbar`` is False), add horizontal lines at the
        mean and sd MCSE.
    mean_mcse, sd_mcse
        Indexed like ``data``; must select a single value per plot.
    rug_kwargs : dict or None
        Only ``"line_color"`` is read. The dict is never modified.
    idata
        Must expose ``sample_stats[rug_kind]`` when ``rug`` is True.
    rug_kind : str
        Name of the ``sample_stats`` variable used as the rug mask.
    _markersize
        Not used by this function.
    _linewidth
        Line width of the rug baseline.
    show : bool
        If True, lay the figures out with ``gridplot`` and show them.

    Returns
    -------
    ax
        The figures that were drawn on.

    Raises
    ------
    ValueError
        If ``rug`` is True and ``idata`` lacks ``sample_stats`` or
        ``sample_stats`` lacks ``rug_kind``.
    """
    if ax is None:
        _, ax = _create_axes_grid(
            length_plotters, rows, cols, figsize=figsize, backend="bokeh"
        )

    if rug:
        # Work on a private copy so the caller's dict is never mutated.
        rug_kwargs = {} if rug_kwargs is None else dict(rug_kwargs)
        # The rug glyph options do not depend on the plotted variable.
        rug_glyph_kwargs = {
            "size": 8,
            "line_color": rug_kwargs.get("line_color", "black"),
            "line_width": 1,
            "line_alpha": 0.35,
            "angle": np.pi / 2,
        }

    for (var_name, selection, x), ax_ in zip(plotters, np.ravel(ax)):
        if errorbar or rug:
            values = data[var_name].sel(**selection).values.flatten()

        if errorbar:
            quantile_values = _quantile(values, probs)
            ax_.dash(probs, quantile_values)
            # One vertical segment per quantile spanning value -/+ MCSE.
            ax_.multi_line(
                list(zip(probs, probs)),
                [
                    (quant - err, quant + err)
                    for quant, err in zip(quantile_values, x)
                ],
            )
        else:
            ax_.circle(probs, x)
            if extra_methods:
                mean_mcse_i = mean_mcse[var_name].sel(**selection).values.item()
                sd_mcse_i = sd_mcse[var_name].sel(**selection).values.item()
                # The mean line is drawn twice as thick as the sd line.
                reference_lines = (
                    (mean_mcse_i, extra_kwargs["linewidth"] * 2),
                    (sd_mcse_i, extra_kwargs["linewidth"]),
                )
                for location, line_width in reference_lines:
                    ax_.renderers.append(
                        Span(
                            location=location,
                            dimension="width",
                            line_color="black",
                            line_width=line_width,
                            line_alpha=extra_kwargs["alpha"],
                        )
                    )

        if rug:
            if not hasattr(idata, "sample_stats"):
                raise ValueError(
                    "InferenceData object must contain sample_stats for rug plot"
                )
            if not hasattr(idata.sample_stats, rug_kind):
                raise ValueError(
                    "InferenceData does not contain {} data".format(rug_kind))

            mask = idata.sample_stats[rug_kind].values.flatten()
            values = rankdata(values)[mask]

            # Height of the rug: slightly below the lowest quantile value
            # (but never above zero) for error bars, otherwise at zero.
            if errorbar:
                lowest, highest = min(quantile_values), max(quantile_values)
                baseline = min(0, lowest - (highest - lowest) * 0.05)
            else:
                baseline = 0

            rug_x = values / (len(mask) - 1)
            rug_y = np.full_like(values, baseline)

            ax_.renderers.append(
                Span(
                    location=baseline,
                    dimension="width",
                    line_color="black",
                    line_width=_linewidth,
                    line_alpha=0.7,
                )
            )

            glyph = Dash(x="rug_x", y="rug_y", **rug_glyph_kwargs)
            cds_rug = ColumnDataSource({
                "rug_x": np.asarray(rug_x),
                "rug_y": np.asarray(rug_y)
            })
            ax_.add_glyph(cds_rug, glyph)

        title = Title()
        title.text = make_label(var_name, selection)
        ax_.title = title

        ax_.xaxis.axis_label = "Quantile"
        ax_.yaxis.axis_label = (r"Value $\pm$ MCSE for quantiles"
                                if errorbar else "MCSE for quantiles")

        if not errorbar:
            ax_.y_range._property_values["start"] = -0.05  # pylint: disable=protected-access
            ax_.y_range._property_values["end"] = 1  # pylint: disable=protected-access

    if show:
        grid = gridplot([list(item) for item in ax], toolbar_location="above")
        bkp.show(grid)

    return ax
コード例 #2
0
ファイル: traceplot.py プロジェクト: waternk/arviz
def _trace_cds_column_name(var_name, selection):
    """Build the ColumnDataSource column name for one variable/selection pair.

    Returns ``var_name`` unchanged when ``selection`` is empty. Otherwise the
    coordinate names and values are appended after an
    ``_ARVIZ_CDS_SELECTION_`` marker, joined with underscores; non-string
    iterable coordinate values are flattened element by element.
    """
    if not selection:
        return var_name
    parts = []
    for key, coord in selection.items():
        parts.append(key)
        if isinstance(coord, str) or not isinstance(coord, Iterable):
            parts.append(coord)
        else:
            parts.extend(coord)
    return "{}_ARVIZ_CDS_SELECTION_{}".format(
        var_name, "_".join(str(part) for part in parts))


def plot_trace(
    data,
    var_names,
    divergences,
    kind,
    figsize,
    rug,
    lines,
    combined,
    chain_prop,
    legend,
    plot_kwargs: [Dict],
    fill_kwargs: [Dict],
    rug_kwargs: [Dict],
    hist_kwargs: [Dict],
    trace_kwargs: [Dict],
    rank_kwargs: [Dict],
    plotters,
    divergence_data,
    axes,
    backend_config,
    backend_kwargs: [Dict],
    show,
):
    """Bokeh traceplot.

    Each entry of ``plotters`` gets one row of two figures: column 0 is used
    as the density axis and column 1 as the trace axis of
    ``_plot_chains_bokeh``.

    Parameters
    ----------
    data
        Dataset exposing ``chain`` and ``draw`` and iterable with
        ``xarray_var_iter``; one ColumnDataSource is built per chain.
    var_names
        Variables put into the per-chain ColumnDataSources.
    divergences
        If truthy, divergent draws are marked with dashes. ``"top"`` places
        the trace markers at the maximum of the plotted values, anything
        else at the minimum. Unless it is exactly ``False``,
        ``divergence_data`` must not be None.
    kind, combined, chain_prop, legend
        Forwarded to ``_plot_chains_bokeh``; ``legend`` also controls
        legend placement/visibility here.
    figsize : tuple
        Used with the ``dpi`` backend kwarg to derive figure sizes.
    rug : bool
        If True, ``rug_kwargs["y"]`` is set to the plotted column name.
    lines : iterable of (var_name, selection, values) or None
        Reference values drawn as vertical lines on the density figure and
        horizontal lines on the trace figure.
    plot_kwargs, trace_kwargs : dict or None
        Get a default ``"line_width"``. ``trace_kwargs["alpha"]``, when
        present, is the alpha of the horizontal reference lines (0.35
        otherwise).
    fill_kwargs, rug_kwargs, hist_kwargs, rank_kwargs : dict
        Forwarded to ``_plot_chains_bokeh``.
    plotters : list of (var_name, selection, value)
        One row of figures is drawn per entry.
    divergence_data
        Boolean divergence indicator, selected with the coordinates of
        ``selection`` that it shares.
    axes : 2D array-like of bokeh figures or None
        Existing figures; created when None (trace figures share the x range
        of the first trace figure).
    backend_config : dict or None
        Provides ``"bounds_y_range"``; defaults are filled in.
    backend_kwargs : dict or None
        Passed to ``bkp.figure`` after ``"dpi"`` has been removed.
    show : bool
        Forwarded to ``show_layout``.

    Returns
    -------
    axes : ndarray
        The (at least 2D) array of bokeh figures.
    """
    # If divergences are plotted they must be provided
    if divergences is not False:
        assert divergence_data is not None

    if backend_config is None:
        backend_config = {}

    backend_config = {
        **backend_kwarg_defaults(("bounds_y_range", "plot.bokeh.bounds_y_range")),
        **backend_config,
    }

    # Set plot default backend kwargs
    if backend_kwargs is None:
        backend_kwargs = {}

    backend_kwargs = {
        **backend_kwarg_defaults(("dpi", "plot.bokeh.figure.dpi")),
        **backend_kwargs,
    }
    dpi = backend_kwargs.pop("dpi")

    backend_kwargs.setdefault("height", int(figsize[1] * dpi // len(plotters)))
    backend_kwargs.setdefault("width", int(figsize[0] * dpi // 2))

    figsize, _, _, _, linewidth, _ = _scale_fig_size(figsize,
                                                     10,
                                                     rows=len(plotters),
                                                     cols=2)

    # Tolerate missing kwargs dicts the same way rank_kwargs is handled.
    if trace_kwargs is None:
        trace_kwargs = {}
    if plot_kwargs is None:
        plot_kwargs = {}
    if rank_kwargs is None:
        rank_kwargs = {}
    if rug and rug_kwargs is None:
        rug_kwargs = {}
    if lines is None:
        lines = ()

    trace_kwargs.setdefault("line_width", linewidth)
    plot_kwargs.setdefault("line_width", linewidth)

    if axes is None:
        axes = []
        for row in range(len(plotters)):
            density_fig = bkp.figure(**backend_kwargs)
            if row == 0:
                trace_fig = bkp.figure(**backend_kwargs)
            else:
                # Link every trace figure to the x range of the first one.
                trace_fig = bkp.figure(x_range=axes[0][1].x_range,
                                       **backend_kwargs)
            axes.append([density_fig, trace_fig])

    axes = np.atleast_2d(axes)

    # Collect, per chain, one column per variable/selection.
    cds_data = {}
    cds_var_groups = {}
    draw_name = "draw"

    for var_name, selection, value in xarray_var_iter(
            data, var_names=var_names, combined=True):
        cds_name = _trace_cds_column_name(var_name, selection)
        cds_var_groups.setdefault(var_name, []).append(cds_name)

        for chain_idx in range(len(data.chain.values)):
            cds_data.setdefault(chain_idx, {})[cds_name] = value[chain_idx]

    # Make sure the draw column does not clash with a variable column.
    while draw_name in cds_data[0]:
        draw_name += "w"

    for chain_columns in cds_data.values():
        chain_columns[draw_name] = data.draw.values

    cds_data = {
        chain_idx: ColumnDataSource(cds)
        for chain_idx, cds in cds_data.items()
    }

    # The divergence markers look the same on the density and trace figures.
    div_glyph_kwargs = {
        "size": 14,
        "line_color": "red",
        "line_width": 2,
        "line_alpha": 0.50,
        "angle": np.pi / 2,
    }

    for idx, (var_name, selection, value) in enumerate(plotters):
        value = np.atleast_2d(value)

        # A 2D value is a single (chain, draw) series; anything with more
        # dimensions draws every column recorded for the variable.
        if value.ndim == 2:
            y_names = [_trace_cds_column_name(var_name, selection)]
        else:
            y_names = cds_var_groups[var_name]

        for y_name in y_names:
            if rug:
                rug_kwargs["y"] = y_name
            _plot_chains_bokeh(
                ax_density=axes[idx, 0],
                ax_trace=axes[idx, 1],
                data=cds_data,
                x_name=draw_name,
                y_name=y_name,
                chain_prop=chain_prop,
                combined=combined,
                rug=rug,
                kind=kind,
                legend=legend,
                trace_kwargs=trace_kwargs,
                hist_kwargs=hist_kwargs,
                plot_kwargs=plot_kwargs,
                fill_kwargs=fill_kwargs,
                rug_kwargs=rug_kwargs,
                rank_kwargs=rank_kwargs,
            )

        for col in (0, 1):
            _title = Title()
            _title.text = make_label(var_name, selection)
            axes[idx, col].title = _title
            axes[idx, col].y_range = DataRange1d(
                bounds=backend_config["bounds_y_range"], min_interval=0.1)

        for _, _, vlines in (j for j in lines
                             if j[0] == var_name and j[1] == selection):
            if isinstance(vlines, (float, int)):
                line_values = [vlines]
            else:
                line_values = np.atleast_1d(vlines).ravel()

            for line_value in line_values:
                vline = Span(
                    location=line_value,
                    dimension="height",
                    line_color="black",
                    line_width=1.5,
                    line_alpha=0.75,
                )
                hline = Span(
                    location=line_value,
                    dimension="width",
                    line_color="black",
                    line_width=1.5,
                    # Fall back to a fixed alpha instead of raising KeyError
                    # when the caller did not set one.
                    line_alpha=trace_kwargs.get("alpha", 0.35),
                )

                axes[idx, 0].renderers.append(vline)
                axes[idx, 1].renderers.append(hline)

        if legend:
            for col in (0, 1):
                axes[idx, col].legend.location = "top_left"
                axes[idx, col].legend.click_policy = "hide"
        else:
            for col in (0, 1):
                if axes[idx, col].legend:
                    axes[idx, col].legend.visible = False

        if divergences:
            div_selection = {
                k: v
                for k, v in selection.items() if k in divergence_data.dims
            }
            divs = np.atleast_2d(divergence_data.sel(**div_selection).values)

            for chain, chain_divs in enumerate(divs):
                div_idxs = np.arange(len(chain_divs))[chain_divs]
                if div_idxs.size > 0:
                    values = value[chain, div_idxs]
                    tmp_cds = ColumnDataSource({"y": values, "x": div_idxs})
                    y_div_trace = (value.max()
                                   if divergences == "top" else value.min())
                    glyph_density = Dash(x="y", y=0.0, **div_glyph_kwargs)
                    glyph_trace = Dash(x="x",
                                       y=y_div_trace,
                                       **div_glyph_kwargs)

                    axes[idx, 0].add_glyph(tmp_cds, glyph_density)
                    axes[idx, 1].add_glyph(tmp_cds, glyph_trace)

    show_layout(axes, show)

    return axes
コード例 #3
0
ファイル: kdeplot.py プロジェクト: aidinhass/arviz
def _kde_palette(cmap, n_colors):
    """Resolve ``cmap`` into the colors handed to Bokeh.

    A string is looked up with ``get_cmap``; a callable colormap is sampled
    at ``n_colors`` evenly spaced points in [0, 1] and converted with
    ``rgb2hex``. Anything else is returned unchanged (presumably already a
    sequence of colors).
    """
    if isinstance(cmap, str):
        cmap = get_cmap(cmap)
    if isinstance(cmap, Callable):
        return [rgb2hex(item) for item in cmap(np.linspace(0, 1, n_colors))]
    return cmap


def _kde_fill_outline(x, density):
    """Return the vertices of the closed area between a curve and zero."""
    outline_x = np.concatenate((x, [x[-1]], x[::-1], [x[0]]))
    outline_y = np.concatenate(
        (np.zeros_like(density), [density[-1]], density[::-1], [0]))
    return outline_x, outline_y


def plot_kde(
    density,
    lower,
    upper,
    density_q,
    xmin,
    xmax,
    ymin,
    ymax,
    gridsize,
    values,
    values2,
    rug,
    label,  # pylint: disable=unused-argument
    quantiles,
    rotated,
    contour,
    fill_last,
    figsize,
    textsize,  # pylint: disable=unused-argument
    plot_kwargs,
    fill_kwargs,
    rug_kwargs,
    contour_kwargs,
    contourf_kwargs,
    pcolormesh_kwargs,
    is_circular,  # pylint: disable=unused-argument
    ax,
    legend,  # pylint: disable=unused-argument
    backend_kwargs,
    show,
    return_glyph,
):
    """Bokeh kde plot.

    Draws a 1D density (line, optional fill, optional quantile bands and
    rug) when ``values2`` is None, otherwise a 2D density as filled contours
    or as an image. None of the ``*_kwargs`` dicts passed in is modified.

    Parameters
    ----------
    density : array
        Density values: 1D for the univariate case, 2D otherwise.
    lower, upper
        Range of the x grid the 1D density is plotted on.
    density_q : array
        Compared against consecutive quantile bounds to pick the points of
        each quantile band (presumably the cumulative density).
    xmin, xmax, ymin, ymax
        Extent of the 2D density.
    gridsize : sequence
        Only ``gridsize[0]`` is read; it sets the number of mesh points on
        both axes for the contour plot.
    values
        Data used for the rug when ``rug_kwargs`` carries no ``"cds"``.
    values2
        Only tested against None to select the 1D or 2D plot.
    rug : bool
        If True, add a rug to the 1D plot.
    label, is_circular, legend
        Not used by this function.
    quantiles : sequence or None
        If given, fill one patch per quantile band instead of drawing the
        density line.
    rotated : bool
        If True, swap the x and y coordinates of the 1D plot.
    contour : bool
        2D only: filled contours if True, image otherwise.
    fill_last : bool
        If False, skip the lowest contour band; if True, also paint the
        figure background with the first palette color.
    figsize, textsize
        Passed to ``_scale_fig_size``.
    plot_kwargs, fill_kwargs, rug_kwargs : dict or None
        Options for the density line, the fill patches and the rug glyphs.
        ``rug_kwargs`` may carry ``"cds"`` (one ColumnDataSource or a dict
        of them) and ``"y"`` (column name).
    contour_kwargs, contourf_kwargs, pcolormesh_kwargs : dict or None
        Options for the 2D plots; ``"levels"`` and ``"cmap"`` are read here.
    ax : bokeh figure or None
        Figure to draw on; created when None.
    backend_kwargs : dict or None
        Merged over ``backend_kwarg_defaults()`` for figure creation.
    show : bool
        Forwarded to ``show_layout``.
    return_glyph : bool
        If True, also return the list of created glyphs.

    Returns
    -------
    ax, or (ax, glyphs) when ``return_glyph`` is True.
    """
    if backend_kwargs is None:
        backend_kwargs = {}

    backend_kwargs = {
        **backend_kwarg_defaults(),
        **backend_kwargs,
    }

    figsize, *_ = _scale_fig_size(figsize, textsize)

    if ax is None:
        ax = create_axes_grid(
            1,
            figsize=figsize,
            squeeze=True,
            backend_kwargs=backend_kwargs,
        )

    glyphs = []
    if values2 is None:
        # Copy so defaults set below never leak into the caller's dicts.
        plot_kwargs = {} if plot_kwargs is None else dict(plot_kwargs)
        fill_kwargs = {} if fill_kwargs is None else dict(fill_kwargs)

        default_color = mpl_rcParams["axes.prop_cycle"].by_key()["color"][0]
        plot_kwargs.setdefault("line_color", default_color)
        fill_kwargs.setdefault("fill_color", default_color)

        if rug:
            if rug_kwargs is None:
                rug_kwargs = {}

            rug_kwargs = rug_kwargs.copy()
            if "cds" in rug_kwargs:
                cds_rug = rug_kwargs.pop("cds")
                rug_varname = rug_kwargs.pop("y", "y")
            else:
                rug_varname = "y"
                cds_rug = ColumnDataSource({rug_varname: np.asarray(values)})

            rug_kwargs.setdefault("size", 8)
            rug_kwargs.setdefault("line_color", plot_kwargs["line_color"])
            rug_kwargs.setdefault("line_width", 1)
            rug_kwargs.setdefault("line_alpha", 0.35)
            if not rotated:
                rug_kwargs.setdefault("angle", np.pi / 2)

            # The rug sits on the zero line of the density axis.
            if not rotated:
                rug_coords = {"x": rug_varname, "y": 0.0}
            else:
                rug_coords = {"x": 0.0, "y": rug_varname}

            if isinstance(cds_rug, dict):
                rug_sources = list(cds_rug.values())
            else:
                rug_sources = [cds_rug]

            glyph = None
            for rug_source in rug_sources:
                glyph = Dash(**rug_coords, **rug_kwargs)
                ax.add_glyph(rug_source, glyph)
            # As before only the last rug glyph is recorded; an empty dict
            # of sources no longer fails on an unbound ``glyph``.
            if glyph is not None:
                glyphs.append(glyph)

        x = np.linspace(lower, upper, len(density))

        if quantiles is not None:
            fill_kwargs.setdefault("fill_alpha", 0.75)
            fill_kwargs.setdefault("line_color", None)

            # Make sure the bands cover the whole [0, 1] range.
            quantiles = sorted(np.clip(quantiles, 0, 1))
            if quantiles[0] != 0:
                quantiles = [0] + quantiles
            if quantiles[-1] != 1:
                quantiles = quantiles + [1]

            for quant_0, quant_1 in zip(quantiles[:-1], quantiles[1:]):
                idx = (density_q > quant_0) & (density_q < quant_1)
                if idx.sum():
                    patch_x, patch_y = _kde_fill_outline(x[idx], density[idx])
                    if not rotated:
                        patch = ax.patch(patch_x, patch_y, **fill_kwargs)
                    else:
                        patch = ax.patch(patch_y, patch_x, **fill_kwargs)
                    glyphs.append(patch)
        else:
            # The area under the curve is only filled on request.
            if fill_kwargs.get("fill_alpha", False):
                patch_x, patch_y = _kde_fill_outline(x, density)
                if not rotated:
                    patch = ax.patch(patch_x, patch_y, **fill_kwargs)
                else:
                    patch = ax.patch(patch_y, patch_x, **fill_kwargs)
                glyphs.append(patch)

            if not rotated:
                line = ax.line(x, density, **plot_kwargs)
            else:
                line = ax.line(density, x, **plot_kwargs)
            glyphs.append(line)

    else:
        # Copy: "cmap" is popped and the dicts are merged below, which must
        # not alter the dicts owned by the caller.
        contour_kwargs = {} if contour_kwargs is None else dict(contour_kwargs)
        contourf_kwargs = {} if contourf_kwargs is None else dict(contourf_kwargs)
        pcolormesh_kwargs = (
            {} if pcolormesh_kwargs is None else dict(pcolormesh_kwargs))

        g_s = complex(gridsize[0])
        x_x, y_y = np.mgrid[xmin:xmax:g_s, ymin:ymax:g_s]

        if contour:

            scaled_density, *scaled_density_args = _scale_axis(density)

            contour_generator = _contour.QuadContourGenerator(
                x_x, y_y, scaled_density, None, True, 0)

            # NOTE(review): "levels" stays in the kwargs later forwarded to
            # ``ax.patch`` -- confirm Bokeh accepts it.
            levels = contour_kwargs.get("levels",
                                        contourf_kwargs.get("levels", 11))

            if isinstance(levels, Integral):
                levels_scaled = np.linspace(0, 1, levels)
                levels = _rescale_axis(levels_scaled, scaled_density_args)
            else:
                levels_scaled_nonclip = _scale_axis(np.asarray(levels),
                                                    scaled_density_args)
                levels_scaled = np.clip(levels_scaled_nonclip, 0, 1)

            colors = _kde_palette(contourf_kwargs.pop("cmap", "viridis"),
                                  len(levels_scaled) + 1)

            contour_kwargs.update(contourf_kwargs)
            contour_kwargs.setdefault("line_color", "black")
            contour_kwargs.setdefault("line_alpha", 0.25)
            contour_kwargs.setdefault("fill_alpha", 1)

            for i, (level, level_upper, color) in enumerate(
                    zip(levels_scaled[:-1], levels_scaled[1:], colors[1:])):
                if not fill_last and (i == 0):
                    continue
                vertices, _ = contour_generator.create_filled_contour(
                    level, level_upper)
                for seg in vertices:
                    patch = ax.patch(*seg.T,
                                     fill_color=color,
                                     **contour_kwargs)
                    glyphs.append(patch)

            if fill_last:
                ax.background_fill_color = colors[0]

            ax.xgrid.grid_line_color = None
            ax.ygrid.grid_line_color = None

            ax.x_range = Range1d(xmin, xmax)
            ax.y_range = Range1d(ymin, ymax)

        else:

            colors = _kde_palette(pcolormesh_kwargs.pop("cmap", "viridis"),
                                  256)

            image = ax.image(image=[density.T],
                             x=xmin,
                             y=ymin,
                             dw=(xmax - xmin) / density.shape[0],
                             dh=(ymax - ymin) / density.shape[1],
                             palette=colors,
                             **pcolormesh_kwargs)
            glyphs.append(image)
            ax.x_range.range_padding = ax.y_range.range_padding = 0

    show_layout(ax, show)

    if return_glyph:
        return ax, glyphs

    return ax
コード例 #4
0
# Sample data: N points of the parabola y = x**2 on [-2, 2], with a marker
# size that ramps linearly from 10 to 20 across the points.
N = 9
x = np.linspace(-2, 2, N)
y = x ** 2
sizes = np.linspace(10, 20, N)

source = ColumnDataSource({"x": x, "y": y, "sizes": sizes})

# Fixed-size canvas without title, toolbar or border padding.
plot = Plot(
    title=None,
    plot_width=300,
    plot_height=300,
    min_border=0,
    toolbar_location=None,
)

# One dash marker per row of ``source``, sized from the "sizes" column.
glyph = Dash(
    x="x",
    y="y",
    size="sizes",
    line_color="#3288bd",
    line_width=1,
    fill_color=None,
)
plot.add_glyph(source, glyph)

# Linear axes below and to the left of the plot area.
xaxis = LinearAxis()
plot.add_layout(xaxis, "below")

yaxis = LinearAxis()
plot.add_layout(yaxis, "left")

# One grid per dimension, each following the ticks of its axis.
for _dimension, _axis in enumerate((xaxis, yaxis)):
    plot.add_layout(Grid(dimension=_dimension, ticker=_axis.ticker))

curdoc().add_root(plot)
コード例 #5
0
def _plot_trace_bokeh(
    data,
    var_names=None,
    coords=None,
    divergences="bottom",
    figsize=None,
    rug=False,
    lines=None,
    compact=False,
    combined=False,
    legend=False,
    plot_kwargs=None,
    fill_kwargs=None,
    rug_kwargs=None,
    hist_kwargs=None,
    trace_kwargs=None,
    backend_kwargs=None,
    show=True,
):
    """Plot distribution (histogram or kernel density estimates) and sampled values.

    Parameters
    ----------
    data : obj
        Any object that can be converted to an az.InferenceData object
        Refer to documentation of az.convert_to_dataset for details
    var_names : string, or list of strings
        One or more variables to be plotted.
    coords : mapping, optional
        Coordinates of var_names to be plotted. Passed to `Dataset.sel`
    divergences : {"bottom", "top", None, False}
        NOT IMPLEMENTED
        Plot location of divergences on the traceplots. Options are "bottom", "top", or False-y.
    figsize : figure size tuple
        If None, size is (12, variables * 2)
    rug : bool
        If True adds a rugplot. Defaults to False. Ignored for 2D KDE. Only affects continuous
        variables.
    lines : tuple
        Tuple of (var_name, {'coord': selection}, [line, positions]) to be overplotted as
        vertical lines on the density and horizontal lines on the trace.
    compact : bool
        Plot multidimensional variables in a single plot.
    combined : bool
        Flag for combining multiple chains into a single line. If False (default), chains will be
        plotted separately.
    legend : bool
        Add a legend to the figure with the chain color code.
    plot_kwargs : dict
        Extra keyword arguments passed to `arviz.plot_dist`. Only affects continuous variables.
    fill_kwargs : dict
        Extra keyword arguments passed to `arviz.plot_dist`. Only affects continuous variables.
    rug_kwargs : dict
        Extra keyword arguments passed to `arviz.plot_dist`. Only affects continuous variables.
    hist_kwargs : dict
        Extra keyword arguments passed to `arviz.plot_dist`. Only affects discrete variables.
    trace_kwargs : dict
        Extra keyword arguments passed to `bokeh.plotting.lines`
    backend_kwargs : dict
        Extra keyword arguments passed to `bokeh.plotting.figure`
    show : bool
        Call `bokeh.plotting.show` for gridded plots `bokeh.layouts.gridplot(axes.tolist())`
    Returns
    -------
    ndarray
        axes (bokeh figures)


    Examples
    --------
    Plot a subset variables

    .. plot::
        :context: close-figs

        >>> import arviz as az
        >>> data = az.load_arviz_data('non_centered_eight')
        >>> coords = {'school': ['Choate', 'Lawrenceville']}
        >>> az.plot_trace(data, var_names=('theta_t', 'theta'), coords=coords)

    Show all dimensions of multidimensional variables in the same plot

    .. plot::
        :context: close-figs

        >>> az.plot_trace(data, compact=True)

    Combine all chains into one distribution

    .. plot::
        :context: close-figs

        >>> az.plot_trace(data, var_names=('theta_t', 'theta'), coords=coords, combined=True)


    Plot reference lines against distribution and trace

    .. plot::
        :context: close-figs

        >>> lines = (('theta_t',{'school': "Choate"}, [-1]),)
        >>> az.plot_trace(data, var_names=('theta_t', 'theta'), coords=coords, lines=lines)

    """
    if divergences:
        try:
            divergence_data = convert_to_dataset(data, group="sample_stats").diverging
        except (ValueError, AttributeError):  # No sample_stats, or no `.diverging`
            divergences = False

    if coords is None:
        coords = {}

    data = get_coords(convert_to_dataset(data, group="posterior"), coords)
    var_names = _var_names(var_names, data)

    if divergences:
        divergence_data = get_coords(
            divergence_data, {k: v for k, v in coords.items() if k in ("chain", "draw")}
        )

    if lines is None:
        lines = ()

    num_colors = len(data.chain) + 1 if combined else len(data.chain)
    colors = [
        prop
        for _, prop in zip(
            range(num_colors), cycle(plt.rcParams["axes.prop_cycle"].by_key()["color"])
        )
    ]

    if compact:
        skip_dims = set(data.dims) - {"chain", "draw"}
    else:
        skip_dims = set()

    plotters = list(xarray_var_iter(data, var_names=var_names, combined=True, skip_dims=skip_dims))
    max_plots = rcParams["plot.max_subplots"]
    max_plots = len(plotters) if max_plots is None else max_plots
    if len(plotters) > max_plots:
        warnings.warn(
            "rcParams['plot.max_subplots'] ({max_plots}) is smaller than the number "
            "of variables to plot ({len_plotters}), generating only {max_plots} "
            "plots".format(max_plots=max_plots, len_plotters=len(plotters)),
            SyntaxWarning,
        )
        plotters = plotters[:max_plots]

    if figsize is None:
        figsize = (12, len(plotters) * 2)

    if trace_kwargs is None:
        trace_kwargs = {}

    trace_kwargs.setdefault("alpha", 0.35)

    if hist_kwargs is None:
        hist_kwargs = {}
    if plot_kwargs is None:
        plot_kwargs = {}
    if fill_kwargs is None:
        fill_kwargs = {}
    if rug_kwargs is None:
        rug_kwargs = {}

    hist_kwargs.setdefault("alpha", 0.35)

    figsize, _, _, _, linewidth, _ = _scale_fig_size(figsize, 10, rows=len(plotters), cols=2)
    figsize = int(figsize[0] * 90 // 2), int(figsize[1] * 90 // len(plotters))

    trace_kwargs.setdefault("line_width", linewidth)
    plot_kwargs.setdefault("line_width", linewidth)

    if backend_kwargs is None:
        backend_kwargs = dict()

    backend_kwargs.setdefault(
        "tools",
        ("pan,wheel_zoom,box_zoom," "lasso_select,poly_select," "undo,redo,reset,save,hover"),
    )
    backend_kwargs.setdefault("output_backend", "webgl")
    backend_kwargs.setdefault("height", figsize[1])
    backend_kwargs.setdefault("width", figsize[0])

    axes = []
    for i in range(len(plotters)):
        if i != 0:
            _axes = [
                bkp.figure(**backend_kwargs),
                bkp.figure(x_range=axes[0][1].x_range, **backend_kwargs),
            ]
        else:
            _axes = [bkp.figure(**backend_kwargs), bkp.figure(**backend_kwargs)]
        axes.append(_axes)

    axes = np.array(axes)

    cds_data = {}
    cds_var_groups = {}
    draw_name = "draw"

    for var_name, selection, value in list(
        xarray_var_iter(data, var_names=var_names, combined=True)
    ):
        if selection:
            cds_name = "{}_ARVIZ_CDS_SELECTION_{}".format(
                var_name,
                "_".join(
                    str(item)
                    for key, value in selection.items()
                    for item in (
                        [key, value]
                        if (isinstance(value, str) or not isinstance(value, Iterable))
                        else [key, *value]
                    )
                ),
            )
        else:
            cds_name = var_name

        if var_name not in cds_var_groups:
            cds_var_groups[var_name] = []
        cds_var_groups[var_name].append(cds_name)

        for chain_idx, _ in enumerate(data.chain.values):
            if chain_idx not in cds_data:
                cds_data[chain_idx] = {}
            _data = value[chain_idx]
            cds_data[chain_idx][cds_name] = _data

    while any(key == draw_name for key in cds_data[0]):
        draw_name += "w"

    for chain_idx in cds_data:
        cds_data[chain_idx][draw_name] = data.draw.values

    cds_data = {chain_idx: ColumnDataSource(cds) for chain_idx, cds in cds_data.items()}

    for idx, (var_name, selection, value) in enumerate(plotters):
        value = np.atleast_2d(value)

        if len(value.shape) == 2:
            y_name = (
                var_name
                if not selection
                else "{}_ARVIZ_CDS_SELECTION_{}".format(
                    var_name,
                    "_".join(
                        str(item)
                        for key, value in selection.items()
                        for item in (
                            (key, value)
                            if (isinstance(value, str) or not isinstance(value, Iterable))
                            else (key, *value)
                        )
                    ),
                )
            )
            if rug:
                rug_kwargs["y"] = y_name
            _plot_chains_bokeh(
                ax_density=axes[idx, 0],
                ax_trace=axes[idx, 1],
                data=cds_data,
                x_name=draw_name,
                y_name=y_name,
                colors=colors,
                combined=combined,
                rug=rug,
                legend=legend,
                trace_kwargs=trace_kwargs,
                hist_kwargs=hist_kwargs,
                plot_kwargs=plot_kwargs,
                fill_kwargs=fill_kwargs,
                rug_kwargs=rug_kwargs,
            )
        else:
            for y_name in cds_var_groups[var_name]:
                if rug:
                    rug_kwargs["y"] = y_name
                _plot_chains_bokeh(
                    ax_density=axes[idx, 0],
                    ax_trace=axes[idx, 1],
                    data=cds_data,
                    x_name=draw_name,
                    y_name=y_name,
                    colors=colors,
                    combined=combined,
                    rug=rug,
                    legend=legend,
                    trace_kwargs=trace_kwargs,
                    hist_kwargs=hist_kwargs,
                    plot_kwargs=plot_kwargs,
                    fill_kwargs=fill_kwargs,
                    rug_kwargs=rug_kwargs,
                )

        for col in (0, 1):
            _title = Title()
            _title.text = make_label(var_name, selection)
            axes[idx, col].title = _title

        for _, _, vlines in (j for j in lines if j[0] == var_name and j[1] == selection):
            if isinstance(vlines, (float, int)):
                line_values = [vlines]
            else:
                line_values = np.atleast_1d(vlines).ravel()

            for line_value in line_values:
                vline = Span(
                    location=line_value,
                    dimension="height",
                    line_color="black",
                    line_width=1.5,
                    line_alpha=0.75,
                )
                hline = Span(
                    location=line_value,
                    dimension="width",
                    line_color="black",
                    line_width=1.5,
                    line_alpha=trace_kwargs["alpha"],
                )

                axes[idx, 0].renderers.append(vline)
                axes[idx, 1].renderers.append(hline)

        if legend:
            for col in (0, 1):
                axes[idx, col].legend.location = "top_left"
                axes[idx, col].legend.click_policy = "hide"
        else:
            for col in (0, 1):
                if axes[idx, col].legend:
                    axes[idx, col].legend.visible = False

        if divergences:
            div_density_kwargs = {}
            div_density_kwargs.setdefault("size", 14)
            div_density_kwargs.setdefault("line_color", "red")
            div_density_kwargs.setdefault("line_width", 2)
            div_density_kwargs.setdefault("line_alpha", 0.50)
            div_density_kwargs.setdefault("angle", np.pi / 2)

            div_trace_kwargs = {}
            div_trace_kwargs.setdefault("size", 14)
            div_trace_kwargs.setdefault("line_color", "red")
            div_trace_kwargs.setdefault("line_width", 2)
            div_trace_kwargs.setdefault("line_alpha", 0.50)
            div_trace_kwargs.setdefault("angle", np.pi / 2)

            div_selection = {k: v for k, v in selection.items() if k in divergence_data.dims}
            divs = divergence_data.sel(**div_selection).values
            divs = np.atleast_2d(divs)

            for chain, chain_divs in enumerate(divs):
                div_idxs = np.arange(len(chain_divs))[chain_divs]
                if div_idxs.size > 0:
                    values = value[chain, div_idxs]
                    tmp_cds = ColumnDataSource({"y": values, "x": div_idxs})
                    if divergences == "top":
                        y_div_trace = value.max()
                    else:
                        y_div_trace = value.min()
                    glyph_density = Dash(x="y", y=0.0, **div_density_kwargs)
                    glyph_trace = Dash(x="x", y=y_div_trace, **div_trace_kwargs)

                    axes[idx, 0].add_glyph(tmp_cds, glyph_density)
                    axes[idx, 1].add_glyph(tmp_cds, glyph_trace)

    if show:
        grid = gridplot([list(item) for item in axes], toolbar_location="above")
        bkp.show(grid)

    return axes
コード例 #6
0
              line_color="#F0027F",
              fill_color=None,
              line_width=2)),
    ("x",
     X(x="x",
       y="y",
       size="sizes",
       line_color="thistle",
       fill_color=None,
       line_width=2)),
    ("hex", Hex(x="x", y="y", size="sizes", line_color="#99D594",
                line_width=2)),
    ("dash",
     Dash(x="x",
          y="y",
          size="sizes",
          angle=0.5,
          line_color="#386CB0",
          line_width=1)),
]


def make_tab(title, glyph):
    """Start building a plot titled ``title`` that renders ``glyph``.

    The glyph is attached to ``source``, a name that is not defined inside
    this function and must therefore exist in the enclosing scope. A linear
    x-axis is laid out below the plot.

    NOTE(review): the visible code stops right after the y-axis object is
    created; the y-axis is never added to the plot and nothing is returned.
    Presumably the remainder of the function was cut off -- confirm against
    the original example before relying on this function.
    """
    plot = Plot()
    plot.title.text = title

    # Draw the glyph using the data source defined outside this function.
    plot.add_glyph(source, glyph)

    xaxis = LinearAxis()
    plot.add_layout(xaxis, 'below')

    # Created but not attached to the plot in the visible code.
    yaxis = LinearAxis()
コード例 #7
0
def plot_ess(
    ax,
    plotters,
    xdata,
    ess_tail_dataset,
    mean_ess,
    sd_ess,
    idata,
    data,
    text_x,
    text_va,
    kind,
    extra_methods,
    rows,
    cols,
    figsize,
    kwargs,
    extra_kwargs,
    text_kwargs,
    _linewidth,
    _markersize,
    n_samples,
    relative,
    min_ess,
    xt_labelsize,
    titlesize,
    ax_labelsize,
    ylabel,
    rug,
    rug_kind,
    rug_kwargs,
    hline_kwargs,
    backend_kwargs,
    show,
):
    """Bokeh essplot.

    Draws one panel per entry of ``plotters`` and returns the 2D array of
    axes.

    Parameters
    ----------
    ax : 2D array of bokeh figures or None
        Existing axes; when ``None`` a new grid is created.
    plotters : iterable of (var_name, selection, x)
        Variable name, coordinate selection and ESS values to plot.
    xdata : array-like
        X coordinates shared by every panel.
    ess_tail_dataset : mapping
        Tail ESS per variable; only read when ``kind == "evolution"``.
    mean_ess, sd_ess : mapping
        ESS of the mean / sd per variable; only read when ``extra_methods``.
    idata : object
        Must expose ``sample_stats[rug_kind]`` when a rug is drawn.
    data : mapping
        Posterior values per variable, used to position the rug marks.
    kind : str
        ``"evolution"`` adds bulk/tail lines and a legend; any other value
        labels the x axis "Quantile" and allows the rug.
    extra_methods : bool
        Draw dashed horizontal lines at the mean and sd ESS.
    rows, cols, figsize
        Forwarded to the axes grid factory when ``ax`` is ``None``.
    _linewidth : float
        Width of the rug baseline.
    n_samples, relative, min_ess
        The red threshold line sits at ``400 / n_samples`` when ``relative``
        is truthy, otherwise at ``min_ess``.
    ylabel : str
        Format string receiving ``"Relative ESS"`` or ``"ESS"``.
    rug : bool
        Draw a rug (ignored when ``kind == "evolution"``).
    rug_kind : str
        Name of the ``sample_stats`` variable used to select rug samples.
    rug_kwargs : dict or None
        ``"space"`` (default 0.1) is the fraction of ``max(x)`` by which the
        rug is placed below zero; ``"line_color"`` (default black) colours it.
    backend_kwargs : dict or None
        Extra options merged over the backend defaults for the grid factory.
    show : bool
        Forwarded to ``show_layout``.
    text_x, text_va, kwargs, extra_kwargs, text_kwargs, _markersize, \
xt_labelsize, titlesize, ax_labelsize, hline_kwargs
        Accepted for interface compatibility; not used by this function.

    Raises
    ------
    ValueError
        If a rug is requested and ``idata`` lacks ``sample_stats`` or the
        ``rug_kind`` variable.
    """
    if backend_kwargs is None:
        backend_kwargs = {}

    backend_kwargs = {
        **backend_kwarg_defaults(("dpi", "plot.bokeh.figure.dpi"), ),
        **backend_kwargs,
    }
    if ax is None:
        _, ax = _create_axes_grid(
            len(plotters),
            rows,
            cols,
            figsize=figsize,
            squeeze=False,
            constrained_layout=True,
            backend="bokeh",
            backend_kwargs=backend_kwargs,
        )
    else:
        ax = np.atleast_2d(ax)

    # Grid cells without a figure are stored as None and must be skipped.
    for (var_name, selection,
         x), ax_ in zip(plotters,
                        (item for item in ax.flatten() if item is not None)):
        bulk_points = ax_.circle(np.asarray(xdata), np.asarray(x), size=6)
        if kind == "evolution":
            bulk_line = ax_.line(np.asarray(xdata), np.asarray(x))
            ess_tail = ess_tail_dataset[var_name].sel(**selection)
            tail_line = ax_.line(np.asarray(xdata),
                                 np.asarray(ess_tail),
                                 color="orange")
            tail_points = ax_.circle(np.asarray(xdata),
                                     np.asarray(ess_tail),
                                     size=6,
                                     color="orange")
        elif rug:
            if rug_kwargs is None:
                rug_kwargs = {}
            if not hasattr(idata, "sample_stats"):
                raise ValueError(
                    "InferenceData object must contain sample_stats for rug plot"
                )
            if not hasattr(idata.sample_stats, rug_kind):
                raise ValueError(
                    "InferenceData does not contain {} data".format(rug_kind))

            _rug_kwargs = {}
            _rug_kwargs.setdefault("size", 8)
            _rug_kwargs.setdefault("line_color",
                                   rug_kwargs.get("line_color", "black"))
            _rug_kwargs.setdefault("line_width", 1)
            _rug_kwargs.setdefault("line_alpha", 0.35)
            _rug_kwargs.setdefault("angle", np.pi / 2)

            values = data[var_name].sel(**selection).values.flatten()
            # presumably a boolean mask (e.g. divergences) -- it is used to
            # index the ranked values below.
            mask = idata.sample_stats[rug_kind].values.flatten()
            values = rankdata(values)[mask]
            # Read the spacing without removing it from ``rug_kwargs``:
            # popping it here would apply a user-supplied value to the first
            # panel only and silently fall back to 0.1 for the others.
            rug_space = np.max(x) * rug_kwargs.get("space", 0.1)
            rug_x, rug_y = values / (len(mask) -
                                     1), np.zeros_like(values) - rug_space

            glyph = Dash(x="rug_x", y="rug_y", **_rug_kwargs)
            cds_rug = ColumnDataSource({
                "rug_x": np.asarray(rug_x),
                "rug_y": np.asarray(rug_y)
            })
            ax_.add_glyph(cds_rug, glyph)

            # Baseline at zero separating the rug from the ESS points.
            hline = Span(
                location=0,
                dimension="width",
                line_color="black",
                line_width=_linewidth,
                line_alpha=0.7,
            )

            ax_.renderers.append(hline)

        if extra_methods:
            mean_ess_i = mean_ess[var_name].sel(**selection).values.item()
            sd_ess_i = sd_ess[var_name].sel(**selection).values.item()

            hline = Span(
                location=mean_ess_i,
                dimension="width",
                line_color="black",
                line_width=2,
                line_dash="dashed",
                line_alpha=1.0,
            )

            ax_.renderers.append(hline)

            hline = Span(
                location=sd_ess_i,
                dimension="width",
                line_color="black",
                line_width=1,
                line_dash="dashed",
                line_alpha=1.0,
            )

            ax_.renderers.append(hline)

        # Threshold line for the minimum acceptable ESS.
        hline = Span(
            location=400 / n_samples if relative else min_ess,
            dimension="width",
            line_color="red",
            line_width=3,
            line_dash="dashed",
            line_alpha=1.0,
        )

        ax_.renderers.append(hline)

        if kind == "evolution":
            legend = Legend(
                items=[("bulk", [bulk_points, bulk_line]),
                       ("tail", [tail_points, tail_line])],
                location="center_right",
                orientation="horizontal",
            )
            ax_.add_layout(legend, "above")
            ax_.legend.click_policy = "hide"

        title = Title()
        title.text = make_label(var_name, selection)
        ax_.title = title

        ax_.xaxis.axis_label = "Total number of draws" if kind == "evolution" else "Quantile"
        ax_.yaxis.axis_label = ylabel.format(
            "Relative ESS" if relative else "ESS")

    show_layout(ax, show)

    return ax
コード例 #8
0
def whisker_sentiment(df, calc_attr):
    """Build a figure of ``calc_attr`` per ``day_range`` with mean and spread.

    Rows whose ``calc_attr`` equals zero are dropped, the remaining rows are
    sorted by ``day_range``, and for every distinct day the raw values are
    scattered, a whisker spanning mean +/- one standard deviation is drawn
    and the mean is marked with a dash.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain a ``day_range`` column and the ``calc_attr`` column.
    calc_attr : str
        Name of the column to analyse; also used (capitalized) for labels.

    Returns
    -------
    The configured bokeh figure.
    """
    # Rows that are exactly zero are excluded before computing statistics.
    non_zero = df[df[calc_attr] != 0]
    df = non_zero.sort_values(by=['day_range']).copy()

    calc_attr_capitalized = calc_attr.capitalize()
    p = figure(plot_width=400,
               plot_height=350,
               y_axis_label=calc_attr_capitalized)

    # One (day, values) pair per distinct day, in order of appearance.
    days = list(df.day_range.unique())
    per_day = [(day, df[df['day_range'] == day][calc_attr]) for day in days]

    base = [day for day, _ in per_day]
    mean = [series.mean() for _, series in per_day]
    spread = [series.std() for _, series in per_day]
    lower = [centre - dev for centre, dev in zip(mean, spread)]
    upper = [centre + dev for centre, dev in zip(mean, spread)]

    source_error = ColumnDataSource(data={
        'base': base,
        'lower': lower,
        'upper': upper
    })
    source_mean = ColumnDataSource(data={'x': base, 'y': mean})

    # Faint scatter of the raw values for each day.
    for day, series in per_day:
        p.circle(x=day, y=series, color='teal', size=3, alpha=0.1)

    whisker = Whisker(source=source_error,
                      base="base",
                      upper="upper",
                      lower="lower",
                      line_color='pink',
                      line_width=5)
    p.add_layout(whisker)

    # Mean of each day marked with a dash glyph.
    p.add_glyph(
        source_mean,
        Dash(x="x",
             y="y",
             size=10,
             line_color="black",
             line_width=1,
             fill_color=None))

    # X range starts at zero and ends one past the last day; six integer
    # ticks are spread evenly between the first and last day.
    first_day, last_day = min(base), max(base)
    p.x_range = Range1d(0, last_day + 1)
    p.xaxis.ticker = np.linspace(first_day, last_day, 6, dtype='int').tolist()

    # The tick formatter looks each tick up in a mapping produced by
    # ``create_tick_date_dict`` and embedded into the JS snippet.
    date_tick_dict = create_tick_date_dict(df)
    p.xaxis.formatter = FuncTickFormatter(code="""
    var mapping = %s;
    return mapping[tick];
    """ % date_tick_dict)

    # Title appearance.
    title_settings = (
        ("text", f"Mean {calc_attr_capitalized} Over Time"),
        ("align", "center"),
        ("text_color", text_color),
        ("text_font_size", title_font_size),
        ("text_font", font),
        ("text_font_style", font_style),
    )
    for attr_name, attr_value in title_settings:
        setattr(p.title, attr_name, attr_value)

    # Axis label, tick label and axis line appearance.
    axis_settings = (
        ("axis_label_text_font_style", font_style),
        ("axis_label_text_font_size", axis_label_font_size),
        ("axis_label_text_font", font),
        ("axis_label_text_color", text_color),
        ("major_label_text_font_size", '10pt'),
        ("major_label_text_font", font),
        ("major_label_text_color", text_color),
        ("axis_line_color", text_color),
    )
    for attr_name, attr_value in axis_settings:
        setattr(p.axis, attr_name, attr_value)

    # Hide minor ticks on both axes.
    p.xaxis.minor_tick_line_color = None
    p.yaxis.minor_tick_line_color = None
    return p
コード例 #9
0
def _plot_kde_bokeh(
    density,
    lower,
    upper,
    density_q,
    xmin,
    xmax,
    ymin,
    ymax,
    gridsize,
    values,
    values2=None,
    rug=False,
    label=None,
    quantiles=None,
    rotated=False,
    contour=True,
    fill_last=True,
    plot_kwargs=None,
    fill_kwargs=None,
    rug_kwargs=None,
    contour_kwargs=None,
    contourf_kwargs=None,
    pcolormesh_kwargs=None,
    ax=None,
    legend=True,
    show=True,
):
    """Draw a precomputed 1D or 2D KDE on a bokeh figure.

    Parameters
    ----------
    density : array
        Density values: a sequence evaluated between ``lower`` and ``upper``
        in the 1D case, a 2D grid when ``values2`` is given.
    lower, upper : float
        End points of the evenly spaced x grid for the 1D density.
    density_q : array
        Compared against consecutive quantile bounds to pick the grid points
        of each filled region (presumably the cumulative density -- confirm
        with the caller).
    xmin, xmax, ymin, ymax : float
        Extent of the 2D density.
    gridsize : sequence
        Only ``gridsize[0]`` is used, as the number of grid points along both
        axes of the 2D mesh.
    values : array-like
        Samples; used for the rug when no ``"cds"`` is given in
        ``rug_kwargs``.
    values2 : array-like or None
        ``None`` selects the 1D code path, anything else the 2D one.
    rug : bool
        Add dash marks for the samples (1D only).
    label : str or None
        Legend label stored in ``plot_kwargs`` when ``legend`` is truthy.
    quantiles : sequence of float or None
        Split the filled area at these quantiles (1D only). When given, only
        the filled patches are drawn, not the density line.
    rotated : bool
        Swap the x and y roles in the 1D plot.
    contour : bool
        2D only: filled contours when truthy, otherwise an image.
    fill_last : bool
        2D contour only: when falsy the lowest band is not drawn; when truthy
        the figure background takes the first colour.
    plot_kwargs, fill_kwargs, rug_kwargs : dict or None
        Options for the line, the filled patches and the rug glyph.
        ``rug_kwargs`` may carry ``"cds"`` (a data source or a dict of them)
        and ``"y"`` (the column name to read).
    contour_kwargs, contourf_kwargs, pcolormesh_kwargs : dict or None
        Options for the 2D drawing. ``"levels"`` and ``"cmap"`` are read from
        them; ``"cmap"`` is removed from the dict passed in.
    ax : bokeh figure or None
        Target figure; one is created from ``rcParams`` when ``None``.
    legend : bool
        Whether ``label`` is applied.
    show : bool
        Display the figure before returning.

    Returns
    -------
    The bokeh figure drawn on.
    """
    if ax is None:
        tools = rcParams["plot.bokeh.tools"]
        output_backend = rcParams["plot.bokeh.output_backend"]
        ax = bkp.figure(
            width=rcParams["plot.bokeh.figure.width"],
            height=rcParams["plot.bokeh.figure.height"],
            output_backend=output_backend,
            tools=tools,
        )

    # ``plot_kwargs`` must be a dict before the legend label is stored in it;
    # with the default ``plot_kwargs=None`` the assignment below would raise
    # TypeError whenever a label is supplied.
    if plot_kwargs is None:
        plot_kwargs = {}

    if legend and label is not None:
        plot_kwargs["legend_label"] = label

    if values2 is None:
        plot_kwargs.setdefault(
            "line_color", mpl_rcParams["axes.prop_cycle"].by_key()["color"][0])

        if fill_kwargs is None:
            fill_kwargs = {}

        fill_kwargs.setdefault(
            "fill_color", mpl_rcParams["axes.prop_cycle"].by_key()["color"][0])

        if rug:
            if rug_kwargs is None:
                rug_kwargs = {}

            # Work on a copy so the "cds"/"y" pops do not alter the caller's dict.
            rug_kwargs = rug_kwargs.copy()
            if "cds" in rug_kwargs:
                cds_rug = rug_kwargs.pop("cds")
                rug_varname = rug_kwargs.pop("y", "y")
            else:
                rug_varname = "y"
                cds_rug = ColumnDataSource({rug_varname: np.asarray(values)})

            rug_kwargs.setdefault("size", 8)
            rug_kwargs.setdefault("line_color", plot_kwargs["line_color"])
            rug_kwargs.setdefault("line_width", 1)
            rug_kwargs.setdefault("line_alpha", 0.35)
            if not rotated:
                rug_kwargs.setdefault("angle", np.pi / 2)
            if isinstance(cds_rug, dict):
                # One rug per data source in the dict.
                for _cds_rug in cds_rug.values():
                    if not rotated:
                        glyph = Dash(x=rug_varname, y=0.0, **rug_kwargs)
                    else:
                        glyph = Dash(x=0.0, y=rug_varname, **rug_kwargs)
                    ax.add_glyph(_cds_rug, glyph)
            else:
                if not rotated:
                    glyph = Dash(x=rug_varname, y=0.0, **rug_kwargs)
                else:
                    glyph = Dash(x=0.0, y=rug_varname, **rug_kwargs)
                ax.add_glyph(cds_rug, glyph)

        x = np.linspace(lower, upper, len(density))

        if quantiles is not None:
            fill_kwargs.setdefault("fill_alpha", 0.75)
            fill_kwargs.setdefault("line_color", None)

            # Clip to [0, 1] and make sure both end points are present so the
            # consecutive pairs cover the whole range.
            quantiles = sorted(np.clip(quantiles, 0, 1))
            if quantiles[0] != 0:
                quantiles = [0] + quantiles
            if quantiles[-1] != 1:
                quantiles = quantiles + [1]

            for quant_0, quant_1 in zip(quantiles[:-1], quantiles[1:]):
                idx = (density_q > quant_0) & (density_q < quant_1)
                if idx.sum():
                    # Closed polygon: along the baseline, up, back along the
                    # density curve and down again.
                    patch_x = np.concatenate(
                        (x[idx], [x[idx][-1]], x[idx][::-1], [x[idx][0]]))
                    patch_y = np.concatenate(
                        (np.zeros_like(density[idx]), [density[idx][-1]],
                         density[idx][::-1], [0]))
                    if not rotated:
                        ax.patch(patch_x, patch_y, **fill_kwargs)
                    else:
                        ax.patch(patch_y, patch_x, **fill_kwargs)
        else:
            # The area under the curve is filled only when a truthy
            # ``fill_alpha`` was requested.
            if fill_kwargs.get("fill_alpha", False):
                patch_x = np.concatenate((x, [x[-1]], x[::-1], [x[0]]))
                patch_y = np.concatenate((np.zeros_like(density),
                                          [density[-1]], density[::-1], [0]))
                if not rotated:
                    ax.patch(patch_x, patch_y, **fill_kwargs)
                else:
                    ax.patch(patch_y, patch_x, **fill_kwargs)

            if not rotated:
                ax.line(x, density, **plot_kwargs)
            else:
                ax.line(density, x, **plot_kwargs)

    else:
        if contour_kwargs is None:
            contour_kwargs = {}
        if contourf_kwargs is None:
            contourf_kwargs = {}
        if pcolormesh_kwargs is None:
            pcolormesh_kwargs = {}

        # A complex step makes ``np.mgrid`` treat it as a number of points.
        g_s = complex(gridsize[0])
        x_x, y_y = np.mgrid[xmin:xmax:g_s, ymin:ymax:g_s]

        if contour:

            scaled_density, *scaled_density_args = _scale_axis(density)

            contour_generator = _contour.QuadContourGenerator(
                x_x, y_y, scaled_density, None, True, 0)

            # ``contour_kwargs`` takes precedence over ``contourf_kwargs``.
            if "levels" in contour_kwargs:
                levels = contour_kwargs.get("levels")
            elif "levels" in contourf_kwargs:
                levels = contourf_kwargs.get("levels")
            else:
                levels = 11

            if isinstance(levels, Integral):
                levels_scaled = np.linspace(0, 1, levels)
                levels = _rescale_axis(levels_scaled, scaled_density_args)
            else:
                levels_scaled_nonclip = _scale_axis(np.asarray(levels),
                                                    scaled_density_args)
                levels_scaled = np.clip(levels_scaled_nonclip, 0, 1)

            cmap = contourf_kwargs.pop("cmap", "viridis")
            if isinstance(cmap, str):
                cmap = get_cmap(cmap)
            if isinstance(cmap, Callable):
                colors = [
                    rgb2hex(item)
                    for item in cmap(np.linspace(0, 1,
                                                 len(levels_scaled) + 1))
                ]
            else:
                # Anything else is used directly as the sequence of colours.
                colors = cmap

            contour_kwargs.update(contourf_kwargs)
            contour_kwargs.setdefault("line_color", "black")
            contour_kwargs.setdefault("line_alpha", 0.25)
            contour_kwargs.setdefault("fill_alpha", 1)

            for i, (level, level_upper, color) in enumerate(
                    zip(levels_scaled[:-1], levels_scaled[1:], colors[1:])):
                if not fill_last and (i == 0):
                    continue
                vertices, _ = contour_generator.create_filled_contour(
                    level, level_upper)
                for seg in vertices:
                    ax.patch(*seg.T, fill_color=color, **contour_kwargs)

            if fill_last:
                ax.background_fill_color = colors[0]

            ax.xgrid.grid_line_color = None
            ax.ygrid.grid_line_color = None

            ax.x_range = Range1d(xmin, xmax)
            ax.y_range = Range1d(ymin, ymax)

        else:

            cmap = pcolormesh_kwargs.pop("cmap", "viridis")
            if isinstance(cmap, str):
                cmap = get_cmap(cmap)
            if isinstance(cmap, Callable):
                colors = [
                    rgb2hex(item) for item in cmap(np.linspace(0, 1, 256))
                ]
            else:
                colors = cmap

            ax.image(image=[density.T],
                     x=xmin,
                     y=ymin,
                     dw=(xmax - xmin) / density.shape[0],
                     dh=(ymax - ymin) / density.shape[1],
                     palette=colors,
                     **pcolormesh_kwargs)
            ax.x_range.range_padding = ax.y_range.range_padding = 0

    if show:
        bkp.show(ax, toolbar_location="above")
    return ax
コード例 #10
0
ファイル: bokeh_essplot.py プロジェクト: bbw7561135/arviz
def _plot_ess(
    ax,
    plotters,
    xdata,
    ess_tail_dataset,
    mean_ess,
    sd_ess,
    idata,
    data,
    text_x,
    text_va,
    kind,
    extra_methods,
    rows,
    cols,
    figsize,
    kwargs,
    extra_kwargs,
    text_kwargs,
    _linewidth,
    _markersize,
    n_samples,
    relative,
    min_ess,
    xt_labelsize,
    titlesize,
    ax_labelsize,
    ylabel,
    rug,
    rug_kind,
    rug_kwargs,
    hline_kwargs,
    show,
):
    """Bokeh ESS plot: one panel per entry of ``plotters``.

    Parameters
    ----------
    ax : 2D array of bokeh figures or None
        Existing axes; when ``None`` a new grid is created.
    plotters : iterable of (var_name, selection, x)
        Variable name, coordinate selection and ESS values to plot.
    xdata : array-like
        X coordinates shared by every panel.
    ess_tail_dataset : mapping
        Tail ESS per variable; only read when ``kind == "evolution"``.
    mean_ess, sd_ess : mapping
        ESS of the mean / sd per variable; only read when ``extra_methods``.
    idata : object
        Must expose ``sample_stats[rug_kind]`` when a rug is drawn.
    data : mapping
        Posterior values per variable, used to position the rug marks.
    kind : str
        ``"evolution"`` adds labelled bulk/tail lines; any other value labels
        the x axis "Quantile" and allows the rug.
    extra_methods : bool
        Draw dashed horizontal lines at the mean and sd ESS.
    rows, cols, figsize
        Forwarded to the axes grid factory when ``ax`` is ``None``.
    _linewidth : float
        Width of the rug baseline.
    n_samples, relative, min_ess
        The red threshold line sits at ``400 / n_samples`` when ``relative``
        is truthy, otherwise at ``min_ess``.
    ylabel : str
        Format string receiving ``"Relative ESS"`` or ``"ESS"``.
    rug : bool
        Draw a rug (ignored when ``kind == "evolution"``).
    rug_kind : str
        Name of the ``sample_stats`` variable used to select rug samples.
    rug_kwargs : dict or None
        ``"space"`` (default 0.1) is the fraction of ``max(x)`` by which the
        rug is placed below zero; ``"line_color"`` (default black) colours it.
    show : bool
        Show the axes as a grid plot before returning.
    text_x, text_va, kwargs, extra_kwargs, text_kwargs, _markersize, \
xt_labelsize, titlesize, ax_labelsize, hline_kwargs
        Accepted for interface compatibility; not used by this function.

    Returns
    -------
    The array of axes.

    Raises
    ------
    ValueError
        If a rug is requested and ``idata`` lacks ``sample_stats`` or the
        ``rug_kind`` variable.
    """
    if ax is None:
        _, ax = _create_axes_grid(
            len(plotters),
            rows,
            cols,
            figsize=figsize,
            squeeze=False,
            constrained_layout=True,
            backend="bokeh",
        )
    for (var_name, selection, x), ax_ in zip(plotters, np.ravel(ax)):
        ax_.circle(np.asarray(xdata), np.asarray(x), size=6)
        if kind == "evolution":
            ax_.line(np.asarray(xdata), np.asarray(x), legend_label="bulk")
            ess_tail = ess_tail_dataset[var_name].sel(**selection)
            ax_.line(np.asarray(xdata), np.asarray(ess_tail), color="orange", legend_label="tail")
            ax_.circle(np.asarray(xdata), np.asarray(ess_tail), size=6, color="orange")
        elif rug:
            if rug_kwargs is None:
                rug_kwargs = {}
            if not hasattr(idata, "sample_stats"):
                raise ValueError("InferenceData object must contain sample_stats for rug plot")
            if not hasattr(idata.sample_stats, rug_kind):
                raise ValueError("InferenceData does not contain {} data".format(rug_kind))

            _rug_kwargs = {}
            _rug_kwargs.setdefault("size", 8)
            _rug_kwargs.setdefault("line_color", rug_kwargs.get("line_color", "black"))
            _rug_kwargs.setdefault("line_width", 1)
            _rug_kwargs.setdefault("line_alpha", 0.35)
            _rug_kwargs.setdefault("angle", np.pi / 2)

            values = data[var_name].sel(**selection).values.flatten()
            # presumably a boolean mask (e.g. divergences) -- it is used to
            # index the ranked values below.
            mask = idata.sample_stats[rug_kind].values.flatten()
            values = rankdata(values)[mask]
            # Read the spacing without removing it from ``rug_kwargs``:
            # popping it here would apply a user-supplied value to the first
            # panel only and silently fall back to 0.1 for the others.
            rug_space = np.max(x) * rug_kwargs.get("space", 0.1)
            rug_x, rug_y = values / (len(mask) - 1), np.zeros_like(values) - rug_space

            glyph = Dash(x="rug_x", y="rug_y", **_rug_kwargs)
            cds_rug = ColumnDataSource({"rug_x": np.asarray(rug_x), "rug_y": np.asarray(rug_y)})
            ax_.add_glyph(cds_rug, glyph)

            # Baseline at zero separating the rug from the ESS points.
            hline = Span(
                location=0,
                dimension="width",
                line_color="black",
                line_width=_linewidth,
                line_alpha=0.7,
            )

            ax_.renderers.append(hline)

        if extra_methods:
            mean_ess_i = mean_ess[var_name].sel(**selection).values.item()
            sd_ess_i = sd_ess[var_name].sel(**selection).values.item()

            hline = Span(
                location=mean_ess_i,
                dimension="width",
                line_color="black",
                line_width=2,
                line_dash="dashed",
                line_alpha=1.0,
            )

            ax_.renderers.append(hline)

            hline = Span(
                location=sd_ess_i,
                dimension="width",
                line_color="black",
                line_width=1,
                line_dash="dashed",
                line_alpha=1.0,
            )

            ax_.renderers.append(hline)

        # Threshold line for the minimum acceptable ESS.
        hline = Span(
            location=400 / n_samples if relative else min_ess,
            dimension="width",
            line_color="red",
            line_width=3,
            line_dash="dashed",
            line_alpha=1.0,
        )

        ax_.renderers.append(hline)

        title = Title()
        title.text = make_label(var_name, selection)
        ax_.title = title

        ax_.xaxis.axis_label = "Total number of draws" if kind == "evolution" else "Quantile"
        ax_.yaxis.axis_label = ylabel.format("Relative ESS" if relative else "ESS")

    if show:
        grid = gridplot([list(item) for item in ax], toolbar_location="above")
        bkp.show(grid)

    return ax
コード例 #11
0
def _plot_kde_bokeh(
    density,
    lower,
    upper,
    density_q,
    xmin,
    xmax,
    ymin,
    ymax,
    gridsize,
    values,
    values2=None,
    rug=False,
    label=None,
    bw=4.5,
    quantiles=None,
    rotated=False,
    contour=True,
    fill_last=True,
    textsize=None,
    plot_kwargs=None,
    fill_kwargs=None,
    rug_kwargs=None,
    contour_kwargs=None,
    contourf_kwargs=None,
    pcolormesh_kwargs=None,
    ax=None,
    legend=True,
    show=True,
):
    if ax is None:
        ax = bkp.figure(sizing_mode="stretch_both")

    if legend and label is not None:
        plot_kwargs["legend_label"] = label

    if values2 is None:
        if plot_kwargs is None:
            plot_kwargs = {}
        plot_kwargs.setdefault(
            "line_color", plt.rcParams["axes.prop_cycle"].by_key()["color"][0])

        default_color = plot_kwargs.get("color")

        if fill_kwargs is None:
            fill_kwargs = {}

        fill_kwargs.setdefault("color", default_color)

        if rug:
            if rug_kwargs is None:
                rug_kwargs = {}

            rug_kwargs = rug_kwargs.copy()
            if "cds" in rug_kwargs:
                cds_rug = rug_kwargs.pop("cds")
                rug_varname = rug_kwargs.pop("y", "y")
            else:
                rug_varname = "y"
                cds_rug = ColumnDataSource({rug_varname: values})

            rug_kwargs.setdefault("size", 8)
            rug_kwargs.setdefault("line_color", plot_kwargs["line_color"])
            rug_kwargs.setdefault("line_width", 1)
            rug_kwargs.setdefault("line_alpha", 0.35)
            rug_kwargs.setdefault("angle", np.pi / 2)
            if isinstance(cds_rug, dict):
                for _cds_rug in cds_rug.values():
                    glyph = Dash(x=rug_varname, y=0.0, **rug_kwargs)
                    ax.add_glyph(_cds_rug, glyph)
            else:
                glyph = Dash(x=rug_varname, y=0.0, **rug_kwargs)
                ax.add_glyph(cds_rug, glyph)

        x = np.linspace(lower, upper, len(density))
        ax.line(x, density, **plot_kwargs)
    else:
        # todo
        raise NotImplementedError("Use matplotlib backend")

    if show:
        bkp.show(ax)
    return ax
コード例 #12
0
ファイル: mcseplot.py プロジェクト: utkarsh-maheshwari/arviz
def plot_mcse(
    ax,
    plotters,
    length_plotters,
    rows,
    cols,
    figsize,
    errorbar,
    rug,
    data,
    probs,
    kwargs,  # pylint: disable=unused-argument
    extra_methods,
    mean_mcse,
    sd_mcse,
    textsize,
    labeller,
    text_kwargs,  # pylint: disable=unused-argument
    rug_kwargs,
    extra_kwargs,
    idata,
    rug_kind,
    backend_kwargs,
    show,
):
    """Bokeh mcse plot.

    Draws one panel per entry of ``plotters`` and returns the 2D array of
    axes.

    Parameters
    ----------
    ax : 2D array of bokeh figures or None
        Existing axes; when ``None`` a new grid is created.
    plotters : iterable of (var_name, selection, isel, x)
        Variable name, coordinate selections and MCSE values per quantile.
    length_plotters, rows, cols, figsize
        Forwarded to the axes grid factory when ``ax`` is ``None``.
    errorbar : bool
        Plot the quantile values with +/- MCSE bars instead of the MCSE
        values themselves.
    rug : bool
        Draw a rug of the samples selected by ``rug_kind``.
    data : mapping
        Posterior values per variable.
    probs : array-like
        Quantile probabilities used as x coordinates.
    extra_methods : bool
        Draw horizontal lines at the mean and sd MCSE (only when
        ``errorbar`` is falsy).
    mean_mcse, sd_mcse : mapping
        MCSE of the mean / sd per variable.
    textsize
        Forwarded to ``_scale_fig_size``.
    labeller : object
        Provides ``make_label_vert`` for panel titles.
    rug_kwargs : dict or None
        ``"line_color"`` (default black) colours the rug. A default
        ``"space"`` entry is stored in it but not otherwise used here.
    extra_kwargs : dict or None
        ``"linewidth"``, ``"color"`` and ``"alpha"`` for the mean/sd lines;
        missing keys are filled in place with defaults.
    idata : object
        Must expose ``sample_stats[rug_kind]`` when ``rug`` is truthy.
    rug_kind : str
        Name of the ``sample_stats`` variable used to select rug samples.
    backend_kwargs : dict or None
        Extra options merged over the backend defaults for the grid factory.
    show : bool
        Forwarded to ``show_layout``.
    kwargs, text_kwargs
        Unused.

    Raises
    ------
    ValueError
        If ``rug`` is truthy and ``idata`` lacks ``sample_stats`` or the
        ``rug_kind`` variable.
    """
    if backend_kwargs is None:
        backend_kwargs = {}

    backend_kwargs = {
        **backend_kwarg_defaults(),
        **backend_kwargs,
    }

    (figsize, *_, _linewidth, _markersize) = _scale_fig_size(figsize, textsize, rows, cols)

    extra_kwargs = {} if extra_kwargs is None else extra_kwargs
    extra_kwargs.setdefault("linewidth", _linewidth / 2)
    extra_kwargs.setdefault("color", "black")
    extra_kwargs.setdefault("alpha", 0.5)

    if ax is None:
        ax = create_axes_grid(
            length_plotters,
            rows,
            cols,
            figsize=figsize,
            backend_kwargs=backend_kwargs,
        )
    else:
        ax = np.atleast_2d(ax)

    # Grid cells without a figure are stored as None and must be skipped.
    for (var_name, selection, isel, x), ax_ in zip(
        plotters, (item for item in ax.flatten() if item is not None)
    ):
        if errorbar or rug:
            values = data[var_name].sel(**selection).values.flatten()
        if errorbar:
            quantile_values = _quantile(values, probs)
            ax_.dash(probs, quantile_values)
            # One vertical segment per quantile spanning value +/- MCSE.
            ax_.multi_line(
                list(zip(probs, probs)),
                [(quant - err, quant + err) for quant, err in zip(quantile_values, x)],
            )
        else:
            ax_.circle(probs, x)
            if extra_methods:
                mean_mcse_i = mean_mcse[var_name].sel(**selection).values.item()
                sd_mcse_i = sd_mcse[var_name].sel(**selection).values.item()
                # The mean line is drawn twice as wide as the sd line.
                hline_mean = Span(
                    location=mean_mcse_i,
                    dimension="width",
                    line_color=extra_kwargs["color"],
                    line_width=extra_kwargs["linewidth"] * 2,
                    line_alpha=extra_kwargs["alpha"],
                )

                ax_.renderers.append(hline_mean)

                # Uses the configured colour like the mean line; it was
                # previously hard-coded to black, ignoring ``extra_kwargs``.
                hline_sd = Span(
                    location=sd_mcse_i,
                    dimension="width",
                    line_color=extra_kwargs["color"],
                    line_width=extra_kwargs["linewidth"],
                    line_alpha=extra_kwargs["alpha"],
                )

                ax_.renderers.append(hline_sd)

        if rug:
            if rug_kwargs is None:
                rug_kwargs = {}
            if not hasattr(idata, "sample_stats"):
                raise ValueError("InferenceData object must contain sample_stats for rug plot")
            if not hasattr(idata.sample_stats, rug_kind):
                raise ValueError("InferenceData does not contain {} data".format(rug_kind))
            rug_kwargs.setdefault("space", 0.1)

            _rug_kwargs = {}
            _rug_kwargs.setdefault("size", 8)
            _rug_kwargs.setdefault("line_color", rug_kwargs.get("line_color", "black"))
            _rug_kwargs.setdefault("line_width", 1)
            _rug_kwargs.setdefault("line_alpha", 0.35)
            _rug_kwargs.setdefault("angle", np.pi / 2)

            # presumably a boolean mask (e.g. divergences) -- it is used to
            # index the ranked values below.
            mask = idata.sample_stats[rug_kind].values.flatten()
            values = rankdata(values, method="average")[mask]

            if errorbar:
                # Place the rug 5% of the quantile range below the smallest
                # quantile, but never above zero.
                rug_baseline = min(
                    0,
                    min(quantile_values) - (max(quantile_values) - min(quantile_values)) * 0.05,
                )
            else:
                rug_baseline = 0

            rug_x = values / (len(mask) - 1)
            rug_y = np.full_like(values, rug_baseline)

            hline = Span(
                location=rug_baseline,
                dimension="width",
                line_color="black",
                line_width=_linewidth,
                line_alpha=0.7,
            )

            ax_.renderers.append(hline)

            glyph = Dash(x="rug_x", y="rug_y", **_rug_kwargs)
            cds_rug = ColumnDataSource({"rug_x": np.asarray(rug_x), "rug_y": np.asarray(rug_y)})
            ax_.add_glyph(cds_rug, glyph)

        title = Title()
        title.text = labeller.make_label_vert(var_name, selection, isel)
        ax_.title = title

        ax_.xaxis.axis_label = "Quantile"
        ax_.yaxis.axis_label = (
            r"Value $\pm$ MCSE for quantiles" if errorbar else "MCSE for quantiles"
        )

        if not errorbar:
            # Writes the range bounds straight into the property store.
            ax_.y_range._property_values["start"] = -0.05  # pylint: disable=protected-access
            ax_.y_range._property_values["end"] = 1  # pylint: disable=protected-access

    show_layout(ax, show)

    return ax
コード例 #13
0
def _plot_trace_bokeh(
    data,
    var_names=None,
    coords=None,
    divergences="bottom",
    figsize=None,
    rug=False,
    lines=None,
    compact=False,
    combined=False,
    legend=False,
    plot_kwargs=None,
    fill_kwargs=None,
    rug_kwargs=None,
    hist_kwargs=None,
    trace_kwargs=None,
    backend_kwargs=None,
    show=True,
):
    """Draw one density figure and one trace figure per plotted variable (bokeh).

    Parameters
    ----------
    data
        Object accepted by ``convert_to_dataset``; the ``posterior`` group is
        plotted and, when ``divergences`` is truthy, ``sample_stats.diverging``
        is read as well.
    var_names
        Variable names, resolved through ``_var_names`` against the posterior.
    coords : dict, optional
        Coordinates applied with ``get_coords``. Only the ``chain`` and
        ``draw`` entries are applied to the divergence data.
    divergences
        Falsy to skip divergence markers. ``"top"`` places the trace markers
        at the maximum of the plotted values, any other truthy value at the
        minimum. Silently disabled when the divergence data cannot be read
        (``ValueError`` / ``AttributeError``).
    figsize : tuple, optional
        Defaults to ``(12, 2 * number_of_rows)``.
    rug : bool
        Forwarded to ``_plot_chains_bokeh``; when true the rug's ``y`` column
        is set to the column being plotted.
    lines : iterable of (var_name, selection, values), optional
        Reference lines; entries matching a row's variable and selection are
        drawn as vertical spans on the density figure and horizontal spans on
        the trace figure.
    compact : bool
        When true every dimension other than ``chain``/``draw`` is skipped
        while iterating, so each variable occupies a single row.
    combined : bool
        Forwarded to ``_plot_chains_bokeh``; one extra color is reserved.
    legend : bool
        When true legends are placed top-left with a "hide" click policy,
        otherwise existing legends are made invisible.
    plot_kwargs, fill_kwargs, rug_kwargs, hist_kwargs, trace_kwargs : dict, optional
        Forwarded to ``_plot_chains_bokeh``. They are copied first, so the
        caller's dictionaries are never modified.
    backend_kwargs : dict, optional
        Keyword arguments for ``bkp.figure``. Copied before defaults are added.
    show : bool
        When true the figures are arranged with ``gridplot`` and shown.

    Returns
    -------
    numpy.ndarray
        Array built from one ``[density_figure, trace_figure]`` pair per row.
    """

    def _cds_column(name, sel):
        """Return the ColumnDataSource column name for one variable selection."""
        if not sel:
            return name
        parts = []
        for key, val in sel.items():
            parts.append(key)
            # Scalars and strings contribute one token, other iterables one per item.
            if isinstance(val, str) or not isinstance(val, Iterable):
                parts.append(val)
            else:
                parts.extend(val)
        return "{}_ARVIZ_CDS_SELECTION_{}".format(name, "_".join(str(part) for part in parts))

    if divergences:
        try:
            divergence_data = convert_to_dataset(data, group="sample_stats").diverging
        except (ValueError, AttributeError):  # No sample_stats, or no `.diverging`
            divergences = False

    if coords is None:
        coords = {}

    data = get_coords(convert_to_dataset(data, group="posterior"), coords)
    var_names = _var_names(var_names, data)

    if divergences:
        divergence_data = get_coords(
            divergence_data,
            {k: v for k, v in coords.items() if k in ("chain", "draw")},
        )

    if lines is None:
        lines = ()

    num_colors = len(data.chain) + 1 if combined else len(data.chain)
    colors = [
        prop
        for _, prop in zip(
            range(num_colors),
            cycle(plt.rcParams["axes.prop_cycle"].by_key()["color"]),
        )
    ]

    skip_dims = set(data.dims) - {"chain", "draw"} if compact else set()

    plotters = list(
        xarray_var_iter(data, var_names=var_names, combined=True, skip_dims=skip_dims)
    )
    max_plots = rcParams["plot.max_subplots"]
    max_plots = len(plotters) if max_plots is None else max_plots
    if len(plotters) > max_plots:
        warnings.warn(
            "rcParams['plot.max_subplots'] ({max_plots}) is smaller than the number "
            "of variables to plot ({len_plotters}), generating only {max_plots} "
            "plots".format(max_plots=max_plots, len_plotters=len(plotters)),
            SyntaxWarning,
        )
        plotters = plotters[:max_plots]

    if figsize is None:
        figsize = (12, len(plotters) * 2)

    # Work on copies: defaults and the per-row rug "y" column are written into
    # these dictionaries and must not leak back into the caller's objects.
    trace_kwargs = {} if trace_kwargs is None else dict(trace_kwargs)
    hist_kwargs = {} if hist_kwargs is None else dict(hist_kwargs)
    plot_kwargs = {} if plot_kwargs is None else dict(plot_kwargs)
    fill_kwargs = {} if fill_kwargs is None else dict(fill_kwargs)
    rug_kwargs = {} if rug_kwargs is None else dict(rug_kwargs)
    backend_kwargs = {} if backend_kwargs is None else dict(backend_kwargs)

    trace_kwargs.setdefault("alpha", 0.35)
    hist_kwargs.setdefault("alpha", 0.35)

    figsize, _, _, _, linewidth, _ = _scale_fig_size(
        figsize, 10, rows=len(plotters), cols=2
    )

    trace_kwargs.setdefault("line_width", linewidth)
    plot_kwargs.setdefault("line_width", linewidth)

    backend_kwargs.setdefault("tools", rcParams["plot.bokeh.tools"])
    backend_kwargs.setdefault("output_backend", rcParams["plot.bokeh.output_backend"])
    backend_kwargs.setdefault(
        "height",
        int(figsize[1] * rcParams["plot.bokeh.figure.dpi"] // len(plotters)),
    )
    backend_kwargs.setdefault(
        "width", int(figsize[0] * rcParams["plot.bokeh.figure.dpi"] // 2)
    )

    # Trace figures after the first share the first trace figure's x_range.
    # Using setdefault (instead of passing x_range explicitly next to
    # **backend_kwargs) avoids a duplicate-keyword TypeError when the caller
    # already supplied an x_range.
    axes = []
    trace_fig_kwargs = dict(backend_kwargs)
    for _ in plotters:
        density_fig = bkp.figure(**backend_kwargs)
        trace_fig = bkp.figure(**trace_fig_kwargs)
        trace_fig_kwargs.setdefault("x_range", trace_fig.x_range)
        axes.append([density_fig, trace_fig])

    axes = np.array(axes)

    # One column per (variable, selection) and one data source per chain.
    cds_data = {}
    cds_var_groups = {}
    for var_name, selection, value in xarray_var_iter(
        data, var_names=var_names, combined=True
    ):
        cds_name = _cds_column(var_name, selection)
        cds_var_groups.setdefault(var_name, []).append(cds_name)
        for chain_idx, _ in enumerate(data.chain.values):
            cds_data.setdefault(chain_idx, {})[cds_name] = value[chain_idx]

    # Pick a name for the draw column that does not collide with a variable column.
    draw_name = "draw"
    while draw_name in cds_data[0]:
        draw_name += "w"

    for chain_columns in cds_data.values():
        chain_columns[draw_name] = data.draw.values

    cds_data = {
        chain_idx: ColumnDataSource(cds) for chain_idx, cds in cds_data.items()
    }

    # The divergence marker styling is the same for every row.
    div_density_kwargs = {
        "size": 14,
        "line_color": "red",
        "line_width": 2,
        "line_alpha": 0.50,
        "angle": np.pi / 2,
    }
    div_trace_kwargs = dict(div_density_kwargs)

    for idx, (var_name, selection, value) in enumerate(plotters):
        value = np.atleast_2d(value)

        # A 2D value is a single (chain, draw) column; anything larger (compact
        # mode) overlays every column registered for the variable.
        if value.ndim == 2:
            y_names = [_cds_column(var_name, selection)]
        else:
            y_names = cds_var_groups[var_name]

        for y_name in y_names:
            if rug:
                rug_kwargs["y"] = y_name
            _plot_chains_bokeh(
                ax_density=axes[idx, 0],
                ax_trace=axes[idx, 1],
                data=cds_data,
                x_name=draw_name,
                y_name=y_name,
                colors=colors,
                combined=combined,
                rug=rug,
                legend=legend,
                trace_kwargs=trace_kwargs,
                hist_kwargs=hist_kwargs,
                plot_kwargs=plot_kwargs,
                fill_kwargs=fill_kwargs,
                rug_kwargs=rug_kwargs,
            )

        for col in (0, 1):
            _title = Title()
            _title.text = make_label(var_name, selection)
            axes[idx, col].title = _title

        for _, _, vlines in (
            j for j in lines if j[0] == var_name and j[1] == selection
        ):
            if isinstance(vlines, (float, int)):
                line_values = [vlines]
            else:
                line_values = np.atleast_1d(vlines).ravel()

            for line_value in line_values:
                vline = Span(
                    location=line_value,
                    dimension="height",
                    line_color="black",
                    line_width=1.5,
                    line_alpha=0.75,
                )
                hline = Span(
                    location=line_value,
                    dimension="width",
                    line_color="black",
                    line_width=1.5,
                    line_alpha=trace_kwargs["alpha"],
                )

                axes[idx, 0].renderers.append(vline)
                axes[idx, 1].renderers.append(hline)

        if legend:
            for col in (0, 1):
                axes[idx, col].legend.location = "top_left"
                axes[idx, col].legend.click_policy = "hide"
        else:
            for col in (0, 1):
                if axes[idx, col].legend:
                    axes[idx, col].legend.visible = False

        if divergences:
            div_selection = {
                k: v for k, v in selection.items() if k in divergence_data.dims
            }
            divs = divergence_data.sel(**div_selection).values
            divs = np.atleast_2d(divs)

            for chain, chain_divs in enumerate(divs):
                div_idxs = np.arange(len(chain_divs))[chain_divs]
                if div_idxs.size > 0:
                    values = value[chain, div_idxs]
                    tmp_cds = ColumnDataSource({"y": values, "x": div_idxs})
                    if divergences == "top":
                        y_div_trace = value.max()
                    else:
                        y_div_trace = value.min()
                    glyph_density = Dash(x="y", y=0.0, **div_density_kwargs)
                    glyph_trace = Dash(x="x", y=y_div_trace, **div_trace_kwargs)

                    axes[idx, 0].add_glyph(tmp_cds, glyph_density)
                    axes[idx, 1].add_glyph(tmp_cds, glyph_trace)

    if show:
        grid = gridplot([list(item) for item in axes], toolbar_location="above")
        bkp.show(grid)

    return axes
コード例 #14
0
ファイル: traceplot.py プロジェクト: utkarsh-maheshwari/arviz
def plot_trace(
    data,
    var_names,
    divergences,
    kind,
    figsize,
    rug,
    lines,
    circ_var_names,  # pylint: disable=unused-argument
    circ_var_units,  # pylint: disable=unused-argument
    compact,
    compact_prop,
    combined,
    chain_prop,
    legend,
    labeller,
    plot_kwargs,
    fill_kwargs,
    rug_kwargs,
    hist_kwargs,
    trace_kwargs,
    rank_kwargs,
    plotters,
    divergence_data,
    axes,
    backend_kwargs,
    backend_config,
    show,
):
    """Bokeh traceplot.

    Draws, for every entry of ``plotters``, a density figure (column 0) and a
    trace figure (column 1), then adds titles, reference lines, legends and
    divergence markers.

    Parameters
    ----------
    data
        Dataset iterated with ``xarray_var_iter``; its ``chain`` and ``draw``
        coordinates are read.
    var_names
        Variable names passed to ``xarray_var_iter``.
    divergences
        Truthy to draw divergence markers. ``"top"`` places the trace markers
        at the maximum of the plotted values, any other truthy value at the
        minimum. Unless it is exactly ``False``, ``divergence_data`` must not
        be ``None``.
    kind
        Forwarded to ``_plot_chains_bokeh``. Divergence markers are added to
        the trace figure only when it equals ``"trace"``.
    figsize : tuple or None
        Defaults to ``(12, 2 * len(plotters))``.
    rug : bool
        Forwarded to ``_plot_chains_bokeh``; when true the rug's ``y`` column
        is set to the column being plotted.
    lines : iterable of (var_name, selection, values) or None
        Reference lines drawn as spans on the matching row.
    circ_var_names, circ_var_units
        Unused by this backend.
    compact : bool
        Selects the defaults used when ``chain_prop`` / ``compact_prop`` are
        ``None``.
    compact_prop : str, tuple, dict or None
        Normalised to a dict here; not otherwise used by this function.
    combined : bool
        Forwarded to ``_plot_chains_bokeh``; one extra chain property value is
        reserved.
    chain_prop : str, tuple, dict or None
        Property cycled over chains. A string is looked up in matplotlib's
        property cycle; a tuple is accepted with a ``FutureWarning``.
    legend : bool
        When true legends are placed top-left with a "hide" click policy,
        otherwise existing legends are made invisible.
    labeller
        Object providing ``make_label_vert(var_name, selection, isel)``.
    plot_kwargs, fill_kwargs, rug_kwargs, hist_kwargs, trace_kwargs, rank_kwargs : dict or None
        Forwarded to ``_plot_chains_bokeh``. They are copied first, so the
        caller's dictionaries are never modified.
    plotters : sequence of (var_name, selection, isel, value)
        One entry per row of figures.
    divergence_data
        Object with ``dims`` and ``sel`` used to locate divergent draws.
    axes : array-like or None
        Existing figures; when ``None`` a new ``len(plotters) x 2`` grid is
        created in which all trace figures share one ``x_range``.
    backend_kwargs : dict or None
        Keyword arguments for ``bkp.figure``; a ``dpi`` entry is removed and
        used to compute the default figure height and width.
    backend_config : dict or None
        Supplies ``bounds_y_range`` for the y ranges.
    show
        Forwarded to ``show_layout``.

    Returns
    -------
    numpy.ndarray
        2D array of figures (density in column 0, trace in column 1).
    """

    def _cds_column(name, sel):
        """Return the ColumnDataSource column name for one variable selection."""
        if not sel:
            return name
        parts = []
        for key, val in sel.items():
            parts.append(key)
            # Scalars and strings contribute one token, other iterables one per item.
            if isinstance(val, str) or not isinstance(val, Iterable):
                parts.append(val)
            else:
                parts.extend(val)
        return "{}_ARVIZ_CDS_SELECTION_{}".format(name, "_".join(str(part) for part in parts))

    # If divergences are plotted they must be provided
    if divergences is not False:
        assert divergence_data is not None

    if backend_config is None:
        backend_config = {}

    backend_config = {
        **backend_kwarg_defaults(("bounds_y_range", "plot.bokeh.bounds_y_range"), ),
        **backend_config,
    }

    # Set plot default backend kwargs
    if backend_kwargs is None:
        backend_kwargs = {}

    backend_kwargs = {
        **backend_kwarg_defaults(("dpi", "plot.bokeh.figure.dpi"), ),
        **backend_kwargs,
    }
    dpi = backend_kwargs.pop("dpi")

    if figsize is None:
        figsize = (12, len(plotters) * 2)

    figsize, _, _, _, linewidth, _ = _scale_fig_size(
        figsize, 10, rows=len(plotters), cols=2
    )

    backend_kwargs.setdefault("height", int(figsize[1] * dpi // len(plotters)))
    backend_kwargs.setdefault("width", int(figsize[0] * dpi // 2))

    if lines is None:
        lines = ()

    num_chain_props = len(data.chain) + 1 if combined else len(data.chain)

    # Defaults: chains are told apart by color, unless the plot is compact, in
    # which case chains use line dashes and color is left for compact_prop.
    if compact:
        if chain_prop is None:
            chain_prop = {"line_dash": ("solid", "dotted", "dashed", "dashdot")}
        if compact_prop is None:
            compact_prop = {
                "line_color": plt.rcParams["axes.prop_cycle"].by_key()["color"]
            }
    elif chain_prop is None:
        chain_prop = {"line_color": plt.rcParams["axes.prop_cycle"].by_key()["color"]}

    if isinstance(chain_prop, str):
        chain_prop = {chain_prop: plt.rcParams["axes.prop_cycle"].by_key()[chain_prop]}
    if isinstance(chain_prop, tuple):
        warnings.warn(
            "chain_prop as a tuple will be deprecated in a future warning, use a dict instead",
            FutureWarning,
        )
        chain_prop = {chain_prop[0]: chain_prop[1]}
    # Repeat each property's values until there is one per chain (plus one
    # extra when combined).
    chain_prop = {
        prop_name: [prop for _, prop in zip(range(num_chain_props), cycle(props))]
        for prop_name, props in chain_prop.items()
    }

    if isinstance(compact_prop, str):
        compact_prop = {
            compact_prop: plt.rcParams["axes.prop_cycle"].by_key()[compact_prop]
        }
    if isinstance(compact_prop, tuple):
        warnings.warn(
            "compact_prop as a tuple will be deprecated in a future warning, use a dict instead",
            FutureWarning,
        )
        compact_prop = {compact_prop[0]: compact_prop[1]}

    # Work on copies: defaults and the per-row rug "y" column are written into
    # these dictionaries and must not leak back into the caller's objects.
    trace_kwargs = {} if trace_kwargs is None else dict(trace_kwargs)
    hist_kwargs = {} if hist_kwargs is None else dict(hist_kwargs)
    plot_kwargs = {} if plot_kwargs is None else dict(plot_kwargs)
    fill_kwargs = {} if fill_kwargs is None else dict(fill_kwargs)
    rug_kwargs = {} if rug_kwargs is None else dict(rug_kwargs)
    rank_kwargs = {} if rank_kwargs is None else dict(rank_kwargs)

    trace_kwargs.setdefault("alpha", 0.35)
    hist_kwargs.setdefault("alpha", 0.35)
    trace_kwargs.setdefault("line_width", linewidth)
    plot_kwargs.setdefault("line_width", linewidth)

    if axes is None:
        axes = []
        # Every trace figure after the first reuses the first one's x_range,
        # unless the caller already supplied an x_range of their own.
        trace_fig_kwargs = backend_kwargs.copy()
        for _ in plotters:
            density_fig = bkp.figure(**backend_kwargs)
            trace_fig = bkp.figure(**trace_fig_kwargs)
            trace_fig_kwargs.setdefault("x_range", trace_fig.x_range)
            axes.append([density_fig, trace_fig])

    axes = np.atleast_2d(axes)

    # One column per (variable, selection) and one data source per chain.
    cds_data = {}
    cds_var_groups = {}
    for var_name, selection, _, value in xarray_var_iter(
        data, var_names=var_names, combined=True
    ):
        cds_name = _cds_column(var_name, selection)
        cds_var_groups.setdefault(var_name, []).append(cds_name)
        for chain_idx, _ in enumerate(data.chain.values):
            cds_data.setdefault(chain_idx, {})[cds_name] = value[chain_idx]

    # Pick a name for the draw column that does not collide with a variable column.
    draw_name = "draw"
    while draw_name in cds_data[0]:
        draw_name += "w"

    for chain_columns in cds_data.values():
        chain_columns[draw_name] = data.draw.values

    cds_data = {
        chain_idx: ColumnDataSource(cds) for chain_idx, cds in cds_data.items()
    }

    # The divergence marker styling is the same for every row.
    div_density_kwargs = {
        "size": 14,
        "line_color": "red",
        "line_width": 2,
        "line_alpha": 0.50,
        "angle": np.pi / 2,
    }
    div_trace_kwargs = dict(div_density_kwargs)

    for idx, (var_name, selection, isel, value) in enumerate(plotters):
        value = np.atleast_2d(value)

        # A 2D value is a single (chain, draw) column; anything larger overlays
        # every column registered for the variable.
        if value.ndim == 2:
            y_names = [_cds_column(var_name, selection)]
        else:
            y_names = cds_var_groups[var_name]

        for y_name in y_names:
            if rug:
                rug_kwargs["y"] = y_name
            _plot_chains_bokeh(
                ax_density=axes[idx, 0],
                ax_trace=axes[idx, 1],
                data=cds_data,
                x_name=draw_name,
                y_name=y_name,
                chain_prop=chain_prop,
                combined=combined,
                rug=rug,
                kind=kind,
                legend=legend,
                trace_kwargs=trace_kwargs,
                hist_kwargs=hist_kwargs,
                plot_kwargs=plot_kwargs,
                fill_kwargs=fill_kwargs,
                rug_kwargs=rug_kwargs,
                rank_kwargs=rank_kwargs,
            )

        for col in (0, 1):
            _title = Title()
            _title.text = labeller.make_label_vert(var_name, selection, isel)
            axes[idx, col].title = _title
            axes[idx, col].y_range = DataRange1d(
                bounds=backend_config["bounds_y_range"], min_interval=0.1
            )

        for _, _, vlines in (
            j for j in lines if j[0] == var_name and j[1] == selection
        ):
            if isinstance(vlines, (float, int)):
                line_values = [vlines]
            else:
                line_values = np.atleast_1d(vlines).ravel()

            for line_value in line_values:
                vline = Span(
                    location=line_value,
                    dimension="height",
                    line_color="black",
                    line_width=1.5,
                    line_alpha=0.75,
                )
                hline = Span(
                    location=line_value,
                    dimension="width",
                    line_color="black",
                    line_width=1.5,
                    line_alpha=trace_kwargs["alpha"],
                )

                axes[idx, 0].renderers.append(vline)
                axes[idx, 1].renderers.append(hline)

        if legend:
            for col in (0, 1):
                axes[idx, col].legend.location = "top_left"
                axes[idx, col].legend.click_policy = "hide"
        else:
            for col in (0, 1):
                if axes[idx, col].legend:
                    axes[idx, col].legend.visible = False

        if divergences:
            div_selection = {
                k: v for k, v in selection.items() if k in divergence_data.dims
            }
            divs = divergence_data.sel(**div_selection).values
            divs = np.atleast_2d(divs)

            for chain, chain_divs in enumerate(divs):
                div_idxs = np.arange(len(chain_divs))[chain_divs]
                if div_idxs.size > 0:
                    values = value[chain, div_idxs]
                    tmp_cds = ColumnDataSource({"y": values, "x": div_idxs})
                    if divergences == "top":
                        y_div_trace = value.max()
                    else:
                        y_div_trace = value.min()
                    glyph_density = Dash(x="y", y=0.0, **div_density_kwargs)
                    if kind == "trace":
                        glyph_trace = Dash(x="x", y=y_div_trace, **div_trace_kwargs)
                        axes[idx, 1].add_glyph(tmp_cds, glyph_trace)

                    axes[idx, 0].add_glyph(tmp_cds, glyph_density)

    show_layout(axes, show)

    return axes