示例#1
0
    def test_gen_batch_initial_conditions(self):
        """Check initial conditions over a grid of option combinations.

        For every combination of ``nonnegative``, ``seed``,
        ``init_batch_limit``, fixed features and ``sample_around_best`` this
        checks the shape/device/dtype of the returned initial conditions, the
        shape of the most recent raw-sample batch handed to the acquisition
        function, and that fixed features are honored.
        """
        bounds = torch.stack([torch.zeros(2), torch.ones(2)])
        mock_acqf = MockAcquisitionFunction()
        mock_acqf.objective = lambda y: y.squeeze(-1)
        for dtype in (torch.float, torch.double):
            bounds = bounds.to(device=self.device, dtype=dtype)
            mock_acqf.X_baseline = bounds  # for testing sample_around_best
            mock_acqf.model = MockModel(MockPosterior(mean=bounds[:, :1]))
            for nonnegative, seed, init_batch_limit, ffs, sample_around_best in product(
                [True, False], [None, 1234], [None, 1], [None, {0: 0.5}], [True, False]
            ):
                # Wrap (not replace) __call__ so the mock's real logic runs
                # while its call arguments are recorded.
                with mock.patch.object(
                    MockAcquisitionFunction,
                    "__call__",
                    wraps=mock_acqf.__call__,
                ) as mock_acqf_call:
                    batch_initial_conditions = gen_batch_initial_conditions(
                        acq_function=mock_acqf,
                        bounds=bounds,
                        q=1,
                        num_restarts=2,
                        raw_samples=10,
                        fixed_features=ffs,
                        options={
                            "nonnegative": nonnegative,
                            "eta": 0.01,
                            "alpha": 0.1,
                            "seed": seed,
                            "init_batch_limit": init_batch_limit,
                            "sample_around_best": sample_around_best,
                        },
                    )
                    expected_shape = torch.Size([2, 1, 2])
                    self.assertEqual(batch_initial_conditions.shape, expected_shape)
                    self.assertEqual(batch_initial_conditions.device, bounds.device)
                    self.assertEqual(batch_initial_conditions.dtype, bounds.dtype)
                    # call_args holds the arguments of the most recent call.
                    # Without a batch limit the test expects every raw sample
                    # in one call (20 when sample_around_best adds points,
                    # otherwise 10); with a limit, a chunk of
                    # init_batch_limit samples.
                    raw_samps = mock_acqf_call.call_args[0][0]
                    batch_shape = (
                        torch.Size([20 if sample_around_best else 10])
                        if init_batch_limit is None
                        else torch.Size([init_batch_limit])
                    )
                    expected_raw_samps_shape = batch_shape + torch.Size([1, 2])
                    self.assertEqual(raw_samps.shape, expected_raw_samps_shape)

                    if ffs is not None:
                        for idx, val in ffs.items():
                            self.assertTrue(
                                torch.all(batch_initial_conditions[..., idx] == val)
                            )
示例#2
0
 def test_gen_batch_initial_conditions(self):
     """Initial conditions must match the requested shape, device and dtype.

     Runs over float/double bounds and every pairing of ``nonnegative`` and
     ``seed``.
     """
     expected_shape = torch.Size([2, 1, 2])
     option_grid = [
         (nonneg, rng_seed)
         for nonneg in (True, False)
         for rng_seed in (None, 1234)
     ]
     for dtype in (torch.float, torch.double):
         bounds = torch.tensor(
             [[0, 0], [1, 1]], device=self.device, dtype=dtype
         )
         for nonnegative, seed in option_grid:
             opts = {
                 "nonnegative": nonnegative,
                 "eta": 0.01,
                 "alpha": 0.1,
                 "seed": seed,
             }
             ics = gen_batch_initial_conditions(
                 acq_function=MockAcquisitionFunction(),
                 bounds=bounds,
                 q=1,
                 num_restarts=2,
                 raw_samples=10,
                 options=opts,
             )
             self.assertEqual(ics.shape, expected_shape)
             self.assertEqual(ics.device, bounds.device)
             self.assertEqual(ics.dtype, bounds.dtype)
示例#3
0
 def test_gen_batch_initial_conditions_highdim(self):
     """High-dimensional (d=120, q=10) generation records a SamplingWarning.

     Also checks the shape, device and dtype of the returned initial
     conditions for each ``nonnegative`` / ``seed`` pairing.
     """
     d = 120
     expected_shape = torch.Size([1, 10, d])
     bounds = torch.stack([torch.zeros(d), torch.ones(d)])
     option_grid = [
         (nonneg, rng_seed)
         for nonneg in (True, False)
         for rng_seed in (None, 1234)
     ]
     for dtype in (torch.float, torch.double):
         bounds = bounds.to(device=self.device, dtype=dtype)
         for nonnegative, seed in option_grid:
             opts = {
                 "nonnegative": nonnegative,
                 "eta": 0.01,
                 "alpha": 0.1,
                 "seed": seed,
             }
             with warnings.catch_warnings(record=True) as ws, settings.debug(True):
                 ics = gen_batch_initial_conditions(
                     acq_function=MockAcquisitionFunction(),
                     bounds=bounds,
                     q=10,
                     num_restarts=1,
                     raw_samples=2,
                     options=opts,
                 )
                 saw_sampling_warning = any(
                     issubclass(w.category, SamplingWarning) for w in ws
                 )
                 self.assertTrue(saw_sampling_warning)
             self.assertEqual(ics.shape, expected_shape)
             self.assertEqual(ics.device, bounds.device)
             self.assertEqual(ics.dtype, bounds.dtype)
