def plot_marginal_effects(model: ModelBridge, metric: str) -> AxPlotConfig:
    """Calculate and plot the marginal effects of each factor.

    A marginal effect is the effect of changing one factor away from the
    randomized distribution of the experiment and fixing it at a particular
    level.

    The modeled mean and standard error of every in-sample arm are collected
    into a single table and handed to ``marginal_effects``. The ``Beta`` and
    ``SE`` columns of its result are then drawn as one bar chart per distinct
    ``Name`` value, laid out as a single row of subplots sharing a y-axis.

    Args:
        model: Model to use for estimating effects.
        metric: The metric for which to plot marginal effects.

    Returns:
        AxPlotConfig of the marginal effects.
    """
    plot_data, _, _ = get_plot_data(model, {}, {metric})

    # One single-row frame per arm: its parameter values plus the modeled
    # mean / standard error for the requested metric.
    arm_dfs = []
    for arm in plot_data.in_sample.values():
        arm_df = pd.DataFrame(arm.parameters, index=[arm.name])
        arm_df["mean"] = arm.y_hat[metric]
        arm_df["sem"] = arm.se_hat[metric]
        arm_dfs.append(arm_df)

    # `axis` must be passed by keyword: every argument of `pd.concat` other
    # than `objs` is keyword-only in pandas >= 2.0.
    effect_table = marginal_effects(pd.concat(arm_dfs, axis=0))

    varnames = effect_table["Name"].unique()
    data: List[Any] = []
    for varname in varnames:
        var_df = effect_table[effect_table["Name"] == varname]
        data.append(
            # pyre-ignore[16]
            go.Bar(
                x=var_df["Level"],
                y=var_df["Beta"],
                error_y={"type": "data", "array": var_df["SE"]},
                name=varname,
            )
        )

    fig = tools.make_subplots(
        cols=len(varnames),
        rows=1,
        subplot_titles=list(varnames),
        print_grid=False,
        shared_yaxes=True,
    )
    # Subplot columns are 1-indexed.
    for idx, item in enumerate(data):
        fig.append_trace(item, 1, idx + 1)

    fig.layout.showlegend = False
    fig.layout.title = "Marginal Effects by Factor"
    fig.layout.yaxis = {
        "title": "% better than experiment average",
        "hoverformat": ".{}f".format(DECIMALS),
    }
    return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)
def _get_contour_predictions(
    model: ModelBridge,
    x_param_name: str,
    y_param_name: str,
    metric: str,
    generator_runs_dict: TNullableGeneratorRunsDict,
    density: int,
    slice_values: Optional[Dict[str, Any]] = None,
    fixed_features: Optional[ObservationFeatures] = None,
) -> ContourPredictions:
    """Predict one metric over a 2-d grid spanned by two range parameters.

    Args:
        model: ModelBridge used for the predictions.
        x_param_name: Name of the range parameter on the x-axis.
        y_param_name: Name of the range parameter on the y-axis.
        metric: Name of the metric to predict.
        generator_runs_dict: A dictionary {name: generator run}; ``None`` is
            treated as an empty dictionary.
        density: Grid resolution requested per axis; ``density ** 2`` points
            are predicted.
        slice_values: A dictionary {param_name: value} for the parameters
            that are being sliced on. Replaced by
            ``fixed_features.parameters`` when ``fixed_features`` is given.
        fixed_features: Observation features copied for every grid point
            before its parameters are set.

    Returns:
        ``(plot_data, f_plt, sd_plt, grid_x, grid_y, scales)`` where
        ``scales`` maps ``"x"`` / ``"y"`` to each parameter's ``log_scale``.
    """
    x_param = get_range_parameter(model, x_param_name)
    y_param = get_range_parameter(model, y_param_name)

    plot_data, _, _ = get_plot_data(
        model,
        generator_runs_dict or {},
        {metric},
        fixed_features=fixed_features,
    )

    grid_x = get_grid_for_parameter(x_param, density)
    grid_y = get_grid_for_parameter(y_param, density)
    scales = {"x": x_param.log_scale, "y": y_param.log_scale}

    # Flattened coordinates of every (x, y) combination on the mesh.
    mesh_x, mesh_y = np.meshgrid(grid_x, grid_y)
    flat_x = mesh_x.flatten()
    flat_y = mesh_y.flatten()

    if fixed_features is None:
        fixed_features = ObservationFeatures(parameters={})
    else:
        slice_values = fixed_features.parameters

    fixed_values = get_fixed_values(model, slice_values)

    def _features_at(index: int) -> ObservationFeatures:
        # Start from the fixed features, then overwrite the two swept
        # parameters with this grid point's coordinates.
        point = deepcopy(fixed_features)
        point.parameters = fixed_values.copy()
        point.parameters[x_param_name] = flat_x[index]
        point.parameters[y_param_name] = flat_y[index]
        return point

    param_grid_obsf = [_features_at(index) for index in range(density ** 2)]

    mu, cov = model.predict(param_grid_obsf)

    f_plt = mu[metric]
    sd_plt = np.sqrt(cov[metric][metric])
    # pyre-fixme[7]: Expected `Tuple[PlotData, np.ndarray, np.ndarray, np.ndarray,
    #  np.ndarray, Dict[str, bool]]` but got `Tuple[PlotData, typing.List[float],
    #  typing.Any, np.ndarray, np.ndarray, Dict[str, bool]]`.
    return plot_data, f_plt, sd_plt, grid_x, grid_y, scales
def _get_contour_predictions(
    model: ModelBridge,
    x_param_name: str,
    y_param_name: str,
    metric: str,
    generator_runs_dict: TNullableGeneratorRunsDict,
    density: int,
    slice_values: Optional[Dict[str, Any]] = None,
) -> ContourPredictions:
    """Predict one metric over a 2-d grid spanned by two range parameters.

    Args:
        model: ModelBridge used for the predictions.
        x_param_name: Name of the range parameter on the x-axis.
        y_param_name: Name of the range parameter on the y-axis.
        metric: Name of the metric to predict.
        generator_runs_dict: A dictionary {name: generator run}; ``None`` is
            treated as an empty dictionary.
        density: Grid resolution requested per axis; ``density ** 2`` points
            are predicted.
        slice_values: A dictionary {param_name: value} for the parameters
            that are being sliced on.

    Returns:
        ``(plot_data, f_plt, sd_plt, grid_x, grid_y, scales)`` where
        ``scales`` maps ``"x"`` / ``"y"`` to each parameter's ``log_scale``.
    """
    x_param = get_range_parameter(model, x_param_name)
    y_param = get_range_parameter(model, y_param_name)

    plot_data, _, _ = get_plot_data(model, generator_runs_dict or {}, {metric})

    grid_x = get_grid_for_parameter(x_param, density)
    grid_y = get_grid_for_parameter(y_param, density)
    scales = {"x": x_param.log_scale, "y": y_param.log_scale}

    # Flattened coordinates of every (x, y) combination on the mesh.
    mesh_x, mesh_y = np.meshgrid(grid_x, grid_y)
    flat_x = mesh_x.flatten()
    flat_y = mesh_y.flatten()

    fixed_values = get_fixed_values(model, slice_values)

    def _features_at(index: int) -> ObservationFeatures:
        # Fixed values for all parameters, with the two swept parameters
        # overwritten by this grid point's coordinates.
        point_params = fixed_values.copy()
        point_params[x_param_name] = flat_x[index]
        point_params[y_param_name] = flat_y[index]
        return ObservationFeatures(point_params)

    param_grid_obsf = [_features_at(index) for index in range(density ** 2)]

    mu, cov = model.predict(param_grid_obsf)

    f_plt = mu[metric]
    sd_plt = np.sqrt(cov[metric][metric])
    return plot_data, f_plt, sd_plt, grid_x, grid_y, scales
