Пример #1
0
def tile_cross_validation(
    cv_results: List[CVResult],
    show_arm_details_on_hover: bool = True,
    show_context: bool = True,
) -> AxPlotConfig:
    """Tile version of CV plots: one predicted-vs-actual subplot per metric.

    Subplots are laid out two per row in the order of ``data.metrics`` as
    returned by ``_get_cv_plot_data``; this function does not reorder them.
    Each subplot contains a diagonal reference line and an error-bar scatter
    of the in-sample arms (actual on x, predicted on y).

    Args:
        cv_results: cross-validation results.
        show_arm_details_on_hover: if True, display
            parameterizations of arms on hover. Default is True.
        show_context: if True (default), display context on
            hover.

    Returns:
        An ``AxPlotConfig`` of type ``GENERIC`` wrapping the plotly figure.

    Raises:
        ValueError: if the cross-validation results contain no metrics.
    """
    data = _get_cv_plot_data(cv_results)
    metrics = data.metrics
    if len(metrics) == 0:
        # Without this guard the subplot spacing below divides by zero rows.
        raise ValueError("Cannot plot cross-validation results without metrics.")

    # make subplots (2 plots per row)
    nrows = int(np.ceil(len(metrics) / 2))
    ncols = min(len(metrics), 2)
    fig = tools.make_subplots(
        rows=nrows,
        cols=ncols,
        print_grid=False,
        subplot_titles=tuple(metrics),
        horizontal_spacing=0.15,
        vertical_spacing=0.30 / nrows,
    )

    for i, metric in enumerate(metrics):
        # 1-based subplot position in the 2-column grid.
        row, col = i // 2 + 1, i % 2 + 1
        y_hat = []
        se_hat = []
        y_raw = []
        se_raw = []
        for arm in data.in_sample.values():
            y_hat.append(arm.y_hat[metric])
            se_hat.append(arm.se_hat[metric])
            y_raw.append(arm.y[metric])
            se_raw.append(arm.se[metric])
        min_, max_ = _get_min_max_with_errors(y_raw, y_hat, se_raw, se_hat)
        fig.append_trace(_diagonal_trace(min_, max_), row, col)
        fig.append_trace(
            _error_scatter_trace(
                # Expected `List[typing.Union[PlotInSampleArm,
                # ax.plot.base.PlotOutOfSampleArm]]` for 1st anonymous
                # parameter to call `ax.plot.scatter._error_scatter_trace` but
                # got `List[PlotInSampleArm]`.
                # pyre-fixme[6]:
                list(data.in_sample.values()),
                y_axis_var=PlotMetric(metric, True),
                x_axis_var=PlotMetric(metric, False),
                y_axis_label="Predicted",
                x_axis_label="Actual",
                hoverinfo="text",
                show_arm_details_on_hover=show_arm_details_on_hover,
                show_context=show_context,
            ),
            row,
            col,
        )

    # A 2-column grid holding an odd number (> 1) of metrics has one unused
    # cell at the end, generated by `tools.make_subplots`; remove its axes.
    # A single metric yields a 1x1 grid with no blank cell, so nothing may be
    # deleted in that case (the old `len % 2 == 1` test removed the only axes).
    if nrows * ncols > len(metrics):
        del fig["layout"]["xaxis{}".format(nrows * ncols)]
        del fig["layout"]["yaxis{}".format(nrows * ncols)]

    # allocate 400 px per plot (equal aspect ratio)
    fig["layout"].update(
        title="Cross-Validation",  # What should I replace this with?
        hovermode="closest",
        width=800,
        height=400 * nrows,
        font={"size": 10},
        showlegend=False,
    )

    # update subplot title size and the axis labels; there is one title
    # annotation per metric, so axis i + 1 exists for every annotation i.
    for i, ant in enumerate(fig["layout"]["annotations"]):
        ant["font"].update(size=12)
        fig["layout"]["xaxis{}".format(i + 1)].update(title="Actual Outcome",
                                                      mirror=True,
                                                      linecolor="black",
                                                      linewidth=0.5)
        fig["layout"]["yaxis{}".format(i + 1)].update(
            title="Predicted Outcome",
            mirror=True,
            linecolor="black",
            linewidth=0.5)

    return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)
