def _get_cv_plot_data(cv_results: List[CVResult]) -> PlotData:
    """Build a ``PlotData`` of observed vs. predicted values from CV results.

    Each cross-validation result becomes one in-sample arm keyed by
    ``"{arm_name}_{index}"`` so that repeated arms across folds remain
    distinct entries.

    Args:
        cv_results: Cross-validation results to convert.

    Returns:
        PlotData with observed (``y``/``se``) and predicted
        (``y_hat``/``se_hat``) values per arm; no out-of-sample arms.
    """
    if not cv_results:
        return PlotData(
            metrics=[], in_sample={}, out_of_sample=None, status_quo_name=None
        )

    # Assume input is well formed and metric names are consistent across
    # results.
    metric_names = cv_results[0].predicted.metric_names

    in_sample: Dict[str, PlotInSampleArm] = {}
    for index, cv_result in enumerate(cv_results):
        observed = cv_result.observed.data
        predicted = cv_result.predicted
        arm_name = cv_result.observed.arm_name
        in_sample[f"{arm_name}_{index}"] = PlotInSampleArm(
            # pyre-fixme[6]: Expected `str`, got `Optional[str]`.
            name=arm_name,
            parameters=cv_result.observed.features.parameters,
            y={m: observed.means[i] for i, m in enumerate(observed.metric_names)},
            # The covariance diagonal is a variance; sqrt gives the SE.
            se={
                m: np.sqrt(observed.covariance[i][i])
                for i, m in enumerate(observed.metric_names)
            },
            y_hat={
                m: predicted.means[i] for i, m in enumerate(predicted.metric_names)
            },
            se_hat={
                m: np.sqrt(predicted.covariance[i][i])
                for i, m in enumerate(predicted.metric_names)
            },
            context_stratum=None,
        )
    return PlotData(
        metrics=metric_names,
        in_sample=in_sample,
        out_of_sample=None,
        status_quo_name=None,
    )
def interact_empirical_model_validation(batch: BatchTrial, data: Data) -> AxPlotConfig:
    """Compare the model predictions for the batch arms against observed data.

    Relies on the model predictions stored on the generator_runs of batch.

    Args:
        batch: Batch on which to perform analysis.
        data: Observed data for the batch.

    Returns:
        AxPlotConfig for the plot.

    Raises:
        ValueError: If no generator run on the batch carries model predictions.
    """
    insample_data: Dict[str, PlotInSampleArm] = {}
    metric_names = list(data.df["metric_name"].unique())
    for struct in batch.generator_run_structs:
        generator_run = struct.generator_run
        if generator_run.model_predictions is None:
            continue
        # model_predictions is (means, covariances): means maps
        # metric -> per-arm means; covariances maps metric -> metric ->
        # per-arm covariances, so the diagonal entry is a variance.
        # Hoisted out of the per-arm loop: invariant per generator run.
        predictions = generator_run.model_predictions
        for i, arm in enumerate(generator_run.arms):
            arm_data = {
                "name": arm.name_or_short_signature,
                "y": {},
                "se": {},
                "parameters": arm.parameters,
                "y_hat": {},
                "se_hat": {},
                "context_stratum": None,
            }
            for _, row in data.df[
                data.df["arm_name"] == arm.name_or_short_signature
            ].iterrows():
                metric_name = row["metric_name"]
                arm_data["y"][metric_name] = row["mean"]
                arm_data["se"][metric_name] = row["sem"]
                arm_data["y_hat"][metric_name] = predictions[0][metric_name][i]
                # Bug fix: the covariance diagonal is a variance; take the
                # square root so se_hat is a standard error, consistent with
                # the other plot-data builders in this module.
                arm_data["se_hat"][metric_name] = np.sqrt(
                    predictions[1][metric_name][metric_name][i]
                )
            # pyre-fixme[6]: Expected `Optional[Dict[str, Union[float, str]]]`
            # for 1st param but got heterogeneous dict.
            insample_data[arm.name_or_short_signature] = PlotInSampleArm(**arm_data)
    if not insample_data:
        raise ValueError("No model predictions present on the batch.")
    plot_data = PlotData(
        metrics=metric_names,
        in_sample=insample_data,
        out_of_sample=None,
        status_quo_name=None,
    )
    fig = _obs_vs_pred_dropdown_plot(data=plot_data, rel=False)
    # NOTE(review): title reads "Cross-validation" although this is empirical
    # validation of stored predictions — confirm intended; preserved as-is.
    fig["layout"]["title"] = "Cross-validation"
    return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)