示例#4
0
 def test_gen_batch_initial_conditions_warning(self):
     """All-zero Sobol draws yield one BadInitialCandidatesWarning.

     ``draw_sobol_samples`` is patched to return zeros, so the returned
     initial conditions are expected to be all zeros as well.
     """
     for dtype in (torch.float, torch.double):
         tkw = {"device": self.device, "dtype": dtype}
         bounds = torch.tensor([[0, 0], [1, 1]], **tkw)
         samples = torch.zeros(10, 1, 2, **tkw)
         # Creating the patcher does not activate it; the `with` below does.
         sobol_patch = mock.patch(
             "botorch.optim.initializers.draw_sobol_samples",
             return_value=samples,
         )
         with warnings.catch_warnings(record=True) as ws, settings.debug(
             True
         ), sobol_patch:
             ics = gen_batch_initial_conditions(
                 acq_function=MockAcquisitionFunction(),
                 bounds=bounds,
                 q=1,
                 num_restarts=2,
                 raw_samples=10,
                 options={"seed": 1234},
             )
             self.assertEqual(len(ws), 1)
             got_bad_ic_warning = any(
                 issubclass(w.category, BadInitialCandidatesWarning)
                 for w in ws
             )
             self.assertTrue(got_bad_ic_warning)
             expected = torch.zeros(2, 1, 2, **tkw)
             self.assertTrue(torch.equal(ics, expected))
示例#5
0
 def test_gen_batch_initial_conditions(self):
     """Check initial conditions and the raw-sample batch seen by the acqf.

     Covers every combination of ``nonnegative``, ``seed`` and
     ``init_batch_limit`` for float and double bounds.
     """
     for dtype in (torch.float, torch.double):
         bounds = torch.tensor([[0, 0], [1, 1]],
                               device=self.device,
                               dtype=dtype)
         for nonnegative in (True, False):
             for seed in (None, 1234):
                 for init_batch_limit in (None, 1):
                     mock_acqf = MockAcquisitionFunction()
                     # Wrap (not replace) __call__ so the mock's real logic
                     # runs while its call arguments are recorded.
                     with mock.patch.object(
                             MockAcquisitionFunction,
                             "__call__",
                             wraps=mock_acqf.__call__,
                     ) as mock_acqf_call:
                         batch_initial_conditions = gen_batch_initial_conditions(
                             acq_function=mock_acqf,
                             bounds=bounds,
                             q=1,
                             num_restarts=2,
                             raw_samples=10,
                             options={
                                 "nonnegative": nonnegative,
                                 "eta": 0.01,
                                 "alpha": 0.1,
                                 "seed": seed,
                                 "init_batch_limit": init_batch_limit,
                             },
                         )
                         expected_shape = torch.Size([2, 1, 2])
                         self.assertEqual(batch_initial_conditions.shape,
                                          expected_shape)
                         self.assertEqual(batch_initial_conditions.device,
                                          bounds.device)
                         self.assertEqual(batch_initial_conditions.dtype,
                                          bounds.dtype)
                         # call_args holds the most recent call: all 10 raw
                         # samples at once without a limit, otherwise a chunk
                         # of init_batch_limit samples.
                         raw_samps = mock_acqf_call.call_args[0][0]
                         batch_shape = (torch.Size([10])
                                        if init_batch_limit is None else
                                        torch.Size([init_batch_limit]))
                         expected_raw_samps_shape = batch_shape + torch.Size(
                             [1, 2])
                         self.assertEqual(raw_samps.shape,
                                          expected_raw_samps_shape)
示例#6
0
 def test_optimize_acqf_sequential(
     self, mock_gen_candidates_scipy, mock_gen_batch_initial_conditions
 ):
     """Sequential ``optimize_acqf`` concatenates per-step results.

     The two mock arguments (presumably patching ``gen_candidates_scipy`` and
     ``gen_batch_initial_conditions``) get one side effect per q-step. The
     returned candidates are expected to be the rounded concatenation, and
     the acquisition values the concatenation of the per-step values. A
     OneShotAcquisitionFunction must raise NotImplementedError.
     """
     q, num_restarts, raw_samples = 3, 2, 10
     options = {}
     for dtype in (torch.float, torch.double):
         tkw = {"device": self.device, "dtype": dtype}
         mock_acq_function = MockAcquisitionFunction()
         mock_gen_batch_initial_conditions.side_effect = [
             torch.zeros(num_restarts, **tkw) for _ in range(q)
         ]
         step_results = []
         for step in range(q):
             step_results.append(
                 (
                     torch.tensor([[[1.1, 2.1, 3.1]]], **tkw),
                     torch.tensor([step], **tkw),
                 )
             )
         mock_gen_candidates_scipy.side_effect = step_results
         # rounding_func is passed as post-processing, hence .round() here.
         expected_candidates = torch.cat(
             [cand[0] for cand, _ in step_results], dim=-2
         ).round()
         bounds = torch.stack(
             [torch.zeros(3, **tkw), 4 * torch.ones(3, **tkw)]
         )
         inequality_constraints = [
             (torch.tensor([3]), torch.tensor([4]), torch.tensor(5))
         ]
         candidates, acq_value = optimize_acqf(
             acq_function=mock_acq_function,
             bounds=bounds,
             q=q,
             num_restarts=num_restarts,
             raw_samples=raw_samples,
             options=options,
             inequality_constraints=inequality_constraints,
             post_processing_func=rounding_func,
             sequential=True,
         )
         self.assertTrue(torch.equal(candidates, expected_candidates))
         expected_acq_value = torch.cat([val for _, val in step_results])
         self.assertTrue(torch.equal(acq_value, expected_acq_value))
     # verify error when using a OneShotAcquisitionFunction
     with self.assertRaises(NotImplementedError):
         optimize_acqf(
             acq_function=mock.Mock(spec=OneShotAcquisitionFunction),
             bounds=bounds,
             q=q,
             num_restarts=num_restarts,
             raw_samples=raw_samples,
             sequential=True,
         )
示例#7
0
 def test_optimize_acqf_sequential_notimplemented(self):
     """Expect NotImplementedError for sequential=True with return_best_only=False."""
     with self.assertRaises(NotImplementedError):
         call_kwargs = {
             "acq_function": MockAcquisitionFunction(),
             "bounds": torch.stack([torch.zeros(3), 4 * torch.ones(3)]),
             "q": 3,
             "num_restarts": 2,
             "raw_samples": 10,
             "return_best_only": False,
             "sequential": True,
         }
         optimize_acqf(**call_kwargs)