Пример #2
0
def _obs_vs_pred_dropdown_plot(
    data: PlotData,
    rel: bool,
    show_context: bool = False,
    xlabel: str = "Actual Outcome",
    ylabel: str = "Predicted Outcome",
) -> Dict[str, Any]:
    """Build an observed-vs-predicted scatter figure with a metric dropdown.

    Two traces are produced per metric in ``data.metrics``: a diagonal
    reference line and an error-bar scatter of the in-sample arms. Only the
    first metric's pair starts visible; the left dropdown restyles the
    ``visible`` flags so exactly one metric's pair is shown, and the right
    menu toggles the error-bar width/thickness ("Show CI").

    Args:
        data: a named tuple storing observed and predicted data
            from a model.
        rel: if True, plot metrics relative to the status quo.
        show_context: Show context on hover.
        xlabel: Label for x-axis.
        ylabel: Label for y-axis.

    Returns:
        The plotly ``go.Figure`` built below (the ``Dict[str, Any]``
        annotation is kept for interface compatibility).

    Raises:
        ValueError: if relativization and context are both requested while
            ``data`` names a status quo arm.
    """
    status_quo_arm = None
    if rel and data.status_quo_name is not None:
        if show_context:
            raise ValueError(
                "This plot does not support both context and relativization at "
                "the same time.")
        # pyre-fixme[6]: Expected `str` for 1st param but got `Optional[str]`.
        status_quo_arm = data.in_sample[data.status_quo_name]

    total_traces = 2 * len(data.metrics)
    traces = []
    metric_buttons = []

    for position, metric in enumerate(data.metrics):
        shown_first = position == 0
        observed = PlotMetric(metric, False)
        predicted = PlotMetric(metric, True)

        y_raw, se_raw, y_hat, se_hat = _error_scatter_data(
            # pyre-fixme[6]: got `List[PlotInSampleArm]`, expected a list of
            # `Union[PlotInSampleArm, PlotOutOfSampleArm]`.
            list(data.in_sample.values()),
            y_axis_var=predicted,
            x_axis_var=observed,
            rel=rel,
            status_quo_arm=status_quo_arm,
        )
        lower, upper = _get_min_max_with_errors(
            y_raw, y_hat, se_raw or [], se_hat)

        traces.append(_diagonal_trace(lower, upper, visible=shown_first))
        traces.append(
            _error_scatter_trace(
                # pyre-fixme[6]: got `List[PlotInSampleArm]`, expected a list
                # of `Union[PlotInSampleArm, PlotOutOfSampleArm]`.
                arms=list(data.in_sample.values()),
                hoverinfo="text",
                rel=rel,
                show_arm_details_on_hover=True,
                show_CI=True,
                show_context=show_context,
                status_quo_arm=status_quo_arm,
                visible=shown_first,
                x_axis_label=xlabel,
                x_axis_var=observed,
                y_axis_label=ylabel,
                y_axis_var=predicted,
            ))

        # Trace k belongs to metric k // 2, so selecting this metric shows
        # exactly its diagonal line and its scatter.
        visibility = [k // 2 == position for k in range(total_traces)]
        metric_buttons.append({
            "args": ["visible", visibility],
            "label": metric,
            "method": "restyle"
        })

    def _ci_button(label, bar_width, bar_thickness):
        # Restyle every trace's x/y error bars to the given size.
        return {
            "args": [{
                "error_x.width": bar_width,
                "error_x.thickness": bar_thickness,
                "error_y.width": bar_width,
                "error_y.thickness": bar_thickness,
            }],
            "label": label,
            "method": "restyle",
        }

    metric_menu = {
        "x": 0,
        "y": 1.125,
        "yanchor": "top",
        "xanchor": "left",
        "buttons": metric_buttons,
    }
    ci_menu = {
        "buttons": [_ci_button("Yes", 4, 2), _ci_button("No", 0, 0)],
        "x": 1.125,
        "xanchor": "left",
        "y": 0.8,
        "yanchor": "middle",
    }

    ci_caption = {
        "showarrow": False,
        "text": "Show CI",
        "x": 1.125,
        "xanchor": "left",
        "xref": "paper",
        "y": 0.9,
        "yanchor": "middle",
        "yref": "paper",
    }
    axis_style = {
        "zeroline": False,
        "mirror": True,
        "linecolor": "black",
        "linewidth": 0.5,
    }

    layout = go.Layout(
        annotations=[ci_caption],
        xaxis={"title": xlabel, **axis_style},
        yaxis={"title": ylabel, **axis_style},
        showlegend=False,
        hovermode="closest",
        updatemenus=[metric_menu, ci_menu],
        width=530,
        height=500,
    )
    return go.Figure(data=traces, layout=layout)
Пример #3
0
def table_view_plot(
    experiment: Experiment,
    data: Data,
    use_empirical_bayes: bool = True,
    only_data_frame: bool = False,
    arm_noun: str = "arm",
):
    """Table of means and confidence intervals.

    The rendered table has one row per metric and one column per arm. The
    first column holds the metric names and its header is ``f"{arm_noun}s"``:

    +----------+------------+-----------+
    |   arms   |    0_0     |    0_1    |
    +==========+============+===========+
    | metric_1 | mean ± CI  |    ...    |
    +----------+------------+-----------+
    | metric_2 |    ...     |    ...    |
    +----------+------------+-----------+

    CI is ``Z * se``. Values are computed with ``rel=True`` when the plot data
    names a status quo arm, and with ``rel=False`` otherwise.

    Args:
        experiment: experiment whose (non-collection) metrics are tabulated.
        data: data passed to the Thompson model factory.
        use_empirical_bayes: if True, build the model with
            ``get_empirical_bayes_thompson``; otherwise with ``get_thompson``.
        only_data_frame: if True, return a ``(means, cis)`` tuple of
            DataFrames (indexed by metric name, one column per arm) instead
            of a plot.
        arm_noun: noun used for the header of the metric-name column.

    Returns:
        The DataFrame tuple, or an ``AxPlotConfig`` wrapping a plotly table.
    """
    model_func = get_empirical_bayes_thompson if use_empirical_bayes else get_thompson
    model = model_func(experiment=experiment, data=data)

    # We don't want to include metrics from a collection,
    # or the chart will be too big to read easily.
    # Example:
    # experiment.metrics = {
    #   'regular_metric': Metric(),
    #   'collection_metric: CollectionMetric()', # collection_metric =[metric1, metric2]
    # }
    # model.metric_names = [regular_metric, metric1, metric2] # "exploded" out
    # We want to filter model.metric_names and get rid of metric1, metric2
    metric_names = [
        metric_name for metric_name in model.metric_names
        if metric_name in experiment.metrics
    ]

    metric_name_to_lower_is_better = {
        metric_name: experiment.metrics[metric_name].lower_is_better
        for metric_name in metric_names
    }

    plot_data, _, _ = get_plot_data(
        model=model,
        generator_runs_dict={},
        # pyre-fixme[6]: Expected `Optional[typing.Set[str]]` for 3rd param but got
        #  `List[str]`.
        metric_names=metric_names,
    )

    if plot_data.status_quo_name:
        status_quo_arm = plot_data.in_sample.get(plot_data.status_quo_name)
        rel = True
    else:
        status_quo_arm = None
        rel = False

    # `arm_names` is reassigned for every metric below. It is initialised here
    # so that an experiment with no tabulated metrics builds an empty table
    # instead of raising NameError when the header is assembled.
    arm_names = []
    records = []  # per metric: formatted "mean ± CI" strings, one per arm
    colors = []  # per metric: cell colors, one per arm
    records_with_mean = []
    records_with_ci = []
    for metric_name in metric_names:
        arm_names, _, ys, ys_se = _error_scatter_data(
            # pyre-fixme[6]: Expected
            #  `List[typing.Union[ax.plot.base.PlotInSampleArm,
            #  ax.plot.base.PlotOutOfSampleArm]]` for 1st param but got
            #  `List[ax.plot.base.PlotInSampleArm]`.
            arms=list(plot_data.in_sample.values()),
            y_axis_var=PlotMetric(metric_name, pred=True, rel=rel),
            x_axis_var=None,
            status_quo_arm=status_quo_arm,
        )

        results_by_arm = list(zip(arm_names, ys, ys_se))
        colors.append([
            get_color(
                x=y,
                ci=Z * y_se,
                rel=rel,
                # pyre-fixme[6]: Expected `bool` for 4th param but got
                #  `Optional[bool]`.
                reverse=metric_name_to_lower_is_better[metric_name],
            ) for (_, y, y_se) in results_by_arm
        ])
        records.append([
            "{:.3f} ± {:.3f}".format(y, Z * y_se)
            for (_, y, y_se) in results_by_arm
        ])
        records_with_mean.append(
            {arm_name: y
             for (arm_name, y, _) in results_by_arm})
        records_with_ci.append(
            {arm_name: Z * y_se
             for (arm_name, _, y_se) in results_by_arm})

    if only_data_frame:
        return tuple(
            pd.DataFrame.from_records(metric_records, index=metric_names)
            for metric_records in [records_with_mean, records_with_ci])

    def transpose(m):
        # Per-metric rows -> per-arm columns (plotly tables take
        # `cells.values` column by column). `zip` also handles `m == []`.
        return [list(column) for column in zip(*m)]

    records = [[name.replace(":", " : ")
                for name in metric_names]] + transpose(records)
    colors = [["#ffffff"] * len(metric_names)] + transpose(colors)
    # pyre-fixme[6]: Expected `List[str]` for 1st param but got `List[float]`.
    header = [f"<b>{x}</b>" for x in [f"{arm_noun}s"] + arm_names]
    column_widths = [300] + [150] * len(arm_names)

    trace = go.Table(
        header={
            "values": header,
            "align": ["left"]
        },
        cells={
            "values": records,
            "align": ["left"],
            "fill": {
                "color": colors
            }
        },
        columnwidth=column_widths,
    )
    layout = go.Layout(
        width=sum(column_widths),
        margin=go.layout.Margin(l=0, r=20, b=20, t=20, pad=4),  # noqa E741
    )
    fig = go.Figure(data=[trace], layout=layout)
    return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)
