Пример #1
0
def table_view_plot(
    experiment: Experiment,
    data: Data,
    use_empirical_bayes: bool = True,
    only_data_frame: bool = False,
    arm_noun: str = "arm",
):
    """Table of means and confidence intervals, one row per metric.

    The plotted table has a row per metric and a column per arm:

    +------------+------------+-----------+
    |    arms    |    0_0     |    0_1    |
    +============+============+===========+
    |  metric_1  | mean +- CI |    ...    |
    +------------+------------+-----------+
    |  metric_2  |    ...     |    ...    |
    +------------+------------+-----------+

    Each cell shows the predicted value (``pred=True``) and ``Z * SE`` for
    one arm. When the plot data names a status quo arm, metrics are requested
    with ``rel=True`` (relative to that arm). Each cell is colored via
    ``get_color``, passing the metric's ``lower_is_better`` as ``reverse``.

    Args:
        experiment: Experiment whose metrics are tabulated. Only model metrics
            that are keys of ``experiment.metrics`` are shown, so metrics
            "exploded" out of a collection metric are dropped.
        data: Data passed, together with ``experiment``, to the model factory.
        use_empirical_bayes: Fit the model with
            ``get_empirical_bayes_thompson`` if True, else ``get_thompson``.
        only_data_frame: If True, return ``(means_df, ci_df)`` instead of a
            plot. Both DataFrames are indexed by metric name with one column
            per arm; ``ci_df`` holds ``Z * SE``.
        arm_noun: Noun shown (pluralized with a trailing "s") in the header's
            corner cell.

    Returns:
        An ``AxPlotConfig`` wrapping a plotly table, or the tuple of
        DataFrames described under ``only_data_frame``. If no metric survives
        the filtering, an empty table is produced instead of raising.
    """
    model_func = get_empirical_bayes_thompson if use_empirical_bayes else get_thompson
    model = model_func(experiment=experiment, data=data)

    # We don't want to include metrics from a collection,
    # or the chart will be too big to read easily.
    # Example:
    # experiment.metrics = {
    #   'regular_metric': Metric(),
    #   'collection_metric: CollectionMetric()', # collection_metric =[metric1, metric2]
    # }
    # model.metric_names = [regular_metric, metric1, metric2] # "exploded" out
    # We want to filter model.metric_names and get rid of metric1, metric2
    metric_names = [
        metric_name for metric_name in model.metric_names
        if metric_name in experiment.metrics
    ]

    metric_name_to_lower_is_better = {
        metric_name: experiment.metrics[metric_name].lower_is_better
        for metric_name in metric_names
    }

    plot_data, _, _ = get_plot_data(
        model=model,
        generator_runs_dict={},
        # pyre-fixme[6]: Expected `Optional[typing.Set[str]]` for 3rd param but got
        #  `List[str]`.
        metric_names=metric_names,
    )

    if plot_data.status_quo_name:
        status_quo_arm = plot_data.in_sample.get(plot_data.status_quo_name)
        rel = True
    else:
        status_quo_arm = None
        rel = False

    # records[metric][arm] -> formatted "mean ± CI" text;
    # colors[metric][arm] -> cell fill color.
    records = []
    colors = []
    records_with_mean = []
    records_with_ci = []
    # Pre-bound so the header below is well-defined even when metric_names is
    # empty (otherwise this would be an UnboundLocalError).
    arm_names = []
    for metric_name in metric_names:
        arm_names, _, ys, ys_se = _error_scatter_data(
            # pyre-fixme[6]: Expected
            #  `List[typing.Union[ax.plot.base.PlotInSampleArm,
            #  ax.plot.base.PlotOutOfSampleArm]]` for 1st param but got
            #  `List[ax.plot.base.PlotInSampleArm]`.
            arms=list(plot_data.in_sample.values()),
            y_axis_var=PlotMetric(metric_name, pred=True, rel=rel),
            x_axis_var=None,
            status_quo_arm=status_quo_arm,
        )

        results_by_arm = list(zip(arm_names, ys, ys_se))
        colors.append([
            get_color(
                x=y,
                ci=Z * y_se,
                rel=rel,
                # pyre-fixme[6]: Expected `bool` for 4th param but got
                #  `Optional[bool]`.
                reverse=metric_name_to_lower_is_better[metric_name],
            ) for (_, y, y_se) in results_by_arm
        ])
        records.append([
            "{:.3f} ± {:.3f}".format(y, Z * y_se)
            for (_, y, y_se) in results_by_arm
        ])
        records_with_mean.append(
            {arm_name: y
             for (arm_name, y, _) in results_by_arm})
        records_with_ci.append(
            {arm_name: Z * y_se
             for (arm_name, _, y_se) in results_by_arm})

    if only_data_frame:
        # Rows are metrics, columns are arms.
        return tuple(
            pd.DataFrame.from_records(frame_records, index=metric_names)
            for frame_records in [records_with_mean, records_with_ci])

    def transpose(m):
        """Swap rows and columns of a rectangular list of lists."""
        if not m:
            # No metrics: indexing m[0] below would raise IndexError.
            return []
        return [[m[j][i] for j in range(len(m))] for i in range(len(m[0]))]

    # go.Table takes column-major cell values. The first column is the metric
    # names; after transposing, each following column holds one arm's values
    # across all metrics.
    records = [[name.replace(":", " : ")
                for name in metric_names]] + transpose(records)
    colors = [["#ffffff"] * len(metric_names)] + transpose(colors)
    # pyre-fixme[6]: Expected `List[str]` for 1st param but got `List[float]`.
    header = [f"<b>{x}</b>" for x in [f"{arm_noun}s"] + arm_names]
    column_widths = [300] + [150] * len(arm_names)

    trace = go.Table(
        header={
            "values": header,
            "align": ["left"]
        },
        cells={
            "values": records,
            "align": ["left"],
            "fill": {
                "color": colors
            }
        },
        columnwidth=column_widths,
    )
    layout = go.Layout(
        width=sum(column_widths),
        margin=go.layout.Margin(l=0, r=20, b=20, t=20, pad=4),  # noqa E741
    )
    fig = go.Figure(data=[trace], layout=layout)
    return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)