示例#8
0
 def test_optimize_acqf_empty_ff(self):
     """``optimize_acqf_mixed`` must reject an empty ``fixed_features_list``."""
     # Build the inputs outside assertRaises so that only the call under test
     # can satisfy the expected ValueError.
     mock_acq_function = MockAcquisitionFunction()
     bounds = torch.stack([torch.zeros(3), 4 * torch.ones(3)])
     with self.assertRaises(ValueError):
         optimize_acqf_mixed(
             acq_function=mock_acq_function,
             q=1,
             fixed_features_list=[],
             bounds=bounds,
             num_restarts=2,
             raw_samples=10,
         )
示例#9
0
class TestDeprecatedOptimize(BotorchTestCase):
    """Tests for the deprecated ``joint_optimize`` / ``sequential_optimize``.

    Each test checks that the deprecated entry point emits a
    DeprecationWarning and forwards its arguments to the patched
    ``optimize_acqf`` with the matching ``sequential`` flag.
    """

    # Keyword arguments shared by both deprecated entry points.
    shared_kwargs = {
        "acq_function": MockAcquisitionFunction(),
        "bounds": torch.zeros(2, 2),
        "q": 3,
        "num_restarts": 2,
        "raw_samples": 10,
        "options": {},
        "inequality_constraints": None,
        "equality_constraints": None,
        "fixed_features": None,
        "post_processing_func": None,
    }

    @mock.patch("botorch.optim.optimize.optimize_acqf",
                return_value=(None, None))
    def test_joint_optimize(self, mock_optimize_acqf):
        kwargs = {
            **self.shared_kwargs,
            "return_best_only": True,
            "batch_initial_conditions": None,
        }
        with warnings.catch_warnings(record=True) as ws:
            # catch_warnings(record=True) does not change the filters, and
            # DeprecationWarning is ignored by default outside __main__;
            # force it to be recorded so the assertions don't depend on the
            # test runner's warning configuration.
            warnings.simplefilter("always", DeprecationWarning)
            candidates, acq_values = joint_optimize(**kwargs)
            self.assertTrue(
                any(issubclass(w.category, DeprecationWarning) for w in ws))
            self.assertTrue(
                any("joint_optimize is deprecated" in str(w.message)
                    for w in ws))
            mock_optimize_acqf.assert_called_once_with(**kwargs,
                                                       sequential=False)
            self.assertIsNone(candidates)
            self.assertIsNone(acq_values)

    @mock.patch("botorch.optim.optimize.optimize_acqf",
                return_value=(None, None))
    def test_sequential_optimize(self, mock_optimize_acqf):
        with warnings.catch_warnings(record=True) as ws:
            # See test_joint_optimize: make sure the DeprecationWarning is
            # recorded regardless of the ambient warning filters.
            warnings.simplefilter("always", DeprecationWarning)
            candidates, acq_values = sequential_optimize(**self.shared_kwargs)
            self.assertTrue(
                any(issubclass(w.category, DeprecationWarning) for w in ws))
            self.assertTrue(
                any("sequential_optimize is deprecated" in str(w.message)
                    for w in ws))
            mock_optimize_acqf.assert_called_once_with(
                **self.shared_kwargs,
                return_best_only=True,
                sequential=True,
                batch_initial_conditions=None,
            )
            self.assertIsNone(candidates)
            self.assertIsNone(acq_values)
示例#10
0
    def test_optimize_acqf_mixed_q2(self, mock_optimize_acqf):
        """``optimize_acqf_mixed`` with q=2 keeps the best config per step.

        ``mock_optimize_acqf`` returns random (candidate, value) pairs, one per
        fixed-feature configuration per q-step. The expected candidates are
        the per-step argmax picks. The expected acquisition value is
        recomputed on them, via ``evaluate`` for the one-shot mock.
        """
        num_restarts, raw_samples, q = 2, 10, 2
        options = {}
        tkwargs = {"device": self.device}
        bounds = torch.stack([torch.zeros(3), 4 * torch.ones(3)])
        mock_acq_functions = [
            MockAcquisitionFunction(),
            MockOneShotEvaluateAcquisitionFunction(),
        ]
        combos = itertools.product(
            [1, 3], (torch.float, torch.double), mock_acq_functions
        )
        for num_ff, dtype, acqf in combos:
            tkwargs["dtype"] = dtype
            mock_optimize_acqf.reset_mock()
            bounds = bounds.to(**tkwargs)

            fixed_features_list = [{i: i * 0.1} for i in range(num_ff)]
            all_candidates, exp_candidates, all_values = [], [], []
            # Pre-draw every mocked optimize_acqf result and record which
            # candidate should win at each q-step.
            for _ in range(q):
                step_cands = [torch.rand(1, 3, **tkwargs) for _ in range(num_ff)]
                step_vals = [torch.rand(1, **tkwargs) for _ in range(num_ff)]
                winner = torch.argmax(torch.stack(step_vals))
                exp_candidates.append(step_cands[winner])
                all_candidates.extend(step_cands)
                all_values.extend(step_vals)
            mock_optimize_acqf.side_effect = list(zip(all_candidates, all_values))

            candidates, acq_value = optimize_acqf_mixed(
                acq_function=acqf,
                q=q,
                fixed_features_list=fixed_features_list,
                bounds=bounds,
                num_restarts=num_restarts,
                raw_samples=raw_samples,
                options=options,
                post_processing_func=rounding_func,
            )

            expected_candidates = torch.cat(exp_candidates, dim=-2)
            if not isinstance(acqf, MockOneShotEvaluateAcquisitionFunction):
                expected_acq_value = acqf(expected_candidates)
            else:
                expected_acq_value = acqf.evaluate(
                    expected_candidates, bounds=bounds
                )
            self.assertTrue(torch.equal(candidates, expected_candidates))
            self.assertTrue(torch.equal(acq_value, expected_acq_value))
