def plot_feature_importance(df: pd.DataFrame, title: str) -> AxPlotConfig:
    """Plot feature importances as one horizontal bar chart per value column.

    The first column of ``df`` supplies the bar labels; every remaining
    column is drawn in its own subplot (stacked vertically, sharing the
    x-axis) and titled with the column name.

    Args:
        df: Table whose first column holds the labels and whose remaining
            columns hold the values to plot. The caller's frame is left
            unmodified.
        title: Title of the resulting figure.

    Returns:
        AxPlotConfig wrapping the plotly figure.

    Raises:
        NoDataError: If ``df`` is empty.
    """
    if df.empty:
        raise NoDataError("No Data on Feature Importances found.")

    # Re-index into a new frame instead of in place: mutating the caller's
    # DataFrame would make a second call on the same frame pick a value
    # column as the label column.
    indexed = df.set_index(df.columns[0])

    bars = [
        go.Bar(
            y=indexed.index,
            x=indexed[column_name],
            name=column_name,
            orientation="h",
        )
        for column_name in indexed.columns
    ]
    fig = subplots.make_subplots(
        rows=len(indexed.columns),
        cols=1,
        subplot_titles=indexed.columns,
        print_grid=False,
        shared_xaxes=True,
    )
    for row, bar in enumerate(bars, start=1):
        fig.append_trace(bar, row, 1)

    fig.layout.showlegend = False
    # Left margin grows with the longest label (8 px per character), capped
    # at 75 characters.
    longest_label = max(len(label) for label in indexed.index)
    fig.layout.margin = go.layout.Margin(
        l=8 * min(longest_label, 75)  # noqa E741
    )
    fig.layout.title = title
    return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)
def interact_empirical_model_validation(batch: BatchTrial, data: Data) -> AxPlotConfig:
    """Compare the model predictions for the batch arms against observed data.

    Relies on the model predictions stored on the generator_runs of batch.

    Args:
        batch: Batch on which to perform analysis.
        data: Observed data for the batch.

    Returns:
        AxPlotConfig for the plot.

    Raises:
        ValueError: If none of the batch's generator runs carries model
            predictions.
    """
    observed = data.df
    metric_names = list(observed["metric_name"].unique())
    insample_data: Dict[str, PlotInSampleArm] = {}

    for struct in batch.generator_run_structs:
        generator_run = struct.generator_run
        predictions = generator_run.model_predictions
        if predictions is None:
            # Nothing to compare against for this generator run.
            continue
        for i, arm in enumerate(generator_run.arms):
            arm_name = arm.name_or_short_signature
            parameters = arm.parameters
            y: Dict[str, Any] = {}
            se: Dict[str, Any] = {}
            y_hat: Dict[str, Any] = {}
            se_hat: Dict[str, Any] = {}
            arm_rows = observed[observed["arm_name"] == arm_name]
            for _, row in arm_rows.iterrows():
                metric_name = row["metric_name"]
                y[metric_name] = row["mean"]
                se[metric_name] = row["sem"]
                # predictions[0] is indexed [metric][arm position];
                # predictions[1] is indexed [metric][metric][arm position].
                y_hat[metric_name] = predictions[0][metric_name][i]
                se_hat[metric_name] = predictions[1][metric_name][metric_name][i]
            insample_data[arm_name] = PlotInSampleArm(
                name=arm_name,
                y=y,
                se=se,
                parameters=parameters,
                y_hat=y_hat,
                se_hat=se_hat,
                context_stratum=None,
            )

    if not insample_data:
        raise ValueError("No model predictions present on the batch.")

    plot_data = PlotData(
        metrics=metric_names,
        in_sample=insample_data,
        out_of_sample=None,
        status_quo_name=None,
    )
    fig = _obs_vs_pred_dropdown_plot(data=plot_data, rel=False)
    fig["layout"]["title"] = "Cross-validation"
    return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)
def interact_batch_comparison(
    observations: List[Observation],
    experiment: Experiment,
    batch_x: int,
    batch_y: int,
    rel: bool = False,
    status_quo_name: Optional[str] = None,
) -> AxPlotConfig:
    """Compare repeated arms from two trials; select metric via dropdown.

    Args:
        observations: List of observations to compute comparison.
        experiment: Experiment the observations belong to. Its status quo
            supplies ``status_quo_name`` when none is given, and
            observations are converted first if it is a
            ``MultiTypeExperiment``.
        batch_x: Index of batch for x-axis.
        batch_y: Index of batch for y-axis.
        rel: Whether to relativize data against status_quo arm.
        status_quo_name: Name of the status_quo arm.

    Returns:
        AxPlotConfig for the plot.
    """
    if isinstance(experiment, MultiTypeExperiment):
        observations = convert_mt_observations(observations, experiment)

    # Fall back to the experiment's status quo when no name was supplied.
    if not status_quo_name:
        status_quo = experiment.status_quo
        if status_quo:
            status_quo_name = not_none(status_quo).name

    plot_data = _get_batch_comparison_plot_data(
        observations, batch_x, batch_y, rel=rel, status_quo_name=status_quo_name
    )
    fig = _obs_vs_pred_dropdown_plot(
        data=plot_data,
        rel=rel,
        xlabel=f"Batch {batch_x}",
        ylabel=f"Batch {batch_y}",
    )
    fig["layout"]["title"] = "Repeated arms across trials"
    return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)
def plot_relative_feature_importance(model: ModelBridge) -> AxPlotConfig:
    """Create a horizontal bar chart of relative feature importances per metric.

    For every metric of ``model`` (in sorted order) the feature importances
    are normalized to sum to one, and one bar series is drawn per feature.
    Metrics whose importances cannot be obtained are logged and skipped.

    Args:
        model: ModelBridge providing ``metric_names`` and
            ``feature_importances``.

    Returns:
        AxPlotConfig wrapping the plotly figure.
    """
    importances = []
    for metric_name in sorted(model.metric_names):
        try:
            # Copy so that the "index" bookkeeping key added below does not
            # leak into the dictionary owned by the model.
            vals: Dict[str, Any] = dict(model.feature_importances(metric_name))
        except Exception:
            # Deliberately best-effort: a metric without importances is
            # skipped rather than failing the whole plot.
            logger.warning(
                "Model for {} does not support feature importances.".format(
                    metric_name
                )
            )
            continue
        vals["index"] = metric_name
        importances.append(vals)

    df = pd.DataFrame(importances)
    df = df.set_index("index")
    # Normalize each metric's row so its importances sum to one.
    df = df.div(df.sum(axis=1), axis=0)

    data = [
        go.Bar(y=df.index, x=df[column_name], name=column_name, orientation="h")
        for column_name in df.columns
    ]
    layout = go.Layout(
        margin=go.layout.Margin(l=250),  # noqa E741
        # "grouped" is not a valid plotly barmode (valid values: "stack",
        # "group", "overlay", "relative"); "group" is the valid spelling.
        # NOTE(review): the previous docstring said "stacked" -- switch to
        # "stack" if a stacked chart is what is actually wanted.
        barmode="group",
        yaxis={"title": ""},
        xaxis={"title": "Relative Feature importance"},
        showlegend=False,
        title="Relative Feature Importance per Metric",
    )
    fig = go.Figure(data=data, layout=layout)
    return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)
def plot_feature_importance_by_metric(model: ModelBridge) -> AxPlotConfig:
    """Wrap the figure from plot_feature_importance_by_metric_plotly in an
    AxPlotConfig.

    Args:
        model: ModelBridge forwarded to the plotly implementation.

    Returns:
        AxPlotConfig of type GENERIC holding the plotly figure.
    """
    fig = plot_feature_importance_by_metric_plotly(model)
    return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)
def optimization_trace_all_methods(
    y_dict: Dict[str, np.ndarray],
    optimum: Optional[float] = None,
    title: str = "",
    ylabel: str = "",
    trace_colors: List[Tuple[int]] = DISCRETE_COLOR_SCALE,
    optimum_color: Tuple[int] = COLORS.ORANGE.value,
) -> AxPlotConfig:
    """Plots a comparison of optimization traces with 2-SEM bands for multiple
    methods on the same problem.

    Args:
        y_dict: a mapping of method names to (r x t) arrays, where r is the
            number of runs in the test, and t is the number of trials.
        optimum: value of the optimal objective.
        title: Title for this plot.
        ylabel: Label for y axis
        trace_colors: tuples of 3 int values representing
            RGB colors to use for different methods shown in the combination
            plot. Defaults to Ax discrete color scale.
        optimum_color: tuple of 3 int values representing an RGB color.
            Defaults to orange.

    Returns:
        AxPlotConfig: plot of the comparison of optimization traces
    """
    n_colors = len(trace_colors)
    scatters: List[go.Scatter] = []
    for position, (method, runs) in enumerate(y_dict.items()):
        # Wrap around the palette when there are more methods than colors.
        color = trace_colors[position % n_colors]
        mean_line = mean_trace_scatter(
            y=runs, trace_color=color, legend_label=method
        )
        # pyre-fixme[23]: Unable to unpack single value, 2 were expected.
        band_low, band_high = sem_range_scatter(
            y=runs, trace_color=color, legend_label=method
        )
        scatters += [band_low, mean_line, band_high]

    if optimum is not None:
        # The optimum line spans the longest trace.
        longest = max(runs.shape[1] for runs in y_dict.values())
        optimum_line = optimum_objective_scatter(
            optimum=optimum,
            num_iterations=longest,
            optimum_color=optimum_color,
        )
        scatters.append(optimum_line)

    layout = go.Layout(  # pyre-ignore[16]: ...graph_objs` has no attr. `Layout`
        title=title,
        showlegend=True,
        yaxis={"title": ylabel},
        xaxis={"title": "Iteration"},
    )
    # pyre-ignore[16]: ...graph_objs` has no attr. `Figure`
    fig = go.Figure(layout=layout, data=scatters)
    return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)
