def plot_slice_plotly(
    model: ModelBridge,
    param_name: str,
    metric_name: str,
    generator_runs_dict: TNullableGeneratorRunsDict = None,
    relative: bool = False,
    density: int = 50,
    slice_values: Optional[Dict[str, Any]] = None,
    fixed_features: Optional[ObservationFeatures] = None,
    trial_index: Optional[int] = None,
) -> go.Figure:
    """Plot predictions for a 1-d slice of the parameter space.

    Args:
        model: ModelBridge that contains model for predictions
        param_name: Name of parameter that will be sliced
        metric_name: Name of metric to plot
        generator_runs_dict: A dictionary {name: generator run} of generator
            runs whose arms will be plotted, if they lie in the slice.
        relative: Predictions relative to status quo
        density: Number of points along slice to evaluate predictions.
        slice_values: A dictionary {name: val} for the fixed values of the
            other parameters. If not provided, then the status quo values
            will be used if there is a status quo, otherwise the mean of
            numeric parameters or the mode of choice parameters. Ignored if
            fixed_features is specified.
        fixed_features: An ObservationFeatures object containing the values
            of features (including non-parameter features like context) to
            be set in the slice.
        trial_index: Forwarded unchanged to ``_get_slice_predictions``.
            NOTE(review): presumably selects the trial used when resolving
            the fixed values -- confirm against that helper.

    Returns:
        go.Figure: plot of objective vs. parameter value
    """
    (
        arm_predictions,
        arm_parameters,
        predicted_mean,
        observed_rows,
        slice_grid,
        _,
        _,
        _,
        fixed_values,
        predicted_sd,
        log_scale,
    ) = _get_slice_predictions(
        model=model,
        param_name=param_name,
        metric_name=metric_name,
        generator_runs_dict=generator_runs_dict,
        relative=relative,
        density=density,
        slice_values=slice_values,
        fixed_features=fixed_features,
        trial_index=trial_index,
    )

    # The values are read back from ``AxPlotConfig(...).data`` rather than
    # used directly. NOTE(review): presumably this normalizes them into plain
    # serializable types -- confirm against AxPlotConfig.
    payload = AxPlotConfig(
        {
            "arm_data": arm_predictions,
            "arm_name_to_parameters": arm_parameters,
            "f": predicted_mean,
            "fit_data": observed_rows,
            "grid": slice_grid,
            "metric": metric_name,
            "param": param_name,
            "rel": relative,
            "setx": fixed_values,
            "sd": predicted_sd,
            "is_log": log_scale,
        },
        plot_type=AxPlotTypes.GENERIC,
    ).data

    x_values = payload["grid"]
    x_is_log = payload["is_log"]
    y_label = payload["metric"]
    x_label = payload["param"]

    traces = slice_config_to_trace(
        payload["arm_data"],
        payload["arm_name_to_parameters"],
        payload["f"],
        payload["fit_data"],
        x_values,
        y_label,
        x_label,
        payload["rel"],
        payload["setx"],
        payload["sd"],
        x_is_log,
        True,  # a single-metric plot is always visible
    )

    x_axis = {
        "anchor": "y",
        "autorange": False,
        "exponentformat": "e",
        "range": axis_range(x_values, x_is_log),
        "tickfont": {"size": 11},
        "tickmode": "auto",
        "title": x_label,
        "type": "log" if x_is_log else "linear",
    }
    y_axis = {
        "anchor": "x",
        "tickfont": {"size": 11},
        "tickmode": "auto",
        "title": y_label,
    }
    layout = {"hovermode": "closest", "xaxis": x_axis, "yaxis": y_axis}

    return go.Figure(data=traces, layout=layout)