Пример #2
0
def _obs_vs_pred_dropdown_plot(
    data: PlotData,
    rel: bool,
    show_context: bool = False,
    xlabel: str = "Actual Outcome",
    ylabel: str = "Predicted Outcome",
) -> Dict[str, Any]:
    """Plot observed vs. predicted values, one metric at a time via dropdown.

    For each metric in ``data.metrics`` two traces are added: a diagonal
    reference line, then an error-bar scatter of the in-sample arms with the
    observed value on x and the predicted value on y. A dropdown selects which
    metric's pair of traces is visible. A second menu, labelled "Show CI",
    sets the error-bar width and thickness to non-zero ("Yes") or zero ("No").

    Args:
        data: Named tuple storing observed and predicted data from a model.
        rel: If True, plot metrics relative to the status quo. The status quo
            arm is looked up only when ``data.status_quo_name`` is set;
            otherwise ``None`` is passed along as the status quo arm.
        show_context: Show context on hover.
        xlabel: Label for the x-axis.
        ylabel: Label for the y-axis.

    Returns:
        A plotly ``go.Figure`` (the return annotation says ``Dict``).

    Raises:
        ValueError: If ``show_context`` is True while ``rel`` is True and
            ``data.status_quo_name`` is set.
    """
    use_status_quo = rel and data.status_quo_name is not None
    if use_status_quo and show_context:
        raise ValueError(
            "This plot does not support both context and relativization at "
            "the same time.")

    status_quo_arm = None
    if use_status_quo:
        # pyre-fixme[6]: Expected `str` for 1st param but got `Optional[str]`.
        status_quo_arm = data.in_sample[data.status_quo_name]

    num_traces = 2 * len(data.metrics)
    traces = []
    metric_dropdown = []
    for idx, metric in enumerate(data.metrics):
        shown_initially = idx == 0
        observed = PlotMetric(metric, False)
        predicted = PlotMetric(metric, True)

        y_raw, se_raw, y_hat, se_hat = _error_scatter_data(
            # A list of in-sample arms is passed where pyre expects
            # `List[Union[PlotInSampleArm, PlotOutOfSampleArm]]`.
            # pyre-fixme[6]:
            list(data.in_sample.values()),
            y_axis_var=predicted,
            x_axis_var=observed,
            rel=rel,
            status_quo_arm=status_quo_arm,
        )
        min_, max_ = _get_min_max_with_errors(
            y_raw, y_hat, se_raw or [], se_hat
        )

        # Trace 2*idx: the y = x reference line for this metric.
        traces.append(_diagonal_trace(min_, max_, visible=shown_initially))
        # Trace 2*idx + 1: the error-bar scatter for this metric.
        traces.append(
            _error_scatter_trace(
                # A list of in-sample arms is passed where pyre expects
                # `List[Union[PlotInSampleArm, PlotOutOfSampleArm]]`.
                # pyre-fixme[6]:
                arms=list(data.in_sample.values()),
                hoverinfo="text",
                rel=rel,
                show_arm_details_on_hover=True,
                show_CI=True,
                show_context=show_context,
                status_quo_arm=status_quo_arm,
                visible=shown_initially,
                x_axis_label=xlabel,
                x_axis_var=observed,
                y_axis_label=ylabel,
                y_axis_var=predicted,
            )
        )

        # Picking this metric in the dropdown shows exactly its two traces
        # (indices 2*idx and 2*idx + 1) and hides every other trace.
        visibility_mask = [t // 2 == idx for t in range(num_traces)]
        metric_dropdown.append({
            "args": ["visible", visibility_mask],
            "label": metric,
            "method": "restyle",
        })

    def ci_button(label, bar_width, bar_thickness):
        # A "Show CI" option: restyle every trace's x/y error bars.
        return {
            "args": [{
                "error_x.width": bar_width,
                "error_x.thickness": bar_thickness,
                "error_y.width": bar_width,
                "error_y.thickness": bar_thickness,
            }],
            "label": label,
            "method": "restyle",
        }

    def axis_style(title):
        # Shared axis styling; only the title differs between x and y.
        return {
            "title": title,
            "zeroline": False,
            "mirror": True,
            "linecolor": "black",
            "linewidth": 0.5,
        }

    metric_menu = {
        "x": 0,
        "y": 1.125,
        "yanchor": "top",
        "xanchor": "left",
        "buttons": metric_dropdown,
    }
    ci_menu = {
        "buttons": [ci_button("Yes", 4, 2), ci_button("No", 0, 0)],
        "x": 1.125,
        "xanchor": "left",
        "y": 0.8,
        "yanchor": "middle",
    }
    # Text label placed just above the "Show CI" menu.
    show_ci_label = {
        "showarrow": False,
        "text": "Show CI",
        "x": 1.125,
        "xanchor": "left",
        "xref": "paper",
        "y": 0.9,
        "yanchor": "middle",
        "yref": "paper",
    }

    layout = go.Layout(
        annotations=[show_ci_label],
        xaxis=axis_style(xlabel),
        yaxis=axis_style(ylabel),
        showlegend=False,
        hovermode="closest",
        updatemenus=[metric_menu, ci_menu],
        width=530,
        height=500,
    )
    return go.Figure(data=traces, layout=layout)
Пример #3
0
def table_view_plot(
    experiment: Experiment,
    data: Data,
    use_empirical_bayes: bool = True,
    only_data_frame: bool = False,
):
    """Table of means and confidence intervals.

    Table is of the form (metric columns sorted by metric name):

    +-------+------------+-----------+
    |  arm  |  metric_1  |  metric_2 |
    +=======+============+===========+
    |  0_0  | mean +- CI |    ...    |
    +-------+------------+-----------+
    |  0_1  |    ...     |    ...    |
    +-------+------------+-----------+

    Cells show ``y`` and ``Z * y_se`` as returned by ``_error_scatter_data``.
    When the plot data names a status quo arm, ``rel=True`` is used.

    Args:
        experiment: Experiment whose metrics' ``lower_is_better`` flags are
            passed to ``get_color`` as ``reverse``.
        data: Data passed, together with ``experiment``, to the model factory.
        use_empirical_bayes: Fit the model with
            ``get_empirical_bayes_thompson`` if True, else ``get_thompson``.
        only_data_frame: If True, return ``(means_df, se_df)`` instead of a
            plot. Both are indexed by arm name with one column per model
            metric. ``se_df`` holds the unscaled ``y_se`` values, not the
            ``Z``-scaled values shown in the table.

    Returns:
        An ``AxPlotConfig`` wrapping a plotly table, or the tuple of
        DataFrames described under ``only_data_frame``.
    """
    model_func = get_empirical_bayes_thompson if use_empirical_bayes else get_thompson
    model = model_func(experiment=experiment, data=data)
    metric_name_to_lower_is_better = {
        metric.name: metric.lower_is_better
        for metric in experiment.metrics.values()
    }

    plot_data, _, _ = get_plot_data(model=model,
                                    generator_runs_dict={},
                                    metric_names=model.metric_names)

    if plot_data.status_quo_name:
        status_quo_arm = plot_data.in_sample.get(plot_data.status_quo_name)
        rel = True
    else:
        status_quo_arm = None
        rel = False

    results = {}
    records_with_mean = []
    records_with_ci = []
    # Pre-bound so an empty metric list yields an empty table rather than an
    # UnboundLocalError when building the arm column below.
    arm_names = []
    for metric_name in model.metric_names:
        arm_names, _, ys, ys_se = _error_scatter_data(
            arms=list(plot_data.in_sample.values()),
            y_axis_var=PlotMetric(metric_name, True),
            x_axis_var=None,
            rel=rel,
            status_quo_arm=status_quo_arm,
        )
        # results[metric] will hold a list of tuples, one tuple per arm
        tuples = list(zip(arm_names, ys, ys_se))
        results[metric_name] = tuples
        # used if only_data_frame == True
        records_with_mean.append({arm: y for (arm, y, _) in tuples})
        records_with_ci.append({arm: y_se for (arm, _, y_se) in tuples})

    if only_data_frame:
        # Built with metrics as rows, then transposed so rows are arms.
        return tuple(
            pd.DataFrame.from_records(frame_records,
                                      index=model.metric_names).transpose()
            for frame_records in [records_with_mean, records_with_ci])

    # cells and colors are both lists of lists
    # each top-level list corresponds to a column,
    # so the first is a list of arms
    cells = [[f"<b>{x}</b>" for x in arm_names]]
    colors = [["#ffffff"] * len(arm_names)]
    metric_names = []
    for metric_name, list_of_tuples in sorted(results.items()):
        cells.append([
            "{:.3f} &plusmn; {:.3f}".format(y, Z * y_se)
            for (_, y, y_se) in list_of_tuples
        ])
        metric_names.append(metric_name.replace(":", " : "))

        # The model may report metrics that are not keys of
        # experiment.metrics (for example, members of a collection metric).
        # Those have no lower_is_better entry, so fall back to None instead of
        # raising KeyError; lower_is_better itself may already be None.
        reverse = metric_name_to_lower_is_better.get(metric_name)
        colors.append([
            get_color(x=y, ci=Z * y_se, rel=rel, reverse=reverse)
            for (_, y, y_se) in list_of_tuples
        ])

    header = [f"<b>{x}</b>" for x in ["arms"] + metric_names]
    trace = go.Table(
        header={
            "values": header,
            "align": ["left"]
        },
        cells={
            "values": cells,
            "align": ["left"],
            "fill": {
                "color": colors
            }
        },
    )
    layout = go.Layout(
        height=min([400, len(arm_names) * 20 + 200]),
        width=175 * len(header),
        # go.layout.Margin replaces the deprecated go.Margin alias.
        margin=go.layout.Margin(l=0, r=20, b=20, t=20, pad=4),  # noqa E741
    )
    fig = go.Figure(data=[trace], layout=layout)
    return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)