示例#11
0
    def test_gen_batch_initial_conditions_highdim(self):
        """Generate initial conditions where q * d exceeds Sobol's max dimension.

        Expects a SamplingWarning for every option combination, plus the
        right shape/device/dtype and honored fixed features.
        """
        d = 2200  # 2200 * 10 (q) > 21201 (sobol max dim)
        bounds = torch.stack([torch.zeros(d), torch.ones(d)])
        ffs_map = {i: random() for i in range(0, d, 2)}
        mock_acqf = MockAcquisitionFunction()
        mock_acqf.objective = lambda y: y.squeeze(-1)
        for dtype in (torch.float, torch.double):
            bounds = bounds.to(device=self.device, dtype=dtype)
            mock_acqf.X_baseline = bounds  # for testing sample_around_best
            mock_acqf.model = MockModel(MockPosterior(mean=bounds[:, :1]))

            for nonnegative, seed, ffs, sample_around_best in product(
                [True, False], [None, 1234], [None, ffs_map], [True, False]):
                with warnings.catch_warnings(
                        record=True) as ws, settings.debug(True):
                    batch_initial_conditions = gen_batch_initial_conditions(
                        # Pass the configured mock (X_baseline, model and
                        # objective set above) so sample_around_best is
                        # actually exercised; a fresh mock would leave that
                        # setup unused.
                        acq_function=mock_acqf,
                        bounds=bounds,
                        q=10,
                        num_restarts=1,
                        raw_samples=2,
                        fixed_features=ffs,
                        options={
                            "nonnegative": nonnegative,
                            "eta": 0.01,
                            "alpha": 0.1,
                            "seed": seed,
                            "sample_around_best": sample_around_best,
                        },
                    )
                    self.assertTrue(
                        any(
                            issubclass(w.category, SamplingWarning)
                            for w in ws))
                expected_shape = torch.Size([1, 10, d])
                self.assertEqual(batch_initial_conditions.shape,
                                 expected_shape)
                self.assertEqual(batch_initial_conditions.device,
                                 bounds.device)
                self.assertEqual(batch_initial_conditions.dtype, bounds.dtype)
                if ffs is not None:
                    for idx, val in ffs.items():
                        self.assertTrue(
                            torch.all(batch_initial_conditions[...,
                                                               idx] == val))
示例#12
0
 def test_gen_batch_initial_conditions_constraints(self):
     """Raw samples must satisfy the linear equality/inequality constraints.

     The equality constraint fixes x0 == 0.5 and the inequality requires
     -4 * x1 >= -3. Both are checked on the most recent raw-sample batch
     passed to the acquisition function, alongside shape/device/dtype.
     """
     for dtype in (torch.float, torch.double):
         bounds = torch.tensor([[0, 0], [1, 1]],
                               device=self.device,
                               dtype=dtype)
         inequality_constraints = [(
             torch.tensor([1], device=self.device, dtype=dtype),
             torch.tensor([-4], device=self.device, dtype=dtype),
             torch.tensor(-3, device=self.device, dtype=dtype),
         )]
         equality_constraints = [(
             torch.tensor([0], device=self.device, dtype=dtype),
             torch.tensor([1], device=self.device, dtype=dtype),
             torch.tensor(0.5, device=self.device, dtype=dtype),
         )]
         for nonnegative in (True, False):
             for seed in (None, 1234):
                 for init_batch_limit in (None, 1):
                     mock_acqf = MockAcquisitionFunction()
                     # Wrap (not replace) __call__ so the mock's real logic
                     # runs while its call arguments are recorded.
                     with mock.patch.object(
                             MockAcquisitionFunction,
                             "__call__",
                             wraps=mock_acqf.__call__,
                     ) as mock_acqf_call:
                         batch_initial_conditions = gen_batch_initial_conditions(
                             acq_function=mock_acqf,
                             bounds=bounds,
                             q=1,
                             num_restarts=2,
                             raw_samples=10,
                             options={
                                 "nonnegative": nonnegative,
                                 "eta": 0.01,
                                 "alpha": 0.1,
                                 "seed": seed,
                                 "init_batch_limit": init_batch_limit,
                                 "thinning": 2,
                                 "n_burnin": 3,
                             },
                             inequality_constraints=inequality_constraints,
                             equality_constraints=equality_constraints,
                         )
                         expected_shape = torch.Size([2, 1, 2])
                         self.assertEqual(batch_initial_conditions.shape,
                                          expected_shape)
                         self.assertEqual(batch_initial_conditions.device,
                                          bounds.device)
                         self.assertEqual(batch_initial_conditions.dtype,
                                          bounds.dtype)
                         # call_args holds the most recent call: all 10 raw
                         # samples without a limit, else a chunk of
                         # init_batch_limit samples.
                         raw_samps = mock_acqf_call.call_args[0][0]
                         batch_shape = (torch.Size([10])
                                        if init_batch_limit is None else
                                        torch.Size([init_batch_limit]))
                         expected_raw_samps_shape = batch_shape + torch.Size(
                             [1, 2])
                         self.assertEqual(raw_samps.shape,
                                          expected_raw_samps_shape)
                         self.assertTrue((raw_samps[..., 0] == 0.5).all())
                         self.assertTrue(
                             (-4 * raw_samps[..., 1] >= -3).all())
