def test_optimize_acqf_list(self, mock_optimize_acqf):
    """Test the orchestration performed by ``optimize_acqf_list``.

    ``optimize_acqf`` is mocked (``mock_optimize_acqf``) to return one random
    candidate and value per acquisition function. The test therefore only
    checks what ``optimize_acqf_list`` itself does:

    - which acquisition functions receive ``set_X_pending`` calls, and with
      which pending points;
    - how the final candidates are assembled;
    - which keyword arguments are forwarded to ``optimize_acqf``.
    """
    from contextlib import ExitStack

    num_restarts = 2
    raw_samples = 10
    options = {}
    tkwargs = {"device": self.device}
    bounds = torch.stack([torch.zeros(3), 4 * torch.ones(3)])
    mock_acq_function_list = [MockAcquisitionFunction(), MockAcquisitionFunction()]
    for num_acqf, dtype in itertools.product([1, 2], (torch.float, torch.double)):
        mock_optimize_acqf.reset_mock()
        tkwargs["dtype"] = dtype
        bounds = bounds.to(**tkwargs)
        # Build the constraint only after ``dtype`` is known, so that the
        # coefficients and rhs match this iteration's dtype. Converting before
        # updating ``tkwargs`` lagged one iteration behind. The indices are
        # feature positions and stay integer.
        inequality_constraints = [
            (
                torch.tensor([3], device=self.device),
                torch.tensor([4], **tkwargs),
                torch.tensor(5, **tkwargs),
            )
        ]
        acq_function_list = mock_acq_function_list[:num_acqf]
        candidate_rvs = []
        acq_val_rvs = []
        for _ in range(num_acqf):
            candidate_rvs.append(torch.rand(1, 3, **tkwargs))
            acq_val_rvs.append(torch.rand(1, **tkwargs))
        mock_optimize_acqf.side_effect = list(zip(candidate_rvs, acq_val_rvs))

        with ExitStack() as stack:
            # Wrap ``set_X_pending`` on each *instance*, so every call is
            # recorded against the acquisition function it was made on.
            # Patching the same class attribute twice leaves only the innermost
            # mock active, which made a "not called" check vacuous.
            set_X_pending_mocks = [
                stack.enter_context(
                    mock.patch.object(
                        acqf, "set_X_pending", wraps=acqf.set_X_pending
                    )
                )
                for acqf in acq_function_list
            ]
            candidates, acq_values = optimize_acqf_list(
                acq_function_list=acq_function_list,
                bounds=bounds,
                num_restarts=num_restarts,
                raw_samples=raw_samples,
                options=options,
                inequality_constraints=inequality_constraints,
                post_processing_func=rounding_func,
            )

        # The first acquisition function gets no pending points. The j-th one
        # (j >= 1) is told about the j candidates produced before it.
        # NOTE(review): this assumes the first function's ``X_pending`` starts
        # as None (no base pending points); confirm in MockAcquisitionFunction.
        set_X_pending_mocks[0].assert_not_called()
        for j in range(1, num_acqf):
            set_X_pending_mocks[j].assert_called_once()
            self.assertTrue(
                torch.equal(
                    set_X_pending_mocks[j].call_args[0][0],
                    torch.cat(candidate_rvs[:j], dim=-2),
                )
            )

        # check final candidates: the mocked candidates stacked along q
        self.assertTrue(torch.equal(candidates, torch.cat(candidate_rvs, dim=-2)))

        # check call arguments for optimize_acqf: one call per acq function
        call_args_list = mock_optimize_acqf.call_args_list
        self.assertEqual(len(call_args_list), num_acqf)
        expected_call_args = {
            "bounds": bounds,
            "q": 1,
            "num_restarts": num_restarts,
            "raw_samples": raw_samples,
            "options": options,
            "inequality_constraints": inequality_constraints,
            "equality_constraints": None,
            "fixed_features": None,
            "post_processing_func": rounding_func,
            "batch_initial_conditions": None,
            "return_best_only": True,
            "sequential": False,
        }
        for acqf, call_args in zip(acq_function_list, call_args_list):
            for k, v in call_args[1].items():
                if k == "acq_function":
                    # each call must receive its own acquisition function
                    self.assertIs(v, acqf)
                elif torch.is_tensor(v):
                    self.assertTrue(torch.equal(expected_call_args[k], v))
                else:
                    self.assertEqual(expected_call_args[k], v)