Пример #4
0
def _single_metric_traces(
    model: ModelBridge,
    metric: str,
    generator_runs_dict: TNullableGeneratorRunsDict,
    rel: bool,
    show_arm_details_on_hover: bool = True,
    showlegend: bool = True,
    show_CI: bool = True,
    arm_noun: str = "arm",
    fixed_features: Optional[ObservationFeatures] = None,
) -> Traces:
    """Plot scatterplots with errors for a single metric (y-axis).

    Arms are plotted on the x-axis. The first trace holds the in-sample arms;
    one further trace is appended per generator run found in the
    out-of-sample plot data.

    Args:
        model: model to draw predictions from.
        metric: name of metric to plot.
        generator_runs_dict: a mapping from
            generator run name to generator run.
        rel: if True, plot relative predictions.
        show_arm_details_on_hover: if True, display
            parameterizations of arms on hover. Default is True.
        showlegend: if True, show legend for trace.
        show_CI: if True, render confidence intervals.
        arm_noun: noun to use instead of "arm" (e.g. group)
        fixed_features: Fixed features to use when making model predictions.

    Returns:
        List of traces: in-sample first, then one per generator run.
    """
    plot_data, _, _ = get_plot_data(model,
                                    generator_runs_dict or {}, {metric},
                                    fixed_features=fixed_features)

    status_quo_arm = (
        None if plot_data.status_quo_name is None
        # pyre-fixme[6]: Expected `str` for 1st param but got `Optional[str]`.
        else plot_data.in_sample.get(plot_data.status_quo_name))

    traces = [
        _error_scatter_trace(
            # Expected `List[Union[PlotInSampleArm, PlotOutOfSampleArm]]`
            # for 1st anonymous parameter to call
            # `ax.plot.scatter._error_scatter_trace` but got
            # `List[PlotInSampleArm]`.
            # pyre-fixme[6]:
            list(plot_data.in_sample.values()),
            x_axis_var=None,
            y_axis_var=PlotMetric(metric, pred=True, rel=rel),
            status_quo_arm=status_quo_arm,
            legendgroup="In-sample",
            showlegend=showlegend,
            show_arm_details_on_hover=show_arm_details_on_hover,
            show_CI=show_CI,
            arm_noun=arm_noun,
        )
    ]

    # Candidates. Colors start at palette index 1 (index 0 is never used for
    # candidates here) and cycle through indices 1..len-1, so having more
    # generator runs than palette entries no longer raises IndexError.
    # For i < len(DISCRETE_COLOR_SCALE) the chosen index is exactly i.
    n_candidate_colors = len(DISCRETE_COLOR_SCALE) - 1
    for i, (generator_run_name, cand_arms) in enumerate(
        (plot_data.out_of_sample or {}).items(), start=1):
        traces.append(
            _error_scatter_trace(
                list(cand_arms.values()),
                x_axis_var=None,
                y_axis_var=PlotMetric(metric, pred=True, rel=rel),
                status_quo_arm=status_quo_arm,
                name=generator_run_name,
                color=DISCRETE_COLOR_SCALE[1 + (i - 1) % n_candidate_colors],
                legendgroup=generator_run_name,
                showlegend=showlegend,
                show_arm_details_on_hover=show_arm_details_on_hover,
                show_CI=show_CI,
                arm_noun=arm_noun,
            ))
    return traces