示例#13
0
    def test_optimize_acqf_joint(self, mock_gen_candidates,
                                 mock_gen_batch_initial_conditions):
        """Joint (non-sequential) ``optimize_acqf`` with mocked internals.

        Checks the best-only result (restart 0 has the highest mocked value),
        the all-restarts result when initial conditions are supplied, the
        cumulative call count of the initial-condition generator, and
        candidate extraction for a one-shot acquisition function.
        """
        q, num_restarts, raw_samples = 3, 2, 10
        options = {}
        mock_acq_function = MockAcquisitionFunction()
        for expected_ic_calls, dtype in enumerate(
                (torch.float, torch.double), start=1):
            tkw = {"device": self.device, "dtype": dtype}
            mock_gen_batch_initial_conditions.return_value = torch.zeros(
                num_restarts, q, 3, **tkw)
            base_cand = torch.ones(1, q, 3, **tkw)
            mock_candidates = torch.cat(
                [k * base_cand for k in range(num_restarts)], dim=0)
            # Values decrease with the restart index, so restart 0 is best.
            mock_acq_values = num_restarts - torch.arange(num_restarts, **tkw)
            mock_gen_candidates.return_value = (mock_candidates,
                                                mock_acq_values)
            bounds = torch.stack([
                torch.zeros(3, **tkw),
                4 * torch.ones(3, **tkw),
            ])
            common = {
                "acq_function": mock_acq_function,
                "bounds": bounds,
                "q": q,
                "num_restarts": num_restarts,
                "raw_samples": raw_samples,
                "options": options,
            }

            candidates, acq_vals = optimize_acqf(**common)
            self.assertTrue(torch.equal(candidates, mock_candidates[0]))
            self.assertTrue(torch.equal(acq_vals, mock_acq_values[0]))

            supplied_ics = torch.zeros(num_restarts, q, 3, **tkw)
            candidates, acq_vals = optimize_acqf(
                **common,
                return_best_only=False,
                batch_initial_conditions=supplied_ics,
            )
            self.assertTrue(torch.equal(candidates, mock_candidates))
            self.assertTrue(torch.equal(acq_vals, mock_acq_values))
            # The generator call count is expected to grow by exactly one per
            # dtype, i.e. only the call without supplied ICs generates them.
            self.assertEqual(mock_gen_batch_initial_conditions.call_count,
                             expected_ic_calls)

        # test OneShotAcquisitionFunction
        one_shot = MockOneShotAcquisitionFunction()
        candidates, acq_vals = optimize_acqf(
            acq_function=one_shot,
            bounds=bounds,
            q=q,
            num_restarts=num_restarts,
            raw_samples=raw_samples,
            options=options,
        )
        expected = one_shot.extract_candidates(mock_candidates[0])
        self.assertTrue(torch.equal(candidates, expected))
        self.assertTrue(torch.equal(acq_vals, mock_acq_values[0]))
示例#14
0
    def test_optimize_acqf_mixed(self, mock_optimize_acqf):
        """``optimize_acqf_mixed`` (q=1) returns the best fixed-feature result.

        ``mock_optimize_acqf`` yields one random (candidate, value) pair per
        fixed-feature dict. The result must be the argmax pair, and each
        recorded call must carry the expected keyword arguments with that
        call's fixed features.
        """
        num_restarts, raw_samples, q = 2, 10, 1
        options = {}
        tkwargs = {"device": self.device}
        bounds = torch.stack([torch.zeros(3), 4 * torch.ones(3)])
        mock_acq_function = MockAcquisitionFunction()
        grid = itertools.product([1, 3], (torch.float, torch.double))
        for num_ff, dtype in grid:
            tkwargs["dtype"] = dtype
            mock_optimize_acqf.reset_mock()
            bounds = bounds.to(**tkwargs)

            draws = [
                (torch.rand(1, 3, **tkwargs), torch.rand(1, **tkwargs))
                for _ in range(num_ff)
            ]
            candidate_rvs = [cand for cand, _ in draws]
            acq_val_rvs = [val for _, val in draws]
            fixed_features_list = [{i: i * 0.1} for i in range(num_ff)]
            mock_optimize_acqf.side_effect = list(zip(candidate_rvs, acq_val_rvs))

            candidates, acq_value = optimize_acqf_mixed(
                acq_function=mock_acq_function,
                q=q,
                fixed_features_list=fixed_features_list,
                bounds=bounds,
                num_restarts=num_restarts,
                raw_samples=raw_samples,
                options=options,
                post_processing_func=rounding_func,
            )
            ff_acq_values = torch.stack(acq_val_rvs)
            best = torch.argmax(ff_acq_values)
            self.assertTrue(torch.equal(candidates, candidate_rvs[best]))
            self.assertTrue(torch.equal(acq_value, ff_acq_values[best]))

            expected_call_args = {
                "acq_function": None,
                "bounds": bounds,
                "q": q,
                "num_restarts": num_restarts,
                "raw_samples": raw_samples,
                "options": options,
                "inequality_constraints": None,
                "equality_constraints": None,
                "fixed_features": None,
                "post_processing_func": rounding_func,
                "batch_initial_conditions": None,
                "return_best_only": True,
                "sequential": False,
            }
            for idx, recorded in enumerate(mock_optimize_acqf.call_args_list):
                expected_call_args["fixed_features"] = fixed_features_list[idx]
                for key, value in recorded[1].items():
                    if torch.is_tensor(value):
                        self.assertTrue(torch.equal(expected_call_args[key], value))
                    elif key == "acq_function":
                        self.assertIsInstance(value, MockAcquisitionFunction)
                    else:
                        self.assertEqual(expected_call_args[key], value)
