예제 #1
0
 def test_fit(self, mock_fit_gpytorch, mock_MLL, mock_state_dict):
     """Check that ``Surrogate.fit`` either fits from scratch or loads state.

     Without a ``state_dict``, ``fit`` is expected to build the MLL and run
     the fitting routine without loading any state. With a ``state_dict`` and
     ``refit=False``, it is expected to load that state and skip both the MLL
     and the fitting routine.

     The three mock arguments are presumably injected by ``@patch``
     decorators (not visible here) on the fitting function, the MLL class,
     and ``load_state_dict`` respectively -- TODO confirm against decorators.
     """
     surrogate = Surrogate(
         botorch_model_class=self.botorch_model_class,
         mll_class=ExactMarginalLogLikelihood,
     )
     # The model must not exist until `fit` (and `construct`) have run.
     self.assertIsNone(surrogate._model)

     shared_fit_kwargs = {
         "training_data": self.training_data,
         "search_space_digest": self.search_space_digest,
         "metric_names": self.metric_names,
     }

     # Scenario 1: no `state_dict` -> MLL is built and fitting runs,
     # while no state is loaded.
     surrogate.fit(**shared_fit_kwargs, refit=self.refit)
     mock_state_dict.assert_not_called()
     mock_MLL.assert_called_once()
     mock_fit_gpytorch.assert_called_once()

     # Clear call history so the next scenario is checked in isolation.
     for patched in (mock_state_dict, mock_MLL, mock_fit_gpytorch):
         patched.reset_mock()

     # Scenario 2: a `state_dict` with `refit=False` -> state is loaded and
     # neither the MLL nor the fitting routine is touched.
     surrogate.fit(
         **shared_fit_kwargs,
         refit=False,
         state_dict={"state_attribute": "value"},
     )
     mock_state_dict.assert_called_once()
     mock_MLL.assert_not_called()
     mock_fit_gpytorch.assert_not_called()
예제 #2
0
 def test_mll_options(self, _):
     """Check that ``mll_options`` reach the MLL class as keyword arguments.

     A ``MagicMock`` spec'd on ``self.mll_class`` stands in for the MLL class,
     so the keyword arguments it was last called with can be inspected after
     ``fit`` runs.
     """
     options_to_forward = {"some_option": "some_value"}
     mll_spy = MagicMock(self.mll_class)
     surrogate = Surrogate(
         botorch_model_class=self.botorch_model_class,
         mll_class=mll_spy,
         mll_options=options_to_forward,
     )
     surrogate.fit(
         training_data=self.training_data,
         search_space_digest=self.search_space_digest,
         metric_names=self.metric_names,
         refit=self.refit,
     )
     # `call_args` unpacks into (positional args, keyword args).
     _positional, constructor_kwargs = mll_spy.call_args
     self.assertEqual(constructor_kwargs["some_option"], "some_value")
예제 #3
0
 def test_fit(self, mock_fit_gpytorch, mock_MLL, mock_state_dict):
     """Check ``Surrogate.fit`` data plumbing and fit-vs-load behaviour.

     The first ``fit`` (no ``state_dict``) must hand the training inputs and
     targets from ``self.surrogate_kwargs`` to the constructed model, build
     the MLL and run the fitting routine without loading state. The second
     ``fit`` (``state_dict`` given, ``refit=False``) must load the state and
     skip both the MLL and the fitting routine.

     The three mock arguments are presumably injected by ``@patch``
     decorators (not visible here) on the fitting function, the MLL class,
     and ``load_state_dict`` respectively -- TODO confirm against decorators.
     """
     surrogate = Surrogate(
         botorch_model_class=self.botorch_model_class,
         mll_class=ExactMarginalLogLikelihood,
     )
     # The model must not exist until `fit` (and `construct`) have run.
     self.assertIsNone(surrogate._model)

     shared_fit_kwargs = {
         "training_data": self.training_data,
         "search_space_digest": self.search_space_digest,
         "metric_names": self.metric_names,
     }

     # Scenario 1: no `state_dict` -> MLL is built and fitting runs.
     surrogate.fit(**shared_fit_kwargs, refit=self.refit)

     # The training data must arrive unchanged on the underlying model.
     fitted_inputs = surrogate.model.train_inputs[0]
     expected_inputs = self.surrogate_kwargs.get("train_X")
     self.assertTrue(torch.equal(fitted_inputs, expected_inputs))

     # Targets are compared after squeezing dimension 1 of `train_Y`.
     fitted_targets = surrogate.model.train_targets
     expected_targets = self.surrogate_kwargs.get("train_Y").squeeze(1)
     self.assertTrue(torch.equal(fitted_targets, expected_targets))

     mock_state_dict.assert_not_called()
     mock_MLL.assert_called_once()
     mock_fit_gpytorch.assert_called_once()

     # Clear call history so the next scenario is checked in isolation.
     for patched in (mock_state_dict, mock_MLL, mock_fit_gpytorch):
         patched.reset_mock()

     # Scenario 2: a `state_dict` with `refit=False` -> state is loaded and
     # neither the MLL nor the fitting routine is touched.
     surrogate.fit(
         **shared_fit_kwargs,
         refit=False,
         state_dict={"state_attribute": "value"},
     )
     mock_state_dict.assert_called_once()
     mock_MLL.assert_not_called()
     mock_fit_gpytorch.assert_not_called()