def table_view_plot(
    experiment: Experiment,
    data: Data,
    use_empirical_bayes: bool = True,
    only_data_frame: bool = False,
    arm_noun: str = "arm",
):
    """Table of means and confidence intervals.

    The table has one row per metric and one column per arm:

        +----------+------------+-----------+
        | arms     | 0_0        | 0_1       |
        +==========+============+===========+
        | metric_1 | mean +- CI | ...       |
        +----------+------------+-----------+
        | metric_2 | ...        | ...       |
        +----------+------------+-----------+

    The header of the first column is ``f"{arm_noun}s"``. Each cell shows
    ``mean ± Z * SE`` and is colored by ``get_color`` using the mean, the CI
    half-width and the metric's ``lower_is_better`` flag. When the plot data
    has a status quo arm, values are computed with ``rel=True``.

    Args:
        experiment: Experiment whose metrics are tabulated. Only metric names
            present in ``experiment.metrics`` are kept.
        data: Data used to fit the model.
        use_empirical_bayes: If True, fit with ``get_empirical_bayes_thompson``,
            otherwise with ``get_thompson``.
        only_data_frame: If True, skip plotting and return a tuple
            ``(means_df, ci_df)``. Each frame is indexed by metric name and
            has one column per arm. ``ci_df`` holds ``Z * SE``.
        arm_noun: Noun used (pluralized with a trailing "s") for the first
            header cell.

    Returns:
        An ``AxPlotConfig`` wrapping a plotly table, or the DataFrame tuple
        described above when ``only_data_frame`` is True.
    """
    model_func = get_empirical_bayes_thompson if use_empirical_bayes else get_thompson
    model = model_func(experiment=experiment, data=data)

    # We don't want to include metrics from a collection,
    # or the chart will be too big to read easily.
    # Example:
    # experiment.metrics = {
    #   'regular_metric': Metric(),
    #   'collection_metric: CollectionMetric()', # collection_metric =[metric1, metric2]
    # }
    # model.metric_names = [regular_metric, metric1, metric2] # "exploded" out
    # We want to filter model.metric_names and get rid of metric1, metric2
    metric_names = [
        metric_name
        for metric_name in model.metric_names
        if metric_name in experiment.metrics
    ]

    metric_name_to_lower_is_better = {
        metric_name: experiment.metrics[metric_name].lower_is_better
        for metric_name in metric_names
    }

    plot_data, _, _ = get_plot_data(
        model=model,
        generator_runs_dict={},
        # pyre-fixme[6]: Expected `Optional[typing.Set[str]]` for 3rd param but got
        #  `List[str]`.
        metric_names=metric_names,
    )

    if plot_data.status_quo_name:
        status_quo_arm = plot_data.in_sample.get(plot_data.status_quo_name)
        rel = True
    else:
        status_quo_arm = None
        rel = False

    # One entry per metric in each list: formatted "mean ± CI" cells, cell
    # colors, and {arm_name: value} dicts used for the DataFrame output.
    records = []
    colors = []
    records_with_mean = []
    records_with_ci = []
    # Default so that an experiment with no plottable metrics produces an
    # empty table instead of failing on `arm_names` below.
    arm_names = []
    for metric_name in metric_names:
        arm_names, _, ys, ys_se = _error_scatter_data(
            # pyre-fixme[6]: Expected
            #  `List[typing.Union[ax.plot.base.PlotInSampleArm,
            #  ax.plot.base.PlotOutOfSampleArm]]` for 1st param but got
            #  `List[ax.plot.base.PlotInSampleArm]`.
            arms=list(plot_data.in_sample.values()),
            y_axis_var=PlotMetric(metric_name, pred=True, rel=rel),
            x_axis_var=None,
            status_quo_arm=status_quo_arm,
        )
        results_by_arm = list(zip(arm_names, ys, ys_se))
        colors.append(
            [
                get_color(
                    x=y,
                    ci=Z * y_se,
                    rel=rel,
                    # pyre-fixme[6]: Expected `bool` for 4th param but got
                    #  `Optional[bool]`.
                    reverse=metric_name_to_lower_is_better[metric_name],
                )
                for (_, y, y_se) in results_by_arm
            ]
        )
        records.append(
            ["{:.3f} ± {:.3f}".format(y, Z * y_se) for (_, y, y_se) in results_by_arm]
        )
        records_with_mean.append({arm_name: y for (arm_name, y, _) in results_by_arm})
        records_with_ci.append(
            {arm_name: Z * y_se for (arm_name, _, y_se) in results_by_arm}
        )

    if only_data_frame:
        return tuple(
            pd.DataFrame.from_records(metric_records, index=metric_names)
            for metric_records in [records_with_mean, records_with_ci]
        )

    def transpose(m):
        # Swap rows and columns. zip(*m) also copes with an empty `m`, where
        # indexing m[0] would raise IndexError.
        return [list(row) for row in zip(*m)]

    # plotly Table cells are column-major: column 0 holds the metric names,
    # and each following column holds one arm's values across metrics.
    records = [[name.replace(":", " : ") for name in metric_names]] + transpose(records)
    colors = [["#ffffff"] * len(metric_names)] + transpose(colors)
    # pyre-fixme[6]: Expected `List[str]` for 1st param but got `List[float]`.
    header = [f"<b>{x}</b>" for x in [f"{arm_noun}s"] + arm_names]
    column_widths = [300] + [150] * len(arm_names)

    trace = go.Table(
        header={"values": header, "align": ["left"]},
        cells={"values": records, "align": ["left"], "fill": {"color": colors}},
        columnwidth=column_widths,
    )
    layout = go.Layout(
        width=sum(column_widths),
        margin=go.layout.Margin(l=0, r=20, b=20, t=20, pad=4),  # noqa E741
    )
    fig = go.Figure(data=[trace], layout=layout)
    return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)