def plot_contour(
    model: ModelBridge,
    param_x: str,
    param_y: str,
    metric_name: str,
    generator_runs_dict: TNullableGeneratorRunsDict = None,
    relative: bool = False,
    density: int = 50,
    slice_values: Optional[Dict[str, Any]] = None,
    lower_is_better: bool = False,
) -> AxPlotConfig:
    """Plot predictions for a 2-d slice of the parameter space.

    Args:
        model: ModelBridge that contains model for predictions
        param_x: Name of parameter that will be sliced on x-axis
        param_y: Name of parameter that will be sliced on y-axis
        metric_name: Name of metric to plot
        generator_runs_dict: A dictionary {name: generator run} of generator runs
            whose arms will be plotted, if they lie in the slice.
        relative: Predictions relative to status quo
        density: Number of points along slice to evaluate predictions.
        slice_values: A dictionary {name: val} for the fixed values of the
            other parameters. If not provided, then the status quo values will
            be used if there is a status quo, otherwise the mean of numeric
            parameters or the mode of choice parameters.
        lower_is_better: Lower values for metric are better.

    Returns:
        AxPlotConfig of type CONTOUR carrying the plot configuration dict.

    Raises:
        ValueError: If ``param_x`` and ``param_y`` are the same parameter.
    """
    if param_x == param_y:
        raise ValueError("Please select different parameters for x- and y-dimensions.")

    predictions = _get_contour_predictions(
        model=model,
        x_param_name=param_x,
        y_param_name=param_y,
        metric=metric_name,
        generator_runs_dict=generator_runs_dict,
        density=density,
        slice_values=slice_values,
    )
    arm_data, mean_grid, sd_grid, grid_x, grid_y, scales = predictions

    config = {
        "arm_data": arm_data,
        "blue_scale": BLUE_SCALE,
        "density": density,
        "f": mean_grid,
        "green_scale": GREEN_SCALE,
        "green_pink_scale": GREEN_PINK_SCALE,
        "grid_x": grid_x,
        "grid_y": grid_y,
        "lower_is_better": lower_is_better,
        "metric": metric_name,
        "rel": relative,
        "sd": sd_grid,
        "xvar": param_x,
        "yvar": param_y,
        "x_is_log": scales["x"],
        "y_is_log": scales["y"],
    }
    return AxPlotConfig(config, plot_type=AxPlotTypes.CONTOUR)
def optimization_trace_single_method(
    y: np.ndarray,
    optimum: Optional[float] = None,
    model_transitions: Optional[List[int]] = None,
    title: str = "",
    ylabel: str = "",
    hover_labels: Optional[List[str]] = None,
    trace_color: Tuple[int] = COLORS.STEELBLUE.value,
    optimum_color: Tuple[int] = COLORS.ORANGE.value,
    generator_change_color: Tuple[int] = COLORS.TEAL.value,
    optimization_direction: Optional[str] = "passthrough",
    plot_trial_points: bool = False,
    trial_points_color: Tuple[int] = COLORS.LIGHT_PURPLE.value,
) -> AxPlotConfig:
    """Plots an optimization trace with mean and 2 SEMs

    Thin wrapper: all arguments are forwarded unchanged to
    ``optimization_trace_single_method_plotly`` and the resulting figure is
    wrapped in an AxPlotConfig.

    Args:
        y: (r x t) array; result to plot, with r runs and t trials
        optimum: value of the optimal objective
        model_transitions: iterations, before which generators
            changed
        title: title for this plot.
        ylabel: label for the Y-axis.
        hover_labels: optional, text to show on hover; list where the i-th value
            corresponds to the i-th value in the value of the `y` argument.
        trace_color: tuple of 3 int values representing an RGB color for plotting
            running optimum. Defaults to blue.
        optimum_color: tuple of 3 int values representing an RGB color.
            Defaults to orange.
        generator_change_color: tuple of 3 int values representing
            an RGB color. Defaults to teal.
        optimization_direction: str, "minimize" will plot running minimum,
            "maximize" will plot running maximum, "passthrough" (default) will plot
            y as lines, None does not plot running optimum)
        plot_trial_points: bool, whether to plot the objective for each trial, as
            supplied in y (default False for backward compatibility)
        trial_points_color: tuple of 3 int values representing an RGB color for
            plotting trial points. Defaults to light purple.

    Returns:
        AxPlotConfig: plot of the optimization trace
    """
    fig = optimization_trace_single_method_plotly(
        y=y,
        optimum=optimum,
        model_transitions=model_transitions,
        title=title,
        ylabel=ylabel,
        hover_labels=hover_labels,
        trace_color=trace_color,
        optimum_color=optimum_color,
        generator_change_color=generator_change_color,
        optimization_direction=optimization_direction,
        plot_trial_points=plot_trial_points,
        trial_points_color=trial_points_color,
    )
    return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)
def plot_feature_importance_by_feature(
    model: ModelBridge, relative: bool = True, caption: str = ""
) -> AxPlotConfig:
    """Wrap the figure from plot_feature_importance_by_feature_plotly in an
    AxPlotConfig.

    Args:
        model: ModelBridge forwarded to the plotly implementation.
        relative: Forwarded positionally as the second argument.
        caption: Forwarded positionally as the third argument.

    Returns:
        AxPlotConfig of type GENERIC holding the plotly figure.
    """
    fig = plot_feature_importance_by_feature_plotly(model, relative, caption)
    return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)
def plot_marginal_effects(model: ModelBridge, metric: str) -> AxPlotConfig:
    """
    Calculates and plots the marginal effects -- the effect of changing one
    factor away from the randomized distribution of the experiment and fixing it
    at a particular level.

    Args:
        model: Model to use for estimating effects
        metric: The metric for which to plot marginal effects.

    Returns:
        AxPlotConfig of the marginal effects
    """
    plot_data, _, _ = get_plot_data(model, {}, {metric})

    # One single-row frame per in-sample arm: its parameters plus the
    # predicted mean and SEM for the requested metric.
    arm_dfs = []
    for arm in plot_data.in_sample.values():
        arm_df = pd.DataFrame(arm.parameters, index=[arm.name])
        arm_df["mean"] = arm.y_hat[metric]
        arm_df["sem"] = arm.se_hat[metric]
        arm_dfs.append(arm_df)

    # `axis` must be passed by keyword: pandas >= 2.0 makes every argument of
    # `concat` other than `objs` keyword-only.
    effect_table = marginal_effects(pd.concat(arm_dfs, axis=0))

    varnames = effect_table["Name"].unique()
    data: List[Any] = []
    for varname in varnames:
        var_df = effect_table[effect_table["Name"] == varname]
        data.append(
            # pyre-ignore[16]
            go.Bar(
                x=var_df["Level"],
                y=var_df["Beta"],
                error_y={"type": "data", "array": var_df["SE"]},
                name=varname,
            )
        )

    # One subplot per factor, laid out side by side with a shared y-axis.
    fig = tools.make_subplots(
        cols=len(varnames),
        rows=1,
        subplot_titles=list(varnames),
        print_grid=False,
        shared_yaxes=True,
    )
    for idx, item in enumerate(data):
        fig.append_trace(item, 1, idx + 1)

    fig.layout.showlegend = False
    fig.layout.title = "Marginal Effects by Factor"
    fig.layout.yaxis = {
        "title": "% better than experiment average",
        "hoverformat": ".{}f".format(DECIMALS),
    }
    return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)
def plot_slice(
    model: ModelBridge,
    param_name: str,
    metric_name: str,
    generator_runs_dict: TNullableGeneratorRunsDict = None,
    relative: bool = False,
    density: int = 50,
    slice_values: Optional[Dict[str, Any]] = None,
    fixed_features: Optional[ObservationFeatures] = None,
) -> AxPlotConfig:
    """Plot predictions for a 1-d slice of the parameter space.

    Args:
        model: ModelBridge that contains model for predictions
        param_name: Name of parameter that will be sliced
        metric_name: Name of metric to plot
        generator_runs_dict: A dictionary {name: generator run} of generator runs
            whose arms will be plotted, if they lie in the slice.
        relative: Predictions relative to status quo
        density: Number of points along slice to evaluate predictions.
        slice_values: A dictionary {name: val} for the fixed values of the
            other parameters. If not provided, then the status quo values will
            be used if there is a status quo, otherwise the mean of numeric
            parameters or the mode of choice parameters. Ignored if
            fixed_features is specified.
        fixed_features: An ObservationFeatures object containing the values of
            features (including non-parameter features like context) to be set
            in the slice.

    Returns:
        AxPlotConfig of type SLICE carrying the plot configuration dict.
    """
    (
        arm_data,
        arm_name_to_parameters,
        mean_values,
        fit_data,
        grid,
        _,
        _,
        _,
        set_x,
        sd_values,
        is_log,
    ) = _get_slice_predictions(
        model=model,
        param_name=param_name,
        metric_name=metric_name,
        generator_runs_dict=generator_runs_dict,
        relative=relative,
        density=density,
        slice_values=slice_values,
        fixed_features=fixed_features,
    )

    config = {
        "arm_data": arm_data,
        "arm_name_to_parameters": arm_name_to_parameters,
        "f": mean_values,
        "fit_data": fit_data,
        "grid": grid,
        "metric": metric_name,
        "param": param_name,
        "rel": relative,
        "setx": set_x,
        "sd": sd_values,
        "is_log": is_log,
    }
    return AxPlotConfig(config, plot_type=AxPlotTypes.SLICE)