def test_optimize_acqf_cyclic(self, mock_optimize_acqf):
    """Test the orchestration performed by ``optimize_acqf_cyclic``.

    ``optimize_acqf`` is mocked. Its first call returns a full q-batch, and
    every later call returns a single candidate. For ``q > 1`` the test
    checks:

    - the ``X_pending`` set while each candidate is re-optimized;
    - the final reset of ``X_pending`` to None;
    - the returned candidates;
    - the keyword arguments forwarded to ``optimize_acqf``.
    """
    num_restarts = 2
    raw_samples = 10
    num_cycles = 2
    options = {}
    tkwargs = {"device": self.device}
    bounds = torch.stack([torch.zeros(3), 4 * torch.ones(3)])
    mock_acq_function = MockAcquisitionFunction()
    for q, dtype in itertools.product([1, 3], (torch.float, torch.double)):
        mock_optimize_acqf.reset_mock()
        tkwargs["dtype"] = dtype
        bounds = bounds.to(**tkwargs)
        # Build the constraint only after ``dtype`` is known, so that the
        # coefficients and rhs match this iteration's dtype. Converting before
        # updating ``tkwargs`` lagged one iteration behind. The indices are
        # feature positions and stay integer.
        inequality_constraints = [
            (
                torch.tensor([3], device=self.device),
                torch.tensor([4], **tkwargs),
                torch.tensor(5, **tkwargs),
            )
        ]
        candidate_rvs = []
        acq_val_rvs = []
        for cycle_j in range(num_cycles):
            gcs_return_vals = [
                (torch.rand(1, 3, **tkwargs), torch.rand(1, **tkwargs))
                for _ in range(q)
            ]
            if cycle_j == 0:
                # return `q` candidates for first call
                candidate_rvs.append(
                    torch.cat([rv[0] for rv in gcs_return_vals], dim=-2)
                )
                acq_val_rvs.append(torch.cat([rv[1] for rv in gcs_return_vals]))
            else:
                # return 1 candidate for subsequent calls
                for rv in gcs_return_vals:
                    candidate_rvs.append(rv[0])
                    acq_val_rvs.append(rv[1])
        mock_optimize_acqf.side_effect = list(zip(candidate_rvs, acq_val_rvs))
        # snapshot of the first q-batch, taken before the optimization runs
        orig_candidates = candidate_rvs[0].clone()
        # wrap the set_X_pending method to check its call arguments
        with mock.patch.object(
            MockAcquisitionFunction,
            "set_X_pending",
            wraps=mock_acq_function.set_X_pending,
        ) as mock_set_X_pending:
            candidates, acq_value = optimize_acqf_cyclic(
                acq_function=mock_acq_function,
                bounds=bounds,
                q=q,
                num_restarts=num_restarts,
                raw_samples=raw_samples,
                options=options,
                inequality_constraints=inequality_constraints,
                post_processing_func=rounding_func,
                cyclic_options={"maxiter": num_cycles},
            )
            # check that X_pending is set correctly in cyclic optimization
            if q > 1:
                x_pending_call_args_list = mock_set_X_pending.call_args_list
                idxr = torch.ones(q, dtype=torch.bool, device=self.device)
                # All calls but the last: while re-optimizing candidate i, the
                # other q - 1 current candidates are the pending points.
                for i in range(len(x_pending_call_args_list) - 1):
                    idxr[i] = 0
                    self.assertTrue(
                        torch.equal(
                            x_pending_call_args_list[i][0][0],
                            orig_candidates[idxr],
                        )
                    )
                    idxr[i] = 1
                    orig_candidates[i] = candidate_rvs[i + 1]
                # check reset to base_X_pending (None for this mock)
                self.assertIsNone(x_pending_call_args_list[-1][0][0])
            else:
                mock_set_X_pending.assert_not_called()
        # check final candidates
        expected_candidates = (
            torch.cat(candidate_rvs[-q:], dim=0) if q > 1 else candidate_rvs[0]
        )
        self.assertTrue(torch.equal(candidates, expected_candidates))
        # check call arguments for optimize_acqf
        call_args_list = mock_optimize_acqf.call_args_list
        expected_call_args = {
            "acq_function": mock_acq_function,
            "bounds": bounds,
            "num_restarts": num_restarts,
            "raw_samples": raw_samples,
            "options": options,
            "inequality_constraints": inequality_constraints,
            "equality_constraints": None,
            "fixed_features": None,
            "post_processing_func": rounding_func,
            "return_best_only": True,
            "sequential": True,
        }
        # NOTE(review): this clone is taken *after* the optimization ran. The
        # check relies on the recorded ``batch_initial_conditions`` reflecting
        # the values in ``candidate_rvs[0]`` at this point. That presumably
        # holds because the implementation updates the returned q-batch in
        # place and passes views of it. Confirm against
        # ``optimize_acqf_cyclic`` before relying on this check to catch a
        # wrong warm start.
        orig_candidates = candidate_rvs[0].clone()
        for i in range(len(call_args_list)):
            if i == 0:
                # first cycle
                expected_call_args.update({"batch_initial_conditions": None, "q": q})
            else:
                expected_call_args.update(
                    {"batch_initial_conditions": orig_candidates[i - 1 : i], "q": 1}
                )
                orig_candidates[i - 1] = candidate_rvs[i]
            for k, v in call_args_list[i][1].items():
                if torch.is_tensor(v):
                    self.assertTrue(torch.equal(expected_call_args[k], v))
                elif k == "acq_function":
                    # check the object actually received, not the expected one
                    self.assertIs(v, mock_acq_function)
                else:
                    self.assertEqual(expected_call_args[k], v)
