def test_optimize_acqf_joint(
    self, mock_gen_candidates, mock_gen_batch_initial_conditions
):
    """Check joint ``optimize_acqf`` against mocked candidate generators.

    ``mock_gen_candidates`` and ``mock_gen_batch_initial_conditions`` are
    mocks (presumably injected by ``mock.patch`` decorators outside this
    view -- confirm) whose return values are configured here. For each float
    dtype the test verifies that:

    * by default only the best restart (highest mocked acquisition value)
      is returned;
    * ``return_best_only=False`` returns every restart;
    * passing ``batch_initial_conditions`` skips initial-condition
      generation, so the generator runs exactly once per dtype iteration.

    After the loop, a one-shot acquisition function is checked to have its
    best candidate passed through ``extract_candidates``.
    """
    n_restarts = 2
    q_batch = 3
    n_raw = 10
    opt_options = {}
    acqf = MockAcquisitionFunction()
    dtypes = (torch.float, torch.double)
    # ``expected_ic_calls`` is the cumulative number of times the
    # initial-condition generator should have run by the end of each iteration.
    for expected_ic_calls, dtype in enumerate(dtypes, start=1):
        tkw = {"device": self.device, "dtype": dtype}
        mock_gen_batch_initial_conditions.return_value = torch.zeros(
            n_restarts, q_batch, 3, **tkw
        )
        # Restart ``k`` is a (q, 3) block filled with the value ``k``.
        unit_block = torch.ones(1, q_batch, 3, **tkw)
        mock_candidates = torch.cat(
            [k * unit_block for k in range(n_restarts)], dim=0
        )
        # Strictly decreasing values, so restart 0 is the best one.
        mock_acq_values = n_restarts - torch.arange(n_restarts, **tkw)
        mock_gen_candidates.return_value = (mock_candidates, mock_acq_values)
        lower = torch.zeros(3, **tkw)
        upper = 4 * torch.ones(3, **tkw)
        bounds = torch.stack([lower, upper])
        call_kwargs = {
            "acq_function": acqf,
            "bounds": bounds,
            "q": q_batch,
            "num_restarts": n_restarts,
            "raw_samples": n_raw,
            "options": opt_options,
        }

        best_cand, best_val = optimize_acqf(**call_kwargs)
        self.assertTrue(torch.equal(best_cand, mock_candidates[0]))
        self.assertTrue(torch.equal(best_val, mock_acq_values[0]))

        # With explicit initial conditions the generator must not be called.
        given_ics = torch.zeros(n_restarts, q_batch, 3, **tkw)
        all_cands, all_vals = optimize_acqf(
            **call_kwargs,
            return_best_only=False,
            batch_initial_conditions=given_ics,
        )
        self.assertTrue(torch.equal(all_cands, mock_candidates))
        self.assertTrue(torch.equal(all_vals, mock_acq_values))
        self.assertEqual(
            mock_gen_batch_initial_conditions.call_count, expected_ic_calls
        )

    # One-shot acquisition functions return candidates via extract_candidates;
    # this reuses ``bounds`` and the mocks from the final (double) iteration.
    one_shot_acqf = MockOneShotAcquisitionFunction()
    os_cand, os_val = optimize_acqf(
        acq_function=one_shot_acqf,
        bounds=bounds,
        q=q_batch,
        num_restarts=n_restarts,
        raw_samples=n_raw,
        options=opt_options,
    )
    self.assertTrue(
        torch.equal(os_cand, one_shot_acqf.extract_candidates(mock_candidates[0]))
    )
    self.assertTrue(torch.equal(os_val, mock_acq_values[0]))