示例#15
0
 def test_optimize_acqf_list(self, mock_optimize_acqf):
     """``optimize_acqf_list`` optimizes each acqf in turn and stacks results.

     ``mock_optimize_acqf`` returns random (candidate, value) pairs via
     ``side_effect``. The test checks the concatenated candidates, the
     keyword arguments of every ``optimize_acqf`` call, and the
     ``set_X_pending`` calls made on each acquisition function.
     """
     num_restarts = 2
     raw_samples = 10
     options = {}
     tkwargs = {"device": self.device}
     bounds = torch.stack([torch.zeros(3), 4 * torch.ones(3)])
     inequality_constraints = [[
         torch.tensor([3]),
         torch.tensor([4]),
         torch.tensor(5)
     ]]
     # Two acquisition functions; their X_pending is cleared at the start of
     # every iteration below.
     mock_acq_function_1 = MockAcquisitionFunction()
     mock_acq_function_2 = MockAcquisitionFunction()
     mock_acq_function_list = [mock_acq_function_1, mock_acq_function_2]
     for num_acqf, dtype in itertools.product([1, 2],
                                              (torch.float, torch.double)):
         for m in mock_acq_function_list:
             # clear previous X_pending
             m.set_X_pending(None)
         tkwargs["dtype"] = dtype
         inequality_constraints[0] = [
             t.to(**tkwargs) for t in inequality_constraints[0]
         ]
         mock_optimize_acqf.reset_mock()
         bounds = bounds.to(**tkwargs)
         candidate_rvs = []
         acq_val_rvs = []
         gcs_return_vals = [(torch.rand(1, 3, **tkwargs),
                             torch.rand(1, **tkwargs))
                            for _ in range(num_acqf)]
         for rv in gcs_return_vals:
             candidate_rvs.append(rv[0])
             acq_val_rvs.append(rv[1])
         side_effect = list(zip(candidate_rvs, acq_val_rvs))
         mock_optimize_acqf.side_effect = side_effect
         orig_candidates = candidate_rvs[0].clone()
         # Wrap set_X_pending on each *instance*. Patching the same class
         # attribute twice made the second patch shadow the first:
         # mock_set_X_pending_1 never recorded anything and calls on either
         # instance were forwarded to mock_acq_function_2.set_X_pending.
         with mock.patch.object(
                 mock_acq_function_1,
                 "set_X_pending",
                 wraps=mock_acq_function_1.set_X_pending,
         ) as mock_set_X_pending_1, mock.patch.object(
                 mock_acq_function_2,
                 "set_X_pending",
                 wraps=mock_acq_function_2.set_X_pending,
         ) as mock_set_X_pending_2:
             candidates, acq_values = optimize_acqf_list(
                 acq_function_list=mock_acq_function_list[:num_acqf],
                 bounds=bounds,
                 num_restarts=num_restarts,
                 raw_samples=raw_samples,
                 options=options,
                 inequality_constraints=inequality_constraints,
                 post_processing_func=rounding_func,
             )
             # check that X_pending is set correctly in sequential optimization
             if num_acqf > 1:
                 x_pending_call_args_list = mock_set_X_pending_2.call_args_list
                 idxr = torch.ones(num_acqf,
                                   dtype=torch.bool,
                                   device=self.device)
                 # NOTE(review): this loop runs len(calls) - 1 times, so it
                 # checks nothing when set_X_pending is called only once.
                 for i in range(len(x_pending_call_args_list) - 1):
                     idxr[i] = 0
                     self.assertTrue(
                         torch.equal(x_pending_call_args_list[i][0][0],
                                     orig_candidates[idxr]))
                     idxr[i] = 1
                     orig_candidates[i] = candidate_rvs[i + 1]
             else:
                 mock_set_X_pending_1.assert_not_called()
         # check final candidates
         expected_candidates = (torch.cat(candidate_rvs[-num_acqf:], dim=0)
                                if num_acqf > 1 else candidate_rvs[0])
         self.assertTrue(torch.equal(candidates, expected_candidates))
         # check call arguments for optimize_acqf
         call_args_list = mock_optimize_acqf.call_args_list
         expected_call_args = {
             "acq_function": None,
             "bounds": bounds,
             "q": 1,
             "num_restarts": num_restarts,
             "raw_samples": raw_samples,
             "options": options,
             "inequality_constraints": inequality_constraints,
             "equality_constraints": None,
             "fixed_features": None,
             "post_processing_func": rounding_func,
             "batch_initial_conditions": None,
             "return_best_only": True,
             "sequential": False,
         }
         for i in range(len(call_args_list)):
             expected_call_args["acq_function"] = mock_acq_function_list[i]
             for k, v in call_args_list[i][1].items():
                 if torch.is_tensor(v):
                     self.assertTrue(torch.equal(expected_call_args[k], v))
                 elif k == "acq_function":
                     # Check the recorded argument, not the expected object.
                     self.assertIsInstance(v, MockAcquisitionFunction)
                 else:
                     self.assertEqual(expected_call_args[k], v)