Пример #5
0
def lattice_multiple_metrics(
    model: ModelBridge,
    generator_runs_dict: TNullableGeneratorRunsDict = None,
    rel: bool = True,
    show_arm_details_on_hover: bool = False,
) -> AxPlotConfig:
    """Plot raw values or predictions of combinations of two metrics for arms.

    Builds an ``n x n`` grid of subplots, where ``n`` is the number of model
    metrics. The cell at row ``j`` / column ``i`` plots metric ``i`` (x)
    against metric ``j`` (y). Diagonal cells instead hold box plots of that
    single metric.

    Every cell appends the same trace sequence: in-sample observed (hidden
    initially), in-sample modeled, then one trace per generator run. The
    "Type" menu's visibility masks depend on this layout.

    Args:
        model: model to draw predictions from.
        generator_runs_dict: a mapping from
            generator run name to generator run.
        rel: if True, use relative effects. Default is True.
        show_arm_details_on_hover: if True, display
            parameterizations of arms on hover. Default is False.

    Returns:
        An ``AxPlotConfig`` of type ``GENERIC`` wrapping the plotly figure.
    """
    metrics = model.metric_names
    fig = tools.make_subplots(
        rows=len(metrics),
        cols=len(metrics),
        print_grid=False,
        shared_xaxes=False,
        shared_yaxes=False,
    )

    plot_data, _, _ = get_plot_data(
        model, generator_runs_dict if generator_runs_dict is not None else {},
        metrics)
    status_quo_arm = (
        None if plot_data.status_quo_name is None
        # pyre-fixme[6]: Expected `str` for 1st param but got `Optional[str]`.
        else plot_data.in_sample.get(plot_data.status_quo_name))

    # Candidate colors start at palette index 1 (index 0 is never used for
    # candidates here). They cycle through indices 1..len-1, so having more
    # generator runs than palette entries no longer raises IndexError. For
    # k < len(DISCRETE_COLOR_SCALE) the chosen color is exactly
    # DISCRETE_COLOR_SCALE[k].
    n_candidate_colors = len(DISCRETE_COLOR_SCALE) - 1

    def _candidate_color(k):
        return DISCRETE_COLOR_SCALE[1 + (k - 1) % n_candidate_colors]

    # iterate over all combinations of metrics and generate scatter traces
    for i, o1 in enumerate(metrics, start=1):
        for j, o2 in enumerate(metrics, start=1):
            if o1 != o2:
                # in-sample observed and predicted
                obs_insample_trace = _error_scatter_trace(
                    # Expected `List[Union[PlotInSampleArm,
                    # PlotOutOfSampleArm]]` for 1st anonymous parameter to call
                    # `ax.plot.scatter._error_scatter_trace` but got
                    # `List[PlotInSampleArm]`.
                    # pyre-fixme[6]:
                    list(plot_data.in_sample.values()),
                    x_axis_var=PlotMetric(o1, pred=False, rel=rel),
                    y_axis_var=PlotMetric(o2, pred=False, rel=rel),
                    status_quo_arm=status_quo_arm,
                    showlegend=(i == 1 and j == 2),
                    legendgroup="In-sample",
                    visible=False,
                    show_arm_details_on_hover=show_arm_details_on_hover,
                )
                predicted_insample_trace = _error_scatter_trace(
                    # Expected `List[Union[PlotInSampleArm,
                    # PlotOutOfSampleArm]]` for 1st anonymous parameter to call
                    # `ax.plot.scatter._error_scatter_trace` but got
                    # `List[PlotInSampleArm]`.
                    # pyre-fixme[6]:
                    list(plot_data.in_sample.values()),
                    x_axis_var=PlotMetric(o1, pred=True, rel=rel),
                    y_axis_var=PlotMetric(o2, pred=True, rel=rel),
                    status_quo_arm=status_quo_arm,
                    legendgroup="In-sample",
                    showlegend=(i == 1 and j == 2),
                    visible=True,
                    show_arm_details_on_hover=show_arm_details_on_hover,
                )
                fig.append_trace(obs_insample_trace, j, i)
                fig.append_trace(predicted_insample_trace, j, i)

                # one trace per generator run (out-of-sample candidates)
                for k, (generator_run_name, cand_arms) in enumerate(
                    (plot_data.out_of_sample or {}).items(), start=1):
                    fig.append_trace(
                        _error_scatter_trace(
                            list(cand_arms.values()),
                            x_axis_var=PlotMetric(o1, pred=True, rel=rel),
                            y_axis_var=PlotMetric(o2, pred=True, rel=rel),
                            status_quo_arm=status_quo_arm,
                            name=generator_run_name,
                            color=_candidate_color(k),
                            showlegend=(i == 1 and j == 2),
                            legendgroup=generator_run_name,
                            show_arm_details_on_hover=show_arm_details_on_hover,
                        ),
                        j,
                        i,
                    )
            else:
                # diagonal cell: box plots of the single metric, in the same
                # observed / modeled / candidates order as off-diagonal cells
                fig.append_trace(
                    go.Box(
                        y=[arm.y[o1] for arm in plot_data.in_sample.values()],
                        name=None,
                        marker={"color": rgba(COLORS.STEELBLUE.value)},
                        showlegend=False,
                        legendgroup="In-sample",
                        visible=False,
                        hoverinfo="none",
                    ),
                    j,
                    i,
                )
                fig.append_trace(
                    go.Box(
                        y=[
                            arm.y_hat[o1]
                            for arm in plot_data.in_sample.values()
                        ],
                        name=None,
                        marker={"color": rgba(COLORS.STEELBLUE.value)},
                        showlegend=False,
                        legendgroup="In-sample",
                        hoverinfo="none",
                    ),
                    j,
                    i,
                )

                for k, (generator_run_name, cand_arms) in enumerate(
                    (plot_data.out_of_sample or {}).items(), start=1):
                    fig.append_trace(
                        go.Box(
                            y=[arm.y_hat[o1] for arm in cand_arms.values()],
                            name=None,
                            marker={"color": rgba(_candidate_color(k))},
                            showlegend=False,
                            legendgroup=generator_run_name,
                            hoverinfo="none",
                        ),
                        j,
                        i,
                    )

    # Each of the len(metrics) ** 2 cells contributes
    # [observed, modeled, *candidates] traces, so a per-cell mask repeated
    # len(metrics) ** 2 times lines up with fig["data"].
    n_out_of_sample = len(plot_data.out_of_sample or {})
    n_cells = len(metrics)**2
    fig["layout"].update(
        height=800,
        width=960,
        font={"size": 10},
        hovermode="closest",
        legend={
            "orientation": "h",
            "x": 0,
            "y": 1.05,
            "xanchor": "left",
            "yanchor": "middle",
        },
        updatemenus=[
            {
                "x": 0.35,
                "y": 1.08,
                "xanchor": "left",
                "yanchor": "middle",
                "buttons": [
                    {
                        "args": [{
                            "error_x.width": 0,
                            "error_x.thickness": 0,
                            "error_y.width": 0,
                            "error_y.thickness": 0,
                        }],
                        "label": "No",
                        "method": "restyle",
                    },
                    {
                        "args": [{
                            "error_x.width": 4,
                            "error_x.thickness": 2,
                            "error_y.width": 4,
                            "error_y.thickness": 2,
                        }],
                        "label": "Yes",
                        "method": "restyle",
                    },
                ],
            },
            {
                "x": 0.1,
                "y": 1.08,
                "xanchor": "left",
                "yanchor": "middle",
                "buttons": [
                    {
                        "args": [{
                            "visible":
                            ([False, True] + [True] * n_out_of_sample) *
                            n_cells
                        }],
                        "label": "Modeled",
                        "method": "restyle",
                    },
                    {
                        "args": [{
                            "visible":
                            ([True, False] + [False] * n_out_of_sample) *
                            n_cells
                        }],
                        "label": "In-sample",
                        "method": "restyle",
                    },
                ],
            },
        ],
        annotations=[
            {
                "x": 0.02,
                "y": 1.1,
                "xref": "paper",
                "yref": "paper",
                "text": "Type",
                "showarrow": False,
                "yanchor": "middle",
                "xanchor": "left",
            },
            {
                "x": 0.30,
                "y": 1.1,
                "xref": "paper",
                "yref": "paper",
                "text": "Show CI",
                "showarrow": False,
                "yanchor": "middle",
                "xanchor": "left",
            },
        ],
    )

    # Label axes with metric names: only the bottom row's x-axes and the
    # left column's y-axes. Axes are numbered row-major, so the bottom row's
    # x-axis for column i + 1 is n*n - n + i + 1, and the first-column y-axis
    # for row i + 1 is 1 + n*i.
    for i, o in enumerate(metrics):
        pos_x = len(metrics) * len(metrics) - len(metrics) + i + 1
        pos_y = 1 + (len(metrics) * i)
        fig["layout"]["xaxis{}".format(pos_x)].update(title=_wrap_metric(o),
                                                      titlefont={"size": 10})
        fig["layout"]["yaxis{}".format(pos_y)].update(title=_wrap_metric(o),
                                                      titlefont={"size": 10})

    # do not put x-axis ticks for boxplots
    boxplot_xaxes = []
    for trace in fig["data"]:
        if trace["type"] == "box":
            # stores the xaxes which correspond to boxplot subplots; the
            # trace's axis reference (e.g. "x3") minus its leading "x" gives
            # the layout key suffix ("xaxis3")
            boxplot_xaxes.append("xaxis{}".format(trace["xaxis"][1:]))
        else:
            # clear all error bars since default is no CI
            trace["error_x"].update(width=0, thickness=0)
            trace["error_y"].update(width=0, thickness=0)
    for xaxis in boxplot_xaxes:
        fig["layout"][xaxis]["showticklabels"] = False

    return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)