def test_optimize_acqf_joint(
    self, mock_gen_candidates, mock_gen_batch_initial_conditions
):
    """Test joint (non-sequential) ``optimize_acqf`` with mocked generators.

    ``mock_gen_candidates`` and ``mock_gen_batch_initial_conditions`` are
    mocks (presumably injected by ``mock.patch`` decorators outside this
    view -- confirm) whose return values are set below. For each float dtype
    the test checks that:

    * the best restart (highest mocked acquisition value) is returned by
      default, and initial conditions are generated once for that call;
    * all restarts are returned with ``return_best_only=False``, and no
      initial conditions are generated when ``batch_initial_conditions``
      is supplied;
    * ``fixed_features`` is forwarded to the candidate generator and the
      returned candidates carry the fixed value in that feature column;
    * when every feature is fixed, the fixed point (expanded to ``q`` rows)
      is returned without generating initial conditions.

    Finally, a one-shot acquisition function is checked to have its best
    candidate passed through ``extract_candidates``.
    """
    q = 3
    num_restarts = 2
    raw_samples = 10
    options = {}
    mock_acq_function = MockAcquisitionFunction()
    # Running count of expected calls to the initial-condition generator.
    cnt = 0
    for dtype in (torch.float, torch.double):
        mock_gen_batch_initial_conditions.return_value = torch.zeros(
            num_restarts, q, 3, device=self.device, dtype=dtype
        )
        # Every q-point of restart i is i * [0, 1, 2]; shape (num_restarts, q, 3).
        base_cand = torch.arange(3, device=self.device, dtype=dtype).expand(1, q, 3)
        mock_candidates = torch.cat(
            [i * base_cand for i in range(num_restarts)], dim=0
        )
        # Strictly decreasing values, so restart 0 is the best one.
        mock_acq_values = num_restarts - torch.arange(
            num_restarts, device=self.device, dtype=dtype
        )
        mock_gen_candidates.return_value = (mock_candidates, mock_acq_values)
        bounds = torch.stack(
            [
                torch.zeros(3, device=self.device, dtype=dtype),
                4 * torch.ones(3, device=self.device, dtype=dtype),
            ]
        )
        candidates, acq_vals = optimize_acqf(
            acq_function=mock_acq_function,
            bounds=bounds,
            q=q,
            num_restarts=num_restarts,
            raw_samples=raw_samples,
            options=options,
        )
        self.assertTrue(torch.equal(candidates, mock_candidates[0]))
        self.assertTrue(torch.equal(acq_vals, mock_acq_values[0]))
        cnt += 1
        self.assertEqual(mock_gen_batch_initial_conditions.call_count, cnt)

        # test generation with provided initial conditions
        candidates, acq_vals = optimize_acqf(
            acq_function=mock_acq_function,
            bounds=bounds,
            q=q,
            num_restarts=num_restarts,
            raw_samples=raw_samples,
            options=options,
            return_best_only=False,
            batch_initial_conditions=torch.zeros(
                num_restarts, q, 3, device=self.device, dtype=dtype
            ),
        )
        self.assertTrue(torch.equal(candidates, mock_candidates))
        self.assertTrue(torch.equal(acq_vals, mock_acq_values))
        # Explicit initial conditions: the generator must not be called again.
        self.assertEqual(mock_gen_batch_initial_conditions.call_count, cnt)

        # test fixed features
        # Keys of ``fixed_features`` index the last (feature) dimension, as the
        # all-fixed case below shows, so pin column 0 of every q-point. The
        # previous ``mock_candidates[:, 0]`` overwrote all features of the
        # first q-point instead (only shape-valid because q == d == 3).
        fixed_features = {0: 0.1}
        mock_candidates[..., 0] = 0.1
        mock_gen_candidates.return_value = (mock_candidates, mock_acq_values)
        candidates, acq_vals = optimize_acqf(
            acq_function=mock_acq_function,
            bounds=bounds,
            q=q,
            num_restarts=num_restarts,
            raw_samples=raw_samples,
            options=options,
            fixed_features=fixed_features,
        )
        self.assertEqual(
            mock_gen_candidates.call_args[1]["fixed_features"], fixed_features
        )
        self.assertTrue(torch.equal(candidates, mock_candidates[0]))
        self.assertTrue(torch.all(candidates[..., 0] == 0.1))
        cnt += 1
        self.assertEqual(mock_gen_batch_initial_conditions.call_count, cnt)

        # test trivial case when all features are fixed
        candidates, acq_vals = optimize_acqf(
            acq_function=mock_acq_function,
            bounds=bounds,
            q=q,
            num_restarts=num_restarts,
            raw_samples=raw_samples,
            options=options,
            fixed_features={0: 0.1, 1: 0.2, 2: 0.3},
        )
        self.assertTrue(
            torch.equal(
                candidates,
                torch.tensor(
                    [0.1, 0.2, 0.3], device=self.device, dtype=dtype
                ).expand(3, 3),
            )
        )
        # Nothing to optimize, so no initial conditions are generated.
        self.assertEqual(mock_gen_batch_initial_conditions.call_count, cnt)

    # test OneShotAcquisitionFunction (reuses mocks from the last iteration)
    mock_acq_function = MockOneShotAcquisitionFunction()
    candidates, acq_vals = optimize_acqf(
        acq_function=mock_acq_function,
        bounds=bounds,
        q=q,
        num_restarts=num_restarts,
        raw_samples=raw_samples,
        options=options,
    )
    self.assertTrue(
        torch.equal(
            candidates, mock_acq_function.extract_candidates(mock_candidates[0])
        )
    )
    self.assertTrue(torch.equal(acq_vals, mock_acq_values[0]))