def test_gen_batch_initial_conditions(self):
    bounds = torch.stack([torch.zeros(2), torch.ones(2)])
    mock_acqf = MockAcquisitionFunction()
    mock_acqf.objective = lambda y: y.squeeze(-1)
    for dtype in (torch.float, torch.double):
        bounds = bounds.to(device=self.device, dtype=dtype)
        mock_acqf.X_baseline = bounds  # for testing sample_around_best
        mock_acqf.model = MockModel(MockPosterior(mean=bounds[:, :1]))
        for nonnegative, seed, init_batch_limit, ffs, sample_around_best in product(
            [True, False], [None, 1234], [None, 1], [None, {0: 0.5}], [True, False]
        ):
            with mock.patch.object(
                MockAcquisitionFunction,
                "__call__",
                wraps=mock_acqf.__call__,
            ) as mock_acqf_call:
                batch_initial_conditions = gen_batch_initial_conditions(
                    acq_function=mock_acqf,
                    bounds=bounds,
                    q=1,
                    num_restarts=2,
                    raw_samples=10,
                    fixed_features=ffs,
                    options={
                        "nonnegative": nonnegative,
                        "eta": 0.01,
                        "alpha": 0.1,
                        "seed": seed,
                        "init_batch_limit": init_batch_limit,
                        "sample_around_best": sample_around_best,
                    },
                )
            expected_shape = torch.Size([2, 1, 2])
            self.assertEqual(batch_initial_conditions.shape, expected_shape)
            self.assertEqual(batch_initial_conditions.device, bounds.device)
            self.assertEqual(batch_initial_conditions.dtype, bounds.dtype)
            raw_samps = mock_acqf_call.call_args[0][0]
            batch_shape = (
                torch.Size([20 if sample_around_best else 10])
                if init_batch_limit is None
                else torch.Size([init_batch_limit])
            )
            expected_raw_samps_shape = batch_shape + torch.Size([1, 2])
            self.assertEqual(raw_samps.shape, expected_raw_samps_shape)
            if ffs is not None:
                for idx, val in ffs.items():
                    self.assertTrue(
                        torch.all(batch_initial_conditions[..., idx] == val)
                    )
def test_gen_batch_initial_conditions(self):
    for dtype in (torch.float, torch.double):
        bounds = torch.tensor([[0, 0], [1, 1]], device=self.device, dtype=dtype)
        for nonnegative in (True, False):
            for seed in (None, 1234):
                batch_initial_conditions = gen_batch_initial_conditions(
                    acq_function=MockAcquisitionFunction(),
                    bounds=bounds,
                    q=1,
                    num_restarts=2,
                    raw_samples=10,
                    options={
                        "nonnegative": nonnegative,
                        "eta": 0.01,
                        "alpha": 0.1,
                        "seed": seed,
                    },
                )
                expected_shape = torch.Size([2, 1, 2])
                self.assertEqual(batch_initial_conditions.shape, expected_shape)
                self.assertEqual(batch_initial_conditions.device, bounds.device)
                self.assertEqual(batch_initial_conditions.dtype, bounds.dtype)
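# The tests in this section exercise gen_batch_initial_conditions against
# MockAcquisitionFunction (and, in some tests, MockModel / MockPosterior) from
# botorch.utils.testing. For orientation, a minimal sketch of the kind of mock
# these tests assume -- an illustrative stand-in, not the actual botorch
# helper, whose details may differ:
class _SketchAcquisitionFunction:
    """Dummy acquisition function: maps a `b x q x d` candidate tensor to a
    `b`-dim tensor of values and tracks pending points like the real mock."""

    def __init__(self):
        self.model = None
        self.X_pending = None

    def __call__(self, X):
        # deterministic value per batch so that argmax-style selection works
        return X[..., 0].max(dim=-1).values

    def set_X_pending(self, X_pending=None):
        self.X_pending = X_pending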
def test_gen_batch_initial_conditions_highdim(self):
    d = 120
    bounds = torch.stack([torch.zeros(d), torch.ones(d)])
    for dtype in (torch.float, torch.double):
        bounds = bounds.to(device=self.device, dtype=dtype)
        for nonnegative in (True, False):
            for seed in (None, 1234):
                with warnings.catch_warnings(record=True) as ws, settings.debug(True):
                    batch_initial_conditions = gen_batch_initial_conditions(
                        acq_function=MockAcquisitionFunction(),
                        bounds=bounds,
                        q=10,
                        num_restarts=1,
                        raw_samples=2,
                        options={
                            "nonnegative": nonnegative,
                            "eta": 0.01,
                            "alpha": 0.1,
                            "seed": seed,
                        },
                    )
                    self.assertTrue(
                        any(issubclass(w.category, SamplingWarning) for w in ws)
                    )
                expected_shape = torch.Size([1, 10, d])
                self.assertEqual(batch_initial_conditions.shape, expected_shape)
                self.assertEqual(batch_initial_conditions.device, bounds.device)
                self.assertEqual(batch_initial_conditions.dtype, bounds.dtype)
def test_gen_batch_initial_conditions_warning(self):
    for dtype in (torch.float, torch.double):
        bounds = torch.tensor([[0, 0], [1, 1]], device=self.device, dtype=dtype)
        samples = torch.zeros(10, 1, 2, device=self.device, dtype=dtype)
        with ExitStack() as es:
            ws = es.enter_context(warnings.catch_warnings(record=True))
            es.enter_context(settings.debug(True))
            es.enter_context(
                mock.patch(
                    "botorch.optim.initializers.draw_sobol_samples",
                    return_value=samples,
                )
            )
            batch_initial_conditions = gen_batch_initial_conditions(
                acq_function=MockAcquisitionFunction(),
                bounds=bounds,
                q=1,
                num_restarts=2,
                raw_samples=10,
                options={"seed": 1234},
            )
            self.assertEqual(len(ws), 1)
            self.assertTrue(
                any(issubclass(w.category, BadInitialCandidatesWarning) for w in ws)
            )
            self.assertTrue(
                torch.equal(
                    batch_initial_conditions,
                    torch.zeros(2, 1, 2, device=self.device, dtype=dtype),
                )
            )
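# For context beyond the mocked tests above, a minimal end-to-end sketch of
# calling gen_batch_initial_conditions directly. SingleTaskGP and
# ExpectedImprovement are real botorch classes; the data and the helper name
# are made up purely for illustration:
def _example_gen_batch_initial_conditions():
    from botorch.acquisition import ExpectedImprovement
    from botorch.models import SingleTaskGP

    train_X = torch.rand(8, 2, dtype=torch.double)
    train_Y = train_X.sum(dim=-1, keepdim=True)
    model = SingleTaskGP(train_X, train_Y)
    acqf = ExpectedImprovement(model, best_f=train_Y.max())
    ics = gen_batch_initial_conditions(
        acq_function=acqf,
        bounds=torch.stack([torch.zeros(2), torch.ones(2)]).to(torch.double),
        q=1,
        num_restarts=4,
        raw_samples=32,
    )
    return ics  # shape: num_restarts x q x d = 4 x 1 x 2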
def test_gen_batch_initial_conditions(self):
    for dtype in (torch.float, torch.double):
        bounds = torch.tensor([[0, 0], [1, 1]], device=self.device, dtype=dtype)
        for nonnegative in (True, False):
            for seed in (None, 1234):
                for init_batch_limit in (None, 1):
                    mock_acqf = MockAcquisitionFunction()
                    with mock.patch.object(
                        MockAcquisitionFunction,
                        "__call__",
                        wraps=mock_acqf.__call__,
                    ) as mock_acqf_call:
                        batch_initial_conditions = gen_batch_initial_conditions(
                            acq_function=mock_acqf,
                            bounds=bounds,
                            q=1,
                            num_restarts=2,
                            raw_samples=10,
                            options={
                                "nonnegative": nonnegative,
                                "eta": 0.01,
                                "alpha": 0.1,
                                "seed": seed,
                                "init_batch_limit": init_batch_limit,
                            },
                        )
                    expected_shape = torch.Size([2, 1, 2])
                    self.assertEqual(batch_initial_conditions.shape, expected_shape)
                    self.assertEqual(batch_initial_conditions.device, bounds.device)
                    self.assertEqual(batch_initial_conditions.dtype, bounds.dtype)
                    raw_samps = mock_acqf_call.call_args[0][0]
                    batch_shape = (
                        torch.Size([10])
                        if init_batch_limit is None
                        else torch.Size([init_batch_limit])
                    )
                    expected_raw_samps_shape = batch_shape + torch.Size([1, 2])
                    self.assertEqual(raw_samps.shape, expected_raw_samps_shape)
def test_optimize_acqf_sequential(
    self, mock_gen_candidates_scipy, mock_gen_batch_initial_conditions
):
    q = 3
    num_restarts = 2
    raw_samples = 10
    options = {}
    for dtype in (torch.float, torch.double):
        mock_acq_function = MockAcquisitionFunction()
        mock_gen_batch_initial_conditions.side_effect = [
            torch.zeros(num_restarts, device=self.device, dtype=dtype)
            for _ in range(q)
        ]
        gcs_return_vals = [
            (
                torch.tensor([[[1.1, 2.1, 3.1]]], device=self.device, dtype=dtype),
                torch.tensor([i], device=self.device, dtype=dtype),
            )
            for i in range(q)
        ]
        mock_gen_candidates_scipy.side_effect = gcs_return_vals
        expected_candidates = torch.cat(
            [rv[0][0] for rv in gcs_return_vals], dim=-2
        ).round()
        bounds = torch.stack(
            [
                torch.zeros(3, device=self.device, dtype=dtype),
                4 * torch.ones(3, device=self.device, dtype=dtype),
            ]
        )
        inequality_constraints = [
            (torch.tensor([3]), torch.tensor([4]), torch.tensor(5))
        ]
        candidates, acq_value = optimize_acqf(
            acq_function=mock_acq_function,
            bounds=bounds,
            q=q,
            num_restarts=num_restarts,
            raw_samples=raw_samples,
            options=options,
            inequality_constraints=inequality_constraints,
            post_processing_func=rounding_func,
            sequential=True,
        )
        self.assertTrue(torch.equal(candidates, expected_candidates))
        self.assertTrue(
            torch.equal(acq_value, torch.cat([rv[1] for rv in gcs_return_vals]))
        )
    # verify error when using a OneShotAcquisitionFunction
    with self.assertRaises(NotImplementedError):
        optimize_acqf(
            acq_function=mock.Mock(spec=OneShotAcquisitionFunction),
            bounds=bounds,
            q=q,
            num_restarts=num_restarts,
            raw_samples=raw_samples,
            sequential=True,
        )
def test_optimize_acqf_sequential_notimplemented(self):
    with self.assertRaises(NotImplementedError):
        optimize_acqf(
            acq_function=MockAcquisitionFunction(),
            bounds=torch.stack([torch.zeros(3), 4 * torch.ones(3)]),
            q=3,
            num_restarts=2,
            raw_samples=10,
            return_best_only=False,
            sequential=True,
        )
def test_optimize_acqf_empty_ff(self):
    with self.assertRaises(ValueError):
        mock_acq_function = MockAcquisitionFunction()
        optimize_acqf_mixed(
            acq_function=mock_acq_function,
            q=1,
            fixed_features_list=[],
            bounds=torch.stack([torch.zeros(3), 4 * torch.ones(3)]),
            num_restarts=2,
            raw_samples=10,
        )
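# For contrast with the empty-list error above, a sketch of a valid
# optimize_acqf_mixed call: one sub-optimization is run per fixed-features
# dict and the best result is returned. `acqf` stands in for a real
# acquisition function; the helper name is illustrative only:
def _example_optimize_acqf_mixed(acqf):
    return optimize_acqf_mixed(
        acq_function=acqf,
        bounds=torch.stack([torch.zeros(3), 4 * torch.ones(3)]),
        q=1,
        num_restarts=2,
        raw_samples=10,
        fixed_features_list=[{0: 0.0}, {0: 1.0}],  # must be non-empty
    )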
class TestDeprecatedOptimize(BotorchTestCase):
    shared_kwargs = {
        "acq_function": MockAcquisitionFunction(),
        "bounds": torch.zeros(2, 2),
        "q": 3,
        "num_restarts": 2,
        "raw_samples": 10,
        "options": {},
        "inequality_constraints": None,
        "equality_constraints": None,
        "fixed_features": None,
        "post_processing_func": None,
    }

    @mock.patch("botorch.optim.optimize.optimize_acqf", return_value=(None, None))
    def test_joint_optimize(self, mock_optimize_acqf):
        kwargs = {
            **self.shared_kwargs,
            "return_best_only": True,
            "batch_initial_conditions": None,
        }
        with warnings.catch_warnings(record=True) as ws:
            candidates, acq_values = joint_optimize(**kwargs)
        self.assertTrue(any(issubclass(w.category, DeprecationWarning) for w in ws))
        self.assertTrue(
            any("joint_optimize is deprecated" in str(w.message) for w in ws)
        )
        mock_optimize_acqf.assert_called_once_with(**kwargs, sequential=False)
        self.assertIsNone(candidates)
        self.assertIsNone(acq_values)

    @mock.patch("botorch.optim.optimize.optimize_acqf", return_value=(None, None))
    def test_sequential_optimize(self, mock_optimize_acqf):
        with warnings.catch_warnings(record=True) as ws:
            candidates, acq_values = sequential_optimize(**self.shared_kwargs)
        self.assertTrue(any(issubclass(w.category, DeprecationWarning) for w in ws))
        self.assertTrue(
            any("sequential_optimize is deprecated" in str(w.message) for w in ws)
        )
        mock_optimize_acqf.assert_called_once_with(
            **self.shared_kwargs,
            return_best_only=True,
            sequential=True,
            batch_initial_conditions=None,
        )
        self.assertIsNone(candidates)
        self.assertIsNone(acq_values)
def test_optimize_acqf_mixed_q2(self, mock_optimize_acqf):
    num_restarts = 2
    raw_samples = 10
    q = 2
    options = {}
    tkwargs = {"device": self.device}
    bounds = torch.stack([torch.zeros(3), 4 * torch.ones(3)])
    mock_acq_functions = [
        MockAcquisitionFunction(),
        MockOneShotEvaluateAcquisitionFunction(),
    ]
    for num_ff, dtype, mock_acq_function in itertools.product(
        [1, 3], (torch.float, torch.double), mock_acq_functions
    ):
        tkwargs["dtype"] = dtype
        mock_optimize_acqf.reset_mock()
        bounds = bounds.to(**tkwargs)
        fixed_features_list = [{i: i * 0.1} for i in range(num_ff)]
        candidate_rvs, exp_candidates, acq_val_rvs = [], [], []
        # generate mock side effects and compute expected outputs
        for _ in range(q):
            candidate_rvs_q = [torch.rand(1, 3, **tkwargs) for _ in range(num_ff)]
            acq_val_rvs_q = [torch.rand(1, **tkwargs) for _ in range(num_ff)]
            best = torch.argmax(torch.stack(acq_val_rvs_q))
            exp_candidates.append(candidate_rvs_q[best])
            candidate_rvs += candidate_rvs_q
            acq_val_rvs += acq_val_rvs_q
        side_effect = list(zip(candidate_rvs, acq_val_rvs))
        mock_optimize_acqf.side_effect = side_effect

        candidates, acq_value = optimize_acqf_mixed(
            acq_function=mock_acq_function,
            q=q,
            fixed_features_list=fixed_features_list,
            bounds=bounds,
            num_restarts=num_restarts,
            raw_samples=raw_samples,
            options=options,
            post_processing_func=rounding_func,
        )

        expected_candidates = torch.cat(exp_candidates, dim=-2)
        if isinstance(mock_acq_function, MockOneShotEvaluateAcquisitionFunction):
            expected_acq_value = mock_acq_function.evaluate(
                expected_candidates, bounds=bounds
            )
        else:
            expected_acq_value = mock_acq_function(expected_candidates)
        self.assertTrue(torch.equal(candidates, expected_candidates))
        self.assertTrue(torch.equal(acq_value, expected_acq_value))
def test_gen_batch_initial_conditions_highdim(self):
    d = 2200  # 2200 * 10 (q) > 21201 (sobol max dim)
    bounds = torch.stack([torch.zeros(d), torch.ones(d)])
    ffs_map = {i: random() for i in range(0, d, 2)}
    mock_acqf = MockAcquisitionFunction()
    mock_acqf.objective = lambda y: y.squeeze(-1)
    for dtype in (torch.float, torch.double):
        bounds = bounds.to(device=self.device, dtype=dtype)
        mock_acqf.X_baseline = bounds  # for testing sample_around_best
        mock_acqf.model = MockModel(MockPosterior(mean=bounds[:, :1]))
        for nonnegative, seed, ffs, sample_around_best in product(
            [True, False], [None, 1234], [None, ffs_map], [True, False]
        ):
            with warnings.catch_warnings(record=True) as ws, settings.debug(True):
                # pass the configured mock (X_baseline and model set above)
                # so that sample_around_best has data to work with
                batch_initial_conditions = gen_batch_initial_conditions(
                    acq_function=mock_acqf,
                    bounds=bounds,
                    q=10,
                    num_restarts=1,
                    raw_samples=2,
                    fixed_features=ffs,
                    options={
                        "nonnegative": nonnegative,
                        "eta": 0.01,
                        "alpha": 0.1,
                        "seed": seed,
                        "sample_around_best": sample_around_best,
                    },
                )
                self.assertTrue(
                    any(issubclass(w.category, SamplingWarning) for w in ws)
                )
            expected_shape = torch.Size([1, 10, d])
            self.assertEqual(batch_initial_conditions.shape, expected_shape)
            self.assertEqual(batch_initial_conditions.device, bounds.device)
            self.assertEqual(batch_initial_conditions.dtype, bounds.dtype)
            if ffs is not None:
                for idx, val in ffs.items():
                    self.assertTrue(
                        torch.all(batch_initial_conditions[..., idx] == val)
                    )
def test_gen_batch_initial_conditions_constraints(self):
    for dtype in (torch.float, torch.double):
        bounds = torch.tensor([[0, 0], [1, 1]], device=self.device, dtype=dtype)
        inequality_constraints = [
            (
                torch.tensor([1], device=self.device, dtype=dtype),
                torch.tensor([-4], device=self.device, dtype=dtype),
                torch.tensor(-3, device=self.device, dtype=dtype),
            )
        ]
        equality_constraints = [
            (
                torch.tensor([0], device=self.device, dtype=dtype),
                torch.tensor([1], device=self.device, dtype=dtype),
                torch.tensor(0.5, device=self.device, dtype=dtype),
            )
        ]
        for nonnegative in (True, False):
            for seed in (None, 1234):
                for init_batch_limit in (None, 1):
                    mock_acqf = MockAcquisitionFunction()
                    with mock.patch.object(
                        MockAcquisitionFunction,
                        "__call__",
                        wraps=mock_acqf.__call__,
                    ) as mock_acqf_call:
                        batch_initial_conditions = gen_batch_initial_conditions(
                            acq_function=mock_acqf,
                            bounds=bounds,
                            q=1,
                            num_restarts=2,
                            raw_samples=10,
                            options={
                                "nonnegative": nonnegative,
                                "eta": 0.01,
                                "alpha": 0.1,
                                "seed": seed,
                                "init_batch_limit": init_batch_limit,
                                "thinning": 2,
                                "n_burnin": 3,
                            },
                            inequality_constraints=inequality_constraints,
                            equality_constraints=equality_constraints,
                        )
                    expected_shape = torch.Size([2, 1, 2])
                    self.assertEqual(batch_initial_conditions.shape, expected_shape)
                    self.assertEqual(batch_initial_conditions.device, bounds.device)
                    self.assertEqual(batch_initial_conditions.dtype, bounds.dtype)
                    raw_samps = mock_acqf_call.call_args[0][0]
                    batch_shape = (
                        torch.Size([10])
                        if init_batch_limit is None
                        else torch.Size([init_batch_limit])
                    )
                    expected_raw_samps_shape = batch_shape + torch.Size([1, 2])
                    self.assertEqual(raw_samps.shape, expected_raw_samps_shape)
                    self.assertTrue((raw_samps[..., 0] == 0.5).all())
                    self.assertTrue((-4 * raw_samps[..., 1] >= -3).all())
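# Constraint format (botorch convention): each (indices, coefficients, rhs)
# tuple encodes sum_i coefficients[i] * x[indices[i]] >= rhs for inequality
# constraints and == rhs for equality constraints. In the test above,
# ([0], [1], 0.5) pins x[0] == 0.5, and ([1], [-4], -3) enforces
# -4 * x[1] >= -3, i.e. x[1] <= 0.75, which is exactly what the two final
# assertions verify on the raw samples.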
def test_optimize_acqf_joint(
    self, mock_gen_candidates, mock_gen_batch_initial_conditions
):
    q = 3
    num_restarts = 2
    raw_samples = 10
    options = {}
    mock_acq_function = MockAcquisitionFunction()
    cnt = 1
    for dtype in (torch.float, torch.double):
        mock_gen_batch_initial_conditions.return_value = torch.zeros(
            num_restarts, q, 3, device=self.device, dtype=dtype
        )
        base_cand = torch.ones(1, q, 3, device=self.device, dtype=dtype)
        mock_candidates = torch.cat(
            [i * base_cand for i in range(num_restarts)], dim=0
        )
        mock_acq_values = num_restarts - torch.arange(
            num_restarts, device=self.device, dtype=dtype
        )
        mock_gen_candidates.return_value = (mock_candidates, mock_acq_values)
        bounds = torch.stack(
            [
                torch.zeros(3, device=self.device, dtype=dtype),
                4 * torch.ones(3, device=self.device, dtype=dtype),
            ]
        )
        candidates, acq_vals = optimize_acqf(
            acq_function=mock_acq_function,
            bounds=bounds,
            q=q,
            num_restarts=num_restarts,
            raw_samples=raw_samples,
            options=options,
        )
        self.assertTrue(torch.equal(candidates, mock_candidates[0]))
        self.assertTrue(torch.equal(acq_vals, mock_acq_values[0]))

        candidates, acq_vals = optimize_acqf(
            acq_function=mock_acq_function,
            bounds=bounds,
            q=q,
            num_restarts=num_restarts,
            raw_samples=raw_samples,
            options=options,
            return_best_only=False,
            batch_initial_conditions=torch.zeros(
                num_restarts, q, 3, device=self.device, dtype=dtype
            ),
        )
        self.assertTrue(torch.equal(candidates, mock_candidates))
        self.assertTrue(torch.equal(acq_vals, mock_acq_values))
        self.assertEqual(mock_gen_batch_initial_conditions.call_count, cnt)
        cnt += 1

    # test OneShotAcquisitionFunction
    mock_acq_function = MockOneShotAcquisitionFunction()
    candidates, acq_vals = optimize_acqf(
        acq_function=mock_acq_function,
        bounds=bounds,
        q=q,
        num_restarts=num_restarts,
        raw_samples=raw_samples,
        options=options,
    )
    self.assertTrue(
        torch.equal(
            candidates, mock_acq_function.extract_candidates(mock_candidates[0])
        )
    )
    self.assertTrue(torch.equal(acq_vals, mock_acq_values[0]))
def test_optimize_acqf_mixed(self, mock_optimize_acqf):
    num_restarts = 2
    raw_samples = 10
    q = 1
    options = {}
    tkwargs = {"device": self.device}
    bounds = torch.stack([torch.zeros(3), 4 * torch.ones(3)])
    mock_acq_function = MockAcquisitionFunction()
    for num_ff, dtype in itertools.product([1, 3], (torch.float, torch.double)):
        tkwargs["dtype"] = dtype
        mock_optimize_acqf.reset_mock()
        bounds = bounds.to(**tkwargs)
        candidate_rvs = []
        acq_val_rvs = []
        gcs_return_vals = [
            (torch.rand(1, 3, **tkwargs), torch.rand(1, **tkwargs))
            for _ in range(num_ff)
        ]
        for rv in gcs_return_vals:
            candidate_rvs.append(rv[0])
            acq_val_rvs.append(rv[1])
        fixed_features_list = [{i: i * 0.1} for i in range(num_ff)]
        side_effect = list(zip(candidate_rvs, acq_val_rvs))
        mock_optimize_acqf.side_effect = side_effect

        candidates, acq_value = optimize_acqf_mixed(
            acq_function=mock_acq_function,
            q=q,
            fixed_features_list=fixed_features_list,
            bounds=bounds,
            num_restarts=num_restarts,
            raw_samples=raw_samples,
            options=options,
            post_processing_func=rounding_func,
        )
        # compute expected output
        ff_acq_values = torch.stack(acq_val_rvs)
        best = torch.argmax(ff_acq_values)
        expected_candidates = candidate_rvs[best]
        expected_acq_value = ff_acq_values[best]
        self.assertTrue(torch.equal(candidates, expected_candidates))
        self.assertTrue(torch.equal(acq_value, expected_acq_value))
        # check call arguments for optimize_acqf
        call_args_list = mock_optimize_acqf.call_args_list
        expected_call_args = {
            "acq_function": None,
            "bounds": bounds,
            "q": q,
            "num_restarts": num_restarts,
            "raw_samples": raw_samples,
            "options": options,
            "inequality_constraints": None,
            "equality_constraints": None,
            "fixed_features": None,
            "post_processing_func": rounding_func,
            "batch_initial_conditions": None,
            "return_best_only": True,
            "sequential": False,
        }
        for i in range(len(call_args_list)):
            expected_call_args["fixed_features"] = fixed_features_list[i]
            for k, v in call_args_list[i][1].items():
                if torch.is_tensor(v):
                    self.assertTrue(torch.equal(expected_call_args[k], v))
                elif k == "acq_function":
                    self.assertIsInstance(v, MockAcquisitionFunction)
                else:
                    self.assertEqual(expected_call_args[k], v)
def test_optimize_acqf_list(self, mock_optimize_acqf):
    num_restarts = 2
    raw_samples = 10
    options = {}
    tkwargs = {"device": self.device}
    bounds = torch.stack([torch.zeros(3), 4 * torch.ones(3)])
    inequality_constraints = [
        [torch.tensor([3]), torch.tensor([4]), torch.tensor(5)]
    ]
    # create the mock acquisition functions once; X_pending is cleared per
    # iteration below
    mock_acq_function_1 = MockAcquisitionFunction()
    mock_acq_function_2 = MockAcquisitionFunction()
    mock_acq_function_list = [mock_acq_function_1, mock_acq_function_2]
    for num_acqf, dtype in itertools.product([1, 2], (torch.float, torch.double)):
        for m in mock_acq_function_list:
            # clear previous X_pending
            m.set_X_pending(None)
        tkwargs["dtype"] = dtype
        inequality_constraints[0] = [
            t.to(**tkwargs) for t in inequality_constraints[0]
        ]
        mock_optimize_acqf.reset_mock()
        bounds = bounds.to(**tkwargs)
        candidate_rvs = []
        acq_val_rvs = []
        gcs_return_vals = [
            (torch.rand(1, 3, **tkwargs), torch.rand(1, **tkwargs))
            for _ in range(num_acqf)
        ]
        for rv in gcs_return_vals:
            candidate_rvs.append(rv[0])
            acq_val_rvs.append(rv[1])
        side_effect = list(zip(candidate_rvs, acq_val_rvs))
        mock_optimize_acqf.side_effect = side_effect
        orig_candidates = candidate_rvs[0].clone()
        # wrap the set_X_pending methods for checking call arguments
        with mock.patch.object(
            MockAcquisitionFunction,
            "set_X_pending",
            wraps=mock_acq_function_1.set_X_pending,
        ) as mock_set_X_pending_1, mock.patch.object(
            MockAcquisitionFunction,
            "set_X_pending",
            wraps=mock_acq_function_2.set_X_pending,
        ) as mock_set_X_pending_2:
            candidates, acq_values = optimize_acqf_list(
                acq_function_list=mock_acq_function_list[:num_acqf],
                bounds=bounds,
                num_restarts=num_restarts,
                raw_samples=raw_samples,
                options=options,
                inequality_constraints=inequality_constraints,
                post_processing_func=rounding_func,
            )
            # check that X_pending is set correctly in sequential optimization
            if num_acqf > 1:
                x_pending_call_args_list = mock_set_X_pending_2.call_args_list
                idxr = torch.ones(num_acqf, dtype=torch.bool, device=self.device)
                for i in range(len(x_pending_call_args_list) - 1):
                    idxr[i] = 0
                    self.assertTrue(
                        torch.equal(
                            x_pending_call_args_list[i][0][0], orig_candidates[idxr]
                        )
                    )
                    idxr[i] = 1
                    orig_candidates[i] = candidate_rvs[i + 1]
            else:
                mock_set_X_pending_1.assert_not_called()
        # check final candidates
        expected_candidates = (
            torch.cat(candidate_rvs[-num_acqf:], dim=0)
            if num_acqf > 1
            else candidate_rvs[0]
        )
        self.assertTrue(torch.equal(candidates, expected_candidates))
        # check call arguments for optimize_acqf
        call_args_list = mock_optimize_acqf.call_args_list
        expected_call_args = {
            "acq_function": None,
            "bounds": bounds,
            "q": 1,
            "num_restarts": num_restarts,
            "raw_samples": raw_samples,
            "options": options,
            "inequality_constraints": inequality_constraints,
            "equality_constraints": None,
            "fixed_features": None,
            "post_processing_func": rounding_func,
            "batch_initial_conditions": None,
            "return_best_only": True,
            "sequential": False,
        }
        for i in range(len(call_args_list)):
            expected_call_args["acq_function"] = mock_acq_function_list[i]
            for k, v in call_args_list[i][1].items():
                if torch.is_tensor(v):
                    self.assertTrue(torch.equal(expected_call_args[k], v))
                elif k == "acq_function":
                    self.assertIsInstance(
                        mock_acq_function_list[i], MockAcquisitionFunction
                    )
                else:
                    self.assertEqual(expected_call_args[k], v)
def test_optimize_acqf_cyclic(self, mock_optimize_acqf):
    num_restarts = 2
    raw_samples = 10
    num_cycles = 2
    options = {}
    tkwargs = {"device": self.device}
    bounds = torch.stack([torch.zeros(3), 4 * torch.ones(3)])
    inequality_constraints = [
        [torch.tensor([3]), torch.tensor([4]), torch.tensor(5)]
    ]
    mock_acq_function = MockAcquisitionFunction()
    for q, dtype in itertools.product([1, 3], (torch.float, torch.double)):
        tkwargs["dtype"] = dtype
        inequality_constraints[0] = [
            t.to(**tkwargs) for t in inequality_constraints[0]
        ]
        mock_optimize_acqf.reset_mock()
        bounds = bounds.to(**tkwargs)
        candidate_rvs = []
        acq_val_rvs = []
        for cycle_j in range(num_cycles):
            gcs_return_vals = [
                (torch.rand(1, 3, **tkwargs), torch.rand(1, **tkwargs))
                for _ in range(q)
            ]
            if cycle_j == 0:
                # return `q` candidates for first call
                candidate_rvs.append(
                    torch.cat([rv[0] for rv in gcs_return_vals], dim=-2)
                )
                acq_val_rvs.append(torch.cat([rv[1] for rv in gcs_return_vals]))
            else:
                # return 1 candidate for subsequent calls
                for rv in gcs_return_vals:
                    candidate_rvs.append(rv[0])
                    acq_val_rvs.append(rv[1])
        mock_optimize_acqf.side_effect = list(zip(candidate_rvs, acq_val_rvs))
        orig_candidates = candidate_rvs[0].clone()
        # wrap the set_X_pending method for checking call arguments
        with mock.patch.object(
            MockAcquisitionFunction,
            "set_X_pending",
            wraps=mock_acq_function.set_X_pending,
        ) as mock_set_X_pending:
            candidates, acq_value = optimize_acqf_cyclic(
                acq_function=mock_acq_function,
                bounds=bounds,
                q=q,
                num_restarts=num_restarts,
                raw_samples=raw_samples,
                options=options,
                inequality_constraints=inequality_constraints,
                post_processing_func=rounding_func,
                cyclic_options={"maxiter": num_cycles},
            )
            # check that X_pending is set correctly in cyclic optimization
            if q > 1:
                x_pending_call_args_list = mock_set_X_pending.call_args_list
                idxr = torch.ones(q, dtype=torch.bool, device=self.device)
                for i in range(len(x_pending_call_args_list) - 1):
                    idxr[i] = 0
                    self.assertTrue(
                        torch.equal(
                            x_pending_call_args_list[i][0][0], orig_candidates[idxr]
                        )
                    )
                    idxr[i] = 1
                    orig_candidates[i] = candidate_rvs[i + 1]
                # check reset to base_X_pending
                self.assertIsNone(x_pending_call_args_list[-1][0][0])
            else:
                mock_set_X_pending.assert_not_called()
        # check final candidates
        expected_candidates = (
            torch.cat(candidate_rvs[-q:], dim=0) if q > 1 else candidate_rvs[0]
        )
        self.assertTrue(torch.equal(candidates, expected_candidates))
        # check call arguments for optimize_acqf
        call_args_list = mock_optimize_acqf.call_args_list
        expected_call_args = {
            "acq_function": mock_acq_function,
            "bounds": bounds,
            "num_restarts": num_restarts,
            "raw_samples": raw_samples,
            "options": options,
            "inequality_constraints": inequality_constraints,
            "equality_constraints": None,
            "fixed_features": None,
            "post_processing_func": rounding_func,
            "return_best_only": True,
            "sequential": True,
        }
        orig_candidates = candidate_rvs[0].clone()
        for i in range(len(call_args_list)):
            if i == 0:
                # first cycle
                expected_call_args.update(
                    {"batch_initial_conditions": None, "q": q}
                )
            else:
                expected_call_args.update(
                    {"batch_initial_conditions": orig_candidates[i - 1 : i], "q": 1}
                )
                orig_candidates[i - 1] = candidate_rvs[i]
            for k, v in call_args_list[i][1].items():
                if torch.is_tensor(v):
                    self.assertTrue(torch.equal(expected_call_args[k], v))
                elif k == "acq_function":
                    self.assertIsInstance(mock_acq_function, MockAcquisitionFunction)
                else:
                    self.assertEqual(expected_call_args[k], v)
def test_remove_fixed_features_from_optimization(self):
    fixed_features = {1: 1.0, 3: -1.0}
    b, q, d = 7, 3, 5
    initial_conditions = torch.randn(b, q, d, device=self.device)
    tensor_lower_bounds = torch.randn(q, d, device=self.device)
    tensor_upper_bounds = tensor_lower_bounds + torch.rand(q, d, device=self.device)
    old_inequality_constraints = [
        (
            torch.arange(0, 5, 2, device=self.device),
            torch.rand(3, device=self.device),
            1.0,
        )
    ]
    old_equality_constraints = [
        (
            torch.arange(0, 3, 1, device=self.device),
            torch.rand(3, device=self.device),
            1.0,
        )
    ]
    acqf = MockAcquisitionFunction()

    def check_bounds_and_init(old_val, new_val):
        if isinstance(old_val, float):
            self.assertEqual(old_val, new_val)
        elif isinstance(old_val, torch.Tensor):
            mask = [(i not in fixed_features.keys()) for i in range(d)]
            self.assertTrue(torch.equal(old_val[..., mask], new_val))
        else:
            self.assertIsNone(new_val)

    def check_cons(old_cons, new_cons):
        if old_cons:
            # we don't fix all dimensions in this test
            new_dim = d - len(fixed_features)
            self.assertTrue(
                torch.all((new_cons[0][0] <= new_dim) & (new_cons[0][0] >= 0))
            )
        else:
            self.assertEqual(old_cons, new_cons)

    for (
        lower_bounds,
        upper_bounds,
        inequality_constraints,
        equality_constraints,
    ) in product(
        [None, -1.0, tensor_lower_bounds],
        [None, 1.0, tensor_upper_bounds],
        [None, old_inequality_constraints],
        [None, old_equality_constraints],
    ):
        _no_ff = _remove_fixed_features_from_optimization(
            fixed_features=fixed_features,
            acquisition_function=acqf,
            initial_conditions=initial_conditions,
            lower_bounds=lower_bounds,
            upper_bounds=upper_bounds,
            inequality_constraints=inequality_constraints,
            equality_constraints=equality_constraints,
        )
        self.assertIsInstance(
            _no_ff.acquisition_function, FixedFeatureAcquisitionFunction
        )
        check_bounds_and_init(initial_conditions, _no_ff.initial_conditions)
        check_bounds_and_init(lower_bounds, _no_ff.lower_bounds)
        check_bounds_and_init(upper_bounds, _no_ff.upper_bounds)
        check_cons(inequality_constraints, _no_ff.inequality_constraints)
        check_cons(equality_constraints, _no_ff.equality_constraints)
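# _remove_fixed_features_from_optimization wraps the acquisition function in a
# FixedFeatureAcquisitionFunction (a real botorch class) over the non-fixed
# dimensions, which is what the isinstance check above asserts. A minimal
# usage sketch with made-up data (the helper name and the 5-dim model are
# illustrative only):
def _example_fixed_feature_wrapper():
    from botorch.acquisition import ExpectedImprovement
    from botorch.acquisition.fixed_feature import FixedFeatureAcquisitionFunction
    from botorch.models import SingleTaskGP

    train_X = torch.rand(8, 5, dtype=torch.double)
    train_Y = train_X.sum(dim=-1, keepdim=True)
    ei = ExpectedImprovement(SingleTaskGP(train_X, train_Y), best_f=train_Y.max())
    # fix dims 1 and 3; the wrapper accepts `... x q x 3` inputs and
    # re-inserts the fixed columns before evaluating the wrapped acqf
    ff_ei = FixedFeatureAcquisitionFunction(
        acq_function=ei, d=5, columns=[1, 3], values=[1.0, -1.0]
    )
    return ff_ei(torch.rand(4, 1, 3, dtype=torch.double))  # shape: 4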
def test_optimize_acqf_joint(
    self, mock_gen_candidates, mock_gen_batch_initial_conditions
):
    q = 3
    num_restarts = 2
    raw_samples = 10
    options = {}
    mock_acq_function = MockAcquisitionFunction()
    cnt = 0
    for dtype in (torch.float, torch.double):
        mock_gen_batch_initial_conditions.return_value = torch.zeros(
            num_restarts, q, 3, device=self.device, dtype=dtype
        )
        base_cand = torch.arange(3, device=self.device, dtype=dtype).expand(1, q, 3)
        mock_candidates = torch.cat(
            [i * base_cand for i in range(num_restarts)], dim=0
        )
        mock_acq_values = num_restarts - torch.arange(
            num_restarts, device=self.device, dtype=dtype
        )
        mock_gen_candidates.return_value = (mock_candidates, mock_acq_values)
        bounds = torch.stack(
            [
                torch.zeros(3, device=self.device, dtype=dtype),
                4 * torch.ones(3, device=self.device, dtype=dtype),
            ]
        )
        candidates, acq_vals = optimize_acqf(
            acq_function=mock_acq_function,
            bounds=bounds,
            q=q,
            num_restarts=num_restarts,
            raw_samples=raw_samples,
            options=options,
        )
        self.assertTrue(torch.equal(candidates, mock_candidates[0]))
        self.assertTrue(torch.equal(acq_vals, mock_acq_values[0]))
        cnt += 1
        self.assertEqual(mock_gen_batch_initial_conditions.call_count, cnt)

        # test generation with provided initial conditions
        candidates, acq_vals = optimize_acqf(
            acq_function=mock_acq_function,
            bounds=bounds,
            q=q,
            num_restarts=num_restarts,
            raw_samples=raw_samples,
            options=options,
            return_best_only=False,
            batch_initial_conditions=torch.zeros(
                num_restarts, q, 3, device=self.device, dtype=dtype
            ),
        )
        self.assertTrue(torch.equal(candidates, mock_candidates))
        self.assertTrue(torch.equal(acq_vals, mock_acq_values))
        self.assertEqual(mock_gen_batch_initial_conditions.call_count, cnt)

        # test fixed features
        fixed_features = {0: 0.1}
        mock_candidates[:, 0] = 0.1
        mock_gen_candidates.return_value = (mock_candidates, mock_acq_values)
        candidates, acq_vals = optimize_acqf(
            acq_function=mock_acq_function,
            bounds=bounds,
            q=q,
            num_restarts=num_restarts,
            raw_samples=raw_samples,
            options=options,
            fixed_features=fixed_features,
        )
        self.assertEqual(
            mock_gen_candidates.call_args[1]["fixed_features"], fixed_features
        )
        self.assertTrue(torch.equal(candidates, mock_candidates[0]))
        cnt += 1
        self.assertEqual(mock_gen_batch_initial_conditions.call_count, cnt)

        # test trivial case when all features are fixed
        candidates, acq_vals = optimize_acqf(
            acq_function=mock_acq_function,
            bounds=bounds,
            q=q,
            num_restarts=num_restarts,
            raw_samples=raw_samples,
            options=options,
            fixed_features={0: 0.1, 1: 0.2, 2: 0.3},
        )
        self.assertTrue(
            torch.equal(
                candidates,
                torch.tensor(
                    [0.1, 0.2, 0.3], device=self.device, dtype=dtype
                ).expand(3, 3),
            )
        )
        self.assertEqual(mock_gen_batch_initial_conditions.call_count, cnt)

    # test OneShotAcquisitionFunction
    mock_acq_function = MockOneShotAcquisitionFunction()
    candidates, acq_vals = optimize_acqf(
        acq_function=mock_acq_function,
        bounds=bounds,
        q=q,
        num_restarts=num_restarts,
        raw_samples=raw_samples,
        options=options,
    )
    self.assertTrue(
        torch.equal(
            candidates, mock_acq_function.extract_candidates(mock_candidates[0])
        )
    )
    self.assertTrue(torch.equal(acq_vals, mock_acq_values[0]))