def interact_slice_plotly(
    model: ModelBridge,
    generator_runs_dict: TNullableGeneratorRunsDict = None,
    relative: bool = False,
    density: int = 50,
    slice_values: Optional[Dict[str, Any]] = None,
    fixed_features: Optional[ObservationFeatures] = None,
    trial_index: Optional[int] = None,
) -> go.Figure:
    """Create interactive plot with predictions for a 1-d slice of the
    parameter space.

    The figure carries two dropdown menus: one that selects which range
    parameter lies on the x-axis and one that selects which metric is shown.

    Args:
        model: ModelBridge that contains model for predictions
        generator_runs_dict: A dictionary {name: generator run} of generator
            runs whose arms will be plotted, if they lie in the slice.
        relative: Predictions relative to status quo
        density: Number of points along slice to evaluate predictions.
        slice_values: A dictionary {name: val} for the fixed values of the
            other parameters. If not provided, then the status quo values
            will be used if there is a status quo, otherwise the mean of
            numeric parameters or the mode of choice parameters. Ignored if
            fixed_features is specified.
        fixed_features: An ObservationFeatures object containing the values
            of features (including non-parameter features like context) to
            be set in the slice.
        trial_index: Forwarded to ``get_fixed_values`` and
            ``_get_slice_predictions``. NOTE(review): presumably selects the
            trial used when resolving the fixed values -- confirm against
            those helpers.

    Returns:
        go.Figure: interactive plot of objective vs. parameter

    Raises:
        ValueError: If the model exposes no metrics or no range parameters,
            since there would be nothing to put on either axis.
    """
    if generator_runs_dict is None:
        generator_runs_dict = {}

    metric_names = list(model.metric_names)
    if not metric_names:
        raise ValueError("Cannot plot a slice: the model has no metrics.")

    range_parameters = get_range_parameters(model)
    param_names = [parameter.name for parameter in range_parameters]
    if not param_names:
        raise ValueError(
            "Cannot plot a slice: the model has no range parameters."
        )

    # The slice definition does not depend on which parameter is on the
    # x-axis, so resolve it once instead of once per parameter.
    if fixed_features is not None:
        slice_values = fixed_features.parameters
    else:
        fixed_features = ObservationFeatures(parameters={})
    fixed_values = get_fixed_values(model, slice_values, trial_index)

    # Populate `pbuttons`, which allows the user to select 1D slices of
    # parameter space with the chosen parameter on the x-axis.
    pbuttons = []
    init_traces = []
    # (metric, number of traces) pairs for the traces placed in the figure.
    init_trace_counts = []
    xaxis_init_format = {}

    for param_idx, param_name in enumerate(param_names):
        parameter = get_range_parameter(model, param_name)
        grid = get_grid_for_parameter(parameter, density)

        prediction_features = []
        for x in grid:
            predf = deepcopy(fixed_features)
            predf.parameters = fixed_values.copy()
            predf.parameters[param_name] = x
            prediction_features.append(predf)
        f, cov = model.predict(prediction_features)

        plot_data_dict = {}
        raw_data_dict = {}
        sd_plt_dict: Dict[str, np.ndarray] = {}
        cond_name_to_parameters_dict = {}
        is_log_dict: Dict[str, bool] = {}

        for metric_name in metric_names:
            # Only the arm data and the log-scale flag are taken from the
            # helper; means and sds come from the joint prediction above.
            (
                arm_predictions,
                arm_parameters,
                _,
                observed_rows,
                _,
                _,
                _,
                _,
                _,
                _,
                log_scale,
            ) = _get_slice_predictions(
                model=model,
                param_name=param_name,
                metric_name=metric_name,
                generator_runs_dict=generator_runs_dict,
                relative=relative,
                density=density,
                slice_values=slice_values,
                fixed_features=fixed_features,
                trial_index=trial_index,
            )
            plot_data_dict[metric_name] = arm_predictions
            raw_data_dict[metric_name] = observed_rows
            cond_name_to_parameters_dict[metric_name] = arm_parameters
            sd_plt_dict[metric_name] = np.sqrt(cov[metric_name][metric_name])
            is_log_dict[metric_name] = log_scale

        config = {
            "arm_data": plot_data_dict,
            "arm_name_to_parameters": cond_name_to_parameters_dict,
            "f": f,
            "fit_data": raw_data_dict,
            "grid": grid,
            "metrics": metric_names,
            "param": param_name,
            "rel": relative,
            "setx": fixed_values,
            "sd": sd_plt_dict,
            "is_log": is_log_dict,
        }
        # NOTE(review): reading the values back from AxPlotConfig presumably
        # normalizes them into plain serializable types -- confirm.
        config = AxPlotConfig(config, plot_type=AxPlotTypes.GENERIC).data

        arm_data = config["arm_data"]
        arm_name_to_parameters = config["arm_name_to_parameters"]
        f = config["f"]
        fit_data = config["fit_data"]
        grid = config["grid"]
        metrics = config["metrics"]
        param = config["param"]
        rel = config["rel"]
        setx = config["setx"]
        sd = config["sd"]
        is_log = config["is_log"]

        # layout: the x-axis scale follows the first metric's log flag
        x_is_log = is_log[metrics[0]]
        xrange = axis_range(grid, x_is_log)
        xtype = "log" if x_is_log else "linear"

        param_traces = []
        trace_counts = []
        for i, metric in enumerate(metrics):
            metric_traces = slice_config_to_trace(
                arm_data[metric],
                arm_name_to_parameters[metric],
                f[metric],
                fit_data[metric],
                grid,
                metric,
                param,
                rel,
                setx,
                sd[metric],
                is_log[metric],
                i == 0,  # only the first metric starts out visible
            )
            trace_counts.append((metric, len(metric_traces)))
            param_traces.extend(metric_traces)

        pbutton_data_args = {
            "x": [trace["x"] for trace in param_traces],
            "y": [trace["y"] for trace in param_traces],
            "error_y": [
                {
                    "type": "data",
                    "array": trace["error_y"]["array"],
                    "visible": True,
                    "color": "black",
                }
                if "error_y" in trace and "array" in trace["error_y"]
                else []
                for trace in param_traces
            ],
        }
        pbutton_args = [
            pbutton_data_args,
            {
                "xaxis.title": param_name,
                "xaxis.range": xrange,
                "xaxis.type": xtype,
            },
        ]
        pbuttons.append(
            {"args": pbutton_args, "label": param_name, "method": "update"}
        )

        # The figure is initialised with the first parameter's traces.
        if param_idx == 0:
            init_traces = param_traces
            init_trace_counts = trace_counts
            xaxis_init_format = {
                "anchor": "y",
                "autorange": False,
                "exponentformat": "e",
                "range": xrange,
                "tickfont": {"size": 11},
                "tickmode": "auto",
                "title": param_name,
                "type": xtype,
            }

    # Populate mbuttons, which allows the user to select which metric to plot.
    # Masks are built from the number of traces actually present in the
    # figure, not from an assumed per-metric trace count.
    mbuttons = []
    total_traces = len(init_traces)
    offset = 0
    for metric, count in init_trace_counts:
        visible = [False] * total_traces
        visible[offset:offset + count] = [True] * count
        offset += count
        mbuttons.append(
            {
                "method": "update",
                "args": [{"visible": visible}, {"yaxis.title": metric}],
                "label": metric,
            }
        )

    layout = {
        "title": "Predictions for a 1-d slice of the parameter space",
        "annotations": [
            {
                "showarrow": False,
                "text": "Choose metric:",
                "x": 0.225,
                "xanchor": "right",
                "xref": "paper",
                "y": -0.455,
                "yanchor": "bottom",
                "yref": "paper",
            },
            {
                "showarrow": False,
                "text": "Choose parameter:",
                "x": 0.225,
                "xanchor": "right",
                "xref": "paper",
                "y": -0.305,
                "yanchor": "bottom",
                "yref": "paper",
            },
        ],
        "updatemenus": [
            {
                "y": -0.35,
                "x": 0.25,
                "xanchor": "left",
                "yanchor": "top",
                "buttons": mbuttons,
                "direction": "up",
            },
            {
                "y": -0.2,
                "x": 0.25,
                "xanchor": "left",
                "yanchor": "top",
                "buttons": pbuttons,
                "direction": "up",
            },
        ],
        "hovermode": "closest",
        "xaxis": xaxis_init_format,
        "yaxis": {
            "anchor": "x",
            "autorange": True,
            "tickfont": {"size": 11},
            "tickmode": "auto",
            "title": init_trace_counts[0][0],
        },
    }

    return go.Figure(data=init_traces, layout=layout)
