Exemple #1
0
class SurrogateTest(TestCase):
    """Unit tests for `Surrogate` wrapping a `SingleTaskGP` BoTorch model.

    Several tests patch out model construction / fitting helpers and only
    verify that `Surrogate` forwards the expected arguments to them.
    """

    def setUp(self):
        """Create a two-point training set and a fresh, unconstructed
        `Surrogate`, plus the argument fixtures shared by the tests."""
        self.botorch_model_class = SingleTaskGP
        self.mll_class = ExactMarginalLogLikelihood
        self.device = torch.device("cpu")
        self.dtype = torch.float
        self.X = torch.tensor([[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]],
                              dtype=self.dtype,
                              device=self.device)

        self.Y = torch.tensor([[3.0], [4.0]],
                              dtype=self.dtype,
                              device=self.device)
        self.Yvar = torch.tensor([[0.0], [2.0]],
                                 dtype=self.dtype,
                                 device=self.device)

        self.training_data = TrainingData(X=self.X, Y=self.Y, Yvar=self.Yvar)
        # Keyword arguments for instantiating the BoTorch model directly
        # (used by `test_from_BoTorch`).
        self.surrogate_kwargs = self.botorch_model_class.construct_inputs(
            self.training_data)
        self.surrogate = Surrogate(
            botorch_model_class=self.botorch_model_class,
            mll_class=self.mll_class)
        self.bounds = [(0.0, 1.0), (1.0, 4.0), (2.0, 5.0)]
        self.task_features = []
        self.feature_names = ["x1", "x2", "x3"]
        self.metric_names = ["y"]
        self.fidelity_features = []
        self.target_fidelities = {1: 1.0}
        self.fixed_features = {1: 2.0}
        self.refit = True
        self.objective_weights = torch.tensor([-1.0, 1.0],
                                              dtype=self.dtype,
                                              device=self.device)
        self.outcome_constraints = (torch.tensor([[1.0]]),
                                    torch.tensor([[0.5]]))
        self.linear_constraints = (
            torch.tensor([[0.0, 0.0, 0.0], [0.0, 1.0, 0.0]]),
            torch.tensor([[0.5], [1.0]]),
        )
        self.options = {}

    @patch(f"{CURRENT_PATH}.Kernel")
    @patch(f"{CURRENT_PATH}.Likelihood")
    def test_init(self, mock_Likelihood, mock_Kernel):
        """Constructor stores the model / MLL classes and rejects custom
        likelihoods and kernels with `NotImplementedError`."""
        self.assertEqual(self.surrogate.botorch_model_class,
                         self.botorch_model_class)
        self.assertEqual(self.surrogate.mll_class, self.mll_class)
        # `Likelihood` / `Kernel` are patched in this module, so the calls
        # below produce mocks rather than real instances.
        with self.assertRaisesRegex(NotImplementedError,
                                    "Customizing likelihood"):
            Surrogate(botorch_model_class=self.botorch_model_class,
                      likelihood=Likelihood())
        with self.assertRaisesRegex(NotImplementedError, "Customizing kernel"):
            Surrogate(botorch_model_class=self.botorch_model_class,
                      kernel_class=Kernel())

    def test_model_property(self):
        """Accessing `model` before construction raises `ValueError`."""
        with self.assertRaisesRegex(
                ValueError, "BoTorch `Model` has not yet been constructed."):
            self.surrogate.model

    def test_training_data_property(self):
        """Accessing `training_data` before construction raises `ValueError`."""
        with self.assertRaisesRegex(
                ValueError,
                "Underlying BoTorch `Model` has not yet received its training_data.",
        ):
            self.surrogate.training_data

    def test_dtype_property(self):
        """After `construct`, `dtype` equals the dtype of the training tensors."""
        self.surrogate.construct(training_data=self.training_data,
                                 fidelity_features=self.fidelity_features)
        self.assertEqual(self.dtype, self.surrogate.dtype)

    def test_device_property(self):
        """After `construct`, `device` equals the device of the training tensors."""
        self.surrogate.construct(training_data=self.training_data,
                                 fidelity_features=self.fidelity_features)
        self.assertEqual(self.device, self.surrogate.device)

    def test_from_BoTorch(self):
        """`from_BoTorch` wraps an existing model instance and flags the
        surrogate as one that should not be reconstructed."""
        surrogate = Surrogate.from_BoTorch(
            self.botorch_model_class(**self.surrogate_kwargs))
        self.assertIsInstance(surrogate.model, self.botorch_model_class)
        self.assertFalse(surrogate._should_reconstruct)

    @patch(f"{CURRENT_PATH}.SingleTaskGP.__init__", return_value=None)
    def test_construct(self, mock_GP):
        """`construct` refuses the abstract `Model` class and otherwise
        instantiates the model class with `train_X` / `train_Y`."""
        base_surrogate = Surrogate(botorch_model_class=Model)
        with self.assertRaisesRegex(TypeError,
                                    "Cannot construct an abstract model."):
            base_surrogate.construct(
                training_data=self.training_data,
                fidelity_features=self.fidelity_features,
            )
        self.surrogate.construct(training_data=self.training_data,
                                 fidelity_features=self.fidelity_features)
        mock_GP.assert_called_with(train_X=self.X, train_Y=self.Y)

    @patch(f"{CURRENT_PATH}.SingleTaskGP.load_state_dict", return_value=None)
    @patch(f"{CURRENT_PATH}.ExactMarginalLogLikelihood")
    @patch(f"{SURROGATE_PATH}.fit_gpytorch_model")
    def test_fit(self, mock_fit_gpytorch, mock_MLL, mock_state_dict):
        """`fit` either fits via MLL (no state dict) or loads the given
        state dict without fitting (state dict given, `refit=False`)."""
        surrogate = Surrogate(
            botorch_model_class=self.botorch_model_class,
            mll_class=ExactMarginalLogLikelihood,
        )
        # Checking that model is None before `fit` (and `construct`) calls.
        self.assertIsNone(surrogate._model)
        # Should instantiate mll and `fit_gpytorch_model` when `state_dict`
        # is `None`.
        surrogate.fit(
            training_data=self.training_data,
            bounds=self.bounds,
            task_features=self.task_features,
            feature_names=self.feature_names,
            metric_names=self.metric_names,
            fidelity_features=self.fidelity_features,
            target_fidelities=self.target_fidelities,
            refit=self.refit,
        )
        mock_state_dict.assert_not_called()
        mock_MLL.assert_called_once()
        mock_fit_gpytorch.assert_called_once()
        mock_state_dict.reset_mock()
        mock_MLL.reset_mock()
        mock_fit_gpytorch.reset_mock()
        # Should `load_state_dict` when `state_dict` is not `None`
        # and `refit` is `False`.
        # Use a non-empty dict: an empty dict is falsy, so it would only
        # exercise this path if the implementation checks `is not None`
        # rather than truthiness. `load_state_dict` is patched, so the
        # contents are never interpreted.
        state_dict = {"state_attribute": "value"}
        surrogate.fit(
            training_data=self.training_data,
            bounds=self.bounds,
            task_features=self.task_features,
            feature_names=self.feature_names,
            metric_names=self.metric_names,
            fidelity_features=self.fidelity_features,
            target_fidelities=self.target_fidelities,
            refit=False,
            state_dict=state_dict,
        )
        mock_state_dict.assert_called_once()
        mock_MLL.assert_not_called()
        mock_fit_gpytorch.assert_not_called()

    @patch(f"{SURROGATE_PATH}.predict_from_model")
    def test_predict(self, mock_predict):
        """`predict` delegates to `predict_from_model` with the underlying
        model and the given `X`."""
        self.surrogate.construct(training_data=self.training_data,
                                 fidelity_features=self.fidelity_features)
        self.surrogate.predict(X=self.X)
        mock_predict.assert_called_with(model=self.surrogate.model, X=self.X)

    def test_best_in_sample_point(self):
        """`best_in_sample_point` raises when the helper yields nothing and
        otherwise forwards its arguments to the helper."""
        self.surrogate.construct(training_data=self.training_data,
                                 fidelity_features=self.fidelity_features)
        # `best_in_sample_point` requires `objective_weights`
        with patch(f"{SURROGATE_PATH}.best_in_sample_point",
                   return_value=None):
            with self.assertRaisesRegex(ValueError, "Could not obtain"):
                self.surrogate.best_in_sample_point(bounds=self.bounds,
                                                    objective_weights=None)
        with patch(f"{SURROGATE_PATH}.best_in_sample_point",
                   return_value=(self.X, 0.0)) as mock_best_in_sample:
            # The values are not inspected; the unpacking only checks that
            # a two-element result is returned.
            _best_point, _observed_value = self.surrogate.best_in_sample_point(
                bounds=self.bounds,
                objective_weights=self.objective_weights,
                outcome_constraints=self.outcome_constraints,
                linear_constraints=self.linear_constraints,
                fixed_features=self.fixed_features,
                options=self.options,
            )
            mock_best_in_sample.assert_called_with(
                Xs=[self.training_data.X],
                model=self.surrogate,
                bounds=self.bounds,
                objective_weights=self.objective_weights,
                outcome_constraints=self.outcome_constraints,
                linear_constraints=self.linear_constraints,
                fixed_features=self.fixed_features,
                options=self.options,
            )

    @patch(f"{ACQUISITION_PATH}.Acquisition.__init__", return_value=None)
    @patch(
        f"{ACQUISITION_PATH}.Acquisition.optimize",
        return_value=([torch.tensor([0.0])], [torch.tensor([1.0])]),
    )
    @patch(
        f"{SURROGATE_PATH}.pick_best_out_of_sample_point_acqf_class",
        return_value=(qSimpleRegret, {
            Keys.SAMPLER: SobolQMCNormalSampler
        }),
    )
    def test_best_out_of_sample_point(self, mock_best_point_util,
                                      mock_acqf_optimize, mock_acqf_init):
        """`best_out_of_sample_point` rejects fixed features and otherwise
        builds an `Acquisition` with the picked acquisition function class
        and returns the first optimized candidate and value."""
        self.surrogate.construct(training_data=self.training_data,
                                 fidelity_features=self.fidelity_features)
        # currently cannot use function with fixed features
        with self.assertRaisesRegex(NotImplementedError, "Fixed features"):
            self.surrogate.best_out_of_sample_point(
                bounds=self.bounds,
                objective_weights=self.objective_weights,
                fixed_features=self.fixed_features,
            )
        candidate, acqf_value = self.surrogate.best_out_of_sample_point(
            bounds=self.bounds,
            objective_weights=self.objective_weights,
            outcome_constraints=self.outcome_constraints,
            linear_constraints=self.linear_constraints,
            fidelity_features=self.fidelity_features,
            target_fidelities=self.target_fidelities,
            options=self.options,
        )
        mock_acqf_init.assert_called_with(
            surrogate=self.surrogate,
            botorch_acqf_class=qSimpleRegret,
            bounds=self.bounds,
            objective_weights=self.objective_weights,
            outcome_constraints=self.outcome_constraints,
            linear_constraints=self.linear_constraints,
            fixed_features=None,
            target_fidelities=self.target_fidelities,
            options={Keys.SAMPLER: SobolQMCNormalSampler},
        )
        # The patched `optimize` returns one-element lists; the surrogate is
        # expected to hand back their single entries.
        self.assertTrue(torch.equal(candidate, torch.tensor([0.0])))
        self.assertTrue(torch.equal(acqf_value, torch.tensor([1.0])))

    @patch(f"{SURROGATE_PATH}.Surrogate.fit")
    def test_update(self, mock_fit):
        """`update` delegates to `fit` by default and raises for surrogates
        flagged as not re-constructible."""
        self.surrogate.construct(training_data=self.training_data,
                                 fidelity_features=self.fidelity_features)
        # Call `fit` by default
        self.surrogate.update(
            training_data=self.training_data,
            bounds=self.bounds,
            task_features=self.task_features,
            feature_names=self.feature_names,
            metric_names=self.metric_names,
            fidelity_features=self.fidelity_features,
            refit=self.refit,
        )
        # NOTE(review): `state_dict` below is the bound `state_dict` method,
        # not the result of calling it. Presumably this mirrors what
        # `Surrogate.update` passes to `fit`; confirm against the
        # implementation.
        mock_fit.assert_called_with(
            training_data=self.training_data,
            bounds=self.bounds,
            task_features=self.task_features,
            feature_names=self.feature_names,
            metric_names=self.metric_names,
            fidelity_features=self.fidelity_features,
            candidate_metadata=None,
            state_dict=self.surrogate.model.state_dict,
            refit=self.refit,
        )
        # If should not be reconstructed, raise Error
        self.surrogate._should_reconstruct = False
        with self.assertRaisesRegex(
                NotImplementedError,
                ".* models that should not be re-constructed"):
            self.surrogate.update(
                training_data=self.training_data,
                bounds=self.bounds,
                task_features=self.task_features,
                feature_names=self.feature_names,
                metric_names=self.metric_names,
                fidelity_features=self.fidelity_features,
                refit=self.refit,
            )