def _get_slice_predictions(
    model: ModelBridge,
    param_name: str,
    metric_name: str,
    generator_runs_dict: TNullableGeneratorRunsDict = None,
    relative: bool = False,
    density: int = 50,
    slice_values: Optional[Dict[str, Any]] = None,
    fixed_features: Optional[ObservationFeatures] = None,
    trial_index: Optional[int] = None,
) -> SlicePredictions:
    """Computes slice prediction configuration values for a single metric name.

    Args:
        model: ModelBridge that contains model for predictions
        param_name: Name of parameter that will be sliced
        metric_name: Name of metric to plot
        generator_runs_dict: A dictionary {name: generator run} of generator runs
            whose arms will be plotted, if they lie in the slice.
        relative: Predictions relative to status quo
        density: Number of points along slice to evaluate predictions.
        slice_values: A dictionary {name: val} for the fixed values of the
            other parameters. If not provided, then the status quo values will
            be used if there is a status quo, otherwise the mean of numeric
            parameters or the mode of choice parameters. Ignored if
            fixed_features is specified.
        fixed_features: An ObservationFeatures object containing the values of
            features (including non-parameter features like context) to be set
            in the slice.
        trial_index: Forwarded unchanged to ``get_fixed_values``.

    Returns:
        Configuration values for AxPlotConfig, as the tuple
        ``(plot_data, cond_name_to_parameters, f_plt, raw_data, grid,
        metric_name, param_name, relative, fixed_values, sd_plt, is_log)``.
    """
    runs = {} if generator_runs_dict is None else generator_runs_dict

    parameter = get_range_parameter(model, param_name)
    grid = get_grid_for_parameter(parameter, density)

    plot_data, raw_data, cond_name_to_parameters = get_plot_data(
        model=model,
        generator_runs_dict=runs,
        metric_names={metric_name},
        fixed_features=fixed_features,
    )

    if fixed_features is None:
        fixed_features = ObservationFeatures(parameters={})
    else:
        # Explicit fixed features take precedence over `slice_values`.
        slice_values = fixed_features.parameters

    fixed_values = get_fixed_values(model, slice_values, trial_index)

    def _features_at(value: Any) -> ObservationFeatures:
        # Copy the fixed features and vary only the sliced parameter.
        point = deepcopy(fixed_features)
        point.parameters = fixed_values.copy()
        point.parameters[param_name] = value
        return point

    prediction_features = [_features_at(value) for value in grid]

    means, covariances = model.predict(prediction_features)
    f_plt = means[metric_name]
    sd_plt = np.sqrt(covariances[metric_name][metric_name])
    # pyre-fixme[7]: Expected `Tuple[PlotData, List[Dict[str, Union[float, str]]],
    #  List[float], np.ndarray, np.ndarray, str, str, bool, Dict[str, Union[None, bool,
    #  float, int, str]], np.ndarray, bool]` but got `Tuple[PlotData, Dict[str,
    #  Dict[str, Union[None, bool, float, int, str]]], List[float], List[Dict[str,
    #  Union[float, str]]], np.ndarray, str, str, bool, Dict[str, Union[None, bool,
    #  float, int, str]], typing.Any, bool]`.
    return (
        plot_data,
        cond_name_to_parameters,
        f_plt,
        raw_data,
        grid,
        metric_name,
        param_name,
        relative,
        fixed_values,
        sd_plt,
        parameter.log_scale,
    )
def table_view_plot(
    experiment: Experiment,
    data: Data,
    use_empirical_bayes: bool = True,
    only_data_frame: bool = False,
    arm_noun: str = "arm",
):
    """Table of means and confidence intervals.

    The table has one column per arm and one row per metric:

    +------------+------------+-----------+
    |    arms    |    0_0     |    0_1    |
    +============+============+===========+
    |  metric_1  | mean ± CI  |    ...    |
    +------------+------------+-----------+
    |  metric_2  |    ...     |    ...    |
    +------------+------------+-----------+

    Values are modeled predictions, relative to the status quo whenever the
    plot data reports a status quo name. The interval half-width is
    ``Z * standard error``.

    Args:
        experiment: Experiment passed to the model factory; only metrics
            present in ``experiment.metrics`` are shown.
        data: Data passed to the model factory.
        use_empirical_bayes: If True, build the model with
            ``get_empirical_bayes_thompson``; otherwise with ``get_thompson``.
        only_data_frame: If True, skip the plot and return the underlying
            numbers as data frames.
        arm_noun: Noun used (pluralized with an "s") as the header of the
            first column.

    Returns:
        If ``only_data_frame`` is True, a tuple of two data frames indexed by
        metric name with one column per arm: the means and the interval
        half-widths. Otherwise an AxPlotConfig wrapping a plotly table. With
        no displayable metrics the table holds only its header column.
    """
    model_func = get_empirical_bayes_thompson if use_empirical_bayes else get_thompson
    model = model_func(experiment=experiment, data=data)

    # We don't want to include metrics from a collection,
    # or the chart will be too big to read easily.
    # Example:
    # experiment.metrics = {
    #   'regular_metric': Metric(),
    #   'collection_metric: CollectionMetric()', # collection_metric =[metric1, metric2]
    # }
    # model.metric_names = [regular_metric, metric1, metric2] # "exploded" out
    # We want to filter model.metric_names and get rid of metric1, metric2
    metric_names = [
        metric_name
        for metric_name in model.metric_names
        if metric_name in experiment.metrics
    ]

    metric_name_to_lower_is_better = {
        metric_name: experiment.metrics[metric_name].lower_is_better
        for metric_name in metric_names
    }

    plot_data, _, _ = get_plot_data(
        model=model,
        generator_runs_dict={},
        # pyre-fixme[6]: Expected `Optional[typing.Set[str]]` for 3rd param but got
        #  `List[str]`.
        metric_names=metric_names,
    )

    if plot_data.status_quo_name:
        status_quo_arm = plot_data.in_sample.get(plot_data.status_quo_name)
        rel = True
    else:
        status_quo_arm = None
        rel = False

    # Bound up front so the header code below still works when the metric
    # filter leaves nothing to iterate over.
    arm_names = []
    records = []
    colors = []
    records_with_mean = []
    records_with_ci = []
    for metric_name in metric_names:
        arm_names, _, ys, ys_se = _error_scatter_data(
            # pyre-fixme[6]: Expected
            #  `List[typing.Union[ax.plot.base.PlotInSampleArm,
            #  ax.plot.base.PlotOutOfSampleArm]]` for 1st param but got
            #  `List[ax.plot.base.PlotInSampleArm]`.
            arms=list(plot_data.in_sample.values()),
            y_axis_var=PlotMetric(metric_name, pred=True, rel=rel),
            x_axis_var=None,
            status_quo_arm=status_quo_arm,
        )

        results_by_arm = list(zip(arm_names, ys, ys_se))
        colors.append(
            [
                get_color(
                    x=y,
                    ci=Z * y_se,
                    rel=rel,
                    # pyre-fixme[6]: Expected `bool` for 4th param but got
                    #  `Optional[bool]`.
                    reverse=metric_name_to_lower_is_better[metric_name],
                )
                for (_, y, y_se) in results_by_arm
            ]
        )
        records.append(
            ["{:.3f} ± {:.3f}".format(y, Z * y_se) for (_, y, y_se) in results_by_arm]
        )
        records_with_mean.append({arm_name: y for (arm_name, y, _) in results_by_arm})
        records_with_ci.append(
            {arm_name: Z * y_se for (arm_name, _, y_se) in results_by_arm}
        )

    if only_data_frame:
        return tuple(
            pd.DataFrame.from_records(frame_records, index=metric_names)
            for frame_records in [records_with_mean, records_with_ci]
        )

    def transpose(m):
        # Rows-per-metric -> columns-per-arm; yields [] for an empty matrix.
        return [list(column) for column in zip(*m)]

    # plotly tables are column-major: the first column lists the metric
    # names, followed by one column per arm.
    records = [[name.replace(":", " : ") for name in metric_names]] + transpose(records)
    colors = [["#ffffff"] * len(metric_names)] + transpose(colors)
    # pyre-fixme[6]: Expected `List[str]` for 1st param but got `List[float]`.
    header = [f"<b>{x}</b>" for x in [f"{arm_noun}s"] + arm_names]
    column_widths = [300] + [150] * len(arm_names)

    trace = go.Table(
        header={"values": header, "align": ["left"]},
        cells={"values": records, "align": ["left"], "fill": {"color": colors}},
        columnwidth=column_widths,
    )
    layout = go.Layout(
        width=sum(column_widths),
        margin=go.layout.Margin(l=0, r=20, b=20, t=20, pad=4),  # noqa E741
    )
    fig = go.Figure(data=[trace], layout=layout)
    return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)
def plot_slice(
    model: ModelBridge,
    param_name: str,
    metric_name: str,
    generator_runs_dict: TNullableGeneratorRunsDict = None,
    relative: bool = False,
    density: int = 50,
    slice_values: Optional[Dict[str, Any]] = None,
) -> AxPlotConfig:
    """Plot predictions for a 1-d slice of the parameter space.

    Args:
        model: ModelBridge that contains model for predictions
        param_name: Name of parameter that will be sliced
        metric_name: Name of metric to plot
        generator_runs_dict: A dictionary {name: generator run} of generator runs
            whose arms will be plotted, if they lie in the slice.
        relative: Predictions relative to status quo
        density: Number of points along slice to evaluate predictions.
        slice_values: A dictionary {name: val} for the fixed values of the
            other parameters. If not provided, then the status quo values will
            be used if there is a status quo, otherwise the mean of numeric
            parameters or the mode of choice parameters.

    Returns:
        AxPlotConfig of type ``AxPlotTypes.SLICE`` holding the predictions
        along the slice.
    """
    runs = {} if generator_runs_dict is None else generator_runs_dict

    parameter = get_range_parameter(model, param_name)
    grid = get_grid_for_parameter(parameter, density)

    plot_data, raw_data, cond_name_to_parameters = get_plot_data(
        model=model, generator_runs_dict=runs, metric_names={metric_name}
    )

    fixed_values = get_fixed_values(model, slice_values)

    def _features_at(value: Any) -> ObservationFeatures:
        # Only the sliced parameter varies; every other one keeps its fixed
        # value. Here we assume context is None.
        point_params = fixed_values.copy()
        point_params[param_name] = value
        return ObservationFeatures(parameters=point_params)

    prediction_features = [_features_at(value) for value in grid]

    means, covariances = model.predict(prediction_features)

    return AxPlotConfig(
        {
            "arm_data": plot_data,
            "arm_name_to_parameters": cond_name_to_parameters,
            "f": means[metric_name],
            "fit_data": raw_data,
            "grid": grid,
            "metric": metric_name,
            "param": param_name,
            "rel": relative,
            "setx": fixed_values,
            "sd": np.sqrt(covariances[metric_name][metric_name]),
            "is_log": parameter.log_scale,
        },
        plot_type=AxPlotTypes.SLICE,
    )