def interact_slice(
    model: ModelBridge,
    param_name: str,
    metric_name: str = "",
    generator_runs_dict: TNullableGeneratorRunsDict = None,
    relative: bool = False,
    density: int = 50,
    slice_values: Optional[Dict[str, Any]] = None,
    fixed_features: Optional[ObservationFeatures] = None,
) -> AxPlotConfig:
    """Create interactive plot with predictions for a 1-d slice of the
    parameter space.

    Every metric in ``model.metric_names`` gets its own set of traces; a
    dropdown menu switches between them and the first metric is shown
    initially.

    Args:
        model: ModelBridge that contains model for predictions
        param_name: Name of parameter that will be sliced
        metric_name: Currently unused; all of the model's metrics are
            plotted. Kept in the signature for backward compatibility.
        generator_runs_dict: A dictionary {name: generator run} of generator
            runs whose arms will be plotted, if they lie in the slice.
        relative: Predictions relative to status quo
        density: Number of points along slice to evaluate predictions.
        slice_values: A dictionary {name: val} for the fixed values of the
            other parameters. If not provided, then the status quo values
            will be used if there is a status quo, otherwise the mean of
            numeric parameters or the mode of choice parameters. Ignored if
            fixed_features is specified.
        fixed_features: An ObservationFeatures object containing the values
            of features (including non-parameter features like context) to
            be set in the slice.

    Returns:
        AxPlotConfig: a GENERIC plot config whose data is the Plotly figure.

    Raises:
        ValueError: If the model exposes no metrics.
    """
    if generator_runs_dict is None:
        generator_runs_dict = {}

    metric_names = list(model.metric_names)
    if not metric_names:
        raise ValueError("Cannot plot a slice: the model has no metrics.")

    parameter = get_range_parameter(model, param_name)
    grid = get_grid_for_parameter(parameter, density)

    if fixed_features is not None:
        slice_values = fixed_features.parameters
    else:
        fixed_features = ObservationFeatures(parameters={})
    fixed_values = get_fixed_values(model, slice_values)

    prediction_features = []
    for x in grid:
        predf = deepcopy(fixed_features)
        predf.parameters = fixed_values.copy()
        predf.parameters[param_name] = x
        prediction_features.append(predf)
    f, cov = model.predict(prediction_features)

    plot_data_dict = {}
    raw_data_dict = {}
    sd_plt_dict: Dict[str, np.ndarray] = {}
    cond_name_to_parameters_dict = {}
    is_log_dict: Dict[str, bool] = {}

    # The loop variable is deliberately not called ``metric_name`` so that it
    # does not shadow the function parameter.
    for name in metric_names:
        # Only the arm data and the log-scale flag are taken from the helper;
        # means and sds come from the joint prediction above.
        (
            arm_predictions,
            arm_parameters,
            _,
            observed_rows,
            _,
            _,
            _,
            _,
            _,
            _,
            log_scale,
        ) = _get_slice_predictions(
            model=model,
            param_name=param_name,
            metric_name=name,
            generator_runs_dict=generator_runs_dict,
            relative=relative,
            density=density,
            slice_values=slice_values,
            fixed_features=fixed_features,
        )
        plot_data_dict[name] = arm_predictions
        raw_data_dict[name] = observed_rows
        cond_name_to_parameters_dict[name] = arm_parameters
        sd_plt_dict[name] = np.sqrt(cov[name][name])
        is_log_dict[name] = log_scale

    config = {
        "arm_data": plot_data_dict,
        "arm_name_to_parameters": cond_name_to_parameters_dict,
        "f": f,
        "fit_data": raw_data_dict,
        "grid": grid,
        "metrics": metric_names,
        "param": param_name,
        "rel": relative,
        "setx": fixed_values,
        "sd": sd_plt_dict,
        "is_log": is_log_dict,
    }
    # NOTE(review): reading the values back from AxPlotConfig presumably
    # normalizes them into plain serializable types -- confirm.
    config = AxPlotConfig(config, plot_type=AxPlotTypes.GENERIC).data

    arm_data = config["arm_data"]
    arm_name_to_parameters = config["arm_name_to_parameters"]
    f = config["f"]
    fit_data = config["fit_data"]
    grid = config["grid"]
    metrics = config["metrics"]
    param = config["param"]
    rel = config["rel"]
    setx = config["setx"]
    sd = config["sd"]
    is_log = config["is_log"]

    traces = []
    trace_counts = []
    for i, metric in enumerate(metrics):
        metric_traces = slice_config_to_trace(
            arm_data[metric],
            arm_name_to_parameters[metric],
            f[metric],
            fit_data[metric],
            grid,
            metric,
            param,
            rel,
            setx,
            sd[metric],
            is_log[metric],
            i == 0,  # only the first metric starts out visible
        )
        trace_counts.append(len(metric_traces))
        traces.extend(metric_traces)

    # layout: the x-axis scale follows the first metric's log flag
    x_is_log = is_log[metrics[0]]
    xrange = axis_range(grid, x_is_log)
    xtype = "log" if x_is_log else "linear"

    # One button per metric. Masks are built from the number of traces each
    # metric actually produced, not from an assumed per-metric trace count.
    buttons = []
    total_traces = len(traces)
    offset = 0
    for metric, count in zip(metrics, trace_counts):
        visible = [False] * total_traces
        visible[offset:offset + count] = [True] * count
        offset += count
        buttons.append(
            {
                "method": "update",
                "args": [{"visible": visible}, {"yaxis.title": metric}],
                "label": metric,
            }
        )

    layout = {
        "title": "Predictions for a 1-d slice of the parameter space",
        "annotations": [
            {
                "showarrow": False,
                "text": "Choose metric:",
                "x": 0.225,
                "xanchor": "center",
                "xref": "paper",
                "y": 1.005,
                "yanchor": "bottom",
                "yref": "paper",
            }
        ],
        "updatemenus": [
            {"y": 1.1, "x": 0.5, "yanchor": "top", "buttons": buttons}
        ],
        "hovermode": "closest",
        "xaxis": {
            "anchor": "y",
            "autorange": False,
            "exponentformat": "e",
            "range": xrange,
            "tickfont": {"size": 11},
            "tickmode": "auto",
            "title": param,
            "type": xtype,
        },
        "yaxis": {
            "anchor": "x",
            "autorange": True,
            "tickfont": {"size": 11},
            "tickmode": "auto",
            "title": metrics[0],
        },
    }

    fig = go.Figure(data=traces, layout=layout)
    return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)