Example #1
    def test_gen_batch_initial_conditions(self):
        bounds = torch.stack([torch.zeros(2), torch.ones(2)])
        mock_acqf = MockAcquisitionFunction()
        mock_acqf.objective = lambda y: y.squeeze(-1)
        for dtype in (torch.float, torch.double):
            bounds = bounds.to(device=self.device, dtype=dtype)
            mock_acqf.X_baseline = bounds  # for testing sample_around_best
            mock_acqf.model = MockModel(MockPosterior(mean=bounds[:, :1]))
            for nonnegative, seed, init_batch_limit, ffs, sample_around_best in product(
                [True, False], [None, 1234], [None, 1], [None, {0: 0.5}], [True, False]
            ):
                with mock.patch.object(
                    MockAcquisitionFunction,
                    "__call__",
                    wraps=mock_acqf.__call__,
                ) as mock_acqf_call:
                    batch_initial_conditions = gen_batch_initial_conditions(
                        acq_function=mock_acqf,
                        bounds=bounds,
                        q=1,
                        num_restarts=2,
                        raw_samples=10,
                        fixed_features=ffs,
                        options={
                            "nonnegative": nonnegative,
                            "eta": 0.01,
                            "alpha": 0.1,
                            "seed": seed,
                            "init_batch_limit": init_batch_limit,
                            "sample_around_best": sample_around_best,
                        },
                    )
                    expected_shape = torch.Size([2, 1, 2])
                    self.assertEqual(batch_initial_conditions.shape, expected_shape)
                    self.assertEqual(batch_initial_conditions.device, bounds.device)
                    self.assertEqual(batch_initial_conditions.dtype, bounds.dtype)
                    raw_samps = mock_acqf_call.call_args[0][0]
                    batch_shape = (
                        torch.Size([20 if sample_around_best else 10])
                        if init_batch_limit is None
                        else torch.Size([init_batch_limit])
                    )
                    expected_raw_samps_shape = batch_shape + torch.Size([1, 2])
                    self.assertEqual(raw_samps.shape, expected_raw_samps_shape)

                    if ffs is not None:
                        for idx, val in ffs.items():
                            self.assertTrue(
                                torch.all(batch_initial_conditions[..., idx] == val)
                            )
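
Both examples are excerpted test methods; they rely on imports and a test-class scaffold that are not shown. A minimal sketch of that context follows, assuming the snippets come from BoTorch's test suite (the module paths and fixture names below, e.g. botorch.utils.testing.MockAcquisitionFunction and BotorchTestCase, are assumptions based on that library and may differ in the actual source).

# Assumed context for the excerpted test methods (not part of the examples).
import warnings
from itertools import product
from random import random
from unittest import mock

import torch
from botorch import settings
from botorch.exceptions.warnings import SamplingWarning
from botorch.optim.initializers import gen_batch_initial_conditions
from botorch.utils.testing import (
    BotorchTestCase,
    MockAcquisitionFunction,
    MockModel,
    MockPosterior,
)


class TestGenBatchInitialConditions(BotorchTestCase):
    # BotorchTestCase provides self.device; the test_* methods shown in the
    # examples would be defined inside a class like this one.
    ...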
Example #2
    def test_gen_batch_initial_conditions_highdim(self):
        d = 2200  # 2200 * 10 (q) > 21201 (sobol max dim)
        bounds = torch.stack([torch.zeros(d), torch.ones(d)])
        ffs_map = {i: random() for i in range(0, d, 2)}
        mock_acqf = MockAcquisitionFunction()
        mock_acqf.objective = lambda y: y.squeeze(-1)
        for dtype in (torch.float, torch.double):
            bounds = bounds.to(device=self.device, dtype=dtype)
            mock_acqf.X_baseline = bounds  # for testing sample_around_best
            mock_acqf.model = MockModel(MockPosterior(mean=bounds[:, :1]))

            for nonnegative, seed, ffs, sample_around_best in product(
                [True, False], [None, 1234], [None, ffs_map], [True, False]
            ):
                with warnings.catch_warnings(record=True) as ws, settings.debug(True):
                    batch_initial_conditions = gen_batch_initial_conditions(
                        acq_function=mock_acqf,  # use the configured mock so the sample_around_best setup above is exercised
                        bounds=bounds,
                        q=10,
                        num_restarts=1,
                        raw_samples=2,
                        fixed_features=ffs,
                        options={
                            "nonnegative": nonnegative,
                            "eta": 0.01,
                            "alpha": 0.1,
                            "seed": seed,
                            "sample_around_best": sample_around_best,
                        },
                    )
                    self.assertTrue(
                        any(issubclass(w.category, SamplingWarning) for w in ws)
                    )
                expected_shape = torch.Size([1, 10, d])
                self.assertEqual(batch_initial_conditions.shape, expected_shape)
                self.assertEqual(batch_initial_conditions.device, bounds.device)
                self.assertEqual(batch_initial_conditions.dtype, bounds.dtype)
                if ffs is not None:
                    for idx, val in ffs.items():
                        self.assertTrue(
                            torch.all(batch_initial_conditions[..., idx] == val)
                        )
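
Outside the test harness, gen_batch_initial_conditions is called with a real acquisition function and model. The sketch below is illustrative only and not taken from the examples above: it assumes BoTorch's SingleTaskGP and ExpectedImprovement, uses toy data, and skips model fitting to stay short.

# A minimal usage sketch (assumed setup; the toy data and model choice are
# illustrative, not part of the examples above).
import torch
from botorch.acquisition import ExpectedImprovement
from botorch.models import SingleTaskGP
from botorch.optim.initializers import gen_batch_initial_conditions

# Toy training data on the unit square.
train_X = torch.rand(8, 2, dtype=torch.double)
train_Y = (train_X.sum(dim=-1, keepdim=True) - 1.0).abs().neg()

# An (unfitted) GP is enough to exercise the initializer; in practice the
# model would be fitted before optimizing the acquisition function.
model = SingleTaskGP(train_X, train_Y)
acqf = ExpectedImprovement(model, best_f=train_Y.max())

bounds = torch.stack([torch.zeros(2), torch.ones(2)]).to(train_X)
ics = gen_batch_initial_conditions(
    acq_function=acqf,
    bounds=bounds,
    q=1,
    num_restarts=2,
    raw_samples=10,
    options={"seed": 1234, "init_batch_limit": 5},
)
# ics has shape num_restarts x q x d, i.e. 2 x 1 x 2 here.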