Beispiel #1
0
    def test_joint_optimize(
        self,
        mock_get_best_candidates,
        mock_gen_candidates,
        mock_gen_batch_initial_conditions,
        cuda=False,
    ):
        """Check joint_optimize with its candidate-generation helpers mocked.

        For each of ``torch.float`` and ``torch.double`` two calls are made:

        1. With default settings: the result must equal (``torch.equal``) the
           tensor configured as ``mock_get_best_candidates.return_value``.
        2. With ``return_best_only=False`` and explicit
           ``batch_initial_conditions``: the result is compared against the
           first entry (along dim 0) of ``mock_gen_candidates.return_value``.

        After the k-th dtype iteration the test asserts that
        ``mock_gen_batch_initial_conditions`` has been called exactly k times
        in total. That means only one of the two calls per dtype is expected to
        invoke it — presumably the one that does not pass its own
        ``batch_initial_conditions``.

        Args:
            mock_get_best_candidates: Mock standing in for the best-candidate
                selector (presumably injected by a ``mock.patch`` decorator
                not visible here).
            mock_gen_candidates: Mock standing in for candidate generation;
                its return value stacks ``num_restarts`` constant tensors, the
                i-th filled with ``i``.
            mock_gen_batch_initial_conditions: Mock standing in for
                initial-condition generation; returns an all-zeros tensor.
            cuda: If True, all tensors are created on the CUDA device.
        """
        q = 3
        num_restarts = 2
        raw_samples = 10
        options = {}
        mock_acq_function = MockAcquisitionFunction()
        tkwargs = {
            "device": torch.device("cuda") if cuda else torch.device("cpu")
        }
        dtypes = (torch.float, torch.double)
        # Counting from 1: after the k-th dtype the cumulative call count of
        # gen_batch_initial_conditions is expected to be k.
        for expected_call_count, dtype in enumerate(dtypes, start=1):
            tkwargs["dtype"] = dtype
            mock_gen_batch_initial_conditions.return_value = torch.zeros(
                num_restarts, q, 3, **tkwargs
            )
            mock_gen_candidates.return_value = torch.cat(
                [i * torch.ones(1, q, 3, **tkwargs) for i in range(num_restarts)],
                dim=0,
            )
            mock_get_best_candidates.return_value = torch.ones(1, q, 3, **tkwargs)
            expected_candidates = mock_get_best_candidates.return_value
            bounds = torch.stack(
                [torch.zeros(3, **tkwargs), 4 * torch.ones(3, **tkwargs)]
            )
            # Keyword arguments common to both joint_optimize calls below.
            shared_kwargs = {
                "acq_function": mock_acq_function,
                "bounds": bounds,
                "q": q,
                "num_restarts": num_restarts,
                "raw_samples": raw_samples,
                "options": options,
            }

            candidates = joint_optimize(**shared_kwargs)
            self.assertTrue(torch.equal(candidates, expected_candidates))

            # Supplying batch_initial_conditions explicitly; the call-count
            # assertion below expects this call not to add another call to
            # gen_batch_initial_conditions.
            candidates = joint_optimize(
                **shared_kwargs,
                return_best_only=False,
                batch_initial_conditions=torch.zeros(
                    num_restarts, q, 3, **tkwargs
                ),
            )
            self.assertTrue(
                torch.equal(candidates, mock_gen_candidates.return_value[0])
            )
            self.assertEqual(
                mock_gen_batch_initial_conditions.call_count, expected_call_count
            )
Beispiel #2
0
 def test_joint_optimize(
     self,
     mock_get_best_candidates,
     mock_gen_candidates,
     mock_gen_batch_initial_conditions,
     cuda=False,
 ):
     """Verify joint_optimize returns the mocked best-candidates tensor.

     For each of ``torch.float`` and ``torch.double`` the three helper mocks
     receive canned tensors. joint_optimize is then called with default
     settings, and its result must equal (``torch.equal``) the tensor set as
     ``mock_get_best_candidates.return_value``.

     Args:
         mock_get_best_candidates: Mock whose return value is the expected
             result of joint_optimize.
         mock_gen_candidates: Mock returning a stack of per-restart constant
             tensors (restart ``r`` filled with ``r``).
         mock_gen_batch_initial_conditions: Mock returning an all-zeros
             tensor of initial conditions.
         cuda: When True, tensors are created on the CUDA device, else CPU.
     """
     q_size = 3
     n_restarts = 2
     n_raw = 10
     n_dims = 3
     opts = {}
     acqf = MockAcquisitionFunction()
     device = torch.device("cuda" if cuda else "cpu")

     for dtype in (torch.float, torch.double):
         factory = {"device": device, "dtype": dtype}

         mock_gen_batch_initial_conditions.return_value = torch.zeros(
             n_restarts, q_size, n_dims, **factory
         )
         # One constant slab per restart: restart r is all r's.
         per_restart = [
             r * torch.ones(1, q_size, n_dims, **factory) for r in range(n_restarts)
         ]
         mock_gen_candidates.return_value = torch.cat(per_restart, dim=0)
         best = torch.ones(1, q_size, n_dims, **factory)
         mock_get_best_candidates.return_value = best

         lower = torch.zeros(n_dims, **factory)
         upper = 4 * torch.ones(n_dims, **factory)
         result = joint_optimize(
             acq_function=acqf,
             bounds=torch.stack([lower, upper]),
             q=q_size,
             num_restarts=n_restarts,
             raw_samples=n_raw,
             options=opts,
         )
         self.assertTrue(torch.equal(result, best))
Beispiel #3
0
 def test_joint_optimize(self, mock_optimize_acqf):
     """Check the deprecated joint_optimize shim.

     Calls joint_optimize with ``self.shared_kwargs`` plus
     ``return_best_only=True`` and ``batch_initial_conditions=None``, then
     asserts that:

     * a ``DeprecationWarning`` was emitted;
     * some emitted warning message contains "joint_optimize is deprecated";
     * ``mock_optimize_acqf`` was called exactly once with the same keyword
       arguments plus ``sequential=False``;
     * both returned values are ``None``. Presumably ``mock_optimize_acqf``
       is configured elsewhere (decorator or setUp) to return
       ``(None, None)``; confirm.

     Args:
         mock_optimize_acqf: Mock presumably patched in for optimize_acqf by a
             decorator not visible here.
     """
     kwargs = {
         **self.shared_kwargs,
         "return_best_only": True,
         "batch_initial_conditions": None,
     }
     with warnings.catch_warnings(record=True) as ws:
         # Python's default filters ignore DeprecationWarning outside __main__,
         # so force recording rather than relying on the test runner's filters.
         warnings.simplefilter("always")
         candidates, acq_values = joint_optimize(**kwargs)
     self.assertTrue(any(issubclass(w.category, DeprecationWarning) for w in ws))
     self.assertTrue(
         any("joint_optimize is deprecated" in str(w.message) for w in ws)
     )
     mock_optimize_acqf.assert_called_once_with(**kwargs, sequential=False)
     self.assertIsNone(candidates)
     self.assertIsNone(acq_values)