Пример #6
0
def _multiple_metric_traces(
    model: ModelBridge,
    metric_x: str,
    metric_y: str,
    generator_runs_dict: TNullableGeneratorRunsDict,
    rel_x: bool,
    rel_y: bool,
    fixed_features: Optional[ObservationFeatures] = None,
) -> Traces:
    """Plot traces for multiple metrics given a model and metrics.

    Args:
        model: model to draw predictions from.
        metric_x: metric to plot on the x-axis.
        metric_y: metric to plot on the y-axis.
        generator_runs_dict: a mapping from
            generator run name to generator run.
        rel_x: if True, use relative effects on metric_x.
        rel_y: if True, use relative effects on metric_y.
        fixed_features: Fixed features to use when making model predictions.

    Returns:
        List of traces, in this order:
        - in-sample observed values (created hidden),
        - in-sample model predictions (visible),
        - one trace per generator run in the out-of-sample data.
    """
    plot_data, _, _ = get_plot_data(
        model,
        generator_runs_dict if generator_runs_dict is not None else {},
        {metric_x, metric_y},
        fixed_features=fixed_features,
    )

    status_quo_arm = (
        None if plot_data.status_quo_name is None
        # pyre-fixme[6]: Expected `str` for 1st param but got `Optional[str]`.
        else plot_data.in_sample.get(plot_data.status_quo_name))

    traces = [
        _error_scatter_trace(
            # Expected `List[Union[PlotInSampleArm, PlotOutOfSampleArm]]`
            # for 1st anonymous parameter to call
            # `ax.plot.scatter._error_scatter_trace` but got
            # `List[PlotInSampleArm]`.
            # pyre-fixme[6]:
            list(plot_data.in_sample.values()),
            x_axis_var=PlotMetric(metric_x, pred=False, rel=rel_x),
            y_axis_var=PlotMetric(metric_y, pred=False, rel=rel_y),
            status_quo_arm=status_quo_arm,
            visible=False,
        ),
        _error_scatter_trace(
            # Expected `List[Union[PlotInSampleArm, PlotOutOfSampleArm]]`
            # for 1st anonymous parameter to call
            # `ax.plot.scatter._error_scatter_trace` but got
            # `List[PlotInSampleArm]`.
            # pyre-fixme[6]:
            list(plot_data.in_sample.values()),
            x_axis_var=PlotMetric(metric_x, pred=True, rel=rel_x),
            y_axis_var=PlotMetric(metric_y, pred=True, rel=rel_y),
            status_quo_arm=status_quo_arm,
            visible=True,
        ),
    ]

    # Candidate colors start at palette index 1 (index 0 is never used for
    # candidates here) and cycle through indices 1..len-1, so having more
    # generator runs than palette entries no longer raises IndexError.
    # For i < len(DISCRETE_COLOR_SCALE) the chosen index is exactly i.
    n_candidate_colors = len(DISCRETE_COLOR_SCALE) - 1
    for i, (generator_run_name, cand_arms) in enumerate(
        (plot_data.out_of_sample or {}).items(), start=1):
        traces.append(
            _error_scatter_trace(
                list(cand_arms.values()),
                x_axis_var=PlotMetric(metric_x, pred=True, rel=rel_x),
                y_axis_var=PlotMetric(metric_y, pred=True, rel=rel_y),
                status_quo_arm=status_quo_arm,
                name=generator_run_name,
                color=DISCRETE_COLOR_SCALE[1 + (i - 1) % n_candidate_colors],
            ))
    return traces