def plot_contour(
    model: ModelBridge,
    param_x: str,
    param_y: str,
    metric_name: str,
    generator_runs_dict: TNullableGeneratorRunsDict = None,
    relative: bool = False,
    density: int = 50,
    slice_values: Optional[Dict[str, Any]] = None,
    lower_is_better: bool = False,
    fixed_features: Optional[ObservationFeatures] = None,
    trial_index: Optional[int] = None,
) -> AxPlotConfig:
    """Plot predictions for a 2-d slice of the parameter space.

    Thin wrapper: all arguments are forwarded unchanged to
    ``plot_contour_plotly`` and the resulting figure is wrapped in an
    AxPlotConfig.

    Args:
        model: ModelBridge that contains model for predictions
        param_x: Name of parameter that will be sliced on x-axis
        param_y: Name of parameter that will be sliced on y-axis
        metric_name: Name of metric to plot
        generator_runs_dict: A dictionary {name: generator run} of generator runs
            whose arms will be plotted, if they lie in the slice.
        relative: Predictions relative to status quo
        density: Number of points along slice to evaluate predictions.
        slice_values: A dictionary {name: val} for the fixed values of the
            other parameters. If not provided, then the status quo values will
            be used if there is a status quo, otherwise the mean of numeric
            parameters or the mode of choice parameters.
        lower_is_better: Lower values for metric are better.
        fixed_features: An ObservationFeatures object containing the values of
            features (including non-parameter features like context) to be set
            in the slice.
        trial_index: Forwarded as-is to ``plot_contour_plotly``.

    Returns:
        AxPlotConfig: contour plot of objective vs. parameter values
    """
    fig = plot_contour_plotly(
        model=model,
        param_x=param_x,
        param_y=param_y,
        metric_name=metric_name,
        generator_runs_dict=generator_runs_dict,
        relative=relative,
        density=density,
        slice_values=slice_values,
        lower_is_better=lower_is_better,
        fixed_features=fixed_features,
        trial_index=trial_index,
    )
    return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)
def get_running_trials_per_minute(
    experiment: Experiment,
    show_until_latest_end_plus_timedelta: timedelta = FIVE_MINUTES,
) -> AxPlotConfig:
    """Plot how many trials were running at each minute of the experiment.

    The time axis is sampled every 60 seconds from the earliest trial start
    up to the latest trial end plus ``show_until_latest_end_plus_timedelta``.

    Args:
        experiment: Experiment whose trials' ``_time_run_started`` and
            ``_time_completed`` timestamps are read. Trials that never
            started are ignored.
        show_until_latest_end_plus_timedelta: Extra time to show after the
            latest trial end.

    Returns:
        AxPlotConfig wrapping a plotly scatter of running-trial counts.

    Raises:
        ValueError: If no trial has started, or no started trial has an end
            time yet (the plotted range cannot be determined).
    """
    trial_runtimes: List[Tuple[int, datetime, Optional[datetime]]] = [
        (
            trial.index,
            not_none(trial._time_run_started),
            trial._time_completed,  # Time trial was completed, failed, or abandoned.
        )
        for trial in experiment.trials.values()
        if trial._time_run_started is not None
    ]
    # Fail with an explicit message instead of the opaque "min()/max() arg is
    # an empty sequence" error; the exception type is unchanged.
    if not trial_runtimes:
        raise ValueError("No trials on the experiment have started running yet.")
    end_times = [end for _, _, end in trial_runtimes if end is not None]
    if not end_times:
        raise ValueError(
            "No started trial has an end time yet; cannot determine the time "
            "range to plot."
        )

    earliest_start = min(start for _, start, _ in trial_runtimes)
    latest_end = max(end_times)

    # A trial is running at a given timestamp if:
    # 1) its run start time is at/before the timestamp, and
    # 2) its end time has not yet come or is at/after the timestamp.
    num_running_at_ts = {
        ts: sum(
            1
            for _, start, end in trial_runtimes
            if start <= ts and (end is None or end >= ts)
        )
        for ts in timestamps_in_range(
            earliest_start,
            latest_end + show_until_latest_end_plus_timedelta,
            timedelta(seconds=60),
        )
    }

    scatter = go.Scatter(
        x=list(num_running_at_ts.keys()),
        y=list(num_running_at_ts.values()),
    )
    return AxPlotConfig(
        data=go.Figure(
            layout=go.Layout(title="Number of running trials during experiment"),
            data=[scatter],
        ),
        plot_type=AxPlotTypes.GENERIC,
    )
def interact_cross_validation(
    cv_results: List[CVResult], show_context: bool = True
) -> AxPlotConfig:
    """Interactive cross-validation (CV) plotting; select metric via dropdown.

    Note: uses the Plotly version of dropdown (which means that all data is
    stored within the notebook).

    Args:
        cv_results: cross-validation results.
        show_context: if True, show context on hover.

    Returns:
        AxPlotConfig of type GENERIC holding the figure.
    """
    plot_data = _get_cv_plot_data(cv_results)
    fig = _obs_vs_pred_dropdown_plot(
        data=plot_data,
        rel=False,
        show_context=show_context,
    )
    fig["layout"]["title"] = "Cross-validation"
    return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)
def scatter_plot_with_pareto_frontier(
    Y: np.ndarray,
    Y_pareto: np.ndarray,
    metric_x: str,
    metric_y: str,
    reference_point: Tuple[float, float],
    minimize: bool = True,
) -> AxPlotConfig:
    """Wrap the figure from scatter_plot_with_pareto_frontier_plotly in an
    AxPlotConfig.

    Args:
        Y: Forwarded as ``Y`` to the plotly implementation.
        Y_pareto: Forwarded as ``Y_pareto``.
        metric_x: Forwarded as ``metric_x``.
        metric_y: Forwarded as ``metric_y``.
        reference_point: Forwarded as ``reference_point``.
        minimize: Currently NOT forwarded; it has no effect on the result.

    Returns:
        AxPlotConfig of type GENERIC holding the plotly figure.
    """
    # NOTE(review): `minimize` is accepted but never passed on, so the plotly
    # helper always uses its own default -- confirm whether it should be
    # forwarded.
    fig = scatter_plot_with_pareto_frontier_plotly(
        Y=Y,
        Y_pareto=Y_pareto,
        metric_x=metric_x,
        metric_y=metric_y,
        reference_point=reference_point,
    )
    return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)
def plot_parallel_coordinates(
    experiment: Experiment, ignored_names: Optional[List[str]] = None
) -> AxPlotConfig:
    """Plot trials as a parallel coordinates graph

    Args:
        experiment: Experiment containing trials to plot
        ignored_names: Metrics present in the experiment data we wish to exclude
            from the final plot. By default we ignore ["generation_method",
            "trial_status", "arm_name"]

    Returns:
        AxPlotConfig: parallel coordinates plot of all experiment trials
    """
    return AxPlotConfig(
        data=plot_parallel_coordinates_plotly(
            experiment=experiment,
            ignored_names=ignored_names,
        ),
        plot_type=AxPlotTypes.GENERIC,
    )
def interact_slice(
    model: ModelBridge,
    generator_runs_dict: TNullableGeneratorRunsDict = None,
    relative: bool = False,
    density: int = 50,
    slice_values: Optional[Dict[str, Any]] = None,
    fixed_features: Optional[ObservationFeatures] = None,
    trial_index: Optional[int] = None,
) -> AxPlotConfig:
    """Create interactive plot with predictions for a 1-d slice of the parameter
    space.

    Thin wrapper: all arguments are forwarded unchanged to
    ``interact_slice_plotly`` and the resulting figure is wrapped in an
    AxPlotConfig.

    Args:
        model: ModelBridge that contains model for predictions
        generator_runs_dict: A dictionary {name: generator run} of generator runs
            whose arms will be plotted, if they lie in the slice.
        relative: Predictions relative to status quo
        density: Number of points along slice to evaluate predictions.
        slice_values: A dictionary {name: val} for the fixed values of the
            other parameters. If not provided, then the status quo values will
            be used if there is a status quo, otherwise the mean of numeric
            parameters or the mode of choice parameters. Ignored if
            fixed_features is specified.
        fixed_features: An ObservationFeatures object containing the values of
            features (including non-parameter features like context) to be set
            in the slice.
        trial_index: Forwarded as-is to ``interact_slice_plotly``.

    Returns:
        AxPlotConfig: interactive plot of objective vs. parameter
    """
    fig = interact_slice_plotly(
        model=model,
        generator_runs_dict=generator_runs_dict,
        relative=relative,
        density=density,
        slice_values=slice_values,
        fixed_features=fixed_features,
        trial_index=trial_index,
    )
    return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)
