def _obs_vs_pred_dropdown_plot(
    data: PlotData,
    rel: bool,
    show_context: bool = False,
    xlabel: str = "Actual Outcome",
    ylabel: str = "Predicted Outcome",
) -> go.Figure:
    """Plot a dropdown plot of observed vs. predicted values from a model.

    For every metric in ``data.metrics`` two traces are added: a diagonal
    reference trace and an error-scatter trace of observed (x) vs. predicted
    (y) values. Only the first metric's pair of traces is visible initially.
    A dropdown menu switches between metrics, and a second menu ("Show CI")
    toggles the error bars on or off.

    Args:
        data: a named tuple storing observed and predicted data from a model.
        rel: if True, plot metrics relative to the status quo.
        show_context: Show context on hover.
        xlabel: Label for x-axis.
        ylabel: Label for y-axis.

    Returns:
        A plotly ``go.Figure`` holding the traces and the dropdown menus.

    Raises:
        ValueError: if ``rel`` and ``show_context`` are both True and ``data``
            has a status quo arm.
    """
    traces = []
    metric_dropdown = []

    if rel and data.status_quo_name is not None:
        if show_context:
            raise ValueError(
                "This plot does not support both context and relativization at "
                "the same time."
            )
        # pyre-fixme[6]: Expected `str` for 1st param but got `Optional[str]`.
        status_quo_arm = data.in_sample[data.status_quo_name]
    else:
        status_quo_arm = None

    # Two traces (diagonal + scatter) are appended per metric.
    num_traces = 2 * len(data.metrics)

    for i, metric in enumerate(data.metrics):
        y_raw, se_raw, y_hat, se_hat = _error_scatter_data(
            # Expected `List[typing.Union[PlotInSampleArm,
            # ax.plot.base.PlotOutOfSampleArm]]` for 1st anonymous
            # parameter to call `ax.plot.scatter._error_scatter_data` but got
            # `List[PlotInSampleArm]`.
            # pyre-fixme[6]:
            list(data.in_sample.values()),
            y_axis_var=PlotMetric(metric, True),
            x_axis_var=PlotMetric(metric, False),
            rel=rel,
            status_quo_arm=status_quo_arm,
        )
        # `se_raw` may be falsy (e.g. None); fall back to an empty list.
        min_, max_ = _get_min_max_with_errors(y_raw, y_hat, se_raw or [], se_hat)

        # Only the first metric's traces are visible when the plot first
        # renders; the dropdown below switches visibility afterwards.
        traces.append(_diagonal_trace(min_, max_, visible=(i == 0)))
        traces.append(
            _error_scatter_trace(
                # Expected `List[typing.Union[PlotInSampleArm,
                # ax.plot.base.PlotOutOfSampleArm]]` for 1st parameter
                # `arms` to call `ax.plot.scatter._error_scatter_trace`
                # but got `List[PlotInSampleArm]`.
                # pyre-fixme[6]:
                arms=list(data.in_sample.values()),
                hoverinfo="text",
                rel=rel,
                show_arm_details_on_hover=True,
                show_CI=True,
                show_context=show_context,
                status_quo_arm=status_quo_arm,
                visible=(i == 0),
                x_axis_label=xlabel,
                x_axis_var=PlotMetric(metric, False),
                y_axis_label=ylabel,
                y_axis_var=PlotMetric(metric, True),
            )
        )

        # Visibility mask used when this metric is picked from the dropdown:
        # show only this metric's two traces, at positions 2*i and 2*i + 1.
        is_visible = [False] * num_traces
        is_visible[2 * i] = True
        is_visible[2 * i + 1] = True

        # On dropdown change, restyle the figure's trace visibility.
        metric_dropdown.append(
            {"args": ["visible", is_visible], "label": metric, "method": "restyle"}
        )

    updatemenus = [
        # Metric selector.
        {
            "x": 0,
            "y": 1.125,
            "yanchor": "top",
            "xanchor": "left",
            "buttons": metric_dropdown,
        },
        # "Show CI" toggle: restyles the error-bar width/thickness.
        {
            "buttons": [
                {
                    "args": [
                        {
                            "error_x.width": 4,
                            "error_x.thickness": 2,
                            "error_y.width": 4,
                            "error_y.thickness": 2,
                        }
                    ],
                    "label": "Yes",
                    "method": "restyle",
                },
                {
                    "args": [
                        {
                            "error_x.width": 0,
                            "error_x.thickness": 0,
                            "error_y.width": 0,
                            "error_y.thickness": 0,
                        }
                    ],
                    "label": "No",
                    "method": "restyle",
                },
            ],
            "x": 1.125,
            "xanchor": "left",
            "y": 0.8,
            "yanchor": "middle",
        },
    ]

    layout = go.Layout(
        annotations=[
            {
                "showarrow": False,
                "text": "Show CI",
                "x": 1.125,
                "xanchor": "left",
                "xref": "paper",
                "y": 0.9,
                "yanchor": "middle",
                "yref": "paper",
            }
        ],
        xaxis={
            "title": xlabel,
            "zeroline": False,
            "mirror": True,
            "linecolor": "black",
            "linewidth": 0.5,
        },
        yaxis={
            "title": ylabel,
            "zeroline": False,
            "mirror": True,
            "linecolor": "black",
            "linewidth": 0.5,
        },
        showlegend=False,
        hovermode="closest",
        updatemenus=updatemenus,
        width=530,
        height=500,
    )

    return go.Figure(data=traces, layout=layout)