def test_gen_batch_initial_conditions_constraints(self):
    """Test ``gen_batch_initial_conditions`` under linear constraints.

    The constraints are ``x_0 == 0.5`` (equality) and ``-4 * x_1 >= -3``
    (inequality). For every combination of ``nonnegative``, ``seed`` and
    ``init_batch_limit``, the test checks:

    - the shape, device and dtype of the returned initial conditions;
    - that the raw samples the acquisition function was evaluated on satisfy
      both constraints.
    """
    q = 1
    num_restarts = 2
    raw_samples = 10
    for dtype in (torch.float, torch.double):
        tkwargs = {"device": self.device, "dtype": dtype}
        bounds = torch.tensor([[0, 0], [1, 1]], **tkwargs)
        # (indices, coefficients, rhs); indices are integer feature positions
        inequality_constraints = [
            (
                torch.tensor([1], device=self.device, dtype=torch.long),
                torch.tensor([-4], **tkwargs),
                torch.tensor(-3, **tkwargs),
            )
        ]
        equality_constraints = [
            (
                torch.tensor([0], device=self.device, dtype=torch.long),
                torch.tensor([1], **tkwargs),
                torch.tensor(0.5, **tkwargs),
            )
        ]
        # same iteration order as three nested loops
        for nonnegative, seed, init_batch_limit in itertools.product(
            (True, False), (None, 1234), (None, 1)
        ):
            mock_acqf = MockAcquisitionFunction()
            with mock.patch.object(
                MockAcquisitionFunction,
                "__call__",
                wraps=mock_acqf.__call__,
            ) as mock_acqf_call:
                batch_initial_conditions = gen_batch_initial_conditions(
                    acq_function=mock_acqf,
                    bounds=bounds,
                    q=q,
                    num_restarts=num_restarts,
                    raw_samples=raw_samples,
                    options={
                        "nonnegative": nonnegative,
                        "eta": 0.01,
                        "alpha": 0.1,
                        "seed": seed,
                        "init_batch_limit": init_batch_limit,
                        "thinning": 2,
                        "n_burnin": 3,
                    },
                    inequality_constraints=inequality_constraints,
                    equality_constraints=equality_constraints,
                )
                self.assertEqual(
                    batch_initial_conditions.shape,
                    torch.Size([num_restarts, q, 2]),
                )
                self.assertEqual(batch_initial_conditions.device, bounds.device)
                self.assertEqual(batch_initial_conditions.dtype, bounds.dtype)
                # ``call_args`` holds the *last* call. All raw samples are
                # expected at once, unless ``init_batch_limit`` is set; then
                # the last call presumably sees a chunk of that size.
                raw_samps = mock_acqf_call.call_args[0][0]
                expected_batch = (
                    raw_samples if init_batch_limit is None else init_batch_limit
                )
                self.assertEqual(
                    raw_samps.shape, torch.Size([expected_batch, q, 2])
                )
                # equality constraint: x_0 == 0.5
                self.assertTrue((raw_samps[..., 0] == 0.5).all())
                # inequality constraint: -4 * x_1 >= -3
                self.assertTrue((-4 * raw_samps[..., 1] >= -3).all())