def plot_bandit_rollout(experiment: Experiment) -> AxPlotConfig:
    """Plot bandit rollout from an experiment.

    Each trial (ordered by index) becomes a "Round <index>" category; for
    every arm, its normalized weight (out of 100) per round is recorded.

    Args:
        experiment: Experiment whose trials must all be ``BatchTrial``s.

    Returns:
        AxPlotConfig of type BANDIT_ROLLOUT carrying the per-arm data,
        the categories, and the color list.

    Raises:
        ValueError: If any trial is not a ``BatchTrial``.
    """
    categories: List[str] = []
    arms: Dict[str, Dict[str, Any]] = {}

    ordered_trials = sorted(
        experiment.trials.values(), key=lambda trial: trial.index
    )
    for trial in ordered_trials:
        if not isinstance(trial, BatchTrial):
            raise ValueError(
                "Bandit rollout graph is not supported for BaseTrial."
            )  # pragma: no cover

        category = f"Round {trial.index}"
        categories.append(category)

        weights = trial.normalized_arm_weights(total=100)
        for arm, weight in weights.items():
            if arm.name not in arms:
                # The number of arms seen so far is this arm's index.
                arms[arm.name] = {
                    "index": len(arms),
                    "name": arm.name,
                    "x": [],
                    "y": [],
                    "text": [],
                }
            entry = arms[arm.name]
            entry["x"].append(category)
            entry["y"].append(weight)
            entry["text"].append("{:.2f}%".format(weight))

    # Arms in first-seen order.
    data = list(arms.values())

    # pyre-fixme[6]: Expected `typing.Tuple[...g.Tuple[int, int, int]`.
    colors = [rgba(c) for c in MIXED_SCALE]
    config = {"data": data, "categories": categories, "colors": colors}
    return AxPlotConfig(config, plot_type=AxPlotTypes.BANDIT_ROLLOUT)
def interact_cross_validation(
    cv_results: List[CVResult], show_context: bool = True
) -> AxPlotConfig:
    """Interactive cross-validation (CV) plotting; select metric via dropdown.

    Note: uses the Plotly version of dropdown (which means that all data is
    stored within the notebook).

    Args:
        cv_results: cross-validation results.
        show_context: if True, show context on hover.

    Returns:
        AxPlotConfig wrapping the figure from
        ``interact_cross_validation_plotly``.
    """
    fig = interact_cross_validation_plotly(
        cv_results=cv_results,
        show_context=show_context,
    )
    return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)
def optimization_trace_single_method(
    y: np.ndarray,
    optimum: Optional[float] = None,
    model_transitions: Optional[List[int]] = None,
    title: str = "",
    ylabel: str = "",
    hover_labels: Optional[List[str]] = None,
    trace_color: Tuple[int] = COLORS.STEELBLUE.value,
    optimum_color: Tuple[int] = COLORS.ORANGE.value,
    generator_change_color: Tuple[int] = COLORS.TEAL.value,
) -> AxPlotConfig:
    """Plots an optimization trace with mean and 2 SEMs

    Args:
        y: (r x t) array; result to plot, with r runs and t trials
        optimum: value of the optimal objective
        model_transitions: iterations, before which generators
            changed
        title: title for this plot.
        ylabel: label for the Y-axis.
        hover_labels: optional, text to show on hover; list where the i-th value
            corresponds to the i-th value in the value of the `y` argument.
        trace_color: tuple of 3 int values representing an RGB color.
            Defaults to ``COLORS.STEELBLUE``.
        optimum_color: tuple of 3 int values representing an RGB color.
            Defaults to ``COLORS.ORANGE``.
        generator_change_color: tuple of 3 int values representing
            an RGB color. Defaults to ``COLORS.TEAL``.

    Returns:
        AxPlotConfig: plot of the optimization trace
    """
    mean_line = mean_trace_scatter(
        y=y, trace_color=trace_color, hover_labels=hover_labels
    )
    band_low, band_high = sem_range_scatter(y=y, trace_color=trace_color)
    layout = go.Layout(
        title=title,
        showlegend=True,
        yaxis={"title": ylabel},
        xaxis={"title": "Iteration"},
    )

    scatters = [band_low, mean_line, band_high]

    if optimum is not None:
        optimum_line = optimum_objective_scatter(
            optimum=optimum,
            num_iterations=y.shape[1],
            optimum_color=optimum_color,
        )
        scatters.append(optimum_line)

    if model_transitions is not None:  # pragma: no cover
        # Vertical extent of the transition markers: the range of the
        # per-iteration 25th/75th percentiles, widened to include the optimum.
        y_lower = np.min(np.percentile(y, 25, axis=0))
        y_upper = np.max(np.percentile(y, 75, axis=0))
        if optimum is not None:
            if optimum < y_lower:
                y_lower = optimum
            if optimum > y_upper:
                y_upper = optimum
        transition_markers = model_transitions_scatter(
            model_transitions=model_transitions,
            y_range=[y_lower, y_upper],
            generator_change_color=generator_change_color,
        )
        scatters.extend(transition_markers)

    fig = go.Figure(layout=layout, data=scatters)
    return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)
def tile_cross_validation(
    cv_results: List[CVResult],
    show_arm_details_on_hover: bool = True,
    show_context: bool = True,
) -> AxPlotConfig:
    """Tile version of CV plots, two plots per row, one per metric.

    Args:
        cv_results: cross-validation results.
        show_arm_details_on_hover: if True, display
            parameterizations of arms on hover. Default is True.
        show_context: if True (default), display context on hover.

    Returns:
        AxPlotConfig of type GENERIC holding the tiled figure.
    """
    data = _get_cv_plot_data(cv_results)
    metrics = data.metrics
    n_metrics = len(metrics)

    # make subplots (2 plots per row)
    nrows = (n_metrics + 1) // 2
    ncols = min(n_metrics, 2)
    fig = tools.make_subplots(
        rows=nrows,
        cols=ncols,
        print_grid=False,
        subplot_titles=tuple(metrics),
        horizontal_spacing=0.15,
        vertical_spacing=0.30 / nrows,
    )

    for i, metric in enumerate(metrics):
        arms = data.in_sample.values()
        y_hat = [arm.y_hat[metric] for arm in arms]
        se_hat = [arm.se_hat[metric] for arm in arms]
        y_raw = [arm.y[metric] for arm in arms]
        se_raw = [arm.se[metric] for arm in arms]
        min_, max_ = _get_min_max_with_errors(y_raw, y_hat, se_raw, se_hat)

        # 1-based grid position of this metric's subplot.
        row, col = i // 2 + 1, i % 2 + 1
        fig.append_trace(_diagonal_trace(min_, max_), row, col)
        scatter = _error_scatter_trace(
            # Expected `List[typing.Union[PlotInSampleArm,
            #  ax.plot.base.PlotOutOfSampleArm]]` for 1st anonymous
            #  parameter to call `ax.plot.scatter._error_scatter_trace` but
            #  got `List[PlotInSampleArm]`.
            # pyre-fixme[6]:
            list(data.in_sample.values()),
            y_axis_var=PlotMetric(metric, True),
            x_axis_var=PlotMetric(metric, False),
            y_axis_label="Predicted",
            x_axis_label="Actual",
            hoverinfo="text",
            show_arm_details_on_hover=show_arm_details_on_hover,
            show_context=show_context,
        )
        fig.append_trace(scatter, row, col)

    # if odd number of plots, need to manually remove the last blank subplot
    # generated by `tools.make_subplots`
    if n_metrics % 2 == 1:
        last_cell = nrows * ncols
        del fig["layout"][f"xaxis{last_cell}"]
        del fig["layout"][f"yaxis{last_cell}"]

    # allocate 400 px per plot (equal aspect ratio)
    fig["layout"].update(
        title="Cross-Validation",
        hovermode="closest",
        width=800,
        height=400 * nrows,
        font={"size": 10},
        showlegend=False,
    )

    # update subplot title size and the axis labels
    axis_frame = {"mirror": True, "linecolor": "black", "linewidth": 0.5}
    for i, ant in enumerate(fig["layout"]["annotations"]):
        ant["font"].update(size=12)
        fig["layout"][f"xaxis{i + 1}"].update(title="Actual Outcome", **axis_frame)
        fig["layout"][f"yaxis{i + 1}"].update(
            title="Predicted Outcome", **axis_frame
        )

    return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)