def _single_metric_traces(
    model: ModelBridge,
    metric: str,
    generator_runs_dict: TNullableGeneratorRunsDict,
    rel: bool,
    show_arm_details_on_hover: bool = True,
    showlegend: bool = True,
    show_CI: bool = True,
    arm_noun: str = "arm",
    fixed_features: Optional[ObservationFeatures] = None,
) -> Traces:
    """Plot scatterplots with errors for a single metric (y-axis).

    Arms are plotted on the x-axis. The first returned trace holds the
    in-sample arms; one further trace follows per generator run.

    Args:
        model: model to draw predictions from.
        metric: name of metric to plot.
        generator_runs_dict: a mapping from
            generator run name to generator run.
        rel: if True, plot relative predictions.
        show_arm_details_on_hover: if True, display
            parameterizations of arms on hover. Default is True.
        showlegend: if True, show legend for trace.
        show_CI: if True, render confidence intervals.
        arm_noun: noun to use instead of "arm" (e.g. group)
        fixed_features: Fixed features to use when making model predictions.

    Returns:
        List of traces: in-sample first, then candidates in the iteration
        order of the plot data's out-of-sample mapping.
    """
    plot_data, _, _ = get_plot_data(
        model, generator_runs_dict or {}, {metric}, fixed_features=fixed_features
    )

    status_quo_arm = (
        None
        if plot_data.status_quo_name is None
        # pyre-fixme[6]: Expected `str` for 1st param but got `Optional[str]`.
        else plot_data.in_sample.get(plot_data.status_quo_name)
    )

    traces = [
        _error_scatter_trace(
            # Expected `List[Union[PlotInSampleArm, PlotOutOfSampleArm]]`
            # for 1st anonymous parameter to call
            # `ax.plot.scatter._error_scatter_trace` but got
            # `List[PlotInSampleArm]`.
            # pyre-fixme[6]:
            list(plot_data.in_sample.values()),
            x_axis_var=None,
            y_axis_var=PlotMetric(metric, pred=True, rel=rel),
            status_quo_arm=status_quo_arm,
            legendgroup="In-sample",
            showlegend=showlegend,
            show_arm_details_on_hover=show_arm_details_on_hover,
            show_CI=show_CI,
            arm_noun=arm_noun,
        )
    ]

    # Candidates
    n_colors = len(DISCRETE_COLOR_SCALE)
    for i, (generator_run_name, cand_arms) in enumerate(
        (plot_data.out_of_sample or {}).items(), start=1
    ):
        traces.append(
            _error_scatter_trace(
                list(cand_arms.values()),
                x_axis_var=None,
                y_axis_var=PlotMetric(metric, pred=True, rel=rel),
                status_quo_arm=status_quo_arm,
                name=generator_run_name,
                # Wrap around the palette so that having more generator runs
                # than colors reuses colors instead of raising IndexError.
                color=DISCRETE_COLOR_SCALE[i % n_colors],
                legendgroup=generator_run_name,
                showlegend=showlegend,
                show_arm_details_on_hover=show_arm_details_on_hover,
                show_CI=show_CI,
                arm_noun=arm_noun,
            )
        )
    return traces
def lattice_multiple_metrics(
    model: ModelBridge,
    generator_runs_dict: TNullableGeneratorRunsDict = None,
    rel: bool = True,
    show_arm_details_on_hover: bool = False,
) -> AxPlotConfig:
    """Plot raw values or predictions of combinations of two metrics for arms.

    Builds a square grid of subplots with one row and one column per metric
    in ``model.metric_names``. Off-diagonal cells hold scatter traces of one
    metric against another; diagonal cells hold box plots of a single metric.
    Every cell contains, in order: an observed in-sample trace (initially
    hidden), a predicted in-sample trace, and one predicted trace per
    generator run. Two button menus switch between "Modeled" and "In-sample"
    values and turn the error bars on or off (off by default).

    Args:
        model: model to draw predictions from.
        generator_runs_dict: a mapping from
            generator run name to generator run.
        rel: if True, use relative effects. Default is True.
        show_arm_details_on_hover: if True, display
            parameterizations of arms on hover. Default is False.

    Returns:
        AxPlotConfig of type ``AxPlotTypes.GENERIC`` wrapping the figure.
    """
    metrics = model.metric_names
    n_metrics = len(metrics)
    fig = tools.make_subplots(
        rows=n_metrics,
        cols=n_metrics,
        print_grid=False,
        shared_xaxes=False,
        shared_yaxes=False,
    )

    plot_data, _, _ = get_plot_data(
        model, generator_runs_dict if generator_runs_dict is not None else {}, metrics
    )
    status_quo_arm = (
        None
        if plot_data.status_quo_name is None
        # pyre-fixme[6]: Expected `str` for 1st param but got `Optional[str]`.
        else plot_data.in_sample.get(plot_data.status_quo_name)
    )

    out_of_sample = plot_data.out_of_sample or {}
    # Candidate colors wrap around the palette so that having more generator
    # runs than colors reuses colors instead of raising IndexError.
    n_colors = len(DISCRETE_COLOR_SCALE)

    # iterate over all combinations of metrics and generate scatter traces
    for i, o1 in enumerate(metrics, start=1):
        for j, o2 in enumerate(metrics, start=1):
            if o1 != o2:
                # Legend entries are emitted once, from the first
                # off-diagonal cell visited.
                in_legend = i == 1 and j == 2

                # in-sample observed and predicted
                obs_insample_trace = _error_scatter_trace(
                    # Expected `List[Union[PlotInSampleArm,
                    # PlotOutOfSampleArm]]` for 1st anonymous parameter to call
                    # `ax.plot.scatter._error_scatter_trace` but got
                    # `List[PlotInSampleArm]`.
                    # pyre-fixme[6]:
                    list(plot_data.in_sample.values()),
                    x_axis_var=PlotMetric(o1, pred=False, rel=rel),
                    y_axis_var=PlotMetric(o2, pred=False, rel=rel),
                    status_quo_arm=status_quo_arm,
                    showlegend=in_legend,
                    legendgroup="In-sample",
                    visible=False,
                    show_arm_details_on_hover=show_arm_details_on_hover,
                )
                predicted_insample_trace = _error_scatter_trace(
                    # Expected `List[Union[PlotInSampleArm,
                    # PlotOutOfSampleArm]]` for 1st anonymous parameter to call
                    # `ax.plot.scatter._error_scatter_trace` but got
                    # `List[PlotInSampleArm]`.
                    # pyre-fixme[6]:
                    list(plot_data.in_sample.values()),
                    x_axis_var=PlotMetric(o1, pred=True, rel=rel),
                    y_axis_var=PlotMetric(o2, pred=True, rel=rel),
                    status_quo_arm=status_quo_arm,
                    legendgroup="In-sample",
                    showlegend=in_legend,
                    visible=True,
                    show_arm_details_on_hover=show_arm_details_on_hover,
                )
                fig.append_trace(obs_insample_trace, j, i)
                fig.append_trace(predicted_insample_trace, j, i)

                # iterate over models here
                for k, (generator_run_name, cand_arms) in enumerate(
                    out_of_sample.items(), start=1
                ):
                    fig.append_trace(
                        _error_scatter_trace(
                            list(cand_arms.values()),
                            x_axis_var=PlotMetric(o1, pred=True, rel=rel),
                            y_axis_var=PlotMetric(o2, pred=True, rel=rel),
                            status_quo_arm=status_quo_arm,
                            name=generator_run_name,
                            color=DISCRETE_COLOR_SCALE[k % n_colors],
                            showlegend=in_legend,
                            legendgroup=generator_run_name,
                            show_arm_details_on_hover=show_arm_details_on_hover,
                        ),
                        j,
                        i,
                    )
            else:
                # diagonal cell: box plots of a single metric, in the same
                # observed / predicted / candidates order as the scatter cells
                fig.append_trace(
                    go.Box(
                        y=[arm.y[o1] for arm in plot_data.in_sample.values()],
                        name=None,
                        marker={"color": rgba(COLORS.STEELBLUE.value)},
                        showlegend=False,
                        legendgroup="In-sample",
                        visible=False,
                        hoverinfo="none",
                    ),
                    j,
                    i,
                )
                fig.append_trace(
                    go.Box(
                        y=[arm.y_hat[o1] for arm in plot_data.in_sample.values()],
                        name=None,
                        marker={"color": rgba(COLORS.STEELBLUE.value)},
                        showlegend=False,
                        legendgroup="In-sample",
                        hoverinfo="none",
                    ),
                    j,
                    i,
                )

                for k, (generator_run_name, cand_arms) in enumerate(
                    out_of_sample.items(), start=1
                ):
                    fig.append_trace(
                        go.Box(
                            y=[arm.y_hat[o1] for arm in cand_arms.values()],
                            name=None,
                            marker={
                                "color": rgba(DISCRETE_COLOR_SCALE[k % n_colors])
                            },
                            showlegend=False,
                            legendgroup=generator_run_name,
                            hoverinfo="none",
                        ),
                        j,
                        i,
                    )

    # Per-cell visibility pattern [observed, predicted, *candidates],
    # repeated once for every cell of the grid.
    n_cells = n_metrics ** 2
    modeled_visible = ([False, True] + [True] * len(out_of_sample)) * n_cells
    in_sample_visible = ([True, False] + [False] * len(out_of_sample)) * n_cells

    fig["layout"].update(
        height=800,
        width=960,
        font={"size": 10},
        hovermode="closest",
        legend={
            "orientation": "h",
            "x": 0,
            "y": 1.05,
            "xanchor": "left",
            "yanchor": "middle",
        },
        updatemenus=[
            {
                "x": 0.35,
                "y": 1.08,
                "xanchor": "left",
                "yanchor": "middle",
                "buttons": [
                    {
                        "args": [
                            {
                                "error_x.width": 0,
                                "error_x.thickness": 0,
                                "error_y.width": 0,
                                "error_y.thickness": 0,
                            }
                        ],
                        "label": "No",
                        "method": "restyle",
                    },
                    {
                        "args": [
                            {
                                "error_x.width": 4,
                                "error_x.thickness": 2,
                                "error_y.width": 4,
                                "error_y.thickness": 2,
                            }
                        ],
                        "label": "Yes",
                        "method": "restyle",
                    },
                ],
            },
            {
                "x": 0.1,
                "y": 1.08,
                "xanchor": "left",
                "yanchor": "middle",
                "buttons": [
                    {
                        "args": [{"visible": modeled_visible}],
                        "label": "Modeled",
                        "method": "restyle",
                    },
                    {
                        "args": [{"visible": in_sample_visible}],
                        "label": "In-sample",
                        "method": "restyle",
                    },
                ],
            },
        ],
        annotations=[
            {
                "x": 0.02,
                "y": 1.1,
                "xref": "paper",
                "yref": "paper",
                "text": "Type",
                "showarrow": False,
                "yanchor": "middle",
                "xanchor": "left",
            },
            {
                "x": 0.30,
                "y": 1.1,
                "xref": "paper",
                "yref": "paper",
                "text": "Show CI",
                "showarrow": False,
                "yanchor": "middle",
                "xanchor": "left",
            },
        ],
    )

    # add metric names to axes - add to each subplot if boxplots on the
    # diagonal and axes are not shared; else, add to the leftmost y-axes
    # and bottom x-axes.
    for i, o in enumerate(metrics):
        pos_x = n_metrics * n_metrics - n_metrics + i + 1
        pos_y = 1 + (n_metrics * i)
        fig["layout"]["xaxis{}".format(pos_x)].update(
            title=_wrap_metric(o), titlefont={"size": 10}
        )
        fig["layout"]["yaxis{}".format(pos_y)].update(
            title=_wrap_metric(o), titlefont={"size": 10}
        )

    # do not put x-axis ticks for boxplots
    boxplot_xaxes = []
    for trace in fig["data"]:
        if trace["type"] == "box":
            # stores the xaxes which correspond to boxplot subplots
            # since we use xaxis1, xaxis2, etc, in plotly.py
            boxplot_xaxes.append("xaxis{}".format(trace["xaxis"][1:]))
        else:
            # clear all error bars since default is no CI
            trace["error_x"].update(width=0, thickness=0)
            trace["error_y"].update(width=0, thickness=0)
    for xaxis in boxplot_xaxes:
        fig["layout"][xaxis]["showticklabels"] = False

    return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)
