Beispiel #1
0
    def get_model_predictions(
        self, metric_names: Optional[List[str]] = None
    ) -> Dict[int, Dict[str, Tuple[float, float]]]:
        """Retrieve model-estimated means and SEMs for all metrics.

        Note: this function retrieves the predictions for the 'in-sample' arms,
        which means that the returned mapping will only contain predictions
        for trials that have been completed with data. Trials that do not have
        an arm attached, or whose arm is not among the in-sample arms, are
        omitted from the result.

        Args:
            metric_names: Names of the metrics for which to retrieve
                predictions. All metrics on the experiment will be retrieved if
                this argument was not specified.

        Returns:
            A mapping from trial index to a mapping of metric names to tuples
            of predicted metric mean and SEM, of form:
            { trial_index -> { metric_name: ( mean, SEM ) } }.

        Raises:
            ValueError: If the generation strategy has no model yet, or if
                `metric_names` is not given and the experiment has no metrics.
        """
        if self.generation_strategy.model is None:  # pragma: no cover
            raise ValueError("No model has been instantiated yet.")
        if metric_names is None and self.experiment.metrics is None:
            raise ValueError(  # pragma: no cover
                "No metrics to retrieve specified on the experiment or as "
                "argument to `get_model_predictions`."
            )
        # Fall back to every metric registered on the experiment.
        requested_metrics = (
            set(metric_names)
            if metric_names is not None
            else set(not_none(self.experiment.metrics).keys())
        )
        arm_info, _, _ = _get_in_sample_arms(
            model=not_none(self.generation_strategy.model),
            metric_names=requested_metrics,
        )
        trials = checked_cast_dict(int, Trial, self.experiment.trials)

        predictions: Dict[int, Dict[str, Tuple[float, float]]] = {}
        for trial_index, trial in trials.items():
            arm = trial.arm
            # A trial without an attached arm cannot have in-sample
            # predictions; skip it rather than failing inside `not_none`.
            if arm is None or arm.name not in arm_info:
                continue
            info = arm_info[arm.name]
            # Pair each predicted mean with its SEM, keyed by metric name.
            predictions[trial_index] = {
                m: (info.y_hat[m], info.se_hat[m]) for m in info.y_hat
            }
        return predictions
Beispiel #2
0
 def test_checked_cast_dict(self):
     """`checked_cast_dict` returns a well-typed dict and rejects bad keys/values."""
     # Keys and values already match the requested types: returned as-is.
     well_typed = {"some": 1}
     self.assertEqual(checked_cast_dict(str, int, well_typed), {"some": 1})
     # First input has a float value where int is required; second has an
     # int key where str is required. Both must be rejected.
     ill_typed_inputs = ({"some": 1.0}, {1: 1})
     for ill_typed in ill_typed_inputs:
         with self.assertRaises(ValueError):
             checked_cast_dict(str, int, ill_typed)