def lattice_multiple_metrics(
    model: ModelBridge,
    generator_runs_dict: TNullableGeneratorRunsDict = None,
    rel: bool = True,
    show_arm_details_on_hover: bool = False,
) -> AxPlotConfig:
    """Plot raw values or predictions of combinations of two metrics for arms.

    Builds an n x n grid of subplots (n = number of model metrics).
    Off-diagonal cells hold scatter traces of one metric against another;
    diagonal cells hold box plots of a single metric. Two dropdown menus
    toggle error bars ("Show CI") and observed vs. modeled values ("Type").

    Args:
        model: model to draw predictions from.
        generator_runs_dict: a mapping from
            generator run name to generator run.
        rel: if True, use relative effects. Default is True.
        show_arm_details_on_hover: if True, display
            parameterizations of arms on hover. Default is False.

    Returns:
        AxPlotConfig of type GENERIC holding the lattice figure.
    """
    metrics = model.metric_names
    fig = tools.make_subplots(
        rows=len(metrics),
        cols=len(metrics),
        print_grid=False,
        shared_xaxes=False,
        shared_yaxes=False,
    )

    plot_data, _, _ = get_plot_data(
        model, generator_runs_dict if generator_runs_dict is not None else {}, metrics)
    status_quo_arm = (
        None if plot_data.status_quo_name is None
        # pyre-fixme[6]: Expected `str` for 1st param but got `Optional[str]`.
        else plot_data.in_sample.get(plot_data.status_quo_name))

    # iterate over all combinations of metrics and generate scatter traces
    # NOTE: every cell appends its traces in the same order -- observed
    # in-sample, predicted in-sample, then one per out-of-sample generator
    # run. The "visible" masks of the "Type" dropdown below rely on that
    # per-cell order.
    for i, o1 in enumerate(metrics, start=1):
        for j, o2 in enumerate(metrics, start=1):
            if o1 != o2:
                # in-sample observed and predicted
                obs_insample_trace = _error_scatter_trace(
                    # Expected `List[Union[PlotInSampleArm,
                    #  PlotOutOfSampleArm]]` for 1st anonymous parameter to call
                    #  `ax.plot.scatter._error_scatter_trace` but got
                    #  `List[PlotInSampleArm]`.
                    # pyre-fixme[6]:
                    list(plot_data.in_sample.values()),
                    x_axis_var=PlotMetric(o1, pred=False, rel=rel),
                    y_axis_var=PlotMetric(o2, pred=False, rel=rel),
                    status_quo_arm=status_quo_arm,
                    showlegend=(i == 1 and j == 2),
                    legendgroup="In-sample",
                    visible=False,
                    show_arm_details_on_hover=show_arm_details_on_hover,
                )
                predicted_insample_trace = _error_scatter_trace(
                    # Expected `List[Union[PlotInSampleArm,
                    #  PlotOutOfSampleArm]]` for 1st anonymous parameter to call
                    #  `ax.plot.scatter._error_scatter_trace` but got
                    #  `List[PlotInSampleArm]`.
                    # pyre-fixme[6]:
                    list(plot_data.in_sample.values()),
                    x_axis_var=PlotMetric(o1, pred=True, rel=rel),
                    y_axis_var=PlotMetric(o2, pred=True, rel=rel),
                    status_quo_arm=status_quo_arm,
                    legendgroup="In-sample",
                    showlegend=(i == 1 and j == 2),
                    visible=True,
                    show_arm_details_on_hover=show_arm_details_on_hover,
                )
                # Traces are placed at (row=j, col=i).
                fig.append_trace(obs_insample_trace, j, i)
                fig.append_trace(predicted_insample_trace, j, i)

                # iterate over models here
                for k, (generator_run_name, cand_arms) in enumerate(
                        (plot_data.out_of_sample or {}).items(), start=1):
                    fig.append_trace(
                        _error_scatter_trace(
                            list(cand_arms.values()),
                            x_axis_var=PlotMetric(o1, pred=True, rel=rel),
                            y_axis_var=PlotMetric(o2, pred=True, rel=rel),
                            status_quo_arm=status_quo_arm,
                            name=generator_run_name,
                            color=DISCRETE_COLOR_SCALE[k],
                            showlegend=(i == 1 and j == 2),
                            legendgroup=generator_run_name,
                            show_arm_details_on_hover=show_arm_details_on_hover,
                        ),
                        j,
                        i,
                    )
            else:
                # diagonal cell (same metric on both axes): add box plots
                # observed values; hidden by default
                fig.append_trace(
                    go.Box(
                        y=[arm.y[o1] for arm in plot_data.in_sample.values()],
                        name=None,
                        marker={"color": rgba(COLORS.STEELBLUE.value)},
                        showlegend=False,
                        legendgroup="In-sample",
                        visible=False,
                        hoverinfo="none",
                    ),
                    j,
                    i,
                )
                # predicted values
                fig.append_trace(
                    go.Box(
                        y=[
                            arm.y_hat[o1] for arm in plot_data.in_sample.values()
                        ],
                        name=None,
                        marker={"color": rgba(COLORS.STEELBLUE.value)},
                        showlegend=False,
                        legendgroup="In-sample",
                        hoverinfo="none",
                    ),
                    j,
                    i,
                )

                for k, (generator_run_name, cand_arms) in enumerate(
                        (plot_data.out_of_sample or {}).items(), start=1):
                    fig.append_trace(
                        go.Box(
                            y=[arm.y_hat[o1] for arm in cand_arms.values()],
                            name=None,
                            marker={"color": rgba(DISCRETE_COLOR_SCALE[k])},
                            showlegend=False,
                            legendgroup=generator_run_name,
                            hoverinfo="none",
                        ),
                        j,
                        i,
                    )

    fig["layout"].update(
        height=800,
        width=960,
        font={"size": 10},
        hovermode="closest",
        legend={
            "orientation": "h",
            "x": 0,
            "y": 1.05,
            "xanchor": "left",
            "yanchor": "middle",
        },
        updatemenus=[
            # "Show CI" menu: restyles error-bar width/thickness.
            {
                "x": 0.35,
                "y": 1.08,
                "xanchor": "left",
                "yanchor": "middle",
                "buttons": [
                    {
                        "args": [{
                            "error_x.width": 0,
                            "error_x.thickness": 0,
                            "error_y.width": 0,
                            "error_y.thickness": 0,
                        }],
                        "label": "No",
                        "method": "restyle",
                    },
                    {
                        "args": [{
                            "error_x.width": 4,
                            "error_x.thickness": 2,
                            "error_y.width": 4,
                            "error_y.thickness": 2,
                        }],
                        "label": "Yes",
                        "method": "restyle",
                    },
                ],
            },
            # "Type" menu: per-cell visibility mask (observed, predicted,
            # then one entry per out-of-sample run), repeated for every cell.
            {
                "x": 0.1,
                "y": 1.08,
                "xanchor": "left",
                "yanchor": "middle",
                "buttons": [
                    {
                        "args": [{
                            "visible": (([False, True] + [True] * len(plot_data.out_of_sample or {})) * (len(metrics)**2))
                        }],
                        "label": "Modeled",
                        "method": "restyle",
                    },
                    {
                        "args": [{
                            "visible": (([True, False] + [False] * len(plot_data.out_of_sample or {})) * (len(metrics)**2))
                        }],
                        "label": "In-sample",
                        "method": "restyle",
                    },
                ],
            },
        ],
        annotations=[
            {
                "x": 0.02,
                "y": 1.1,
                "xref": "paper",
                "yref": "paper",
                "text": "Type",
                "showarrow": False,
                "yanchor": "middle",
                "xanchor": "left",
            },
            {
                "x": 0.30,
                "y": 1.1,
                "xref": "paper",
                "yref": "paper",
                "text": "Show CI",
                "showarrow": False,
                "yanchor": "middle",
                "xanchor": "left",
            },
        ],
    )

    # add metric names to axes: titles go on the x-axes numbered
    # n*n - n + 1 .. n*n and on the y-axes numbered 1, 1 + n, 1 + 2n, ...
    # (n = number of metrics).
    for i, o in enumerate(metrics):
        pos_x = len(metrics) * len(metrics) - len(metrics) + i + 1
        pos_y = 1 + (len(metrics) * i)
        fig["layout"]["xaxis{}".format(pos_x)].update(title=_wrap_metric(o), titlefont={"size": 10})
        fig["layout"]["yaxis{}".format(pos_y)].update(title=_wrap_metric(o), titlefont={"size": 10})

    # do not put x-axis ticks for boxplots
    boxplot_xaxes = []
    for trace in fig["data"]:
        if trace["type"] == "box":
            # stores the xaxes which correspond to boxplot subplots
            # since we use xaxis1, xaxis2, etc, in plotly.py
            boxplot_xaxes.append("xaxis{}".format(trace["xaxis"][1:]))
        else:
            # clear all error bars since default is no CI
            trace["error_x"].update(width=0, thickness=0)
            trace["error_y"].update(width=0, thickness=0)
    for xaxis in boxplot_xaxes:
        fig["layout"][xaxis]["showticklabels"] = False

    return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)