示例#16
0
 def test_optimize_acqf_cyclic(self, mock_optimize_acqf):
     """Check `optimize_acqf_cyclic` against a mocked `optimize_acqf`.

     `mock_optimize_acqf` is presumably injected by a `mock.patch` decorator
     outside this view (TODO confirm). For every combination of `q` and
     dtype it is primed to return `q` candidates on its first call and one
     candidate on every later call. The test then verifies:

     * the arguments of each `set_X_pending` call made while candidates are
       re-optimized one at a time, plus the final reset call;
     * the final candidates returned by `optimize_acqf_cyclic`;
     * the keyword arguments passed to every `optimize_acqf` call.
     """
     num_restarts = 2
     raw_samples = 10
     num_cycles = 2
     options = {}
     tkwargs = {"device": self.device}
     bounds = torch.stack([torch.zeros(3), 4 * torch.ones(3)])
     # The constraint tensors are only forwarded to the mocked optimizer and
     # compared against what it received, so they just need to live on the
     # test device. Creating them once here (instead of re-casting inside the
     # loop with the dtype left over from the previous iteration) keeps the
     # index tensor integral and the inputs identical across iterations.
     inequality_constraints = [[
         torch.tensor([3], device=self.device),
         torch.tensor([4], device=self.device),
         torch.tensor(5, device=self.device),
     ]]
     mock_acq_function = MockAcquisitionFunction()
     for q, dtype in itertools.product([1, 3], (torch.float, torch.double)):
         mock_optimize_acqf.reset_mock()
         tkwargs["dtype"] = dtype
         bounds = bounds.to(**tkwargs)
         candidate_rvs = []
         acq_val_rvs = []
         for cycle_j in range(num_cycles):
             gcs_return_vals = [
                 (torch.rand(1, 3, **tkwargs), torch.rand(1, **tkwargs))
                 for _ in range(q)
             ]
             if cycle_j == 0:
                 # the first call returns all `q` candidates at once
                 candidate_rvs.append(
                     torch.cat([rv[0] for rv in gcs_return_vals], dim=-2))
                 acq_val_rvs.append(
                     torch.cat([rv[1] for rv in gcs_return_vals]))
             else:
                 # every subsequent call returns a single candidate
                 for cand, acq_val in gcs_return_vals:
                     candidate_rvs.append(cand)
                     acq_val_rvs.append(acq_val)
         mock_optimize_acqf.side_effect = list(zip(candidate_rvs, acq_val_rvs))
         orig_candidates = candidate_rvs[0].clone()
         # wrap set_X_pending so that its call arguments can be checked
         with mock.patch.object(
                 MockAcquisitionFunction,
                 "set_X_pending",
                 wraps=mock_acq_function.set_X_pending,
         ) as mock_set_X_pending:
             candidates, acq_value = optimize_acqf_cyclic(
                 acq_function=mock_acq_function,
                 bounds=bounds,
                 q=q,
                 num_restarts=num_restarts,
                 raw_samples=raw_samples,
                 options=options,
                 inequality_constraints=inequality_constraints,
                 post_processing_func=rounding_func,
                 cyclic_options={"maxiter": num_cycles},
             )
             # check that X_pending is set correctly in cyclic optimization
             if q > 1:
                 x_pending_call_args_list = mock_set_X_pending.call_args_list
                 idxr = torch.ones(q, dtype=torch.bool, device=self.device)
                 for i in range(len(x_pending_call_args_list) - 1):
                     # while candidate i is re-optimized, the other q - 1
                     # current candidates are expected to be pending
                     idxr[i] = 0
                     self.assertTrue(
                         torch.equal(x_pending_call_args_list[i][0][0],
                                     orig_candidates[idxr]))
                     idxr[i] = 1
                     # candidate i now holds the mocked re-optimized value
                     orig_candidates[i] = candidate_rvs[i + 1]
                 # the last call resets X_pending to its base value (None)
                 self.assertIsNone(x_pending_call_args_list[-1][0][0])
             else:
                 mock_set_X_pending.assert_not_called()
         # check final candidates
         expected_candidates = (torch.cat(candidate_rvs[-q:], dim=0)
                                if q > 1 else candidate_rvs[0])
         self.assertTrue(torch.equal(candidates, expected_candidates))
         # check call arguments for optimize_acqf
         call_args_list = mock_optimize_acqf.call_args_list
         expected_call_args = {
             "acq_function": mock_acq_function,
             "bounds": bounds,
             "num_restarts": num_restarts,
             "raw_samples": raw_samples,
             "options": options,
             "inequality_constraints": inequality_constraints,
             "equality_constraints": None,
             "fixed_features": None,
             "post_processing_func": rounding_func,
             "return_best_only": True,
             "sequential": True,
         }
         orig_candidates = candidate_rvs[0].clone()
         for i, call_args in enumerate(call_args_list):
             if i == 0:
                 # first cycle: all q candidates are requested in one call
                 expected_call_args.update({
                     "batch_initial_conditions": None,
                     "q": q,
                 })
             else:
                 # later calls request one candidate, warm-started from it.
                 # NOTE: `orig_candidates[i - 1:i]` is a view, so the in-place
                 # update on the next line is also visible through the
                 # expected value; keep this ordering intact.
                 expected_call_args.update({
                     "batch_initial_conditions": orig_candidates[i - 1:i],
                     "q": 1,
                 })
                 orig_candidates[i - 1] = candidate_rvs[i]
             for k, v in call_args[1].items():
                 if torch.is_tensor(v):
                     self.assertTrue(torch.equal(expected_call_args[k], v))
                 elif k == "acq_function":
                     # check the argument that was actually passed, rather
                     # than the local fixture (trivially an instance)
                     self.assertIs(v, mock_acq_function)
                 else:
                     self.assertEqual(expected_call_args[k], v)
示例#17
0
    def test_remove_fixed_features_from_optimization(self):
        """Check `_remove_fixed_features_from_optimization` on all input mixes.

        Lower/upper bounds are each `None`, a float, or a `(q, d)` tensor.
        Inequality/equality constraints are each `None` or a one-element
        list. For every combination, the result must:

        * wrap the acquisition function in `FixedFeatureAcquisitionFunction`;
        * pass float bounds through unchanged;
        * drop the fixed columns from tensor bounds and initial conditions;
        * keep `None` as `None`;
        * remap constraint indices into the reduced (free-dimension) space.
        """
        fixed_features = {1: 1.0, 3: -1.0}
        b, q, d = 7, 3, 5
        initial_conditions = torch.randn(b, q, d, device=self.device)
        tensor_lower_bounds = torch.randn(q, d, device=self.device)
        tensor_upper_bounds = tensor_lower_bounds + torch.rand(
            q, d, device=self.device)
        old_inequality_constraints = [(
            torch.arange(0, 5, 2, device=self.device),
            torch.rand(3, device=self.device),
            1.0,
        )]
        old_equality_constraints = [(
            torch.arange(0, 3, 1, device=self.device),
            torch.rand(3, device=self.device),
            1.0,
        )]
        acqf = MockAcquisitionFunction()
        # boolean mask over the last dimension selecting the non-fixed columns
        free_dims_mask = [i not in fixed_features for i in range(d)]
        # size of the reduced space once the fixed features are removed
        new_dim = d - len(fixed_features)

        def check_bounds_and_init(old_val, new_val):
            if isinstance(old_val, float):
                self.assertEqual(old_val, new_val)
            elif isinstance(old_val, torch.Tensor):
                self.assertTrue(
                    torch.equal(old_val[..., free_dims_mask], new_val))
            else:
                self.assertIsNone(new_val)

        def check_cons(old_cons, new_cons):
            if old_cons:
                # no constraint here involves only fixed dimensions, so the
                # result must be non-empty
                self.assertTrue(new_cons)
                # 0-based indices into the reduced space are 0..new_dim - 1
                for new_con in new_cons:
                    new_indices = new_con[0]
                    self.assertTrue(
                        torch.all((new_indices < new_dim)
                                  & (new_indices >= 0)))
            else:
                self.assertEqual(old_cons, new_cons)

        for (
                lower_bounds,
                upper_bounds,
                inequality_constraints,
                equality_constraints,
        ) in product(
            [None, -1.0, tensor_lower_bounds],
            [None, 1.0, tensor_upper_bounds],
            [None, old_inequality_constraints],
            [None, old_equality_constraints],
        ):
            _no_ff = _remove_fixed_features_from_optimization(
                fixed_features=fixed_features,
                acquisition_function=acqf,
                initial_conditions=initial_conditions,
                lower_bounds=lower_bounds,
                upper_bounds=upper_bounds,
                inequality_constraints=inequality_constraints,
                equality_constraints=equality_constraints,
            )
            self.assertIsInstance(_no_ff.acquisition_function,
                                  FixedFeatureAcquisitionFunction)
            check_bounds_and_init(initial_conditions,
                                  _no_ff.initial_conditions)
            check_bounds_and_init(lower_bounds, _no_ff.lower_bounds)
            check_bounds_and_init(upper_bounds, _no_ff.upper_bounds)
            check_cons(inequality_constraints, _no_ff.inequality_constraints)
            check_cons(equality_constraints, _no_ff.equality_constraints)
