def test_joint_optimize(
    self,
    mock_get_best_candidates,
    mock_gen_candidates,
    mock_gen_batch_initial_conditions,
    cuda=False,
):
    """Check ``joint_optimize`` against mocked helpers for float and double.

    For each dtype the three mocks are primed with known tensors, then
    ``joint_optimize`` is called twice: once with default arguments and once
    with ``return_best_only=False`` plus explicit ``batch_initial_conditions``.

    Args:
        mock_get_best_candidates: Mock whose return value is the expected
            result of the first call.
        mock_gen_candidates: Mock whose return value (first entry) is the
            expected result of the second call.
        mock_gen_batch_initial_conditions: Mock whose cumulative call count
            is checked after every dtype iteration.
        cuda: If truthy, tensors are created on the CUDA device, else CPU.
    """
    q, num_restarts, raw_samples = 3, 2, 10
    options = {}
    mock_acq_function = MockAcquisitionFunction()
    device = torch.device("cuda" if cuda else "cpu")

    # ``expected_calls`` is the running number of times the initial-condition
    # generator should have been invoked so far: it grows by one per dtype,
    # i.e. only one of the two calls below is expected to trigger it
    # (presumably the one without explicit initial conditions).
    for expected_calls, dtype in enumerate((torch.float, torch.double), start=1):
        tkwargs = {"device": device, "dtype": dtype}

        mock_gen_batch_initial_conditions.return_value = torch.zeros(
            num_restarts, q, 3, **tkwargs
        )
        all_candidates = torch.cat(
            [i * torch.ones(1, q, 3, **tkwargs) for i in range(num_restarts)],
            dim=0,
        )
        mock_gen_candidates.return_value = all_candidates
        expected_candidates = torch.ones(1, q, 3, **tkwargs)
        mock_get_best_candidates.return_value = expected_candidates

        lower = torch.zeros(3, **tkwargs)
        upper = 4 * torch.ones(3, **tkwargs)
        common_kwargs = {
            "acq_function": mock_acq_function,
            "bounds": torch.stack([lower, upper]),
            "q": q,
            "num_restarts": num_restarts,
            "raw_samples": raw_samples,
            "options": options,
        }

        # Default call: result must equal what the best-candidate mock returns.
        candidates = joint_optimize(**common_kwargs)
        self.assertTrue(torch.equal(candidates, expected_candidates))

        # Call with user-supplied initial conditions and return_best_only off.
        candidates = joint_optimize(
            **common_kwargs,
            return_best_only=False,
            batch_initial_conditions=torch.zeros(num_restarts, q, 3, **tkwargs),
        )
        self.assertTrue(torch.equal(candidates, all_candidates[0]))
        self.assertEqual(
            mock_gen_batch_initial_conditions.call_count, expected_calls
        )
def test_joint_optimize(
    self,
    mock_get_best_candidates,
    mock_gen_candidates,
    mock_gen_batch_initial_conditions,
    cuda=False,
):
    """Check that ``joint_optimize`` returns the mocked best candidates.

    Runs once per dtype (float, double): the three mocks are primed with
    known tensors and the result of ``joint_optimize`` is compared with the
    return value configured on ``mock_get_best_candidates``.

    Args:
        mock_get_best_candidates: Mock supplying the expected candidates.
        mock_gen_candidates: Mock primed with one candidate set per restart.
        mock_gen_batch_initial_conditions: Mock primed with all-zero initial
            conditions.
        cuda: If truthy, tensors are created on the CUDA device, else CPU.
    """
    q, num_restarts, raw_samples = 3, 2, 10
    options = {}
    mock_acq_function = MockAcquisitionFunction()
    device = torch.device("cuda" if cuda else "cpu")

    for dtype in (torch.float, torch.double):
        tkwargs = {"device": device, "dtype": dtype}

        mock_gen_batch_initial_conditions.return_value = torch.zeros(
            num_restarts, q, 3, **tkwargs
        )
        # Restart ``i`` yields a (1, q, 3) block filled with the value ``i``.
        per_restart = [
            i * torch.ones(1, q, 3, **tkwargs) for i in range(num_restarts)
        ]
        mock_gen_candidates.return_value = torch.cat(per_restart, dim=0)
        expected_candidates = torch.ones(1, q, 3, **tkwargs)
        mock_get_best_candidates.return_value = expected_candidates

        lower = torch.zeros(3, **tkwargs)
        upper = 4 * torch.ones(3, **tkwargs)
        candidates = joint_optimize(
            acq_function=mock_acq_function,
            bounds=torch.stack([lower, upper]),
            q=q,
            num_restarts=num_restarts,
            raw_samples=raw_samples,
            options=options,
        )
        self.assertTrue(torch.equal(candidates, expected_candidates))
def test_joint_optimize(self, mock_optimize_acqf):
    """Check that ``joint_optimize`` warns about deprecation and delegates.

    The call must emit a ``DeprecationWarning`` whose message contains
    "joint_optimize is deprecated", forward all keyword arguments to the
    patched ``optimize_acqf`` with ``sequential=False`` added, and return a
    pair of ``None`` values (presumably the mock is patched to return
    ``(None, None)`` -- confirm against the decorator).

    Args:
        mock_optimize_acqf: Mock standing in for ``optimize_acqf``.
    """
    kwargs = {
        **self.shared_kwargs,
        "return_best_only": True,
        "batch_initial_conditions": None,
    }
    with warnings.catch_warnings(record=True) as ws:
        # Force every warning to be recorded. Without this the test depends
        # on ambient filter state: DeprecationWarning may be ignored by the
        # active filters, which would make the check below fail spuriously.
        # The filter change is undone when the context exits.
        warnings.simplefilter("always")
        candidates, acq_values = joint_optimize(**kwargs)

    self.assertTrue(any(issubclass(w.category, DeprecationWarning) for w in ws))
    self.assertTrue(
        any("joint_optimize is deprecated" in str(w.message) for w in ws)
    )
    mock_optimize_acqf.assert_called_once_with(**kwargs, sequential=False)
    self.assertIsNone(candidates)
    self.assertIsNone(acq_values)