def _multiple_metric_traces(
    model: ModelBridge,
    metric_x: str,
    metric_y: str,
    generator_runs_dict: TNullableGeneratorRunsDict,
    rel_x: bool,
    rel_y: bool,
    fixed_features: Optional[ObservationFeatures] = None,
) -> Traces:
    """Plot traces for multiple metrics given a model and metrics.

    The first two returned traces hold the in-sample arms: observed values
    (created with ``visible=False``) followed by predicted values (created
    with ``visible=True``). One predicted trace per generator run follows.

    Args:
        model: model to draw predictions from.
        metric_x: metric to plot on the x-axis.
        metric_y: metric to plot on the y-axis.
        generator_runs_dict: a mapping from
            generator run name to generator run.
        rel_x: if True, use relative effects on metric_x.
        rel_y: if True, use relative effects on metric_y.
        fixed_features: Fixed features to use when making model predictions.

    Returns:
        List of traces: observed in-sample, predicted in-sample, then
        candidates in the iteration order of the out-of-sample mapping.
    """
    plot_data, _, _ = get_plot_data(
        model,
        generator_runs_dict if generator_runs_dict is not None else {},
        {metric_x, metric_y},
        fixed_features=fixed_features,
    )

    status_quo_arm = (
        None
        if plot_data.status_quo_name is None
        # pyre-fixme[6]: Expected `str` for 1st param but got `Optional[str]`.
        else plot_data.in_sample.get(plot_data.status_quo_name)
    )

    traces = [
        _error_scatter_trace(
            # Expected `List[Union[PlotInSampleArm, PlotOutOfSampleArm]]`
            # for 1st anonymous parameter to call
            # `ax.plot.scatter._error_scatter_trace` but got
            # `List[PlotInSampleArm]`.
            # pyre-fixme[6]:
            list(plot_data.in_sample.values()),
            x_axis_var=PlotMetric(metric_x, pred=False, rel=rel_x),
            y_axis_var=PlotMetric(metric_y, pred=False, rel=rel_y),
            status_quo_arm=status_quo_arm,
            visible=False,
        ),
        _error_scatter_trace(
            # Expected `List[Union[PlotInSampleArm, PlotOutOfSampleArm]]`
            # for 1st anonymous parameter to call
            # `ax.plot.scatter._error_scatter_trace` but got
            # `List[PlotInSampleArm]`.
            # pyre-fixme[6]:
            list(plot_data.in_sample.values()),
            x_axis_var=PlotMetric(metric_x, pred=True, rel=rel_x),
            y_axis_var=PlotMetric(metric_y, pred=True, rel=rel_y),
            status_quo_arm=status_quo_arm,
            visible=True,
        ),
    ]

    n_colors = len(DISCRETE_COLOR_SCALE)
    for i, (generator_run_name, cand_arms) in enumerate(
        (plot_data.out_of_sample or {}).items(), start=1
    ):
        traces.append(
            _error_scatter_trace(
                list(cand_arms.values()),
                x_axis_var=PlotMetric(metric_x, pred=True, rel=rel_x),
                y_axis_var=PlotMetric(metric_y, pred=True, rel=rel_y),
                status_quo_arm=status_quo_arm,
                name=generator_run_name,
                # Wrap around the palette so that having more generator runs
                # than colors reuses colors instead of raising IndexError.
                color=DISCRETE_COLOR_SCALE[i % n_colors],
            )
        )
    return traces
def interact_contour(
    model: ModelBridge,
    metric_name: str,
    generator_runs_dict: TNullableGeneratorRunsDict = None,
    relative: bool = False,
    density: int = 50,
    slice_values: Optional[Dict[str, Any]] = None,
    lower_is_better: bool = False,
    fixed_features: Optional[ObservationFeatures] = None,
) -> AxPlotConfig:
    """Create interactive plot with predictions for a 2-d slice of the parameter
    space.

    Predictions are computed for every ordered pair of range parameters
    (including a parameter paired with itself) and stored in nested
    dictionaries keyed by x-parameter name, then y-parameter name.

    Args:
        model: ModelBridge that contains model for predictions
        metric_name: Name of metric to plot
        generator_runs_dict: A dictionary {name: generator run} of generator runs
            whose arms will be plotted, if they lie in the slice.
        relative: Predictions relative to status quo
        density: Number of points along slice to evaluate predictions.
        slice_values: A dictionary {name: val} for the fixed values of the
            other parameters. If not provided, then the status quo values will
            be used if there is a status quo, otherwise the mean of numeric
            parameters or the mode of choice parameters.
        lower_is_better: Lower values for metric are better.
        fixed_features: An ObservationFeatures object containing the values of
            features (including non-parameter features like context) to be set
            in the slice.

    Returns:
        AxPlotConfig of type ``AxPlotTypes.INTERACT_CONTOUR``.
    """
    range_parameters = get_range_parameters(model)
    plot_data, _, _ = get_plot_data(
        model, generator_runs_dict or {}, {metric_name}, fixed_features=fixed_features
    )

    # TODO T38563759: Sort parameters by feature importances
    param_names = [parameter.name for parameter in range_parameters]

    is_log_dict: Dict[str, bool] = {
        parameter.name: parameter.log_scale for parameter in range_parameters
    }
    grid_dict: Dict[str, np.ndarray] = {
        parameter.name: get_grid_for_parameter(parameter, density)
        for parameter in range_parameters
    }

    # Nested {x_param: {y_param: predictions}} tables, filled pair by pair.
    f_dict: Dict[str, Dict[str, np.ndarray]] = {}
    sd_dict: Dict[str, Dict[str, np.ndarray]] = {}
    for x_name in param_names:
        f_row = f_dict.setdefault(x_name, {})
        sd_row = sd_dict.setdefault(x_name, {})
        for y_name in param_names:
            _, pair_f, pair_sd, _, _, _ = _get_contour_predictions(
                model=model,
                x_param_name=x_name,
                y_param_name=y_name,
                metric=metric_name,
                generator_runs_dict=generator_runs_dict,
                density=density,
                slice_values=slice_values,
                fixed_features=fixed_features,
            )
            f_row[y_name] = pair_f
            sd_row[y_name] = pair_sd

    return AxPlotConfig(
        {
            "arm_data": plot_data,
            "blue_scale": BLUE_SCALE,
            "density": density,
            "f_dict": f_dict,
            "green_scale": GREEN_SCALE,
            "green_pink_scale": GREEN_PINK_SCALE,
            "grid_dict": grid_dict,
            "lower_is_better": lower_is_better,
            "metric": metric_name,
            "rel": relative,
            "sd_dict": sd_dict,
            "is_log_dict": is_log_dict,
            "param_names": param_names,
        },
        plot_type=AxPlotTypes.INTERACT_CONTOUR,
    )