def test_optimize_acqf_joint(
    self, mock_gen_candidates, mock_gen_batch_initial_conditions
):
    """Test joint ``optimize_acqf`` with candidate and initial-condition generation mocked.

    For each dtype, the test checks:

    - the best-only result on the default path;
    - the full batch when ``batch_initial_conditions`` is supplied, which must
      not trigger another ``gen_batch_initial_conditions`` call.

    Finally it checks that a one-shot acquisition function's candidates are
    passed through ``extract_candidates``.

    NOTE(review): a second ``def test_optimize_acqf_joint`` appears later in
    this file. If both are methods of the same class, the later one replaces
    this definition and this test never runs.
    """
    q, num_restarts, raw_samples = 3, 2, 10
    options = {}
    acqf = MockAcquisitionFunction()
    # gen_batch_initial_conditions should run once per dtype (first call only)
    expected_init_calls = 1
    for dtype in (torch.float, torch.double):
        tkw = {"device": self.device, "dtype": dtype}
        mock_gen_batch_initial_conditions.return_value = torch.zeros(
            num_restarts, q, 3, **tkw
        )
        ones = torch.ones(1, q, 3, **tkw)
        # restart r is filled entirely with the value r
        mock_candidates = torch.cat([r * ones for r in range(num_restarts)], dim=0)
        # values decrease with the restart index, so restart 0 is the best
        mock_acq_values = num_restarts - torch.arange(num_restarts, **tkw)
        mock_gen_candidates.return_value = (mock_candidates, mock_acq_values)
        bounds = torch.stack([torch.zeros(3, **tkw), 4 * torch.ones(3, **tkw)])
        shared_kwargs = {
            "bounds": bounds,
            "q": q,
            "num_restarts": num_restarts,
            "raw_samples": raw_samples,
            "options": options,
        }

        best_cands, best_vals = optimize_acqf(acq_function=acqf, **shared_kwargs)
        self.assertTrue(torch.equal(best_cands, mock_candidates[0]))
        self.assertTrue(torch.equal(best_vals, mock_acq_values[0]))

        all_cands, all_vals = optimize_acqf(
            acq_function=acqf,
            return_best_only=False,
            batch_initial_conditions=torch.zeros(num_restarts, q, 3, **tkw),
            **shared_kwargs,
        )
        self.assertTrue(torch.equal(all_cands, mock_candidates))
        self.assertTrue(torch.equal(all_vals, mock_acq_values))
        self.assertEqual(
            mock_gen_batch_initial_conditions.call_count, expected_init_calls
        )
        expected_init_calls += 1

    # test OneShotAcquisitionFunction (reuses the last dtype's bounds and mocks)
    one_shot_acqf = MockOneShotAcquisitionFunction()
    cands, vals = optimize_acqf(acq_function=one_shot_acqf, **shared_kwargs)
    self.assertTrue(
        torch.equal(cands, one_shot_acqf.extract_candidates(mock_candidates[0]))
    )
    self.assertTrue(torch.equal(vals, mock_acq_values[0]))
