def test_get_next_by_local_search(self, patch):
    """Exercise SMBO._get_next_by_local_search against a mocked callee.

    ``patch`` is a mock that is invoked once per local search (the test
    pins ``call_count`` to ``num_points``); presumably it stands in for the
    local-search maximizer -- TODO confirm against the decorator.  Each call
    returns a ``ConfigurationMock`` plus a nested ``[[value]]``, with values
    counting down from 9.  Two scenarios are covered: no incumbent, and an
    incumbent set on the SMBO object.
    """

    class CountdownSideEffect(object):
        """Callable yielding ``(ConfigurationMock(v), [[v]])`` for v = 9, 8, ..."""

        def __init__(self):
            self.calls_made = 0

        def __call__(self, *args, **kwargs):
            current = 9 - self.calls_made
            self.calls_made += 1
            return (ConfigurationMock(current), [[current]])

    # Scenario 1: no incumbent known yet.
    patch.side_effect = CountdownSideEffect()
    smbo = SMBO(self.scenario, 1)
    candidates = smbo._get_next_by_local_search(num_points=9)
    self.assertEqual(len(candidates), 9)
    self.assertEqual(patch.call_count, 9)
    for position, entry in enumerate(candidates):
        expected = 9 - position
        self.assertIsInstance(entry[1], ConfigurationMock)
        self.assertEqual(entry[1].value, expected)
        self.assertEqual(entry[0], expected)
        self.assertEqual(entry[1].origin, 'Local Search')

    # Scenario 2: an incumbent is known.
    patch.side_effect = CountdownSideEffect()
    smbo.incumbent = 'Incumbent'
    candidates = smbo._get_next_by_local_search(num_points=10)
    self.assertEqual(len(candidates), 10)
    # The mock is never reset, so its count includes the 9 earlier calls.
    self.assertEqual(patch.call_count, 19)
    # Only the first local search of this run should start from the
    # incumbent; that first call sits at index 9 of the accumulated list.
    first_call_positional = patch.call_args_list[9][0]
    self.assertEqual(first_call_positional[0], 'Incumbent')
    for entry in candidates:
        self.assertEqual(entry[1].origin, 'Local Search')
def test_choose_next_2(self):
    """Exercise SMBO.choose_next with a mocked model and acquisition function.

    The acquisition function's ``_compute`` is mocked so that each row's
    value is the mean of its features.  Local search is replaced by random
    search.  The test expects 2020 returned configurations.  Every
    odd-indexed entry must have origin 'Random Search', and exactly 10 of
    the even-indexed entries must have it.
    """

    def fake_compute(X, derivative):
        # One acquisition value per row: the row mean, shaped as a column.
        return np.mean(X, axis=1).reshape((-1, 1))

    smbo = SMBO(self.scenario, 1)
    smbo.runhistory = RunHistory()
    smbo.model = mock.MagicMock()
    smbo.acquisition_func._compute = mock.MagicMock(side_effect=fake_compute)
    # Real local search would call into the local-search maximizer, which
    # would then need mocking too. Swapping in random search is much simpler.
    smbo._get_next_by_local_search = smbo._get_next_by_random_search

    # Draw order matters for the RNG state: features first, then targets.
    features = smbo.rng.rand(10, 2)
    targets = smbo.rng.rand(10, 1)
    chosen = smbo.choose_next(features, targets)

    self.assertEqual(smbo.model.train.call_count, 1)
    self.assertEqual(smbo.acquisition_func._compute.call_count, 1)
    self.assertEqual(len(chosen), 2020)

    random_search_hits = 0
    for idx in range(0, 2020, 2):
        self.assertIsInstance(chosen[idx], Configuration)
        if chosen[idx].origin == 'Random Search':
            random_search_hits += 1
    # Local search was replaced by random search, so count the
    # occurrences of random search instead of local search.
    self.assertEqual(random_search_hits, 10)

    for idx in range(1, 2020, 2):
        self.assertIsInstance(chosen[idx], Configuration)
        self.assertEqual(chosen[idx].origin, 'Random Search')