def interact_contour( model: ModelBridge, metric_name: str, generator_runs_dict: TNullableGeneratorRunsDict = None, relative: bool = False, density: int = 50, slice_values: Optional[Dict[str, Any]] = None, lower_is_better: bool = False, fixed_features: Optional[ObservationFeatures] = None, ) -> AxPlotConfig: """Create interactive plot with predictions for a 2-d slice of the parameter space. Args: model: ModelBridge that contains model for predictions metric_name: Name of metric to plot generator_runs_dict: A dictionary {name: generator run} of generator runs whose arms will be plotted, if they lie in the slice. relative: Predictions relative to status quo density: Number of points along slice to evaluate predictions. slice_values: A dictionary {name: val} for the fixed values of the other parameters. If not provided, then the status quo values will be used if there is a status quo, otherwise the mean of numeric parameters or the mode of choice parameters. lower_is_better: Lower values for metric are better. fixed_features: An ObservationFeatures object containing the values of features (including non-parameter features like context) to be set in the slice. 
""" range_parameters = get_range_parameters(model) plot_data, _, _ = get_plot_data( model, generator_runs_dict or {}, {metric_name}, fixed_features=fixed_features ) # TODO T38563759: Sort parameters by feature importances param_names = [parameter.name for parameter in range_parameters] is_log_dict: Dict[str, bool] = {} grid_dict: Dict[str, np.ndarray] = {} for parameter in range_parameters: is_log_dict[parameter.name] = parameter.log_scale grid_dict[parameter.name] = get_grid_for_parameter(parameter, density) f_dict: Dict[str, Dict[str, np.ndarray]] = { param1: {param2: [] for param2 in param_names} for param1 in param_names } sd_dict: Dict[str, Dict[str, np.ndarray]] = { param1: {param2: [] for param2 in param_names} for param1 in param_names } for param1 in param_names: for param2 in param_names: _, f_plt, sd_plt, _, _, _ = _get_contour_predictions( model=model, x_param_name=param1, y_param_name=param2, metric=metric_name, generator_runs_dict=generator_runs_dict, density=density, slice_values=slice_values, fixed_features=fixed_features, ) f_dict[param1][param2] = f_plt sd_dict[param1][param2] = sd_plt config = { "arm_data": plot_data, "blue_scale": BLUE_SCALE, "density": density, "f_dict": f_dict, "green_scale": GREEN_SCALE, "green_pink_scale": GREEN_PINK_SCALE, "grid_dict": grid_dict, "lower_is_better": lower_is_better, "metric": metric_name, "rel": relative, "sd_dict": sd_dict, "is_log_dict": is_log_dict, "param_names": param_names, } config = AxPlotConfig(config, plot_type=AxPlotTypes.GENERIC).data arm_data = config["arm_data"] density = config["density"] grid_dict = config["grid_dict"] f_dict = config["f_dict"] lower_is_better = config["lower_is_better"] metric = config["metric"] rel = config["rel"] sd_dict = config["sd_dict"] is_log_dict = config["is_log_dict"] param_names = config["param_names"] green_scale = config["green_scale"] green_pink_scale = config["green_pink_scale"] blue_scale = config["blue_scale"] CONTOUR_CONFIG = { "autocolorscale": False, 
"autocontour": True, "contours": {"coloring": "heatmap"}, "hoverinfo": "x+y+z", "ncontours": int(density / 2), "type": "contour", } if rel: f_scale = reversed(green_pink_scale) if lower_is_better else green_pink_scale else: f_scale = green_scale f_contour_trace_base = { "colorbar": { "len": 0.875, "x": 0.45, "y": 0.5, "ticksuffix": "%" if rel else "", "tickfont": {"size": 8}, }, "colorscale": [(i / (len(f_scale) - 1), rgb(v)) for i, v in enumerate(f_scale)], "xaxis": "x", "yaxis": "y", # zmax and zmin are ignored if zauto is true "zauto": not rel, } sd_contour_trace_base = { "colorbar": { "len": 0.875, "x": 1, "y": 0.5, "ticksuffix": "%" if rel else "", "tickfont": {"size": 8}, }, "colorscale": [ (i / (len(blue_scale) - 1), rgb(v)) for i, v in enumerate(blue_scale) ], "xaxis": "x2", "yaxis": "y2", } f_contour_trace_base.update(CONTOUR_CONFIG) sd_contour_trace_base.update(CONTOUR_CONFIG) insample_param_values = {} for param_name in param_names: insample_param_values[param_name] = [] for arm_name in arm_data["in_sample"].keys(): insample_param_values[param_name].append( arm_data["in_sample"][arm_name]["parameters"][param_name] ) insample_arm_text = list(arm_data["in_sample"].keys()) out_of_sample_param_values = {} for param_name in param_names: out_of_sample_param_values[param_name] = {} for generator_run_name in arm_data["out_of_sample"].keys(): out_of_sample_param_values[param_name][generator_run_name] = [] for arm_name in arm_data["out_of_sample"][generator_run_name].keys(): out_of_sample_param_values[param_name][generator_run_name].append( arm_data["out_of_sample"][generator_run_name][arm_name][ "parameters" ][param_name] ) out_of_sample_arm_text = {} for generator_run_name in arm_data["out_of_sample"].keys(): out_of_sample_arm_text[generator_run_name] = [ "<em>Candidate " + arm_name + "</em>" for arm_name in arm_data["out_of_sample"][generator_run_name].keys() ] # Number of traces for each pair of parameters trace_cnt = 4 + (len(arm_data["out_of_sample"]) * 2) 
xbuttons = [] ybuttons = [] for xvar in param_names: xbutton_data_args = {"x": [], "y": [], "z": []} for yvar in param_names: res = relativize_data( f_dict[xvar][yvar], sd_dict[xvar][yvar], rel, arm_data, metric ) f_final = res[0] sd_final = res[1] # transform to nested array f_plt = [] for ind in range(0, len(f_final), density): f_plt.append(f_final[ind : ind + density]) sd_plt = [] for ind in range(0, len(sd_final), density): sd_plt.append(sd_final[ind : ind + density]) # grid + in-sample xbutton_data_args["x"] += [ grid_dict[xvar], grid_dict[xvar], insample_param_values[xvar], insample_param_values[xvar], ] xbutton_data_args["y"] += [ grid_dict[yvar], grid_dict[yvar], insample_param_values[yvar], insample_param_values[yvar], ] xbutton_data_args["z"] = xbutton_data_args["z"] + [f_plt, sd_plt, [], []] for generator_run_name in out_of_sample_param_values[xvar]: generator_run_x_vals = out_of_sample_param_values[xvar][ generator_run_name ] xbutton_data_args["x"] += [generator_run_x_vals] * 2 for generator_run_name in out_of_sample_param_values[yvar]: generator_run_y_vals = out_of_sample_param_values[yvar][ generator_run_name ] xbutton_data_args["y"] += [generator_run_y_vals] * 2 xbutton_data_args["z"] += [[]] * 2 xbutton_args = [ xbutton_data_args, { "xaxis.title": short_name(xvar), "xaxis2.title": short_name(xvar), "xaxis.range": axis_range(grid_dict[xvar], is_log_dict[xvar]), "xaxis2.range": axis_range(grid_dict[xvar], is_log_dict[xvar]), }, ] xbuttons.append({"args": xbutton_args, "label": xvar, "method": "update"}) # No y button for first param so initial value is sane for y_idx in range(1, len(param_names)): visible = [False] * (len(param_names) * trace_cnt) for i in range(y_idx * trace_cnt, (y_idx + 1) * trace_cnt): visible[i] = True y_param = param_names[y_idx] ybuttons.append( { "args": [ {"visible": visible}, { "yaxis.title": short_name(y_param), "yaxis.range": axis_range( grid_dict[y_param], is_log_dict[y_param] ), "yaxis2.range": axis_range( 
grid_dict[y_param], is_log_dict[y_param] ), }, ], "label": param_names[y_idx], "method": "update", } ) # calculate max of abs(outcome), used for colorscale # TODO(T37079623) Make this work for relative outcomes # let f_absmax = Math.max(Math.abs(Math.min(...f_final)), Math.max(...f_final)) traces = [] xvar = param_names[0] base_in_sample_arm_config = None # start symbol at 2 for out-of-sample candidate markers i = 2 for yvar_idx, yvar in enumerate(param_names): cur_visible = yvar_idx == 1 f_start = xbuttons[0]["args"][0]["z"][trace_cnt * yvar_idx] sd_start = xbuttons[0]["args"][0]["z"][trace_cnt * yvar_idx + 1] # create traces f_trace = { "x": grid_dict[xvar], "y": grid_dict[yvar], "z": f_start, "visible": cur_visible, } for key in f_contour_trace_base.keys(): f_trace[key] = f_contour_trace_base[key] sd_trace = { "x": grid_dict[xvar], "y": grid_dict[yvar], "z": sd_start, "visible": cur_visible, } for key in sd_contour_trace_base.keys(): sd_trace[key] = sd_contour_trace_base[key] f_in_sample_arm_trace = {"xaxis": "x", "yaxis": "y"} sd_in_sample_arm_trace = {"showlegend": False, "xaxis": "x2", "yaxis": "y2"} base_in_sample_arm_config = { "hoverinfo": "text", "legendgroup": "In-sample", "marker": {"color": "black", "symbol": 1, "opacity": 0.5}, "mode": "markers", "name": "In-sample", "text": insample_arm_text, "type": "scatter", "visible": cur_visible, "x": insample_param_values[xvar], "y": insample_param_values[yvar], } for key in base_in_sample_arm_config.keys(): f_in_sample_arm_trace[key] = base_in_sample_arm_config[key] sd_in_sample_arm_trace[key] = base_in_sample_arm_config[key] traces += [f_trace, sd_trace, f_in_sample_arm_trace, sd_in_sample_arm_trace] # iterate over out-of-sample arms for generator_run_name in arm_data["out_of_sample"].keys(): traces.append( { "hoverinfo": "text", "legendgroup": generator_run_name, "marker": {"color": "black", "symbol": i, "opacity": 0.5}, "mode": "markers", "name": generator_run_name, "text": 
out_of_sample_arm_text[generator_run_name], "type": "scatter", "xaxis": "x", "x": out_of_sample_param_values[xvar][generator_run_name], "yaxis": "y", "y": out_of_sample_param_values[yvar][generator_run_name], "visible": cur_visible, } ) traces.append( { "hoverinfo": "text", "legendgroup": generator_run_name, "marker": {"color": "black", "symbol": i, "opacity": 0.5}, "mode": "markers", "name": "In-sample", "showlegend": False, "text": out_of_sample_arm_text[generator_run_name], "type": "scatter", "x": out_of_sample_param_values[xvar][generator_run_name], "xaxis": "x2", "y": out_of_sample_param_values[yvar][generator_run_name], "yaxis": "y2", "visible": cur_visible, } ) i += 1 xrange = axis_range(grid_dict[xvar], is_log_dict[xvar]) yrange = axis_range(grid_dict[yvar], is_log_dict[yvar]) xtype = "log" if is_log_dict[xvar] else "linear" ytype = "log" if is_log_dict[yvar] else "linear" layout = { "annotations": [ { "font": {"size": 14}, "showarrow": False, "text": "Mean", "x": 0.25, "xanchor": "center", "xref": "paper", "y": 1, "yanchor": "bottom", "yref": "paper", }, { "font": {"size": 14}, "showarrow": False, "text": "Standard Error", "x": 0.8, "xanchor": "center", "xref": "paper", "y": 1, "yanchor": "bottom", "yref": "paper", }, { "x": 0.26, "y": -0.26, "xref": "paper", "yref": "paper", "text": "x-param:", "showarrow": False, "yanchor": "top", "xanchor": "left", }, { "x": 0.26, "y": -0.4, "xref": "paper", "yref": "paper", "text": "y-param:", "showarrow": False, "yanchor": "top", "xanchor": "left", }, ], "updatemenus": [ { "x": 0.35, "y": -0.29, "buttons": xbuttons, "xanchor": "left", "yanchor": "middle", "direction": "up", }, { "x": 0.35, "y": -0.43, "buttons": ybuttons, "xanchor": "left", "yanchor": "middle", "direction": "up", }, ], "autosize": False, "height": 450, "hovermode": "closest", "legend": {"orientation": "v", "x": 0, "y": -0.2, "yanchor": "top"}, "margin": {"b": 100, "l": 35, "pad": 0, "r": 35, "t": 35}, "width": 950, "xaxis": { "anchor": "y", 
"autorange": False, "domain": [0.05, 0.45], "exponentformat": "e", "range": xrange, "tickfont": {"size": 11}, "tickmode": "auto", "title": short_name(xvar), "type": xtype, }, "xaxis2": { "anchor": "y2", "autorange": False, "domain": [0.6, 1], "exponentformat": "e", "range": xrange, "tickfont": {"size": 11}, "tickmode": "auto", "title": short_name(xvar), "type": xtype, }, "yaxis": { "anchor": "x", "autorange": False, "domain": [0, 1], "exponentformat": "e", "range": yrange, "tickfont": {"size": 11}, "tickmode": "auto", "title": short_name(yvar), "type": ytype, }, "yaxis2": { "anchor": "x2", "autorange": False, "domain": [0, 1], "exponentformat": "e", "range": yrange, "tickfont": {"size": 11}, "tickmode": "auto", "type": ytype, }, } fig = go.Figure(data=traces, layout=layout) return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)
def table_view_plot(
    experiment: Experiment,
    data: Data,
    use_empirical_bayes: bool = True,
    only_data_frame: bool = False,
):
    """Summarize every arm's metric estimates as a table of mean +- CI.

    The rendered table has one row per arm and one column per metric:

    +-------+------------+-----------+
    |  arm  |  metric_1  | metric_2  |
    +=======+============+===========+
    |  0_0  | mean +- CI |    ...    |
    +-------+------------+-----------+
    |  0_1  |    ...     |    ...    |
    +-------+------------+-----------+

    Args:
        experiment: Experiment the model is fit on; its metrics also supply the
            `lower_is_better` flag used when coloring cells.
        data: Data handed to the model constructor.
        use_empirical_bayes: Fit with `get_empirical_bayes_thompson` when True,
            otherwise with `get_thompson`.
        only_data_frame: When True, return the underlying numbers instead of a
            plot.

    Returns:
        If `only_data_frame` is True, a 2-tuple of DataFrames with arms as rows
        and metrics as columns: the first holds the `ys` values and the second
        the `ys_se` values (not scaled by `Z`). Otherwise an AxPlotConfig
        wrapping a plotly table figure.
    """
    fit_model = get_empirical_bayes_thompson if use_empirical_bayes else get_thompson
    model = fit_model(experiment=experiment, data=data)

    lower_is_better_by_metric = {
        m.name: m.lower_is_better for m in experiment.metrics.values()
    }

    plot_data, _, _ = get_plot_data(
        model=model, generator_runs_dict={}, metric_names=model.metric_names
    )

    # Estimates are relativized against the status quo whenever one is named.
    rel = bool(plot_data.status_quo_name)
    status_quo_arm = (
        plot_data.in_sample.get(plot_data.status_quo_name) if rel else None
    )

    # Maps metric name -> list of (arm, y, y_se) rows, one row per arm.
    rows_by_metric = {}
    # One {arm: value} record per metric; only consumed for `only_data_frame`.
    mean_records = []
    se_records = []
    for metric_name in model.metric_names:
        arm_names, _, ys, ys_se = _error_scatter_data(
            arms=list(plot_data.in_sample.values()),
            y_axis_var=PlotMetric(metric_name, True),
            x_axis_var=None,
            rel=rel,
            status_quo_arm=status_quo_arm,
        )
        rows = list(zip(arm_names, ys, ys_se))
        rows_by_metric[metric_name] = rows
        mean_records.append({arm: mean for arm, mean, _ in rows})
        se_records.append({arm: se for arm, _, se in rows})

    if only_data_frame:
        # Records are per-metric, so transpose to get arms as rows.
        mean_df = pd.DataFrame.from_records(
            mean_records, index=model.metric_names
        ).transpose()
        se_df = pd.DataFrame.from_records(
            se_records, index=model.metric_names
        ).transpose()
        return (mean_df, se_df)

    # `cells` and `colors` are column-major lists of lists; the leading column
    # lists the arms (taken from the last metric processed above).
    cell_template = "{:.3f} ± {:.3f}"
    cells = [[f"<b>{name}</b>" for name in arm_names]]
    colors = [["#ffffff"] * len(arm_names)]
    metric_headers = []
    for metric_name in sorted(rows_by_metric):
        rows = rows_by_metric[metric_name]
        cells.append([cell_template.format(mean, Z * se) for _, mean, se in rows])
        metric_headers.append(metric_name.replace(":", " : "))
        colors.append(
            [
                get_color(
                    x=mean,
                    ci=Z * se,
                    rel=rel,
                    reverse=lower_is_better_by_metric[metric_name],
                )
                for _, mean, se in rows
            ]
        )

    header = [f"<b>{title}</b>" for title in ["arms", *metric_headers]]
    table = go.Table(
        header={"values": header, "align": ["left"]},
        cells={"values": cells, "align": ["left"], "fill": {"color": colors}},
    )
    layout = go.Layout(
        height=min(400, len(arm_names) * 20 + 200),
        width=175 * len(header),
        margin=go.Margin(l=0, r=20, b=20, t=20, pad=4),  # noqa E741
    )
    return AxPlotConfig(
        data=go.Figure(data=[table], layout=layout), plot_type=AxPlotTypes.GENERIC
    )