def plot_feature_importance_by_feature(
    model: ModelBridge, relative: bool = True
) -> AxPlotConfig:
    """One plot per metric, showing importances by feature.

    Metrics whose model raises ``NotImplementedError`` from
    ``feature_importances`` are skipped with a warning. The initial visibility
    and the dropdown ``visible`` arrays are indexed against the traces that
    were actually created, so skipped metrics cannot shift them.

    Args:
        model: ModelBridge queried via ``model.feature_importances(metric_name)``
            for every name in ``model.metric_names`` (in sorted order).
        relative: If True, each metric's importances are divided by their sum
            and the x-axis is formatted as a percentage.

    Returns:
        AxPlotConfig (GENERIC) wrapping a figure with one horizontal bar trace
        per supported metric and a dropdown to switch between metrics.

    Raises:
        NotImplementedError: If no metric produced a trace.
    """
    traces = []
    trace_labels = []
    for metric_name in sorted(model.metric_names):
        try:
            importances = model.feature_importances(metric_name)
        except NotImplementedError:
            logger.warning(
                f"Model for {metric_name} does not support feature importances."
            )
            continue
        df = pd.DataFrame(
            [
                {"Factor": factor, "Importance": importance}
                for factor, importance in importances.items()
            ]
        )
        if relative:
            df["Importance"] = df["Importance"].div(df["Importance"].sum())
        df = df.sort_values("Importance")
        traces.append(
            go.Bar(
                name="Importance",
                orientation="h",
                # Only the first trace that was actually built starts visible.
                visible=len(traces) == 0,
                x=df["Importance"],
                y=df["Factor"],
            )
        )
        trace_labels.append(metric_name)

    if not traces:
        raise NotImplementedError("No traces found for metric")

    # One button per trace; each `visible` list has exactly one entry per trace.
    n_traces = len(traces)
    dropdown = [
        {
            "args": ["visible", [j == i for j in range(n_traces)]],
            "label": label,
            "method": "restyle",
        }
        for i, label in enumerate(trace_labels)
    ]
    updatemenus = [
        {
            "x": 0,
            "y": 1,
            "yanchor": "top",
            "xanchor": "left",
            "buttons": dropdown,
            "pad": {
                "t": -40
            },  # hack to put dropdown below title regardless of number of features
        }
    ]
    features = traces[0].y
    title = (
        "Relative Feature Importances" if relative else "Absolute Feature Importances"
    )
    layout = go.Layout(
        # Height grows with the number of features shown on the y-axis.
        height=200 + len(features) * 20,
        hovermode="closest",
        # Left margin scales with the longest feature name, capped at 75 chars.
        margin=go.layout.Margin(
            l=8 * min(max(len(idx) for idx in features), 75)  # noqa E741
        ),
        showlegend=False,
        title=title,
        updatemenus=updatemenus,
    )
    if relative:
        layout.update({"xaxis": {"tickformat": ".0%"}})
    fig = go.Figure(data=traces, layout=layout)
    return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)
def plot_bandit_rollout(experiment: Experiment) -> AxPlotConfig:
    """Plot the bandit rollout of an experiment as a stacked bar chart.

    Every trial becomes one "Round <index>" category on the x-axis, and every
    arm contributes one bar series holding its normalized weight (summing to
    100 per trial) for each round it appears in.

    Args:
        experiment: Experiment whose trials are plotted in order of trial index.

    Returns:
        AxPlotConfig (GENERIC) wrapping the stacked bar figure.

    Raises:
        ValueError: If any trial is not a BatchTrial.
    """
    round_labels: List[str] = []
    series_by_arm: Dict[str, Dict[str, Any]] = {}

    ordered_trials = sorted(experiment.trials.values(), key=lambda t: t.index)
    for trial in ordered_trials:
        if not isinstance(trial, BatchTrial):
            raise ValueError(
                "Bandit rollout graph is not supported for BaseTrial."
            )  # pragma: no cover
        label = f"Round {trial.index}"
        round_labels.append(label)
        for arm, weight in trial.normalized_arm_weights(total=100).items():
            series = series_by_arm.get(arm.name)
            if series is None:
                # "index" records first-seen order; it only picks the bar color.
                series = {
                    "index": len(series_by_arm),
                    "name": arm.name,
                    "x": [],
                    "y": [],
                    "text": [],
                }
                series_by_arm[arm.name] = series
            series["x"].append(label)
            series["y"].append(weight)
            series["text"].append("{:.2f}%".format(weight))

    # pyre-fixme[6]: Expected `typing.Tuple[...g.Tuple[int, int, int]`.
    palette = [rgba(c) for c in MIXED_SCALE]
    layout = go.Layout(  # pyre-ignore[16]
        title="Rollout Process<br>Bandit Weight Graph",
        xaxis={
            "title": "Rounds",
            "zeroline": False,
            "categoryorder": "array",
            "categoryarray": round_labels,
        },
        yaxis={"title": "Percent", "showline": False},
        barmode="stack",
        showlegend=False,
        margin={"r": 40},
    )

    bar_defaults = {"type": "bar", "hoverinfo": "name+text", "width": 0.5}
    bars = []
    for series in series_by_arm.values():
        bar = dict(bar_defaults)
        bar["marker"] = {"color": palette[series["index"] % len(palette)]}
        # The bookkeeping "index" key is left out: figure creation errors on it.
        bar.update((key, value) for key, value in series.items() if key != "index")
        bars.append(bar)

    fig = go.Figure(data=bars, layout=layout)  # pyre-ignore[16]
    return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)
def plot_slice(
    model: ModelBridge,
    param_name: str,
    metric_name: str,
    generator_runs_dict: TNullableGeneratorRunsDict = None,
    relative: bool = False,
    density: int = 50,
    slice_values: Optional[Dict[str, Any]] = None,
) -> AxPlotConfig:
    """Plot predictions for a 1-d slice of the parameter space.

    Args:
        model: ModelBridge that contains model for predictions
        param_name: Name of parameter that will be sliced
        metric_name: Name of metric to plot
        generator_runs_dict: A dictionary {name: generator run} of generator runs
            whose arms will be plotted, if they lie in the slice.
        relative: Predictions relative to status quo
        density: Number of points along slice to evaluate predictions.
        slice_values: A dictionary {name: val} for the fixed values of the
            other parameters. If not provided, then the status quo values will
            be used if there is a status quo, otherwise the mean of numeric
            parameters or the mode of choice parameters.

    Returns:
        AxPlotConfig of type SLICE holding the slice configuration dict.
    """
    runs = {} if generator_runs_dict is None else generator_runs_dict

    sliced_parameter = get_range_parameter(model, param_name)
    grid = get_grid_for_parameter(sliced_parameter, density)

    plot_data, raw_data, cond_name_to_parameters = get_plot_data(
        model=model, generator_runs_dict=runs, metric_names={metric_name}
    )

    fixed_values = get_fixed_values(model, slice_values)

    # One prediction point per grid value: the fixed values with the sliced
    # parameter overridden. Here we assume context is None.
    prediction_features = [
        ObservationFeatures(parameters={**fixed_values, param_name: x}) for x in grid
    ]

    means, covariances = model.predict(prediction_features)

    return AxPlotConfig(
        {
            "arm_data": plot_data,
            "arm_name_to_parameters": cond_name_to_parameters,
            "f": means[metric_name],
            "fit_data": raw_data,
            "grid": grid,
            "metric": metric_name,
            "param": param_name,
            "rel": relative,
            "setx": fixed_values,
            # Standard deviation from the metric's own (co)variance entries.
            "sd": np.sqrt(covariances[metric_name][metric_name]),
            "is_log": sliced_parameter.log_scale,
        },
        plot_type=AxPlotTypes.SLICE,
    )