Пример #7
0
def table_view_plot(
    experiment: Experiment,
    data: Data,
    use_empirical_bayes: bool = True,
    only_data_frame: bool = False,
):
    """Table of means and confidence intervals.

    Table is of the form:

    +-------+------------+-----------+
    |  arm  |  metric_1  |  metric_2 |
    +=======+============+===========+
    |  0_0  | mean +- CI |    ...    |
    +-------+------------+-----------+
    |  0_1  |    ...     |    ...    |
    +-------+------------+-----------+

    CI is ``Z * se``. Metric columns are sorted by name. Values are computed
    with ``rel=True`` when the plot data names a status quo arm.

    Args:
        experiment: experiment whose metrics are tabulated.
        data: data passed to the Thompson model factory.
        use_empirical_bayes: if True, build the model with
            ``get_empirical_bayes_thompson``; otherwise with ``get_thompson``.
        only_data_frame: if True, return a ``(means, standard_errors)`` tuple
            of DataFrames (one row per arm, one column per metric) instead of
            a plot. The second frame holds the raw ``y_se`` values, not the
            ``Z``-scaled intervals shown in the table.

    Returns:
        The DataFrame tuple, or an ``AxPlotConfig`` wrapping a plotly table.
    """
    model_func = get_empirical_bayes_thompson if use_empirical_bayes else get_thompson
    model = model_func(experiment=experiment, data=data)
    # NOTE(review): looked up below for every name in `model.metric_names`;
    # a model metric missing from `experiment.metrics` would raise KeyError.
    metric_name_to_lower_is_better = {
        metric.name: metric.lower_is_better
        for metric in experiment.metrics.values()
    }

    plot_data, _, _ = get_plot_data(model=model,
                                    generator_runs_dict={},
                                    metric_names=model.metric_names)

    if plot_data.status_quo_name:
        status_quo_arm = plot_data.in_sample.get(plot_data.status_quo_name)
        rel = True
    else:
        status_quo_arm = None
        rel = False

    # `arm_names` is reassigned for every metric below (the same in-sample
    # arms are passed each time). It is initialised here so that a model
    # without metrics builds an empty table instead of raising NameError.
    arm_names = []
    results = {}
    records_with_mean = []
    records_with_ci = []
    for metric_name in model.metric_names:
        arm_names, _, ys, ys_se = _error_scatter_data(
            arms=list(plot_data.in_sample.values()),
            y_axis_var=PlotMetric(metric_name, True),
            x_axis_var=None,
            rel=rel,
            status_quo_arm=status_quo_arm,
        )
        # results[metric] will hold a list of tuples, one tuple per arm
        tuples = list(zip(arm_names, ys, ys_se))
        results[metric_name] = tuples
        # used if only_data_frame == True
        records_with_mean.append({arm: y for (arm, y, _) in tuples})
        records_with_ci.append({arm: y_se for (arm, _, y_se) in tuples})

    if only_data_frame:
        # Records are per metric; transpose so rows are arms.
        return tuple(
            pd.DataFrame.from_records(metric_records,
                                      index=model.metric_names).transpose()
            for metric_records in [records_with_mean, records_with_ci])

    # cells and colors are both lists of lists
    # each top-level list corresponds to a column,
    # so the first is a list of arms
    cells = [[f"<b>{x}</b>" for x in arm_names]]
    colors = [["#ffffff"] * len(arm_names)]
    metric_names = []
    for metric_name, list_of_tuples in sorted(results.items()):
        cells.append([
            "{:.3f} &plusmn; {:.3f}".format(y, Z * y_se)
            for (_, y, y_se) in list_of_tuples
        ])
        metric_names.append(metric_name.replace(":", " : "))

        color_vec = []
        for (_, y, y_se) in list_of_tuples:
            color_vec.append(
                get_color(
                    x=y,
                    ci=Z * y_se,
                    rel=rel,
                    reverse=metric_name_to_lower_is_better[metric_name],
                ))
        colors.append(color_vec)

    header = ["arms"] + metric_names
    header = [f"<b>{x}</b>" for x in header]
    trace = go.Table(
        header={
            "values": header,
            "align": ["left"]
        },
        cells={
            "values": cells,
            "align": ["left"],
            "fill": {
                "color": colors
            }
        },
    )
    layout = go.Layout(
        # 20 px per arm row plus fixed chrome, capped at 400 px
        height=min(400, len(arm_names) * 20 + 200),
        width=175 * len(header),
        # `go.layout.Margin` is the non-deprecated spelling of `go.Margin`
        margin=go.layout.Margin(l=0, r=20, b=20, t=20, pad=4),  # noqa E741
    )
    fig = go.Figure(data=[trace], layout=layout)
    return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)