def get_plot_data(
    model: ModelBridge,
    generator_runs_dict: Dict[str, GeneratorRun],
    metric_names: Optional[Set[str]] = None,
    fixed_features: Optional[ObservationFeatures] = None,
    data_selector: Optional[Callable[[Observation], bool]] = None,
) -> Tuple[PlotData, RawData, Dict[str, TParameterization]]:
    """Format data object with metrics for in-sample and out-of-sample arms.

    Calculate both observed and predicted metrics for in-sample arms.
    Calculate predicted metrics for out-of-sample arms passed via the
    `generator_runs_dict` argument.

    In PlotData, in-sample observations are merged with IVW. In RawData, they
    are left un-merged and given as a list of dictionaries, one for each
    observation and having keys 'arm_name', 'mean', and 'sem'.

    Args:
        model: The model.
        generator_runs_dict: a mapping from generator run name to generator run.
        metric_names: Restrict predictions to this set. If None, all metrics
            in the model will be returned.
        fixed_features: Fixed features to use when making model predictions.
        data_selector: Function for selecting observations for plotting.

    Returns:
        A tuple containing

            - PlotData object with in-sample and out-of-sample predictions.
            - List of observations like::

                {'metric_name': 'likes', 'arm_name': '0_1', 'mean': 1., 'sem': 0.1}.

            - Mapping from arm name to parameters.
    """
    # Default to every metric the model knows about.
    if metric_names is None:
        metrics_plot = model.metric_names
    else:
        metrics_plot = metric_names

    in_sample_plot, raw_data, cond_name_to_parameters = _get_in_sample_arms(
        model=model,
        metric_names=metrics_plot,
        fixed_features=fixed_features,
        data_selector=data_selector,
    )
    out_of_sample_plot = _get_out_of_sample_arms(
        model=model,
        generator_runs_dict=generator_runs_dict,
        metric_names=metrics_plot,
        fixed_features=fixed_features,
    )

    status_quo = model.status_quo
    # pyre-fixme[16]: `Optional` has no attribute `arm_name`.
    status_quo_name = status_quo.arm_name if status_quo is not None else None
    plot_data = PlotData(
        metrics=list(metrics_plot),
        in_sample=in_sample_plot,
        out_of_sample=out_of_sample_plot,
        status_quo_name=status_quo_name,
    )
    return plot_data, raw_data, cond_name_to_parameters
def _get_batch_comparison_plot_data(
    observations: List[Observation],
    batch_x: int,
    batch_y: int,
    rel: bool = False,
    status_quo_name: Optional[str] = None,
) -> PlotData:
    """Compute PlotData for comparing repeated arms across trials.

    Args:
        observations: List of observations.
        batch_x: Batch for x-axis.
        batch_y: Batch for y-axis.
        rel: Whether to relativize data against status_quo arm.
        status_quo_name: Name of the status_quo arm.

    Returns:
        PlotData: a plot data object.
    """
    if rel and status_quo_name is None:
        raise ValueError("Experiment status quo must be set for rel=True")

    def _observations_for_trial(trial_index: int) -> Dict[str, Observation]:
        # Group observations for one trial by arm name.
        return {
            obs.arm_name: obs
            for obs in observations
            if obs.features.trial_index == trial_index
        }

    x_observations = _observations_for_trial(batch_x)
    y_observations = _observations_for_trial(batch_y)

    # Assume input is well formed and metric_names are consistent across
    # observations.
    metric_names = observations[0].data.metric_names

    insample_data: Dict[str, PlotInSampleArm] = {}
    for arm_name, obs_x in x_observations.items():
        obs_y = y_observations.get(arm_name)
        if obs_y is None:
            # Restrict to arms present in both trials.
            continue
        # x-trial values plot as observed, y-trial values as "predicted";
        # the covariance diagonal is a variance, so sqrt yields the SE.
        # pyre-fixme[6]: Expected `str` but got `Optional[str]`.
        insample_data[arm_name] = PlotInSampleArm(
            name=arm_name,
            parameters=obs_x.features.parameters,
            y={
                m: obs_x.data.means[i]
                for i, m in enumerate(obs_x.data.metric_names)
            },
            se={
                m: np.sqrt(obs_x.data.covariance[i][i])
                for i, m in enumerate(obs_x.data.metric_names)
            },
            y_hat={
                m: obs_y.data.means[i]
                for i, m in enumerate(obs_y.data.metric_names)
            },
            se_hat={
                m: np.sqrt(obs_y.data.covariance[i][i])
                for i, m in enumerate(obs_y.data.metric_names)
            },
            context_stratum=None,
        )
    return PlotData(
        metrics=metric_names,
        in_sample=insample_data,
        out_of_sample=None,
        status_quo_name=status_quo_name,
    )