def table_view_plot(
    experiment: Experiment,
    data: Data,
    use_empirical_bayes: bool = True,
    only_data_frame: bool = False,
    arm_noun: str = "arm",
):
    """Table of means and confidence intervals.

    Table is of the form:

    +-------+------------+-----------+
    |  arm  |  metric_1  |  metric_2 |
    +=======+============+===========+
    |  0_0  | mean +- CI |    ...    |
    +-------+------------+-----------+
    |  0_1  |    ...     |    ...    |
    +-------+------------+-----------+

    Args:
        experiment: Experiment whose metrics decide which columns are shown.
        data: Data passed to the model constructor.
        use_empirical_bayes: If True the model is built with
            ``get_empirical_bayes_thompson``, otherwise with ``get_thompson``.
        only_data_frame: If True, return a tuple of two DataFrames (means and
            CI half-widths, indexed by metric name) instead of a plot.
        arm_noun: Singular noun used (with an "s" appended) as the header of
            the first column.

    Returns:
        A ``(means, cis)`` tuple of DataFrames when ``only_data_frame`` is
        True, otherwise an AxPlotConfig (GENERIC) wrapping a plotly table.
    """
    if use_empirical_bayes:
        build_model = get_empirical_bayes_thompson
    else:
        build_model = get_thompson
    model = build_model(experiment=experiment, data=data)

    # We don't want to include metrics from a collection,
    # or the chart will be too big to read easily.
    # Example:
    # experiment.metrics = {
    #   'regular_metric': Metric(),
    #   'collection_metric: CollectionMetric()', # collection_metric =[metric1, metric2]
    # }
    # model.metric_names = [regular_metric, metric1, metric2] # "exploded" out
    # We want to filter model.metric_names and get rid of metric1, metric2
    metric_names = [name for name in model.metric_names if name in experiment.metrics]
    lower_is_better_by_metric = {
        name: experiment.metrics[name].lower_is_better for name in metric_names
    }

    plot_data, _, _ = get_plot_data(
        model=model,
        generator_runs_dict={},
        # pyre-fixme[6]: Expected `Optional[typing.Set[str]]` for 3rd param but got
        #  `List[str]`.
        metric_names=metric_names,
    )

    # Values are shown relative to the status quo whenever one is named.
    rel = bool(plot_data.status_quo_name)
    status_quo_arm = (
        plot_data.in_sample.get(plot_data.status_quo_name) if rel else None
    )

    cell_text = []
    cell_colors = []
    means_by_metric = []
    cis_by_metric = []
    for metric_name in metric_names:
        arm_names, _, ys, ys_se = _error_scatter_data(
            # pyre-fixme[6]: Expected
            #  `List[typing.Union[ax.plot.base.PlotInSampleArm,
            #  ax.plot.base.PlotOutOfSampleArm]]` for 1st param but got
            #  `List[ax.plot.base.PlotInSampleArm]`.
            arms=list(plot_data.in_sample.values()),
            y_axis_var=PlotMetric(metric_name, pred=True, rel=rel),
            x_axis_var=None,
            status_quo_arm=status_quo_arm,
        )
        # (arm name, mean, CI half-width) for every arm of this metric.
        per_arm = [
            (arm_name, y, Z * y_se) for arm_name, y, y_se in zip(arm_names, ys, ys_se)
        ]
        cell_colors.append(
            [
                get_color(
                    x=y,
                    ci=ci,
                    rel=rel,
                    # pyre-fixme[6]: Expected `bool` for 4th param but got
                    #  `Optional[bool]`.
                    reverse=lower_is_better_by_metric[metric_name],
                )
                for _, y, ci in per_arm
            ]
        )
        cell_text.append(["{:.3f} ± {:.3f}".format(y, ci) for _, y, ci in per_arm])
        means_by_metric.append({arm_name: y for arm_name, y, _ in per_arm})
        cis_by_metric.append({arm_name: ci for arm_name, _, ci in per_arm})

    if only_data_frame:
        return (
            pd.DataFrame.from_records(means_by_metric, index=metric_names),
            pd.DataFrame.from_records(cis_by_metric, index=metric_names),
        )

    def _swap_axes(matrix):
        # Per-metric rows -> per-arm rows; width is taken from the first row.
        return [[row[col] for row in matrix] for col in range(len(matrix[0]))]

    # First table column lists the metric names; the remaining columns are arms.
    metric_column = [name.replace(":", " : ") for name in metric_names]
    records = [metric_column] + _swap_axes(cell_text)
    colors = [["#ffffff"] * len(metric_names)] + _swap_axes(cell_colors)

    # `arm_names` is the value left over from the last metric processed above.
    # pyre-fixme[6]: Expected `List[str]` for 1st param but got `List[float]`.
    header = [f"<b>{x}</b>" for x in [f"{arm_noun}s"] + arm_names]
    column_widths = [300] + [150] * len(arm_names)

    trace = go.Table(
        header={"values": header, "align": ["left"]},
        cells={"values": records, "align": ["left"], "fill": {"color": colors}},
        columnwidth=column_widths,
    )
    layout = go.Layout(
        width=sum(column_widths),
        margin=go.layout.Margin(l=0, r=20, b=20, t=20, pad=4),  # noqa E741
    )
    fig = go.Figure(data=[trace], layout=layout)
    return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)
def plot_feature_importance(df: pd.DataFrame, title: str) -> AxPlotConfig:
    """Wrap the output of ``plot_feature_importance_plotly`` in an AxPlotConfig.

    Args:
        df: Passed through unchanged to ``plot_feature_importance_plotly``.
        title: Passed through unchanged to ``plot_feature_importance_plotly``.

    Returns:
        AxPlotConfig of type GENERIC whose data is whatever
        ``plot_feature_importance_plotly`` returned (presumably a plotly figure).
    """
    figure = plot_feature_importance_plotly(df, title)
    return AxPlotConfig(data=figure, plot_type=AxPlotTypes.GENERIC)
def optimization_times(
    fit_times: Dict[str, List[float]],
    gen_times: Dict[str, List[float]],
    title: str = "",
) -> AxPlotConfig:
    """Plots wall times for each method as a bar chart.

    Three bar groups are drawn per method: fitting time, generation time, and
    their element-wise sum. Bar heights are means and error bars are twice the
    standard error (``2 * std / sqrt(n)``).

    Args:
        fit_times: A map from method name to a list of the model fitting times.
        gen_times: A map from method name to a list of the gen times. Must have
            an entry for every key of ``fit_times``.
        title: Title for this plot.

    Returns: AxPlotConfig with the plot
    """
    # The methods (and their order) are taken from fit_times.
    methods = list(fit_times.keys())

    def _summarize(
        label: str, samples_by_method
    ) -> Dict[str, Union[str, List[float]]]:
        # Mean and 2 * SEM of the samples of every method, in `methods` order.
        return {
            "name": label,
            "mean": [np.mean(samples_by_method[m]) for m in methods],
            "2sems": [
                2 * np.std(samples_by_method[m]) / np.sqrt(len(samples_by_method[m]))
                for m in methods
            ],
        }

    fit_res = _summarize("Fitting", fit_times)
    gen_res = _summarize("Generation", gen_times)
    # Total time per repetition is fit time plus gen time, element-wise.
    totals_by_method = {
        m: np.array(fit_times[m]) + np.array(gen_times[m]) for m in methods
    }
    total_res = _summarize("Total", totals_by_method)

    # Construct plot
    data: List[go.Bar] = [
        go.Bar(
            x=methods,
            y=res["mean"],
            text=res["name"],
            textposition="auto",
            error_y={"type": "data", "array": res["2sems"], "visible": True},
            marker={
                "color": rgba(DISCRETE_COLOR_SCALE[i]),
                "line": {"color": "rgb(0,0,0)", "width": 1.0},
            },
            opacity=0.6,
            name=res["name"],
        )
        for i, res in enumerate([fit_res, gen_res, total_res])
    ]

    layout = go.Layout(
        title=title,
        showlegend=False,
        yaxis={"title": "Time"},
        xaxis={"title": "Method"},
    )
    figure = go.Figure(layout=layout, data=data)
    return AxPlotConfig(data=figure, plot_type=AxPlotTypes.GENERIC)
def plot_slice_plotly(
    model: ModelBridge,
    param_name: str,
    metric_name: str,
    generator_runs_dict: TNullableGeneratorRunsDict = None,
    relative: bool = False,
    density: int = 50,
    slice_values: Optional[Dict[str, Any]] = None,
    fixed_features: Optional[ObservationFeatures] = None,
    trial_index: Optional[int] = None,
) -> go.Figure:
    """Plot predictions for a 1-d slice of the parameter space.

    Args:
        model: ModelBridge that contains model for predictions
        param_name: Name of parameter that will be sliced
        metric_name: Name of metric to plot
        generator_runs_dict: A dictionary {name: generator run} of generator runs
            whose arms will be plotted, if they lie in the slice.
        relative: Predictions relative to status quo
        density: Number of points along slice to evaluate predictions.
        slice_values: A dictionary {name: val} for the fixed values of the
            other parameters. If not provided, then the status quo values will
            be used if there is a status quo, otherwise the mean of numeric
            parameters or the mode of choice parameters. Ignored if
            fixed_features is specified.
        fixed_features: An ObservationFeatures object containing the values of
            features (including non-parameter features like context) to be set
            in the slice.
        trial_index: Forwarded unchanged to ``_get_slice_predictions``.

    Returns:
        go.Figure: plot of objective vs. parameter value
    """
    (
        plot_data,
        cond_name_to_parameters,
        f_plt,
        raw_data,
        grid,
        _,
        _,
        _,
        fixed_values,
        sd_plt,
        is_log_scale,
    ) = _get_slice_predictions(
        model=model,
        param_name=param_name,
        metric_name=metric_name,
        generator_runs_dict=generator_runs_dict,
        relative=relative,
        density=density,
        slice_values=slice_values,
        fixed_features=fixed_features,
        trial_index=trial_index,
    )

    # NOTE(review): the raw values are passed through AxPlotConfig and read back
    # from `.data`; presumably this normalizes them into plain, serializable
    # types before tracing - confirm against AxPlotConfig.
    plot_config = AxPlotConfig(
        {
            "arm_data": plot_data,
            "arm_name_to_parameters": cond_name_to_parameters,
            "f": f_plt,
            "fit_data": raw_data,
            "grid": grid,
            "metric": metric_name,
            "param": param_name,
            "rel": relative,
            "setx": fixed_values,
            "sd": sd_plt,
            "is_log": is_log_scale,
        },
        plot_type=AxPlotTypes.GENERIC,
    ).data

    grid = plot_config["grid"]
    metric = plot_config["metric"]
    param = plot_config["param"]
    is_log = plot_config["is_log"]

    traces = slice_config_to_trace(
        plot_config["arm_data"],
        plot_config["arm_name_to_parameters"],
        plot_config["f"],
        plot_config["fit_data"],
        grid,
        metric,
        param,
        plot_config["rel"],
        plot_config["setx"],
        plot_config["sd"],
        is_log,
        True,
    )

    # layout
    tick_font = {"size": 11}
    x_axis = {
        "anchor": "y",
        "autorange": False,
        "exponentformat": "e",
        "range": axis_range(grid, is_log),
        "tickfont": tick_font,
        "tickmode": "auto",
        "title": param,
        "type": "log" if is_log else "linear",
    }
    y_axis = {
        "anchor": "x",
        "tickfont": dict(tick_font),
        "tickmode": "auto",
        "title": metric,
    }
    layout = {"hovermode": "closest", "xaxis": x_axis, "yaxis": y_axis}

    return go.Figure(data=traces, layout=layout)