示例#18
0
    def test_optimize_acqf_joint(
        self, mock_gen_candidates, mock_gen_batch_initial_conditions
    ):
        """Test `optimize_acqf` for q=3 candidates with its helpers mocked.

        The two mock parameters are presumably injected by `mock.patch`
        decorators outside this view, patching the candidate generator and
        `gen_batch_initial_conditions`. TODO confirm.

        For float and double dtypes, covers:

        * the default call, expecting only the best restart;
        * user-supplied `batch_initial_conditions` with
          `return_best_only=False`;
        * one fixed feature;
        * all features fixed.

        It then covers a one-shot acquisition function.
        """
        q = 3
        num_restarts = 2
        raw_samples = 10
        options = {}
        mock_acq_function = MockAcquisitionFunction()
        # expected number of mock_gen_batch_initial_conditions calls so far
        cnt = 0
        for dtype in (torch.float, torch.double):
            mock_gen_batch_initial_conditions.return_value = torch.zeros(
                num_restarts, q, 3, device=self.device, dtype=dtype
            )
            # restart i yields i * [0, 1, 2] for each of the q candidates;
            # mock_candidates has shape (num_restarts, q, 3)
            base_cand = torch.arange(3, device=self.device, dtype=dtype).expand(1, q, 3)
            mock_candidates = torch.cat(
                [i * base_cand for i in range(num_restarts)], dim=0
            )
            # values num_restarts, ..., 1: restart 0 has the largest value,
            # which is why restart 0 is the expected "best" below
            mock_acq_values = num_restarts - torch.arange(
                num_restarts, device=self.device, dtype=dtype
            )
            mock_gen_candidates.return_value = (mock_candidates, mock_acq_values)
            bounds = torch.stack(
                [
                    torch.zeros(3, device=self.device, dtype=dtype),
                    4 * torch.ones(3, device=self.device, dtype=dtype),
                ]
            )
            candidates, acq_vals = optimize_acqf(
                acq_function=mock_acq_function,
                bounds=bounds,
                q=q,
                num_restarts=num_restarts,
                raw_samples=raw_samples,
                options=options,
            )
            # without `return_best_only`, only restart 0 is expected back
            self.assertTrue(torch.equal(candidates, mock_candidates[0]))
            self.assertTrue(torch.equal(acq_vals, mock_acq_values[0]))
            # initial conditions were generated for this call
            cnt += 1
            self.assertEqual(mock_gen_batch_initial_conditions.call_count, cnt)

            # test generation with provided initial conditions
            candidates, acq_vals = optimize_acqf(
                acq_function=mock_acq_function,
                bounds=bounds,
                q=q,
                num_restarts=num_restarts,
                raw_samples=raw_samples,
                options=options,
                return_best_only=False,
                batch_initial_conditions=torch.zeros(
                    num_restarts, q, 3, device=self.device, dtype=dtype
                ),
            )
            # all restarts are returned, and no new initial conditions are
            # generated because they were supplied
            self.assertTrue(torch.equal(candidates, mock_candidates))
            self.assertTrue(torch.equal(acq_vals, mock_acq_values))
            self.assertEqual(mock_gen_batch_initial_conditions.call_count, cnt)

            # test fixed features
            fixed_features = {0: 0.1}
            # NOTE(review): `mock_candidates` is (num_restarts, q, 3), so
            # `[:, 0]` sets every feature of the first candidate (q index 0)
            # to 0.1 rather than feature 0; `mock_candidates[..., 0]` looks
            # like the intent. Confirm. This in-place write also changes the
            # tensor used by the comparisons that follow.
            mock_candidates[:, 0] = 0.1
            mock_gen_candidates.return_value = (mock_candidates, mock_acq_values)
            candidates, acq_vals = optimize_acqf(
                acq_function=mock_acq_function,
                bounds=bounds,
                q=q,
                num_restarts=num_restarts,
                raw_samples=raw_samples,
                options=options,
                fixed_features=fixed_features,
            )
            # fixed_features must be forwarded to the candidate generator
            self.assertEqual(
                mock_gen_candidates.call_args[1]["fixed_features"], fixed_features
            )
            self.assertTrue(torch.equal(candidates, mock_candidates[0]))
            cnt += 1
            self.assertEqual(mock_gen_batch_initial_conditions.call_count, cnt)

            # test trivial case when all features are fixed
            candidates, acq_vals = optimize_acqf(
                acq_function=mock_acq_function,
                bounds=bounds,
                q=q,
                num_restarts=num_restarts,
                raw_samples=raw_samples,
                options=options,
                fixed_features={0: 0.1, 1: 0.2, 2: 0.3},
            )
            # expected: the fixed values repeated for each of the q (= 3)
            # candidates, without generating any initial conditions
            self.assertTrue(
                torch.equal(
                    candidates,
                    torch.tensor(
                        [0.1, 0.2, 0.3], device=self.device, dtype=dtype
                    ).expand(3, 3),
                )
            )
            self.assertEqual(mock_gen_batch_initial_conditions.call_count, cnt)

        # test OneShotAcquisitionFunction
        # `bounds`, `mock_candidates` and `mock_acq_values` below are those
        # left over from the last loop iteration (torch.double)
        mock_acq_function = MockOneShotAcquisitionFunction()
        candidates, acq_vals = optimize_acqf(
            acq_function=mock_acq_function,
            bounds=bounds,
            q=q,
            num_restarts=num_restarts,
            raw_samples=raw_samples,
            options=options,
        )
        self.assertTrue(
            torch.equal(
                candidates, mock_acq_function.extract_candidates(mock_candidates[0])
            )
        )
        self.assertTrue(torch.equal(acq_vals, mock_acq_values[0]))