示例#1
0
    def setUp(self):
        """Build a Branin multi-objective experiment and a model-selection node.

        Runs one 20-arm Sobol batch trial to completion and fetches its data.
        It then fits a GPEI and a GPKG ``ModelSpec`` on that data. Finally it
        wraps both fitted specs in a ``GenerationNode`` that chooses between
        them using the "Fisher exact test p" diagnostic.
        """
        self.branin_experiment = get_branin_experiment_with_multi_objective()
        sobol = Models.SOBOL(search_space=self.branin_experiment.search_space)
        sobol_run = sobol.gen(n=20)
        self.branin_experiment.new_batch_trial().add_generator_run(
            sobol_run
        ).run().mark_completed()
        data = self.branin_experiment.fetch_data()

        ms_gpei = ModelSpec(model_enum=Models.GPEI)
        ms_gpei.fit(experiment=self.branin_experiment, data=data)

        ms_gpkg = ModelSpec(model_enum=Models.GPKG)
        ms_gpkg.fit(experiment=self.branin_experiment, data=data)

        self.fitted_model_specs = [ms_gpei, ms_gpkg]

        self.model_selection_node = GenerationNode(
            model_specs=self.fitted_model_specs,
            best_model_selector=SingleDiagnosticBestModelSelector(
                diagnostic="Fisher exact test p",
                # Each keyword takes the enum type it is named for.
                # `metric_aggregation` gets a MetricAggregation, which
                # combines per-metric diagnostic values.
                # `criterion` gets a DiagnosticCriterion, which picks the
                # best model. Presumably a lower Fisher p-value counts as
                # better, hence MIN.
                metric_aggregation=MetricAggregation.MEAN,
                criterion=DiagnosticCriterion.MIN,
            ),
        )
示例#2
0
 def test_construct(self):
     """Check the ModelSpec lifecycle: gen needs a prior fit, update is unsupported."""
     model_spec = ModelSpec(model_enum=Models.GPEI)

     # Generating from an unfitted spec must be rejected.
     self.assertRaises(UserInputError, model_spec.gen, n=1)

     # After fitting, generation goes through.
     model_spec.fit(experiment=self.experiment, data=self.data)
     model_spec.gen(n=1)

     # Incremental updates are expected to raise NotImplementedError.
     self.assertRaises(
         NotImplementedError,
         model_spec.update,
         experiment=self.experiment,
         new_data=self.data,
     )
示例#3
0
    def test_cross_validate_with_GP_model(
        self, mock_cv: Mock, mock_diagnostics: Mock
    ):
        """Cross-validation runs on demand, is cached, and re-runs after a refit.

        ``mock_cv`` and ``mock_diagnostics`` replace the cross-validation and
        diagnostics computations. They are presumably patched in by decorators
        not visible here.
        """
        fake_model_enum = Mock(return_value="fake-modelbridge")
        model_spec = ModelSpec(
            model_enum=fake_model_enum,
            model_cv_kwargs={"test_key": "test-value"},
        )

        def fit_on_first_trial():
            # Data is re-fetched on every fit, exactly like the inline calls.
            model_spec.fit(
                experiment=self.experiment,
                data=self.experiment.trials[0].fetch_data(),
            )

        def reset_mocks():
            for patched in (mock_cv, mock_diagnostics):
                patched.reset_mock()

        def assert_cv_was_run():
            mock_cv.assert_called_with(model="fake-modelbridge", test_key="test-value")
            mock_diagnostics.assert_called_with(["fake-cv-result"])

        fit_on_first_trial()
        cv_results, cv_diagnostics = model_spec.cross_validate()
        assert_cv_was_run()

        self.assertIsNotNone(cv_results)
        self.assertIsNotNone(cv_diagnostics)

        with self.subTest("it caches CV results"):
            reset_mocks()

            cv_results, cv_diagnostics = model_spec.cross_validate()

            self.assertIsNotNone(cv_results)
            self.assertIsNotNone(cv_diagnostics)
            # The repeated call must be answered from the cache.
            mock_cv.assert_not_called()
            mock_diagnostics.assert_not_called()

        with self.subTest("fit clears the CV cache"):
            reset_mocks()

            fit_on_first_trial()
            cv_results, cv_diagnostics = model_spec.cross_validate()

            self.assertIsNotNone(cv_results)
            self.assertIsNotNone(cv_diagnostics)
            assert_cv_was_run()
示例#4
0
    def test_cross_validate_with_non_GP_model(
        self, mock_cv: Mock, mock_diagnostics: Mock
    ):
        """A model that cannot be cross validated warns and returns ``(None, None)``.

        ``mock_cv`` is still called with the configured CV kwargs. No
        diagnostics are computed from its result.
        """
        mock_enum = Mock()
        mock_enum.return_value = "fake-modelbridge"
        ms = ModelSpec(model_enum=mock_enum, model_cv_kwargs={"test_key": "test-value"})
        ms.fit(
            experiment=self.experiment,
            data=self.experiment.trials[0].fetch_data(),
        )
        with warnings.catch_warnings(record=True) as w:
            # Record every warning regardless of inherited filters. An "ignore"
            # or "error" filter from the runner (e.g. ``-W``) or another test
            # would otherwise hide or raise the warning checked below.
            warnings.simplefilter("always")
            cv_results, cv_diagnostics = ms.cross_validate()

        # Count only the warning under test, so that unrelated warnings
        # emitted during cross-validation do not break the assertion.
        cv_warnings = [
            recorded
            for recorded in w
            if "cannot be cross validated" in str(recorded.message)
        ]
        self.assertEqual(len(cv_warnings), 1)
        self.assertIsNone(cv_results)
        self.assertIsNone(cv_diagnostics)

        mock_cv.assert_called_with(model="fake-modelbridge", test_key="test-value")
        mock_diagnostics.assert_not_called()