def interact_contour_plotly(
    model: ModelBridge,
    metric_name: str,
    generator_runs_dict: TNullableGeneratorRunsDict = None,
    relative: bool = False,
    density: int = 50,
    slice_values: Optional[Dict[str, Any]] = None,
    lower_is_better: bool = False,
    fixed_features: Optional[ObservationFeatures] = None,
    trial_index: Optional[int] = None,
) -> go.Figure:
    """Create interactive plot with predictions for a 2-d slice of the parameter
    space.

    The figure shows two side-by-side contour plots (predicted mean and
    standard error) with in-sample arms and candidate arms overlaid, plus two
    dropdowns that let the user pick which range parameter goes on each axis.

    Args:
        model: ModelBridge that contains model for predictions
        metric_name: Name of metric to plot
        generator_runs_dict: A dictionary {name: generator run} of generator runs
            whose arms will be plotted, if they lie in the slice.
        relative: Predictions relative to status quo
        density: Number of points along slice to evaluate predictions.
        slice_values: A dictionary {name: val} for the fixed values of the
            other parameters. If not provided, then the status quo values will
            be used if there is a status quo, otherwise the mean of numeric
            parameters or the mode of choice parameters. The dictionary passed
            in is never modified.
        lower_is_better: Lower values for metric are better.
        fixed_features: An ObservationFeatures object containing the values of
            features (including non-parameter features like context) to be set
            in the slice. When provided, its parameters are used in place of
            `slice_values` when computing the predictions.
        trial_index: If not None, `str(trial_index)` is added to the slice
            values under the key "TRIAL_PARAM".

    Returns:
        go.Figure: interactive plot of objective vs. parameters

    Raises:
        IndexError: If the model has fewer than two range parameters, since one
            parameter is needed for each axis.
    """

    # NOTE: This implements a hack to allow Plotly to specify two parameters
    # simultaneously. It is not possible within Plotly to specify a third,
    # so `metric_name` must be specified and cannot be selected via dropdown
    # by the user.

    if trial_index is not None:
        # Work on a copy so the caller's dictionary is not mutated.
        slice_values = dict(slice_values) if slice_values is not None else {}
        slice_values["TRIAL_PARAM"] = str(trial_index)

    range_parameters = get_range_parameters(model)
    plot_data, _, _ = get_plot_data(
        model, generator_runs_dict or {}, {metric_name}, fixed_features=fixed_features
    )

    # TODO T38563759: Sort parameters by feature importances
    param_names = [parameter.name for parameter in range_parameters]

    is_log_dict: Dict[str, bool] = {}
    grid_dict: Dict[str, np.ndarray] = {}
    for parameter in range_parameters:
        is_log_dict[parameter.name] = parameter.log_scale
        grid_dict[parameter.name] = get_grid_for_parameter(parameter, density)

    # Populate `f_dict` (the predicted expectation value of `metric_name`) and
    # `sd_dict` (the predicted SEM), each of which represents a 2D array of plots
    # where each parameter can be assigned to each of the x or y axes.

    # pyre-fixme[9]: f_dict has type `Dict[str, Dict[str, np.ndarray]]`; used as
    #  `Dict[str, Dict[str, typing.List[Variable[_T]]]]`.
    f_dict: Dict[str, Dict[str, np.ndarray]] = {
        param1: {param2: [] for param2 in param_names} for param1 in param_names
    }
    # pyre-fixme[9]: sd_dict has type `Dict[str, Dict[str, np.ndarray]]`; used as
    #  `Dict[str, Dict[str, typing.List[Variable[_T]]]]`.
    sd_dict: Dict[str, Dict[str, np.ndarray]] = {
        param1: {param2: [] for param2 in param_names} for param1 in param_names
    }
    for param1 in param_names:
        for param2 in param_names:
            _, f_plt, sd_plt, _, _, _ = _get_contour_predictions(
                model=model,
                x_param_name=param1,
                y_param_name=param2,
                metric=metric_name,
                generator_runs_dict=generator_runs_dict,
                density=density,
                slice_values=slice_values,
                fixed_features=fixed_features,
            )
            f_dict[param1][param2] = f_plt
            sd_dict[param1][param2] = sd_plt

    # Set plotting defaults for all subplots
    config = {
        "arm_data": plot_data,
        "blue_scale": BLUE_SCALE,
        "density": density,
        "f_dict": f_dict,
        "green_scale": GREEN_SCALE,
        "green_pink_scale": GREEN_PINK_SCALE,
        "grid_dict": grid_dict,
        "lower_is_better": lower_is_better,
        "metric": metric_name,
        "rel": relative,
        "sd_dict": sd_dict,
        "is_log_dict": is_log_dict,
        "param_names": param_names,
    }

    # Round-trip through AxPlotConfig and read everything back from its `data`
    # attribute; the code below indexes `arm_data` like a nested dict.
    # NOTE(review): presumably this normalizes the values into plain
    # dicts/lists -- confirm against AxPlotConfig.
    config = AxPlotConfig(config, plot_type=AxPlotTypes.GENERIC).data

    arm_data = config["arm_data"]
    density = config["density"]
    grid_dict = config["grid_dict"]
    f_dict = config["f_dict"]
    lower_is_better = config["lower_is_better"]
    metric = config["metric"]
    rel = config["rel"]
    sd_dict = config["sd_dict"]
    is_log_dict = config["is_log_dict"]
    param_names = config["param_names"]

    green_scale = config["green_scale"]
    green_pink_scale = config["green_pink_scale"]
    blue_scale = config["blue_scale"]

    CONTOUR_CONFIG = {
        "autocolorscale": False,
        "autocontour": True,
        "contours": {"coloring": "heatmap"},
        "hoverinfo": "x+y+z",
        "ncontours": int(density / 2),
        "type": "contour",
    }

    if rel:
        # `reversed()` returns an iterator, which does not support `len()`;
        # materialize it because the colorscale below needs the scale's length.
        f_scale = (
            list(reversed(green_pink_scale)) if lower_is_better else green_pink_scale
        )
    else:
        f_scale = green_scale

    f_contour_trace_base = {
        "colorbar": {
            "len": 0.875,
            "x": 0.45,
            "y": 0.5,
            "ticksuffix": "%" if rel else "",
            "tickfont": {"size": 8},
        },
        "colorscale": [(i / (len(f_scale) - 1), rgb(v)) for i, v in enumerate(f_scale)],
        "xaxis": "x",
        "yaxis": "y",
        # zmax and zmin are ignored if zauto is true
        "zauto": not rel,
    }

    sd_contour_trace_base = {
        "colorbar": {
            "len": 0.875,
            "x": 1,
            "y": 0.5,
            "ticksuffix": "%" if rel else "",
            "tickfont": {"size": 8},
        },
        "colorscale": [
            (i / (len(blue_scale) - 1), rgb(v)) for i, v in enumerate(blue_scale)
        ],
        "xaxis": "x2",
        "yaxis": "y2",
    }

    # pyre-fixme[6]: Expected `Mapping[str, typing.Union[Dict[str,
    #  typing.Union[Dict[str, int], float, str]], typing.List[Tuple[float, str]], bool,
    #  str]]` for 1st param but got `Dict[str, typing.Union[Dict[str, str], int,
    #  str]]`.
    f_contour_trace_base.update(CONTOUR_CONFIG)
    # pyre-fixme[6]: Expected `Mapping[str, typing.Union[Dict[str,
    #  typing.Union[Dict[str, int], float, str]], typing.List[Tuple[float, str]],
    #  str]]` for 1st param but got `Dict[str, typing.Union[Dict[str, str], int,
    #  str]]`.
    sd_contour_trace_base.update(CONTOUR_CONFIG)

    in_sample = arm_data["in_sample"]
    out_of_sample = arm_data["out_of_sample"]

    # For each range parameter, the value taken by every in-sample arm.
    insample_param_values = {
        param_name: [arm["parameters"][param_name] for arm in in_sample.values()]
        for param_name in param_names
    }

    # Format and add hovertext to contour plots: one entry per in-sample arm
    # listing its observed metrics (with SEM) followed by its parameter values.
    insample_arm_text = []
    for arm_name, arm in in_sample.items():
        atext = f"Arm {arm_name}"
        params = arm["parameters"]
        ys = arm["y"]
        ses = arm["se"]
        for yname in ys.keys():
            # `None` cannot be formatted with ".6g", so print it verbatim.
            sem_str = f"{ses[yname]}" if ses[yname] is None else f"{ses[yname]:.6g}"
            y_str = f"{ys[yname]}" if ys[yname] is None else f"{ys[yname]:.6g}"
            atext += f"<br>{yname}: {y_str} (SEM: {sem_str})"
        for pname, pval in params.items():
            pstr = f"{pval:.6g}" if isinstance(pval, float) else f"{pval}"
            atext += f"<br>{pname}: {pstr}"
        insample_arm_text.append(atext)

    # For each range parameter and generator run, the value taken by every
    # candidate arm of that run.
    out_of_sample_param_values = {
        param_name: {
            generator_run_name: [
                arm["parameters"][param_name] for arm in run_arms.values()
            ]
            for generator_run_name, run_arms in out_of_sample.items()
        }
        for param_name in param_names
    }

    out_of_sample_arm_text = {
        generator_run_name: [
            "<em>Candidate " + arm_name + "</em>" for arm_name in run_arms.keys()
        ]
        for generator_run_name, run_arms in out_of_sample.items()
    }

    # Populate `xbuttons`, which allows the user to select 1D slices of `f_dict` and
    # `sd_dict`, corresponding to all plots that have a certain parameter on the x-axis.

    # Number of traces for each pair of parameters: two contours and two
    # in-sample scatters, plus two scatters per generator run.
    trace_cnt = 4 + (len(out_of_sample) * 2)

    xbuttons = []
    ybuttons = []

    for xvar in param_names:
        xbutton_data_args = {"x": [], "y": [], "z": []}
        for yvar in param_names:
            res = relativize_data(
                f_dict[xvar][yvar], sd_dict[xvar][yvar], rel, arm_data, metric
            )
            f_final = res[0]
            sd_final = res[1]
            # transform to nested array (rows of `density` values each)
            f_plt = [
                f_final[ind : ind + density] for ind in range(0, len(f_final), density)
            ]
            sd_plt = [
                sd_final[ind : ind + density]
                for ind in range(0, len(sd_final), density)
            ]

            # grid + in-sample
            xbutton_data_args["x"] += [
                grid_dict[xvar],
                grid_dict[xvar],
                insample_param_values[xvar],
                insample_param_values[xvar],
            ]
            xbutton_data_args["y"] += [
                grid_dict[yvar],
                grid_dict[yvar],
                insample_param_values[yvar],
                insample_param_values[yvar],
            ]
            xbutton_data_args["z"] += [f_plt, sd_plt, [], []]

            # out-of-sample: each generator run has a trace on both subplots
            for generator_run_x_vals in out_of_sample_param_values[xvar].values():
                xbutton_data_args["x"] += [generator_run_x_vals] * 2
            for generator_run_y_vals in out_of_sample_param_values[yvar].values():
                xbutton_data_args["y"] += [generator_run_y_vals] * 2
                xbutton_data_args["z"] += [[]] * 2

        x_axis_type = "log" if is_log_dict[xvar] else "linear"
        xbutton_args = [
            xbutton_data_args,
            {
                "xaxis.title": short_name(xvar),
                "xaxis2.title": short_name(xvar),
                "xaxis.range": axis_range(grid_dict[xvar], is_log_dict[xvar]),
                "xaxis2.range": axis_range(grid_dict[xvar], is_log_dict[xvar]),
                "xaxis.type": x_axis_type,
                "xaxis2.type": x_axis_type,
            },
        ]
        xbuttons.append({"args": xbutton_args, "label": xvar, "method": "update"})

    # Populate `ybuttons`, which uses the `visible` arg to mask the 1D slice of plots
    # produced by `xbuttons`, down to a single plot, so that only one element `f_dict`
    # and `sd_dict` remain.

    # No y button for first param so initial value is sane
    for y_idx in range(1, len(param_names)):
        visible = [False] * (len(param_names) * trace_cnt)
        for trace_idx in range(y_idx * trace_cnt, (y_idx + 1) * trace_cnt):
            visible[trace_idx] = True
        y_param = param_names[y_idx]
        y_axis_type = "log" if is_log_dict[y_param] else "linear"
        ybuttons.append(
            {
                "args": [
                    {"visible": visible},
                    {
                        "yaxis.title": short_name(y_param),
                        "yaxis.range": axis_range(
                            grid_dict[y_param], is_log_dict[y_param]
                        ),
                        "yaxis2.range": axis_range(
                            grid_dict[y_param], is_log_dict[y_param]
                        ),
                        "yaxis.type": y_axis_type,
                        "yaxis2.type": y_axis_type,
                    },
                ],
                "label": param_names[y_idx],
                "method": "update",
            }
        )

    # calculate max of abs(outcome), used for colorscale
    # TODO(T37079623) Make this work for relative outcomes
    # let f_absmax = Math.max(Math.abs(Math.min(...f_final)), Math.max(...f_final))

    # Build the initial traces: the first parameter on the x-axis, with one
    # group of `trace_cnt` traces per candidate y-parameter. Only the group for
    # the second parameter starts out visible.
    traces = []
    xvar = param_names[0]

    # start symbol at 2 for out-of-sample candidate markers
    symbol = 2

    initial_z = xbuttons[0]["args"][0]["z"]
    for yvar_idx, yvar in enumerate(param_names):
        cur_visible = yvar_idx == 1
        f_start = initial_z[trace_cnt * yvar_idx]
        sd_start = initial_z[trace_cnt * yvar_idx + 1]

        # create traces
        f_trace = {
            "x": grid_dict[xvar],
            "y": grid_dict[yvar],
            "z": f_start,
            "visible": cur_visible,
        }
        f_trace.update(f_contour_trace_base)

        sd_trace = {
            "x": grid_dict[xvar],
            "y": grid_dict[yvar],
            "z": sd_start,
            "visible": cur_visible,
        }
        sd_trace.update(sd_contour_trace_base)

        f_in_sample_arm_trace = {"xaxis": "x", "yaxis": "y"}
        sd_in_sample_arm_trace = {"showlegend": False, "xaxis": "x2", "yaxis": "y2"}
        base_in_sample_arm_config = {
            "hoverinfo": "text",
            "legendgroup": "In-sample",
            "marker": {"color": "black", "symbol": 1, "opacity": 0.5},
            "mode": "markers",
            "name": "In-sample",
            "text": insample_arm_text,
            "type": "scatter",
            "visible": cur_visible,
            "x": insample_param_values[xvar],
            "y": insample_param_values[yvar],
        }
        f_in_sample_arm_trace.update(base_in_sample_arm_config)
        sd_in_sample_arm_trace.update(base_in_sample_arm_config)

        traces += [f_trace, sd_trace, f_in_sample_arm_trace, sd_in_sample_arm_trace]

        # iterate over out-of-sample arms
        for generator_run_name in out_of_sample.keys():
            traces.append(
                {
                    "hoverinfo": "text",
                    "legendgroup": generator_run_name,
                    "marker": {"color": "black", "symbol": symbol, "opacity": 0.5},
                    "mode": "markers",
                    "name": generator_run_name,
                    "text": out_of_sample_arm_text[generator_run_name],
                    "type": "scatter",
                    "xaxis": "x",
                    "x": out_of_sample_param_values[xvar][generator_run_name],
                    "yaxis": "y",
                    "y": out_of_sample_param_values[yvar][generator_run_name],
                    "visible": cur_visible,
                }
            )
            traces.append(
                {
                    "hoverinfo": "text",
                    "legendgroup": generator_run_name,
                    "marker": {"color": "black", "symbol": symbol, "opacity": 0.5},
                    "mode": "markers",
                    "name": "In-sample",
                    "showlegend": False,
                    "text": out_of_sample_arm_text[generator_run_name],
                    "type": "scatter",
                    "x": out_of_sample_param_values[xvar][generator_run_name],
                    "xaxis": "x2",
                    "y": out_of_sample_param_values[yvar][generator_run_name],
                    "yaxis": "y2",
                    "visible": cur_visible,
                }
            )
            symbol += 1

    # Initially visible yvar
    yvar = param_names[1]

    xrange = axis_range(grid_dict[xvar], is_log_dict[xvar])
    yrange = axis_range(grid_dict[yvar], is_log_dict[yvar])

    xtype = "log" if is_log_dict[xvar] else "linear"
    ytype = "log" if is_log_dict[yvar] else "linear"

    layout = {
        "annotations": [
            {
                "font": {"size": 14},
                "showarrow": False,
                "text": "Mean",
                "x": 0.25,
                "xanchor": "center",
                "xref": "paper",
                "y": 1,
                "yanchor": "bottom",
                "yref": "paper",
            },
            {
                "font": {"size": 14},
                "showarrow": False,
                "text": "Standard Error",
                "x": 0.8,
                "xanchor": "center",
                "xref": "paper",
                "y": 1,
                "yanchor": "bottom",
                "yref": "paper",
            },
            {
                "x": 0.26,
                "y": -0.26,
                "xref": "paper",
                "yref": "paper",
                "text": "x-param:",
                "showarrow": False,
                "yanchor": "top",
                "xanchor": "left",
            },
            {
                "x": 0.26,
                "y": -0.4,
                "xref": "paper",
                "yref": "paper",
                "text": "y-param:",
                "showarrow": False,
                "yanchor": "top",
                "xanchor": "left",
            },
        ],
        "updatemenus": [
            {
                "x": 0.35,
                "y": -0.29,
                "buttons": xbuttons,
                "xanchor": "left",
                "yanchor": "middle",
                "direction": "up",
            },
            {
                "x": 0.35,
                "y": -0.43,
                "buttons": ybuttons,
                "xanchor": "left",
                "yanchor": "middle",
                "direction": "up",
            },
        ],
        "autosize": False,
        "height": 450,
        "hovermode": "closest",
        "legend": {"orientation": "v", "x": 0, "y": -0.2, "yanchor": "top"},
        "margin": {"b": 100, "l": 35, "pad": 0, "r": 35, "t": 35},
        "width": 950,
        "xaxis": {
            "anchor": "y",
            "autorange": False,
            "domain": [0.05, 0.45],
            "exponentformat": "e",
            "range": xrange,
            "tickfont": {"size": 11},
            "tickmode": "auto",
            "title": short_name(xvar),
            "type": xtype,
        },
        "xaxis2": {
            "anchor": "y2",
            "autorange": False,
            "domain": [0.6, 1],
            "exponentformat": "e",
            "range": xrange,
            "tickfont": {"size": 11},
            "tickmode": "auto",
            "title": short_name(xvar),
            "type": xtype,
        },
        "yaxis": {
            "anchor": "x",
            "autorange": False,
            "domain": [0, 1],
            "exponentformat": "e",
            "range": yrange,
            "tickfont": {"size": 11},
            "tickmode": "auto",
            "title": short_name(yvar),
            "type": ytype,
        },
        "yaxis2": {
            "anchor": "x2",
            "autorange": False,
            "domain": [0, 1],
            "exponentformat": "e",
            "range": yrange,
            "tickfont": {"size": 11},
            "tickmode": "auto",
            "type": ytype,
        },
    }

    return go.Figure(data=traces, layout=layout)