def test_fit(self, mock_fit_gpytorch, mock_MLL, mock_state_dict):
    """Exercise both branches of ``Surrogate.fit``.

    1. With no ``state_dict``, fitting should construct the marginal log
       likelihood and run the fitting routine without loading any state.
    2. With a ``state_dict`` and ``refit=False``, fitting should load the
       given state and skip both MLL construction and the fitting routine.

    The three mock arguments are presumably injected by ``mock.patch``
    decorators outside this view (TODO confirm which targets they patch).
    """
    surrogate = Surrogate(
        botorch_model_class=self.botorch_model_class,
        mll_class=ExactMarginalLogLikelihood,
    )
    # No model should exist until `fit` (and hence `construct`) has run.
    self.assertIsNone(surrogate._model)

    # Arguments shared by both `fit` calls below.
    common_fit_kwargs = {
        "training_data": self.training_data,
        "search_space_digest": self.search_space_digest,
        "metric_names": self.metric_names,
    }

    # Branch 1: no `state_dict` -> the MLL is built and the model is fitted,
    # and nothing is loaded from a state dict.
    surrogate.fit(**common_fit_kwargs, refit=self.refit)
    mock_state_dict.assert_not_called()
    mock_MLL.assert_called_once()
    mock_fit_gpytorch.assert_called_once()

    # Clear recorded calls so the second branch is checked in isolation.
    for patched in (mock_state_dict, mock_MLL, mock_fit_gpytorch):
        patched.reset_mock()

    # Branch 2: a `state_dict` together with `refit=False` -> the state is
    # loaded and neither the MLL nor the fitting routine is touched.
    surrogate.fit(
        **common_fit_kwargs,
        refit=False,
        state_dict={"state_attribute": "value"},
    )
    mock_state_dict.assert_called_once()
    mock_MLL.assert_not_called()
    mock_fit_gpytorch.assert_not_called()
def test_mll_options(self, _):
    """Check that ``mll_options`` reach the MLL class as keyword arguments.

    A spec'd ``MagicMock`` stands in for the MLL class so the keyword
    arguments it is called with during ``fit`` can be inspected.
    """
    mll_spy = MagicMock(spec=self.mll_class)
    surrogate = Surrogate(
        botorch_model_class=self.botorch_model_class,
        mll_class=mll_spy,
        mll_options={"some_option": "some_value"},
    )
    surrogate.fit(
        training_data=self.training_data,
        search_space_digest=self.search_space_digest,
        metric_names=self.metric_names,
        refit=self.refit,
    )
    # `call_args` is an (args, kwargs) pair describing the most recent call.
    _args, mll_kwargs = mll_spy.call_args
    self.assertEqual(mll_kwargs["some_option"], "some_value")
def test_fit(self, mock_fit_gpytorch, mock_MLL, mock_state_dict):
    """Exercise both branches of ``Surrogate.fit`` and the data pass-through.

    1. With no ``state_dict``, fitting should construct the marginal log
       likelihood and run the fitting routine without loading any state, and
       the resulting BoTorch model should hold the fixture's training data.
    2. With a ``state_dict`` and ``refit=False``, fitting should load the
       given state and skip both MLL construction and the fitting routine.

    The three mock arguments are presumably injected by ``mock.patch``
    decorators outside this view (TODO confirm which targets they patch).
    """
    surrogate = Surrogate(
        botorch_model_class=self.botorch_model_class,
        mll_class=ExactMarginalLogLikelihood,
    )
    # No model should exist until `fit` (and hence `construct`) has run.
    self.assertIsNone(surrogate._model)

    # Arguments shared by both `fit` calls below.
    common_fit_kwargs = {
        "training_data": self.training_data,
        "search_space_digest": self.search_space_digest,
        "metric_names": self.metric_names,
    }

    # Branch 1: no `state_dict` -> the MLL is built and the model is fitted.
    surrogate.fit(**common_fit_kwargs, refit=self.refit)

    # The fitted BoTorch model must carry exactly the fixture's training data.
    fitted_model = surrogate.model
    self.assertTrue(
        torch.equal(
            fitted_model.train_inputs[0],
            self.surrogate_kwargs.get("train_X"),
        )
    )
    # The targets are compared against `train_Y` with dim 1 squeezed out
    # (presumably a single-output column tensor — confirm against setUp).
    self.assertTrue(
        torch.equal(
            fitted_model.train_targets,
            self.surrogate_kwargs.get("train_Y").squeeze(1),
        )
    )
    mock_state_dict.assert_not_called()
    mock_MLL.assert_called_once()
    mock_fit_gpytorch.assert_called_once()

    # Clear recorded calls so the second branch is checked in isolation.
    for patched in (mock_state_dict, mock_MLL, mock_fit_gpytorch):
        patched.reset_mock()

    # Branch 2: a `state_dict` together with `refit=False` -> the state is
    # loaded and neither the MLL nor the fitting routine is touched.
    surrogate.fit(
        **common_fit_kwargs,
        refit=False,
        state_dict={"state_attribute": "value"},
    )
    mock_state_dict.assert_called_once()
    mock_MLL.assert_not_called()
    mock_fit_gpytorch.assert_not_called()