def interact_slice_plotly(
    model: ModelBridge,
    generator_runs_dict: TNullableGeneratorRunsDict = None,
    relative: bool = False,
    density: int = 50,
    slice_values: Optional[Dict[str, Any]] = None,
    fixed_features: Optional[ObservationFeatures] = None,
    trial_index: Optional[int] = None,
) -> go.Figure:
    """Create interactive plot with predictions for a 1-d slice of the parameter
    space.

    Args:
        model: ModelBridge that contains model for predictions
        generator_runs_dict: A dictionary {name: generator run} of generator runs
            whose arms will be plotted, if they lie in the slice.
        relative: Predictions relative to status quo
        density: Number of points along slice to evaluate predictions.
        slice_values: A dictionary {name: val} for the fixed values of the
            other parameters. If not provided, then the status quo values will
            be used if there is a status quo, otherwise the mean of numeric
            parameters or the mode of choice parameters. Ignored if
            fixed_features is specified.
        fixed_features: An ObservationFeatures object containing the values of
            features (including non-parameter features like context) to be set
            in the slice.
        trial_index: Forwarded to ``get_fixed_values`` and to
            ``_get_slice_predictions``.

    Returns:
        go.Figure: interactive plot of objective vs. parameter

    Raises:
        ValueError: If the model has no range parameters or no metrics, since
            there would be nothing to slice or to plot.
    """
    if generator_runs_dict is None:
        generator_runs_dict = {}

    metric_names = list(model.metric_names)
    param_names = [parameter.name for parameter in get_range_parameters(model)]
    if not param_names:
        raise ValueError("Cannot plot a slice: the model has no range parameters.")
    if not metric_names:
        raise ValueError("Cannot plot a slice: the model has no metrics.")

    # These do not depend on the sliced parameter, so resolve them once:
    # explicit fixed_features override slice_values; otherwise start from
    # empty observation features.
    if fixed_features is not None:
        slice_values = fixed_features.parameters
    else:
        fixed_features = ObservationFeatures(parameters={})

    # Populate `pbuttons`, which allows the user to select 1D slices of parameter
    # space with the chosen parameter on the x-axis.
    pbuttons = []
    init_traces = []
    xaxis_init_format = {}
    is_first_param = True
    for param_name in param_names:
        pbutton_data_args = {"x": [], "y": [], "error_y": []}
        parameter = get_range_parameter(model, param_name)
        grid = get_grid_for_parameter(parameter, density)

        fixed_values = get_fixed_values(model, slice_values, trial_index)

        # One prediction point per grid value, with the sliced parameter
        # overriding its fixed value.
        prediction_features = []
        for x in grid:
            predf = deepcopy(fixed_features)
            predf.parameters = fixed_values.copy()
            predf.parameters[param_name] = x
            prediction_features.append(predf)

        f, cov = model.predict(prediction_features)

        plot_data_dict = {}
        raw_data_dict = {}
        sd_plt_dict: Dict[str, np.ndarray] = {}
        cond_name_to_parameters_dict = {}
        is_log_dict: Dict[str, bool] = {}
        for metric_name in metric_names:
            (
                plot_data,
                cond_name_to_parameters,
                _,
                raw_data,
                _,
                _,
                _,
                _,
                _,
                _,
                is_log_scale,
            ) = _get_slice_predictions(
                model=model,
                param_name=param_name,
                metric_name=metric_name,
                generator_runs_dict=generator_runs_dict,
                relative=relative,
                density=density,
                slice_values=slice_values,
                fixed_features=fixed_features,
                # Keep consistent with get_fixed_values above and with
                # plot_slice_plotly, which both honour trial_index.
                trial_index=trial_index,
            )
            plot_data_dict[metric_name] = plot_data
            raw_data_dict[metric_name] = raw_data
            cond_name_to_parameters_dict[metric_name] = cond_name_to_parameters
            # Standard deviations come from the joint prediction made above.
            sd_plt_dict[metric_name] = np.sqrt(cov[metric_name][metric_name])
            is_log_dict[metric_name] = is_log_scale

        # NOTE(review): values are passed through AxPlotConfig and read back
        # from `.data`; presumably this normalizes them into plain,
        # serializable types before tracing - confirm against AxPlotConfig.
        config = AxPlotConfig(
            {
                "arm_data": plot_data_dict,
                "arm_name_to_parameters": cond_name_to_parameters_dict,
                "f": f,
                "fit_data": raw_data_dict,
                "grid": grid,
                "metrics": metric_names,
                "param": param_name,
                "rel": relative,
                "setx": fixed_values,
                "sd": sd_plt_dict,
                "is_log": is_log_dict,
            },
            plot_type=AxPlotTypes.GENERIC,
        ).data

        arm_data = config["arm_data"]
        arm_name_to_parameters = config["arm_name_to_parameters"]
        f = config["f"]
        fit_data = config["fit_data"]
        grid = config["grid"]
        metrics = config["metrics"]
        param = config["param"]
        rel = config["rel"]
        setx = config["setx"]
        sd = config["sd"]
        is_log = config["is_log"]

        # layout: the x-axis scale follows the first metric's log flag.
        xrange = axis_range(grid, is_log[metrics[0]])
        xtype = "log" if is_log[metrics[0]] else "linear"

        for i, metric in enumerate(metrics):
            # Only the first metric's traces start out visible.
            cur_visible = i == 0
            traces = slice_config_to_trace(
                arm_data[metric],
                arm_name_to_parameters[metric],
                f[metric],
                fit_data[metric],
                grid,
                metric,
                param,
                rel,
                setx,
                sd[metric],
                is_log[metric],
                cur_visible,
            )
            pbutton_data_args["x"] += [trace["x"] for trace in traces]
            pbutton_data_args["y"] += [trace["y"] for trace in traces]
            # Traces without an error-bar array get an empty placeholder so the
            # list stays aligned with "x" and "y".
            pbutton_data_args["error_y"] += [
                {
                    "type": "data",
                    "array": trace["error_y"]["array"],
                    "visible": True,
                    "color": "black",
                }
                if "error_y" in trace and "array" in trace["error_y"]
                else []
                for trace in traces
            ]
            if is_first_param:
                # The figure is initialised with the first parameter's traces.
                init_traces.extend(traces)

        pbuttons.append(
            {
                "args": [
                    pbutton_data_args,
                    {
                        "xaxis.title": param_name,
                        "xaxis.range": xrange,
                        "xaxis.type": xtype,
                    },
                ],
                "label": param_name,
                "method": "update",
            }
        )

        if is_first_param:
            xaxis_init_format = {
                "anchor": "y",
                "autorange": False,
                "exponentformat": "e",
                "range": xrange,
                "tickfont": {"size": 11},
                "tickmode": "auto",
                "title": param_name,
                "type": xtype,
            }
            is_first_param = False

    # Populate mbuttons, which allows the user to select which metric to plot.
    # `metrics` and `arm_data` are the values from the last parameter processed.
    mbuttons = []
    for i, metric in enumerate(metrics):
        # 3 fixed traces per metric plus one per out-of-sample entry.
        trace_cnt = 3 + len(arm_data[metric]["out_of_sample"])
        visible = [False] * (len(metrics) * trace_cnt)
        visible[i * trace_cnt : (i + 1) * trace_cnt] = [True] * trace_cnt
        mbuttons.append(
            {
                "method": "update",
                "args": [{"visible": visible}, {"yaxis.title": metric}],
                "label": metric,
            }
        )

    layout = {
        "title": "Predictions for a 1-d slice of the parameter space",
        "annotations": [
            {
                "showarrow": False,
                "text": "Choose metric:",
                "x": 0.225,
                "xanchor": "right",
                "xref": "paper",
                "y": -0.455,
                "yanchor": "bottom",
                "yref": "paper",
            },
            {
                "showarrow": False,
                "text": "Choose parameter:",
                "x": 0.225,
                "xanchor": "right",
                "xref": "paper",
                "y": -0.305,
                "yanchor": "bottom",
                "yref": "paper",
            },
        ],
        "updatemenus": [
            {
                "y": -0.35,
                "x": 0.25,
                "xanchor": "left",
                "yanchor": "top",
                "buttons": mbuttons,
                "direction": "up",
            },
            {
                "y": -0.2,
                "x": 0.25,
                "xanchor": "left",
                "yanchor": "top",
                "buttons": pbuttons,
                "direction": "up",
            },
        ],
        "hovermode": "closest",
        "xaxis": xaxis_init_format,
        "yaxis": {
            "anchor": "x",
            "autorange": True,
            "tickfont": {"size": 11},
            "tickmode": "auto",
            "title": metrics[0],
        },
    }

    return go.Figure(data=init_traces, layout=layout)