Example #1
def _obs_vs_pred_dropdown_plot(
    data: PlotData,
    rel: bool,
    show_context: bool = False,
    xlabel: str = "Actual Outcome",
    ylabel: str = "Predicted Outcome",
) -> go.Figure:
    """Plot a dropdown plot of observed vs. predicted values from a model.

    Args:
        data: a NamedTuple storing observed and predicted data
            from a model.
        rel: if True, plot metrics relative to the status quo.
        show_context: Show context on hover.
        xlabel: Label for x-axis.
        ylabel: Label for y-axis.

    Returns:
        go.Figure: a Plotly figure with a dropdown for selecting the metric.

    """
    traces = []
    metric_dropdown = []

    if rel and data.status_quo_name is not None:
        if show_context:
            raise ValueError(
                "This plot does not support both context and relativization at "
                "the same time.")
        # pyre-fixme[6]: Expected `str` for 1st param but got `Optional[str]`.
        status_quo_arm = data.in_sample[data.status_quo_name]
    else:
        status_quo_arm = None

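    # build a diagonal reference trace and an error-scatter trace per metric;
    # the dropdown menu below toggles which pair is visible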
    for i, metric in enumerate(data.metrics):
        y_raw, se_raw, y_hat, se_hat = _error_scatter_data(
            # pyre-fixme[6]: Expected `List[typing.Union[PlotInSampleArm,
            #  ax.plot.base.PlotOutOfSampleArm]]` for 1st anonymous
            #  parameter to call `ax.plot.scatter._error_scatter_data` but got
            #  `List[PlotInSampleArm]`.
            list(data.in_sample.values()),
            y_axis_var=PlotMetric(metric, True),
            x_axis_var=PlotMetric(metric, False),
            rel=rel,
            status_quo_arm=status_quo_arm,
        )
        min_, max_ = _get_min_max_with_errors(y_raw, y_hat, se_raw or [],
                                              se_hat)
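        # add the y = x reference line, then the observed-vs-predicted scatter;
        # only the first metric's pair of traces starts out visible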
        traces.append(_diagonal_trace(min_, max_, visible=(i == 0)))
        traces.append(
            _error_scatter_trace(
                # pyre-fixme[6]: Expected `List[typing.Union[PlotInSampleArm,
                #  ax.plot.base.PlotOutOfSampleArm]]` for 1st parameter
                #  `arms` to call `ax.plot.scatter._error_scatter_trace`
                #  but got `List[PlotInSampleArm]`.
                arms=list(data.in_sample.values()),
                hoverinfo="text",
                rel=rel,
                show_arm_details_on_hover=True,
                show_CI=True,
                show_context=show_context,
                status_quo_arm=status_quo_arm,
                visible=(i == 0),
                x_axis_label=xlabel,
                x_axis_var=PlotMetric(metric, False),
                y_axis_label=ylabel,
                y_axis_var=PlotMetric(metric, True),
            ))

        # each dropdown option shows only the two traces (diagonal + scatter)
        # belonging to its metric; all other traces are hidden
        is_visible = [False] * (len(data.metrics) * 2)
        is_visible[2 * i] = True
        is_visible[2 * i + 1] = True

        # on dropdown change, restyle
        metric_dropdown.append({
            "args": ["visible", is_visible],
            "label": metric,
            "method": "restyle"
        })

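    # two control menus: the metric dropdown (top left) and a "Show CI"
    # toggle (right) that restyles the error-bar width and thickness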
    updatemenus = [
        {
            "x": 0,
            "y": 1.125,
            "yanchor": "top",
            "xanchor": "left",
            "buttons": metric_dropdown,
        },
        {
            "buttons": [
                {
                    "args": [{
                        "error_x.width": 4,
                        "error_x.thickness": 2,
                        "error_y.width": 4,
                        "error_y.thickness": 2,
                    }],
                    "label": "Yes",
                    "method": "restyle",
                },
                {
                    "args": [{
                        "error_x.width": 0,
                        "error_x.thickness": 0,
                        "error_y.width": 0,
                        "error_y.thickness": 0,
                    }],
                    "label": "No",
                    "method": "restyle",
                },
            ],
            "x": 1.125,
            "xanchor": "left",
            "y": 0.8,
            "yanchor": "middle",
        },
    ]

    layout = go.Layout(
        annotations=[{
            "showarrow": False,
            "text": "Show CI",
            "x": 1.125,
            "xanchor": "left",
            "xref": "paper",
            "y": 0.9,
            "yanchor": "middle",
            "yref": "paper",
        }],
        xaxis={
            "title": xlabel,
            "zeroline": False,
            "mirror": True,
            "linecolor": "black",
            "linewidth": 0.5,
        },
        yaxis={
            "title": ylabel,
            "zeroline": False,
            "mirror": True,
            "linecolor": "black",
            "linewidth": 0.5,
        },
        showlegend=False,
        hovermode="closest",
        updatemenus=updatemenus,
        width=530,
        height=500,
    )

    return go.Figure(data=traces, layout=layout)
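
A minimal usage sketch (hypothetical: plot_data is assumed to be a PlotData
instance built elsewhere from a fitted model, and a Plotly renderer is assumed
to be available):

fig = _obs_vs_pred_dropdown_plot(data=plot_data, rel=False)
fig.show()  # or wrap it: AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)
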
Example #2
def tile_cross_validation(
    cv_results: List[CVResult],
    show_arm_details_on_hover: bool = True,
    show_context: bool = True,
) -> AxPlotConfig:
    """Tile version of CV plots; sorted by 'best fitting' outcomes.

    Plots are sorted in decreasing order using the p-value of a Fisher exact
    test statistic.

    Args:
        cv_results: cross-validation results.
        show_arm_details_on_hover: if True, display
            parameterizations of arms on hover. Default is True.
        show_context: if True (default), display context on
            hover.

    Returns:
        AxPlotConfig: a tiled figure with one observed-vs-predicted
            subplot per metric.

    """
    data = _get_cv_plot_data(cv_results)
    metrics = data.metrics

    # make subplots (2 plots per row)
    nrows = int(np.ceil(len(metrics) / 2))
    ncols = min(len(metrics), 2)
    fig = tools.make_subplots(
        rows=nrows,
        cols=ncols,
        print_grid=False,
        subplot_titles=tuple(metrics),
        horizontal_spacing=0.15,
        vertical_spacing=0.30 / nrows,
    )

    for i, metric in enumerate(metrics):
        y_hat = []
        se_hat = []
        y_raw = []
        se_raw = []
        for arm in data.in_sample.values():
            y_hat.append(arm.y_hat[metric])
            se_hat.append(arm.se_hat[metric])
            y_raw.append(arm.y[metric])
            se_raw.append(arm.se[metric])
        min_, max_ = _get_min_max_with_errors(y_raw, y_hat, se_raw, se_hat)
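        # add the y = x reference line and the observed-vs-predicted scatter
        # for this metric to its subplot (two subplots per row)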
        fig.append_trace(_diagonal_trace(min_, max_),
                         int(np.floor(i / 2)) + 1, i % 2 + 1)
        fig.append_trace(
            _error_scatter_trace(
                # pyre-fixme[6]: Expected `List[typing.Union[PlotInSampleArm,
                #  ax.plot.base.PlotOutOfSampleArm]]` for 1st anonymous
                #  parameter to call `ax.plot.scatter._error_scatter_trace` but
                #  got `List[PlotInSampleArm]`.
                list(data.in_sample.values()),
                y_axis_var=PlotMetric(metric, True),
                x_axis_var=PlotMetric(metric, False),
                y_axis_label="Predicted",
                x_axis_label="Actual",
                hoverinfo="text",
                show_arm_details_on_hover=show_arm_details_on_hover,
                show_context=show_context,
            ),
            int(np.floor(i / 2)) + 1,
            i % 2 + 1,
        )

    # if odd number of plots, need to manually remove the last blank subplot
    # generated by `tools.make_subplots`
    if len(metrics) % 2 == 1:
        del fig["layout"]["xaxis{}".format(nrows * ncols)]
        del fig["layout"]["yaxis{}".format(nrows * ncols)]

    # allocate 400 px per plot (equal aspect ratio)
    fig["layout"].update(
        title="Cross-Validation",  # What should I replace this with?
        hovermode="closest",
        width=800,
        height=400 * nrows,
        font={"size": 10},
        showlegend=False,
    )

    # update subplot title size and the axis labels
    for i, ant in enumerate(fig["layout"]["annotations"]):
        ant["font"].update(size=12)
        fig["layout"]["xaxis{}".format(i + 1)].update(title="Actual Outcome",
                                                      mirror=True,
                                                      linecolor="black",
                                                      linewidth=0.5)
        fig["layout"]["yaxis{}".format(i + 1)].update(
            title="Predicted Outcome",
            mirror=True,
            linecolor="black",
            linewidth=0.5)

    return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)
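
A usage sketch (hypothetical: assumes a fitted Ax model bridge named model and
that the import path below matches the installed Ax version):

from ax.modelbridge.cross_validation import cross_validate  # assumed import path

cv_results = cross_validate(model)          # List[CVResult] for the fitted model
config = tile_cross_validation(cv_results)  # AxPlotConfig; config.data holds the tiled Plotly figure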