def test_optimize_acqf_joint(
    self, mock_gen_candidates, mock_gen_batch_initial_conditions
):
    """Test joint ``optimize_acqf`` with ``gen_candidates`` and initial-condition generation mocked.

    The test covers:

    - the default path;
    - user-supplied ``batch_initial_conditions``, which must cause no extra
      ``gen_batch_initial_conditions`` call;
    - partially fixed features;
    - the trivial case where all features are fixed, which also causes no
      extra initial-condition generation;
    - a one-shot acquisition function.
    """
    q = 3
    num_restarts = 2
    raw_samples = 10
    options = {}
    mock_acq_function = MockAcquisitionFunction()
    # expected number of gen_batch_initial_conditions calls so far
    cnt = 0
    for dtype in (torch.float, torch.double):
        mock_gen_batch_initial_conditions.return_value = torch.zeros(
            num_restarts, q, 3, device=self.device, dtype=dtype
        )
        base_cand = torch.arange(3, device=self.device, dtype=dtype).expand(1, q, 3)
        mock_candidates = torch.cat(
            [i * base_cand for i in range(num_restarts)], dim=0
        )
        # values decrease with the restart index, so restart 0 is the best
        mock_acq_values = num_restarts - torch.arange(
            num_restarts, device=self.device, dtype=dtype
        )
        mock_gen_candidates.return_value = (mock_candidates, mock_acq_values)
        bounds = torch.stack(
            [
                torch.zeros(3, device=self.device, dtype=dtype),
                4 * torch.ones(3, device=self.device, dtype=dtype),
            ]
        )
        candidates, acq_vals = optimize_acqf(
            acq_function=mock_acq_function,
            bounds=bounds,
            q=q,
            num_restarts=num_restarts,
            raw_samples=raw_samples,
            options=options,
        )
        self.assertTrue(torch.equal(candidates, mock_candidates[0]))
        self.assertTrue(torch.equal(acq_vals, mock_acq_values[0]))
        cnt += 1
        self.assertEqual(mock_gen_batch_initial_conditions.call_count, cnt)

        # test generation with provided initial conditions
        candidates, acq_vals = optimize_acqf(
            acq_function=mock_acq_function,
            bounds=bounds,
            q=q,
            num_restarts=num_restarts,
            raw_samples=raw_samples,
            options=options,
            return_best_only=False,
            batch_initial_conditions=torch.zeros(
                num_restarts, q, 3, device=self.device, dtype=dtype
            ),
        )
        self.assertTrue(torch.equal(candidates, mock_candidates))
        self.assertTrue(torch.equal(acq_vals, mock_acq_values))
        self.assertEqual(mock_gen_batch_initial_conditions.call_count, cnt)

        # test fixed features: feature 0 (last dimension) is pinned to 0.1, so
        # the mocked candidates carry 0.1 in that column for every restart and
        # q-batch element. ``[:, 0]`` would instead overwrite q-element 0.
        fixed_features = {0: 0.1}
        mock_candidates[..., 0] = 0.1
        mock_gen_candidates.return_value = (mock_candidates, mock_acq_values)
        candidates, acq_vals = optimize_acqf(
            acq_function=mock_acq_function,
            bounds=bounds,
            q=q,
            num_restarts=num_restarts,
            raw_samples=raw_samples,
            options=options,
            fixed_features=fixed_features,
        )
        self.assertEqual(
            mock_gen_candidates.call_args[1]["fixed_features"], fixed_features
        )
        self.assertTrue(torch.equal(candidates, mock_candidates[0]))
        cnt += 1
        self.assertEqual(mock_gen_batch_initial_conditions.call_count, cnt)

        # test trivial case when all features are fixed
        candidates, acq_vals = optimize_acqf(
            acq_function=mock_acq_function,
            bounds=bounds,
            q=q,
            num_restarts=num_restarts,
            raw_samples=raw_samples,
            options=options,
            fixed_features={0: 0.1, 1: 0.2, 2: 0.3},
        )
        self.assertTrue(
            torch.equal(
                candidates,
                torch.tensor(
                    [0.1, 0.2, 0.3], device=self.device, dtype=dtype
                ).expand(q, 3),
            )
        )
        self.assertEqual(mock_gen_batch_initial_conditions.call_count, cnt)

    # test OneShotAcquisitionFunction (reuses the last dtype's bounds and mocks)
    mock_acq_function = MockOneShotAcquisitionFunction()
    candidates, acq_vals = optimize_acqf(
        acq_function=mock_acq_function,
        bounds=bounds,
        q=q,
        num_restarts=num_restarts,
        raw_samples=raw_samples,
        options=options,
    )
    self.assertTrue(
        torch.equal(
            candidates, mock_acq_function.extract_candidates(mock_candidates[0])
        )
    )
    self.assertTrue(torch.equal(acq_vals, mock_acq_values[0]))
