Example #1
 def test_optimize_acqf_list(self, mock_optimize_acqf):
     num_restarts = 2
     raw_samples = 10
     options = {}
     tkwargs = {"device": self.device}
     bounds = torch.stack([torch.zeros(3), 4 * torch.ones(3)])
     inequality_constraints = [
         [torch.tensor([3]), torch.tensor([4]), torch.tensor(5)]
     ]
     mock_acq_function_1 = MockAcquisitionFunction()
     mock_acq_function_2 = MockAcquisitionFunction()
     mock_acq_function_list = [mock_acq_function_1, mock_acq_function_2]
     for num_acqf, dtype in itertools.product([1, 2], (torch.float, torch.double)):
         inequality_constraints[0] = [
             t.to(**tkwargs) for t in inequality_constraints[0]
         ]
         mock_optimize_acqf.reset_mock()
         tkwargs["dtype"] = dtype
         bounds = bounds.to(**tkwargs)
         candidate_rvs = []
         acq_val_rvs = []
         gcs_return_vals = [
             (torch.rand(1, 3, **tkwargs), torch.rand(1, **tkwargs))
             for _ in range(num_acqf)
         ]
         for rv in gcs_return_vals:
             candidate_rvs.append(rv[0])
             acq_val_rvs.append(rv[1])
         side_effect = list(zip(candidate_rvs, acq_val_rvs))
         mock_optimize_acqf.side_effect = side_effect
         orig_candidates = candidate_rvs[0].clone()
         # Wrap the set_X_pending method to check that the call arguments are correct
         with mock.patch.object(
             MockAcquisitionFunction,
             "set_X_pending",
             wraps=mock_acq_function_1.set_X_pending,
         ) as mock_set_X_pending_1, mock.patch.object(
             MockAcquisitionFunction,
             "set_X_pending",
             wraps=mock_acq_function_2.set_X_pending,
         ) as mock_set_X_pending_2:
             candidates, acq_values = optimize_acqf_list(
                 acq_function_list=mock_acq_function_list[:num_acqf],
                 bounds=bounds,
                 num_restarts=num_restarts,
                 raw_samples=raw_samples,
                 options=options,
                 inequality_constraints=inequality_constraints,
                 post_processing_func=rounding_func,
             )
             # check that X_pending is set correctly in sequential optimization
             if num_acqf > 1:
                 x_pending_call_args_list = mock_set_X_pending_2.call_args_list
                 idxr = torch.ones(num_acqf, dtype=torch.bool, device=self.device)
                 for i in range(len(x_pending_call_args_list) - 1):
                     idxr[i] = 0
                     self.assertTrue(
                         torch.equal(
                             x_pending_call_args_list[i][0][0], orig_candidates[idxr]
                         )
                     )
                     idxr[i] = 1
                     orig_candidates[i] = candidate_rvs[i + 1]
             else:
                 mock_set_X_pending_1.assert_not_called()
         # check final candidates
         expected_candidates = (
             torch.cat(candidate_rvs[-num_acqf:], dim=0)
             if num_acqf > 1
             else candidate_rvs[0]
         )
         self.assertTrue(torch.equal(candidates, expected_candidates))
         # check call arguments for optimize_acqf
         call_args_list = mock_optimize_acqf.call_args_list
         expected_call_args = {
             "acq_function": None,
             "bounds": bounds,
             "q": 1,
             "num_restarts": num_restarts,
             "raw_samples": raw_samples,
             "options": options,
             "inequality_constraints": inequality_constraints,
             "equality_constraints": None,
             "fixed_features": None,
             "post_processing_func": rounding_func,
             "batch_initial_conditions": None,
             "return_best_only": True,
             "sequential": False,
         }
         for i in range(len(call_args_list)):
             expected_call_args["acq_function"] = mock_acq_function_list[i]
             for k, v in call_args_list[i][1].items():
                 if torch.is_tensor(v):
                     self.assertTrue(torch.equal(expected_call_args[k], v))
                 elif k == "acq_function":
                     self.assertIsInstance(
                         mock_acq_function_list[i], MockAcquisitionFunction
                     )
                 else:
                     self.assertEqual(expected_call_args[k], v)
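
The test above replaces `optimize_acqf` with a mock and only verifies how `optimize_acqf_list` dispatches to it. For orientation, here is a minimal non-test sketch of calling `optimize_acqf_list` against real acquisition functions. It assumes the standard BoTorch classes `SingleTaskGP` and `qExpectedImprovement` (not used in the test itself); the training data are random placeholders and the GP is left unfitted for brevity.

import torch
from botorch.models import SingleTaskGP
from botorch.acquisition import qExpectedImprovement
from botorch.optim.optimize import optimize_acqf_list

# Placeholder training data and an (unfitted) GP surrogate.
train_X = torch.rand(8, 3)
train_Y = train_X.sum(dim=-1, keepdim=True)
model = SingleTaskGP(train_X, train_Y)

# Two MC acquisition functions; optimize_acqf_list optimizes them sequentially,
# passing earlier candidates to later ones via set_X_pending.
acqf_list = [
    qExpectedImprovement(model=model, best_f=train_Y.max()),
    qExpectedImprovement(model=model, best_f=train_Y.max()),
]
bounds = torch.stack([torch.zeros(3), torch.ones(3)])
candidates, acq_values = optimize_acqf_list(
    acq_function_list=acqf_list,
    bounds=bounds,
    num_restarts=2,
    raw_samples=16,
)
# candidates has shape (len(acqf_list), 3): one point per acquisition function.
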
Example #2
 def test_optimize_acqf_cyclic(self, mock_optimize_acqf):
     num_restarts = 2
     raw_samples = 10
     num_cycles = 2
     options = {}
     tkwargs = {"device": self.device}
     bounds = torch.stack([torch.zeros(3), 4 * torch.ones(3)])
     inequality_constraints = [
         [torch.tensor([3]), torch.tensor([4]), torch.tensor(5)]
     ]
     mock_acq_function = MockAcquisitionFunction()
     for q, dtype in itertools.product([1, 3], (torch.float, torch.double)):
         inequality_constraints[0] = [
             t.to(**tkwargs) for t in inequality_constraints[0]
         ]
         mock_optimize_acqf.reset_mock()
         tkwargs["dtype"] = dtype
         bounds = bounds.to(**tkwargs)
         candidate_rvs = []
         acq_val_rvs = []
         for cycle_j in range(num_cycles):
             gcs_return_vals = [
                 (torch.rand(1, 3, **tkwargs), torch.rand(1, **tkwargs))
                 for _ in range(q)
             ]
             if cycle_j == 0:
                 # return `q` candidates for first call
                 candidate_rvs.append(
                     torch.cat([rv[0] for rv in gcs_return_vals], dim=-2)
                 )
                 acq_val_rvs.append(torch.cat([rv[1] for rv in gcs_return_vals]))
             else:
                 # return 1 candidate for subsequent calls
                 for rv in gcs_return_vals:
                     candidate_rvs.append(rv[0])
                     acq_val_rvs.append(rv[1])
         mock_optimize_acqf.side_effect = list(zip(candidate_rvs, acq_val_rvs))
         orig_candidates = candidate_rvs[0].clone()
         # wrap the set_X_pending method to check that the call arguments are correct
         with mock.patch.object(
             MockAcquisitionFunction,
             "set_X_pending",
             wraps=mock_acq_function.set_X_pending,
         ) as mock_set_X_pending:
             candidates, acq_value = optimize_acqf_cyclic(
                 acq_function=mock_acq_function,
                 bounds=bounds,
                 q=q,
                 num_restarts=num_restarts,
                 raw_samples=raw_samples,
                 options=options,
                 inequality_constraints=inequality_constraints,
                 post_processing_func=rounding_func,
                 cyclic_options={"maxiter": num_cycles},
             )
             # check that X_pending is set correctly in cyclic optimization
             if q > 1:
                 x_pending_call_args_list = mock_set_X_pending.call_args_list
                 idxr = torch.ones(q, dtype=torch.bool, device=self.device)
                 for i in range(len(x_pending_call_args_list) - 1):
                     idxr[i] = 0
                     self.assertTrue(
                         torch.equal(
                             x_pending_call_args_list[i][0][0], orig_candidates[idxr]
                         )
                     )
                     idxr[i] = 1
                     orig_candidates[i] = candidate_rvs[i + 1]
                 # check reset to base_X_pending
                 self.assertIsNone(x_pending_call_args_list[-1][0][0])
             else:
                 mock_set_X_pending.assert_not_called()
         # check final candidates
         expected_candidates = (
             torch.cat(candidate_rvs[-q:], dim=0) if q > 1 else candidate_rvs[0]
         )
         self.assertTrue(torch.equal(candidates, expected_candidates))
         # check call arguments for optimize_acqf
         call_args_list = mock_optimize_acqf.call_args_list
         expected_call_args = {
             "acq_function": mock_acq_function,
             "bounds": bounds,
             "num_restarts": num_restarts,
             "raw_samples": raw_samples,
             "options": options,
             "inequality_constraints": inequality_constraints,
             "equality_constraints": None,
             "fixed_features": None,
             "post_processing_func": rounding_func,
             "return_best_only": True,
             "sequential": True,
         }
         orig_candidates = candidate_rvs[0].clone()
         for i in range(len(call_args_list)):
             if i == 0:
                 # first cycle
                 expected_call_args.update(
                     {"batch_initial_conditions": None, "q": q}
                 )
             else:
                 expected_call_args.update(
                     {"batch_initial_conditions": orig_candidates[i - 1 : i], "q": 1}
                 )
                 orig_candidates[i - 1] = candidate_rvs[i]
             for k, v in call_args_list[i][1].items():
                 if torch.is_tensor(v):
                     self.assertTrue(torch.equal(expected_call_args[k], v))
                 elif k == "acq_function":
                     self.assertIsInstance(
                         mock_acq_function, MockAcquisitionFunction
                     )
                 else:
                     self.assertEqual(expected_call_args[k], v)
Example #3
 def test_gen_batch_initial_conditions_constraints(self):
     for dtype in (torch.float, torch.double):
         bounds = torch.tensor([[0, 0], [1, 1]],
                               device=self.device,
                               dtype=dtype)
         inequality_constraints = [(
             torch.tensor([1], device=self.device, dtype=dtype),
             torch.tensor([-4], device=self.device, dtype=dtype),
             torch.tensor(-3, device=self.device, dtype=dtype),
         )]
         equality_constraints = [(
             torch.tensor([0], device=self.device, dtype=dtype),
             torch.tensor([1], device=self.device, dtype=dtype),
             torch.tensor(0.5, device=self.device, dtype=dtype),
         )]
         for nonnegative in (True, False):
             for seed in (None, 1234):
                 mock_acqf = MockAcquisitionFunction()
                 for init_batch_limit in (None, 1):
                     mock_acqf = MockAcquisitionFunction()
                     with mock.patch.object(
                             MockAcquisitionFunction,
                             "__call__",
                             wraps=mock_acqf.__call__,
                     ) as mock_acqf_call:
                         batch_initial_conditions = gen_batch_initial_conditions(
                             acq_function=mock_acqf,
                             bounds=bounds,
                             q=1,
                             num_restarts=2,
                             raw_samples=10,
                             options={
                                 "nonnegative": nonnegative,
                                 "eta": 0.01,
                                 "alpha": 0.1,
                                 "seed": seed,
                                 "init_batch_limit": init_batch_limit,
                                 "thinning": 2,
                                 "n_burnin": 3,
                             },
                             inequality_constraints=inequality_constraints,
                             equality_constraints=equality_constraints,
                         )
                         expected_shape = torch.Size([2, 1, 2])
                         self.assertEqual(batch_initial_conditions.shape,
                                          expected_shape)
                         self.assertEqual(batch_initial_conditions.device,
                                          bounds.device)
                         self.assertEqual(batch_initial_conditions.dtype,
                                          bounds.dtype)
                         raw_samps = mock_acqf_call.call_args[0][0]
                         batch_shape = (torch.Size([10])
                                        if init_batch_limit is None else
                                        torch.Size([init_batch_limit]))
                         expected_raw_samps_shape = batch_shape + torch.Size(
                             [1, 2])
                         self.assertEqual(raw_samps.shape,
                                          expected_raw_samps_shape)
                         self.assertTrue((raw_samps[..., 0] == 0.5).all())
                         self.assertTrue(
                             (-4 * raw_samps[..., 1] >= -3).all())
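
For reference, a minimal non-test sketch of calling `gen_batch_initial_conditions` directly with the same constraint format, where each constraint is an `(indices, coefficients, rhs)` tuple. The surrogate model and data below are placeholder assumptions, not part of the test.

import torch
from botorch.models import SingleTaskGP
from botorch.acquisition import qExpectedImprovement
from botorch.optim.initializers import gen_batch_initial_conditions

# Placeholder data and an (unfitted) GP surrogate on a 2-d domain.
train_X = torch.rand(6, 2)
train_Y = train_X.sum(dim=-1, keepdim=True)
acqf = qExpectedImprovement(SingleTaskGP(train_X, train_Y), best_f=train_Y.max())

bounds = torch.tensor([[0.0, 0.0], [1.0, 1.0]])
# x[0] == 0.5 (equality) and -4 * x[1] >= -3, i.e. x[1] <= 0.75 (inequality).
equality_constraints = [(torch.tensor([0]), torch.tensor([1.0]), 0.5)]
inequality_constraints = [(torch.tensor([1]), torch.tensor([-4.0]), -3.0)]

ics = gen_batch_initial_conditions(
    acq_function=acqf,
    bounds=bounds,
    q=1,
    num_restarts=2,
    raw_samples=10,
    inequality_constraints=inequality_constraints,
    equality_constraints=equality_constraints,
)
# ics has shape (num_restarts, q, d) == (2, 1, 2) and satisfies both constraints.
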
Example #4
    def test_optimize_acqf_joint(
        self, mock_gen_candidates, mock_gen_batch_initial_conditions
    ):
        q = 3
        num_restarts = 2
        raw_samples = 10
        options = {}
        mock_acq_function = MockAcquisitionFunction()
        cnt = 1
        for dtype in (torch.float, torch.double):
            mock_gen_batch_initial_conditions.return_value = torch.zeros(
                num_restarts, q, 3, device=self.device, dtype=dtype
            )
            base_cand = torch.ones(1, q, 3, device=self.device, dtype=dtype)
            mock_candidates = torch.cat(
                [i * base_cand for i in range(num_restarts)], dim=0
            )
            mock_acq_values = num_restarts - torch.arange(
                num_restarts, device=self.device, dtype=dtype
            )
            mock_gen_candidates.return_value = (mock_candidates, mock_acq_values)
            bounds = torch.stack(
                [
                    torch.zeros(3, device=self.device, dtype=dtype),
                    4 * torch.ones(3, device=self.device, dtype=dtype),
                ]
            )
            candidates, acq_vals = optimize_acqf(
                acq_function=mock_acq_function,
                bounds=bounds,
                q=q,
                num_restarts=num_restarts,
                raw_samples=raw_samples,
                options=options,
            )
            self.assertTrue(torch.equal(candidates, mock_candidates[0]))
            self.assertTrue(torch.equal(acq_vals, mock_acq_values[0]))

            candidates, acq_vals = optimize_acqf(
                acq_function=mock_acq_function,
                bounds=bounds,
                q=q,
                num_restarts=num_restarts,
                raw_samples=raw_samples,
                options=options,
                return_best_only=False,
                batch_initial_conditions=torch.zeros(
                    num_restarts, q, 3, device=self.device, dtype=dtype
                ),
            )
            self.assertTrue(torch.equal(candidates, mock_candidates))
            self.assertTrue(torch.equal(acq_vals, mock_acq_values))
            self.assertEqual(mock_gen_batch_initial_conditions.call_count, cnt)
            cnt += 1

        # test OneShotAcquisitionFunction
        mock_acq_function = MockOneShotAcquisitionFunction()
        candidates, acq_vals = optimize_acqf(
            acq_function=mock_acq_function,
            bounds=bounds,
            q=q,
            num_restarts=num_restarts,
            raw_samples=raw_samples,
            options=options,
        )
        self.assertTrue(
            torch.equal(
                candidates, mock_acq_function.extract_candidates(mock_candidates[0])
            )
        )
        self.assertTrue(torch.equal(acq_vals, mock_acq_values[0]))
Example #5
    def test_optimize_acqf_joint(self, mock_gen_candidates,
                                 mock_gen_batch_initial_conditions):
        q = 3
        num_restarts = 2
        raw_samples = 10
        options = {}
        mock_acq_function = MockAcquisitionFunction()
        cnt = 0
        for dtype in (torch.float, torch.double):
            mock_gen_batch_initial_conditions.return_value = torch.zeros(
                num_restarts, q, 3, device=self.device, dtype=dtype)
            base_cand = torch.arange(3, device=self.device,
                                     dtype=dtype).expand(1, q, 3)
            mock_candidates = torch.cat(
                [i * base_cand for i in range(num_restarts)], dim=0)
            mock_acq_values = num_restarts - torch.arange(
                num_restarts, device=self.device, dtype=dtype)
            mock_gen_candidates.return_value = (mock_candidates,
                                                mock_acq_values)
            bounds = torch.stack([
                torch.zeros(3, device=self.device, dtype=dtype),
                4 * torch.ones(3, device=self.device, dtype=dtype),
            ])
            candidates, acq_vals = optimize_acqf(
                acq_function=mock_acq_function,
                bounds=bounds,
                q=q,
                num_restarts=num_restarts,
                raw_samples=raw_samples,
                options=options,
            )
            self.assertTrue(torch.equal(candidates, mock_candidates[0]))
            self.assertTrue(torch.equal(acq_vals, mock_acq_values[0]))
            cnt += 1
            self.assertEqual(mock_gen_batch_initial_conditions.call_count, cnt)

            # test generation with provided initial conditions
            candidates, acq_vals = optimize_acqf(
                acq_function=mock_acq_function,
                bounds=bounds,
                q=q,
                num_restarts=num_restarts,
                raw_samples=raw_samples,
                options=options,
                return_best_only=False,
                batch_initial_conditions=torch.zeros(num_restarts,
                                                     q,
                                                     3,
                                                     device=self.device,
                                                     dtype=dtype),
            )
            self.assertTrue(torch.equal(candidates, mock_candidates))
            self.assertTrue(torch.equal(acq_vals, mock_acq_values))
            self.assertEqual(mock_gen_batch_initial_conditions.call_count, cnt)

            # test fixed features
            fixed_features = {0: 0.1}
            mock_candidates[:, 0] = 0.1
            mock_gen_candidates.return_value = (mock_candidates,
                                                mock_acq_values)
            candidates, acq_vals = optimize_acqf(
                acq_function=mock_acq_function,
                bounds=bounds,
                q=q,
                num_restarts=num_restarts,
                raw_samples=raw_samples,
                options=options,
                fixed_features=fixed_features,
            )
            self.assertEqual(
                mock_gen_candidates.call_args[1]["fixed_features"],
                fixed_features)
            self.assertTrue(torch.equal(candidates, mock_candidates[0]))
            cnt += 1
            self.assertEqual(mock_gen_batch_initial_conditions.call_count, cnt)

            # test trivial case when all features are fixed
            candidates, acq_vals = optimize_acqf(
                acq_function=mock_acq_function,
                bounds=bounds,
                q=q,
                num_restarts=num_restarts,
                raw_samples=raw_samples,
                options=options,
                fixed_features={
                    0: 0.1,
                    1: 0.2,
                    2: 0.3
                },
            )
            self.assertTrue(
                torch.equal(
                    candidates,
                    torch.tensor([0.1, 0.2, 0.3],
                                 device=self.device,
                                 dtype=dtype).expand(3, 3),
                ))
            self.assertEqual(mock_gen_batch_initial_conditions.call_count, cnt)

        # test OneShotAcquisitionFunction
        mock_acq_function = MockOneShotAcquisitionFunction()
        candidates, acq_vals = optimize_acqf(
            acq_function=mock_acq_function,
            bounds=bounds,
            q=q,
            num_restarts=num_restarts,
            raw_samples=raw_samples,
            options=options,
        )
        self.assertTrue(
            torch.equal(
                candidates,
                mock_acq_function.extract_candidates(mock_candidates[0])))
        self.assertTrue(torch.equal(acq_vals, mock_acq_values[0]))
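
Both variants of this test mock out initial-condition and candidate generation. A minimal non-test sketch of the unmocked call, assuming the standard BoTorch classes `SingleTaskGP` and analytic `ExpectedImprovement` (placeholders, not part of the test):

import torch
from botorch.models import SingleTaskGP
from botorch.acquisition import ExpectedImprovement
from botorch.optim.optimize import optimize_acqf

# Placeholder data and an (unfitted) GP surrogate on a 3-d domain.
train_X = torch.rand(8, 3)
train_Y = train_X.sum(dim=-1, keepdim=True)
model = SingleTaskGP(train_X, train_Y)
EI = ExpectedImprovement(model=model, best_f=train_Y.max())

bounds = torch.stack([torch.zeros(3), 4 * torch.ones(3)])
candidate, acq_value = optimize_acqf(
    acq_function=EI,
    bounds=bounds,
    q=1,  # analytic EI supports only q=1
    num_restarts=2,
    raw_samples=10,
    fixed_features={0: 0.1},  # pin the first input dimension, as in the test
)
# candidate has shape (1, 3) with candidate[..., 0] == 0.1.
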
Example #6
    def test_optimize_acqf_mixed(self, mock_optimize_acqf):
        num_restarts = 2
        raw_samples = 10
        q = 1
        options = {}
        tkwargs = {"device": self.device}
        bounds = torch.stack([torch.zeros(3), 4 * torch.ones(3)])
        mock_acq_function = MockAcquisitionFunction()
        for num_ff, dtype in itertools.product([1, 3],
                                               (torch.float, torch.double)):
            tkwargs["dtype"] = dtype
            mock_optimize_acqf.reset_mock()
            bounds = bounds.to(**tkwargs)

            candidate_rvs = []
            acq_val_rvs = []
            gcs_return_vals = [(torch.rand(1, 3, **tkwargs),
                                torch.rand(1, **tkwargs))
                               for _ in range(num_ff)]
            for rv in gcs_return_vals:
                candidate_rvs.append(rv[0])
                acq_val_rvs.append(rv[1])
            fixed_features_list = [{i: i * 0.1} for i in range(num_ff)]
            side_effect = list(zip(candidate_rvs, acq_val_rvs))
            mock_optimize_acqf.side_effect = side_effect

            candidates, acq_value = optimize_acqf_mixed(
                acq_function=mock_acq_function,
                q=q,
                fixed_features_list=fixed_features_list,
                bounds=bounds,
                num_restarts=num_restarts,
                raw_samples=raw_samples,
                options=options,
                post_processing_func=rounding_func,
            )
            # compute expected output
            ff_acq_values = torch.stack(acq_val_rvs)
            best = torch.argmax(ff_acq_values)
            expected_candidates = candidate_rvs[best]
            expected_acq_value = ff_acq_values[best]
            self.assertTrue(torch.equal(candidates, expected_candidates))
            self.assertTrue(torch.equal(acq_value, expected_acq_value))
            # check call arguments for optimize_acqf
            call_args_list = mock_optimize_acqf.call_args_list
            expected_call_args = {
                "acq_function": None,
                "bounds": bounds,
                "q": q,
                "num_restarts": num_restarts,
                "raw_samples": raw_samples,
                "options": options,
                "inequality_constraints": None,
                "equality_constraints": None,
                "fixed_features": None,
                "post_processing_func": rounding_func,
                "batch_initial_conditions": None,
                "return_best_only": True,
                "sequential": False,
            }
            for i in range(len(call_args_list)):
                expected_call_args["fixed_features"] = fixed_features_list[i]
                for k, v in call_args_list[i][1].items():
                    if torch.is_tensor(v):
                        self.assertTrue(torch.equal(expected_call_args[k], v))
                    elif k == "acq_function":
                        self.assertIsInstance(v, MockAcquisitionFunction)
                    else:
                        self.assertEqual(expected_call_args[k], v)
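
Finally, a minimal non-test sketch of `optimize_acqf_mixed`, which runs one continuous optimization per entry of `fixed_features_list` and returns the candidate with the highest acquisition value. The model and the discrete values chosen for feature 0 are placeholder assumptions.

import torch
from botorch.models import SingleTaskGP
from botorch.acquisition import qExpectedImprovement
from botorch.optim.optimize import optimize_acqf_mixed

# Placeholder data and an (unfitted) GP surrogate on a 3-d domain.
train_X = torch.rand(8, 3)
train_Y = train_X.sum(dim=-1, keepdim=True)
acqf = qExpectedImprovement(SingleTaskGP(train_X, train_Y), best_f=train_Y.max())

bounds = torch.stack([torch.zeros(3), torch.ones(3)])
# Treat feature 0 as discrete: optimize the remaining features for each value
# and keep the combination with the best acquisition value.
fixed_features_list = [{0: 0.0}, {0: 0.5}, {0: 1.0}]
candidate, acq_value = optimize_acqf_mixed(
    acq_function=acqf,
    bounds=bounds,
    q=1,
    num_restarts=2,
    raw_samples=10,
    fixed_features_list=fixed_features_list,
)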