Exemple #2
0
class SurrogateTest(TestCase):
    """Tests for `Surrogate` driven through the `SearchSpaceDigest`-based API."""

    def setUp(self):
        """Prepare training data, a fresh surrogate, a search space digest and
        the optimization-related fixtures shared across the tests."""
        self.botorch_model_class = SingleTaskGP
        self.mll_class = ExactMarginalLogLikelihood
        self.device = torch.device("cpu")
        self.dtype = torch.float
        (
            self.Xs,
            self.Ys,
            self.Yvars,
            self.bounds,
            _,
            _,
            _,
        ) = get_torch_test_data(dtype=self.dtype)
        self.training_data = TrainingData.from_block_design(
            X=self.Xs[0], Y=self.Ys[0], Yvar=self.Yvars[0]
        )
        self.surrogate_kwargs = self.botorch_model_class.construct_inputs(
            self.training_data
        )
        self.surrogate = Surrogate(
            botorch_model_class=self.botorch_model_class, mll_class=self.mll_class
        )
        self.search_space_digest = SearchSpaceDigest(
            feature_names=["x1", "x2"],
            bounds=self.bounds,
            target_fidelities={1: 1.0},
        )
        self.metric_names = ["y"]
        self.fixed_features = {1: 2.0}
        self.refit = True
        self.objective_weights = torch.tensor(
            [-1.0, 1.0], dtype=self.dtype, device=self.device
        )
        self.outcome_constraints = (torch.tensor([[1.0]]), torch.tensor([[0.5]]))
        constraint_matrix = torch.tensor([[0.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
        constraint_rhs = torch.tensor([[0.5], [1.0]])
        self.linear_constraints = (constraint_matrix, constraint_rhs)
        self.options = {}

    @patch(f"{CURRENT_PATH}.Kernel")
    @patch(f"{CURRENT_PATH}.Likelihood")
    def test_init(self, mock_Likelihood, mock_Kernel):
        """The constructor records its classes and refuses custom likelihood
        or kernel arguments."""
        model_cls = self.botorch_model_class
        self.assertEqual(self.surrogate.botorch_model_class, model_cls)
        self.assertEqual(self.surrogate.mll_class, self.mll_class)
        with self.assertRaisesRegex(NotImplementedError, "Customizing likelihood"):
            Surrogate(botorch_model_class=model_cls, likelihood=Likelihood())
        with self.assertRaisesRegex(NotImplementedError, "Customizing kernel"):
            Surrogate(botorch_model_class=model_cls, kernel_class=Kernel())

    @patch(f"{SURROGATE_PATH}.fit_gpytorch_model")
    def test_mll_options(self, _):
        """`mll_options` given to the constructor reach the MLL class as
        keyword arguments when fitting."""
        mock_mll = MagicMock(self.mll_class)
        surrogate = Surrogate(
            botorch_model_class=self.botorch_model_class,
            mll_class=mock_mll,
            mll_options={"some_option": "some_value"},
        )
        surrogate.fit(
            training_data=self.training_data,
            search_space_digest=self.search_space_digest,
            metric_names=self.metric_names,
            refit=self.refit,
        )
        _, mll_call_kwargs = mock_mll.call_args
        self.assertEqual(mll_call_kwargs["some_option"], "some_value")

    def test_model_property(self):
        """Reading `model` on an unconstructed surrogate raises `ValueError`."""
        expected_message = "BoTorch `Model` has not yet been constructed."
        with self.assertRaisesRegex(ValueError, expected_message):
            self.surrogate.model

    def test_training_data_property(self):
        """Reading `training_data` on an unconstructed surrogate raises
        `ValueError`."""
        expected_message = (
            "Underlying BoTorch `Model` has not yet received its training_data."
        )
        with self.assertRaisesRegex(ValueError, expected_message):
            self.surrogate.training_data

    def test_dtype_property(self):
        """A constructed surrogate reports the fixture dtype."""
        fidelity_features = self.search_space_digest.fidelity_features
        self.surrogate.construct(
            training_data=self.training_data, fidelity_features=fidelity_features
        )
        self.assertEqual(self.dtype, self.surrogate.dtype)

    def test_device_property(self):
        """A constructed surrogate reports the fixture device."""
        fidelity_features = self.search_space_digest.fidelity_features
        self.surrogate.construct(
            training_data=self.training_data, fidelity_features=fidelity_features
        )
        self.assertEqual(self.device, self.surrogate.device)

    def test_from_botorch(self):
        """`from_botorch` wraps a ready-made model and marks the surrogate as
        manually constructed."""
        botorch_model = self.botorch_model_class(**self.surrogate_kwargs)
        surrogate = Surrogate.from_botorch(botorch_model)
        self.assertIsInstance(surrogate.model, self.botorch_model_class)
        self.assertTrue(surrogate._constructed_manually)

    @patch(f"{CURRENT_PATH}.SingleTaskGP.__init__", return_value=None)
    def test_construct(self, mock_GP):
        """`construct` fails for the base `Model`, instantiates the concrete
        model with the training tensors, and forwards `model_options`."""
        fidelity_features = self.search_space_digest.fidelity_features
        # Base `Model` does not implement `construct_inputs`.
        with self.assertRaises(NotImplementedError):
            Surrogate(botorch_model_class=Model).construct(
                training_data=self.training_data,
                fidelity_features=fidelity_features,
            )

        self.surrogate.construct(
            training_data=self.training_data,
            fidelity_features=self.search_space_digest.fidelity_features,
        )
        mock_GP.assert_called_once()
        _, init_kwargs = mock_GP.call_args
        self.assertTrue(torch.equal(init_kwargs["train_X"], self.Xs[0]))
        self.assertTrue(torch.equal(init_kwargs["train_Y"], self.Ys[0]))
        self.assertFalse(self.surrogate._constructed_manually)

        # `model_options` handed to the `Surrogate` constructor must show up
        # in the `construct_inputs` call.
        with patch.object(
            SingleTaskGP, "construct_inputs", wraps=SingleTaskGP.construct_inputs
        ) as mock_construct_inputs:
            surrogate_with_options = Surrogate(
                botorch_model_class=self.botorch_model_class,
                mll_class=self.mll_class,
                model_options={"some_option": "some_value"},
            )
            surrogate_with_options.construct(self.training_data)
            mock_construct_inputs.assert_called_with(
                training_data=self.training_data, some_option="some_value"
            )

    @patch(f"{CURRENT_PATH}.SingleTaskGP.load_state_dict", return_value=None)
    @patch(f"{CURRENT_PATH}.ExactMarginalLogLikelihood")
    @patch(f"{SURROGATE_PATH}.fit_gpytorch_model")
    def test_fit(self, mock_fit_gpytorch, mock_MLL, mock_state_dict):
        """`fit` fits through the MLL when no state dict is given, and only
        loads the state dict when one is given with `refit=False`."""
        surrogate = Surrogate(
            botorch_model_class=self.botorch_model_class,
            mll_class=ExactMarginalLogLikelihood,
        )
        fit_kwargs = {
            "training_data": self.training_data,
            "search_space_digest": self.search_space_digest,
            "metric_names": self.metric_names,
        }
        # No model may exist before `fit` (and hence `construct`) has run.
        self.assertIsNone(surrogate._model)

        # Without a `state_dict`: the MLL is instantiated and
        # `fit_gpytorch_model` is invoked.
        surrogate.fit(refit=self.refit, **fit_kwargs)
        # The BoTorch `Model` must hold exactly the training data passed in.
        expected_inputs = self.surrogate_kwargs.get("train_X")
        self.assertTrue(torch.equal(surrogate.model.train_inputs[0], expected_inputs))
        expected_targets = self.surrogate_kwargs.get("train_Y").squeeze(1)
        self.assertTrue(torch.equal(surrogate.model.train_targets, expected_targets))
        mock_state_dict.assert_not_called()
        mock_MLL.assert_called_once()
        mock_fit_gpytorch.assert_called_once()

        for used_mock in (mock_state_dict, mock_MLL, mock_fit_gpytorch):
            used_mock.reset_mock()

        # With a `state_dict` and `refit=False`: only `load_state_dict` runs.
        surrogate.fit(
            refit=False, state_dict={"state_attribute": "value"}, **fit_kwargs
        )
        mock_state_dict.assert_called_once()
        mock_MLL.assert_not_called()
        mock_fit_gpytorch.assert_not_called()

    @patch(f"{SURROGATE_PATH}.predict_from_model")
    def test_predict(self, mock_predict):
        """`predict` hands the underlying model and `X` to
        `predict_from_model`."""
        self.surrogate.construct(
            training_data=self.training_data,
            fidelity_features=self.search_space_digest.fidelity_features,
        )
        query_points = self.Xs[0]
        self.surrogate.predict(X=query_points)
        mock_predict.assert_called_with(model=self.surrogate.model, X=query_points)

    def test_best_in_sample_point(self):
        """`best_in_sample_point` raises if the helper returns `None` and
        otherwise passes its arguments (with digest bounds) to the helper."""
        helper_path = f"{SURROGATE_PATH}.best_in_sample_point"
        self.surrogate.construct(
            training_data=self.training_data,
            fidelity_features=self.search_space_digest.fidelity_features,
        )

        # The helper yielding nothing must surface as a `ValueError`.
        with patch(helper_path, return_value=None), self.assertRaisesRegex(
            ValueError, "Could not obtain"
        ):
            self.surrogate.best_in_sample_point(
                search_space_digest=self.search_space_digest,
                objective_weights=None,
            )

        forwarded = {
            "objective_weights": self.objective_weights,
            "outcome_constraints": self.outcome_constraints,
            "linear_constraints": self.linear_constraints,
            "fixed_features": self.fixed_features,
            "options": self.options,
        }
        with patch(helper_path, return_value=(self.Xs[0], 0.0)) as mock_helper:
            best_point, observed_value = self.surrogate.best_in_sample_point(
                search_space_digest=self.search_space_digest, **forwarded
            )
            mock_helper.assert_called_with(
                Xs=[self.training_data.X],
                model=self.surrogate,
                bounds=self.search_space_digest.bounds,
                **forwarded,
            )

    @patch(f"{ACQUISITION_PATH}.Acquisition.__init__", return_value=None)
    @patch(
        f"{ACQUISITION_PATH}.Acquisition.optimize",
        return_value=([torch.tensor([0.0])], [torch.tensor([1.0])]),
    )
    @patch(
        f"{SURROGATE_PATH}.pick_best_out_of_sample_point_acqf_class",
        return_value=(qSimpleRegret, {Keys.SAMPLER: SobolQMCNormalSampler}),
    )
    def test_best_out_of_sample_point(
        self, mock_best_point_util, mock_acqf_optimize, mock_acqf_init
    ):
        """`best_out_of_sample_point` rejects fixed features, builds the
        `Acquisition` from the picked class and options, and returns the
        first optimized candidate with its value."""
        digest = self.search_space_digest
        self.surrogate.construct(
            training_data=self.training_data,
            fidelity_features=digest.fidelity_features,
        )

        # Fixed features are not supported by this method at present.
        with self.assertRaisesRegex(NotImplementedError, "Fixed features"):
            self.surrogate.best_out_of_sample_point(
                search_space_digest=self.search_space_digest,
                objective_weights=self.objective_weights,
                fixed_features=self.fixed_features,
            )

        shared = {
            "search_space_digest": self.search_space_digest,
            "objective_weights": self.objective_weights,
            "outcome_constraints": self.outcome_constraints,
            "linear_constraints": self.linear_constraints,
        }
        candidate, acqf_value = self.surrogate.best_out_of_sample_point(
            options=self.options, **shared
        )
        mock_acqf_init.assert_called_with(
            surrogate=self.surrogate,
            botorch_acqf_class=qSimpleRegret,
            fixed_features=None,
            options={Keys.SAMPLER: SobolQMCNormalSampler},
            **shared,
        )
        self.assertTrue(torch.equal(candidate, torch.tensor([0.0])))
        self.assertTrue(torch.equal(acqf_value, torch.tensor([1.0])))

    @patch(f"{CURRENT_PATH}.SingleTaskGP.load_state_dict", return_value=None)
    @patch(f"{CURRENT_PATH}.ExactMarginalLogLikelihood")
    @patch(f"{SURROGATE_PATH}.fit_gpytorch_model")
    def test_update(self, mock_fit_gpytorch, mock_MLL, mock_state_dict):
        """`update` forwards its arguments to `fit`, swaps in new training
        data, and raises for manually constructed surrogates."""
        self.surrogate.construct(
            training_data=self.training_data,
            fidelity_features=self.search_space_digest.fidelity_features,
        )

        # With `fit` patched out, verify the exact arguments it receives.
        with patch(f"{SURROGATE_PATH}.Surrogate.fit") as mock_fit:
            self.surrogate.update(
                training_data=self.training_data,
                search_space_digest=self.search_space_digest,
                metric_names=self.metric_names,
                refit=self.refit,
                state_dict={"key": "val"},
            )
            mock_fit.assert_called_with(
                training_data=self.training_data,
                search_space_digest=self.search_space_digest,
                metric_names=self.metric_names,
                candidate_metadata=None,
                refit=self.refit,
                state_dict={"key": "val"},
            )

        # With the real `fit`, the BoTorch `Model` must end up holding the
        # new (offset) training data.
        new_Xs, new_Ys, new_Yvars, _bounds, _, _, _ = get_torch_test_data(
            dtype=self.dtype, offset=1.0
        )
        new_training_data = TrainingData.from_block_design(
            X=new_Xs[0], Y=new_Ys[0], Yvar=new_Yvars[0]
        )
        new_model_inputs = self.botorch_model_class.construct_inputs(
            new_training_data
        )
        self.surrogate.update(
            training_data=new_training_data,
            search_space_digest=self.search_space_digest,
            metric_names=self.metric_names,
            refit=self.refit,
            state_dict={"key": "val"},
        )
        self.assertTrue(
            torch.equal(
                self.surrogate.model.train_inputs[0],
                new_model_inputs.get("train_X"),
            )
        )
        self.assertTrue(
            torch.equal(
                self.surrogate.model.train_targets,
                new_model_inputs.get("train_Y").squeeze(1),
            )
        )

        # A manually constructed surrogate cannot be updated.
        self.surrogate._constructed_manually = True
        with self.assertRaisesRegex(NotImplementedError, ".* constructed manually"):
            self.surrogate.update(
                training_data=self.training_data,
                search_space_digest=self.search_space_digest,
                metric_names=self.metric_names,
                refit=self.refit,
            )

    def test_serialize_attributes_as_kwargs(self):
        """`_serialize_attributes_as_kwargs` equals the instance `__dict__`."""
        instance_attributes = self.surrogate.__dict__
        serialized = self.surrogate._serialize_attributes_as_kwargs()
        self.assertEqual(serialized, instance_attributes)