def test_optimize_acqf_mixed(self, mock_optimize_acqf):
    """Test ``optimize_acqf_mixed`` with ``optimize_acqf`` mocked out.

    One mocked ``(candidate, value)`` pair is produced per entry of
    ``fixed_features_list``. The result must be the pair with the largest
    value. Each ``optimize_acqf`` call must see the matching
    ``fixed_features``, with every other argument at its expected value.
    """
    num_restarts, raw_samples, q = 2, 10, 1
    options = {}
    tkwargs = {"device": self.device}
    bounds = torch.stack([torch.zeros(3), 4 * torch.ones(3)])
    acqf = MockAcquisitionFunction()
    settings = itertools.product([1, 3], (torch.float, torch.double))
    for num_ff, dtype in settings:
        tkwargs["dtype"] = dtype
        mock_optimize_acqf.reset_mock()
        bounds = bounds.to(**tkwargs)

        mocked_results = [
            (torch.rand(1, 3, **tkwargs), torch.rand(1, **tkwargs))
            for _ in range(num_ff)
        ]
        candidate_rvs = [cand for cand, _ in mocked_results]
        acq_val_rvs = [val for _, val in mocked_results]
        fixed_features_list = [{i: i * 0.1} for i in range(num_ff)]
        mock_optimize_acqf.side_effect = list(zip(candidate_rvs, acq_val_rvs))

        candidates, acq_value = optimize_acqf_mixed(
            acq_function=acqf,
            q=q,
            fixed_features_list=fixed_features_list,
            bounds=bounds,
            num_restarts=num_restarts,
            raw_samples=raw_samples,
            options=options,
            post_processing_func=rounding_func,
        )

        # the winner is the mocked pair with the largest acquisition value
        stacked_vals = torch.stack(acq_val_rvs)
        winner = torch.argmax(stacked_vals)
        self.assertTrue(torch.equal(candidates, candidate_rvs[winner]))
        self.assertTrue(torch.equal(acq_value, stacked_vals[winner]))

        # every optimize_acqf call differs only in its fixed_features
        expected_call_args = {
            "acq_function": None,
            "bounds": bounds,
            "q": q,
            "num_restarts": num_restarts,
            "raw_samples": raw_samples,
            "options": options,
            "inequality_constraints": None,
            "equality_constraints": None,
            "fixed_features": None,
            "post_processing_func": rounding_func,
            "batch_initial_conditions": None,
            "return_best_only": True,
            "sequential": False,
        }
        for idx, recorded in enumerate(mock_optimize_acqf.call_args_list):
            expected_call_args["fixed_features"] = fixed_features_list[idx]
            for key, value in recorded[1].items():
                if torch.is_tensor(value):
                    self.assertTrue(torch.equal(expected_call_args[key], value))
                elif key == "acq_function":
                    self.assertIsInstance(value, MockAcquisitionFunction)
                else:
                    self.assertEqual(expected_call_args[key], value)