def tile_cross_validation(
    cv_results: List[CVResult],
    show_arm_details_on_hover: bool = True,
    show_context: bool = True,
) -> AxPlotConfig:
    """Tile version of CV plots: one observed-vs-predicted subplot per metric.

    Subplots are laid out two per row, in the order of the metrics in the
    cross-validation plot data. No sorting is performed here.

    Args:
        cv_results: cross-validation results.
        show_arm_details_on_hover: if True, display parameterizations of arms on
            hover. Default is True.
        show_context: if True (default), display context on hover.

    Returns:
        An ``AxPlotConfig`` of type ``GENERIC`` wrapping the tiled figure.

    Raises:
        ValueError: if the cross-validation results contain no metrics.
    """
    data = _get_cv_plot_data(cv_results)
    metrics = data.metrics
    if not metrics:
        # Without this guard, the vertical spacing below divides by zero.
        raise ValueError("Cannot tile cross-validation plots with no metrics.")

    # make subplots (2 plots per row)
    nrows = (len(metrics) + 1) // 2
    ncols = min(len(metrics), 2)
    fig = tools.make_subplots(
        rows=nrows,
        cols=ncols,
        print_grid=False,
        subplot_titles=tuple(metrics),
        horizontal_spacing=0.15,
        vertical_spacing=0.30 / nrows,
    )

    arms = list(data.in_sample.values())
    for i, metric in enumerate(metrics):
        # Subplot position (1-indexed) in the 2-column grid.
        row, col = i // 2 + 1, i % 2 + 1

        y_hat = [arm.y_hat[metric] for arm in arms]
        se_hat = [arm.se_hat[metric] for arm in arms]
        y_raw = [arm.y[metric] for arm in arms]
        se_raw = [arm.se[metric] for arm in arms]
        min_, max_ = _get_min_max_with_errors(y_raw, y_hat, se_raw, se_hat)

        fig.append_trace(_diagonal_trace(min_, max_), row, col)
        fig.append_trace(
            _error_scatter_trace(
                # Expected `List[typing.Union[PlotInSampleArm,
                # ax.plot.base.PlotOutOfSampleArm]]` for 1st anonymous
                # parameter to call `ax.plot.scatter._error_scatter_trace` but
                # got `List[PlotInSampleArm]`.
                # pyre-fixme[6]:
                list(data.in_sample.values()),
                y_axis_var=PlotMetric(metric, True),
                x_axis_var=PlotMetric(metric, False),
                y_axis_label="Predicted",
                x_axis_label="Actual",
                hoverinfo="text",
                show_arm_details_on_hover=show_arm_details_on_hover,
                show_context=show_context,
            ),
            row,
            col,
        )

    # With an odd number of metrics (> 1), `tools.make_subplots` leaves a blank
    # last cell in the grid; remove its axes. With a single metric the grid is
    # 1x1 and has no blank cell, so the only real axes must NOT be deleted.
    if nrows * ncols > len(metrics):
        del fig["layout"]["xaxis{}".format(nrows * ncols)]
        del fig["layout"]["yaxis{}".format(nrows * ncols)]

    # allocate 400 px per plot (equal aspect ratio)
    fig["layout"].update(
        title="Cross-Validation",
        hovermode="closest",
        width=800,
        height=400 * nrows,
        font={"size": 10},
        showlegend=False,
    )

    # Update subplot title size and the axis labels. The annotations are the
    # subplot titles (one per metric), so index i maps to axis pair i + 1.
    for i, ant in enumerate(fig["layout"]["annotations"]):
        ant["font"].update(size=12)
        fig["layout"]["xaxis{}".format(i + 1)].update(
            title="Actual Outcome", mirror=True, linecolor="black", linewidth=0.5
        )
        fig["layout"]["yaxis{}".format(i + 1)].update(
            title="Predicted Outcome", mirror=True, linecolor="black", linewidth=0.5
        )

    return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)