def _obs_vs_pred_dropdown_plot(
    data: PlotData,
    rel: bool,
    show_context: bool = False,
    xlabel: str = "Actual Outcome",
    ylabel: str = "Predicted Outcome",
) -> Dict[str, Any]:
    """Build a figure of observed vs. predicted values with a metric dropdown.

    Each metric in ``data.metrics`` contributes two traces: a diagonal
    reference line and an error-bar scatter of observed (x) against
    predicted (y) values. A dropdown entry shows only that metric's pair of
    traces, and only the first metric is visible initially. A second button
    group, labelled "Show CI", toggles error-bar width and thickness.

    Args:
        data: A named tuple storing observed and predicted data from a model.
        rel: If True, plot metrics relative to the status quo. The status quo
            arm is looked up only when ``data.status_quo_name`` is set.
        show_context: Show context on hover.
        xlabel: Label for the x-axis.
        ylabel: Label for the y-axis.

    Returns:
        A plotly ``go.Figure``. Note that the annotation says
        ``Dict[str, Any]``.

    Raises:
        ValueError: If relativization (with a status quo) and
            ``show_context`` are requested together.
    """
    status_quo_arm = None
    if rel and data.status_quo_name is not None:
        if show_context:
            raise ValueError(
                "This plot does not support both context and relativization at "
                "the same time."
            )
        # pyre-fixme[6]: Expected `str` for 1st param but got `Optional[str]`.
        status_quo_arm = data.in_sample[data.status_quo_name]

    trace_count = 2 * len(data.metrics)
    traces = []
    dropdown_buttons = []
    for idx, metric in enumerate(data.metrics):
        shown_first = idx == 0
        y_raw, se_raw, y_hat, se_hat = _error_scatter_data(
            # Expected `List[typing.Union[PlotInSampleArm,
            # ax.plot.base.PlotOutOfSampleArm]]` for 1st anonymous
            # parameter to call `ax.plot.scatter._error_scatter_data` but got
            # `List[PlotInSampleArm]`.
            # pyre-fixme[6]:
            list(data.in_sample.values()),
            y_axis_var=PlotMetric(metric, True),
            x_axis_var=PlotMetric(metric, False),
            rel=rel,
            status_quo_arm=status_quo_arm,
        )
        lower, upper = _get_min_max_with_errors(y_raw, y_hat, se_raw or [], se_hat)
        traces.append(_diagonal_trace(lower, upper, visible=shown_first))
        traces.append(
            _error_scatter_trace(
                # Expected `List[typing.Union[PlotInSampleArm,
                # ax.plot.base.PlotOutOfSampleArm]]` for 1st parameter
                # `arms` to call `ax.plot.scatter._error_scatter_trace`
                # but got `List[PlotInSampleArm]`.
                # pyre-fixme[6]:
                arms=list(data.in_sample.values()),
                hoverinfo="text",
                rel=rel,
                show_arm_details_on_hover=True,
                show_CI=True,
                show_context=show_context,
                status_quo_arm=status_quo_arm,
                visible=shown_first,
                x_axis_label=xlabel,
                x_axis_var=PlotMetric(metric, False),
                y_axis_label=ylabel,
                y_axis_var=PlotMetric(metric, True),
            )
        )
        # Traces come in (diagonal, scatter) pairs, so trace j belongs to
        # metric j // 2. Selecting this metric shows exactly its pair.
        visibility_mask = [j // 2 == idx for j in range(trace_count)]
        dropdown_buttons.append(
            {"args": ["visible", visibility_mask], "label": metric, "method": "restyle"}
        )

    def error_bar_style(width, thickness):
        # Restyle payload that sets both error-bar axes to the same look.
        return {
            "error_x.width": width,
            "error_x.thickness": thickness,
            "error_y.width": width,
            "error_y.thickness": thickness,
        }

    def axis_style(title):
        # Shared axis styling; only the title differs between x and y.
        return {
            "title": title,
            "zeroline": False,
            "mirror": True,
            "linecolor": "black",
            "linewidth": 0.5,
        }

    metric_menu = {
        "x": 0,
        "y": 1.125,
        "yanchor": "top",
        "xanchor": "left",
        "buttons": dropdown_buttons,
    }
    ci_menu = {
        "buttons": [
            {"args": [error_bar_style(4, 2)], "label": "Yes", "method": "restyle"},
            {"args": [error_bar_style(0, 0)], "label": "No", "method": "restyle"},
        ],
        "x": 1.125,
        "xanchor": "left",
        "y": 0.8,
        "yanchor": "middle",
    }
    ci_label = {
        "showarrow": False,
        "text": "Show CI",
        "x": 1.125,
        "xanchor": "left",
        "xref": "paper",
        "y": 0.9,
        "yanchor": "middle",
        "yref": "paper",
    }

    layout = go.Layout(
        annotations=[ci_label],
        xaxis=axis_style(xlabel),
        yaxis=axis_style(ylabel),
        showlegend=False,
        hovermode="closest",
        updatemenus=[metric_menu, ci_menu],
        width=530,
        height=500,
    )
    return go.Figure(data=traces, layout=layout)
def table_view_plot(
    experiment: Experiment,
    data: Data,
    use_empirical_bayes: bool = True,
    only_data_frame: bool = False,
):
    """Table of means and confidence intervals.

    The table has one row per arm and one column per metric (metrics are
    sorted by name):

        +-------+------------+-----------+
        | arms  | metric_1   | metric_2  |
        +=======+============+===========+
        | 0_0   | mean +- CI | ...       |
        +-------+------------+-----------+
        | 0_1   | ...        | ...       |
        +-------+------------+-----------+

    Each cell shows ``mean ± Z * SE`` and is colored by ``get_color``. When the
    plot data has a status quo arm, values are computed with ``rel=True``.

    Args:
        experiment: Experiment providing the metrics' ``lower_is_better`` flags.
        data: Data used to fit the model.
        use_empirical_bayes: If True, fit with ``get_empirical_bayes_thompson``,
            otherwise with ``get_thompson``.
        only_data_frame: If True, skip plotting and return a tuple
            ``(means_df, ci_df)``. Each frame is indexed by arm and has one
            column per metric in ``model.metric_names``. ``ci_df`` holds the
            CI half-width ``Z * SE``.

    Returns:
        An ``AxPlotConfig`` wrapping a plotly table, or the DataFrame tuple
        described above when ``only_data_frame`` is True.
    """
    model_func = get_empirical_bayes_thompson if use_empirical_bayes else get_thompson
    model = model_func(experiment=experiment, data=data)

    metric_name_to_lower_is_better = {
        metric.name: metric.lower_is_better for metric in experiment.metrics.values()
    }
    plot_data, _, _ = get_plot_data(
        model=model, generator_runs_dict={}, metric_names=model.metric_names
    )

    if plot_data.status_quo_name:
        status_quo_arm = plot_data.in_sample.get(plot_data.status_quo_name)
        rel = True
    else:
        status_quo_arm = None
        rel = False

    # results[metric] holds a list of (arm, mean, se) tuples, one per arm.
    results = {}
    records_with_mean = []
    records_with_ci = []
    # Arm labels for the first column; stays empty (rather than unbound) when
    # the model reports no metrics.
    arms = []
    for metric_name in model.metric_names:
        arms, _, ys, ys_se = _error_scatter_data(
            arms=list(plot_data.in_sample.values()),
            y_axis_var=PlotMetric(metric_name, True),
            x_axis_var=None,
            rel=rel,
            status_quo_arm=status_quo_arm,
        )
        tuples = list(zip(arms, ys, ys_se))
        results[metric_name] = tuples
        # Used if only_data_frame is True. Store the CI half-width (Z * SE) so
        # it matches the "±" value shown in the table.
        records_with_mean.append({arm: y for (arm, y, _) in tuples})
        records_with_ci.append({arm: Z * y_se for (arm, _, y_se) in tuples})

    if only_data_frame:
        return tuple(
            pd.DataFrame.from_records(
                metric_records, index=model.metric_names
            ).transpose()
            for metric_records in [records_with_mean, records_with_ci]
        )

    # cells and colors are both lists of lists. plotly Table cells are
    # column-major, so each inner list is a column; the first is the arms.
    cells = [[f"<b>{x}</b>" for x in arms]]
    colors = [["#ffffff"] * len(arms)]
    metric_names = []
    for metric_name, list_of_tuples in sorted(results.items()):
        cells.append(
            ["{:.3f} ± {:.3f}".format(y, Z * y_se) for (_, y, y_se) in list_of_tuples]
        )
        metric_names.append(metric_name.replace(":", " : "))
        # The model may report metric names that are not keys of
        # experiment.metrics (e.g. members of an expanded collection metric).
        # Fall back to None instead of raising KeyError.
        lower_is_better = metric_name_to_lower_is_better.get(metric_name)
        colors.append(
            [
                get_color(x=y, ci=Z * y_se, rel=rel, reverse=lower_is_better)
                for (_, y, y_se) in list_of_tuples
            ]
        )

    header = [f"<b>{x}</b>" for x in ["arms"] + metric_names]

    trace = go.Table(
        header={"values": header, "align": ["left"]},
        cells={"values": cells, "align": ["left"], "fill": {"color": colors}},
    )
    layout = go.Layout(
        height=min([400, len(arms) * 20 + 200]),
        width=175 * len(header),
        # go.Margin is deprecated in favor of go.layout.Margin.
        margin=go.layout.Margin(l=0, r=20, b=20, t=20, pad=4),  # noqa E741
    )
    fig = go.Figure(data=[trace], layout=layout)
    return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)