def test_gen_batch_initial_conditions(self):
    """Check gen_batch_initial_conditions output and raw-sample batching.

    Sweeps ``nonnegative``, ``seed``, ``init_batch_limit``, ``fixed_features``
    and ``sample_around_best`` for float and double precision, and checks:
    the shape/device/dtype of the returned initial conditions, the shape of
    the last batch of raw samples passed to the acquisition function, and
    that any fixed features are respected in the output.
    """
    bounds = torch.stack([torch.zeros(2), torch.ones(2)])
    mock_acqf = MockAcquisitionFunction()
    mock_acqf.objective = lambda y: y.squeeze(-1)
    for dtype in (torch.float, torch.double):
        bounds = bounds.to(device=self.device, dtype=dtype)
        # X_baseline and model are set for testing sample_around_best.
        mock_acqf.X_baseline = bounds
        mock_acqf.model = MockModel(MockPosterior(mean=bounds[:, :1]))
        option_grid = product(
            [True, False],  # nonnegative
            [None, 1234],  # seed
            [None, 1],  # init_batch_limit
            [None, {0: 0.5}],  # fixed_features
            [True, False],  # sample_around_best
        )
        for nonnegative, seed, init_batch_limit, ffs, sample_around_best in option_grid:
            # Wrap (not replace) __call__ so the real mock still runs while
            # the arguments of every acquisition evaluation are recorded.
            with mock.patch.object(
                MockAcquisitionFunction,
                "__call__",
                wraps=mock_acqf.__call__,
            ) as mock_acqf_call:
                batch_initial_conditions = gen_batch_initial_conditions(
                    acq_function=mock_acqf,
                    bounds=bounds,
                    q=1,
                    num_restarts=2,
                    raw_samples=10,
                    fixed_features=ffs,
                    options={
                        "nonnegative": nonnegative,
                        "eta": 0.01,
                        "alpha": 0.1,
                        "seed": seed,
                        "init_batch_limit": init_batch_limit,
                        "sample_around_best": sample_around_best,
                    },
                )
                expected_shape = torch.Size([2, 1, 2])
                self.assertEqual(batch_initial_conditions.shape, expected_shape)
                self.assertEqual(batch_initial_conditions.device, bounds.device)
                self.assertEqual(batch_initial_conditions.dtype, bounds.dtype)
                # call_args reflects the most recent acquisition evaluation.
                raw_samps = mock_acqf_call.call_args[0][0]
                # Without init_batch_limit all raw samples are expected in one
                # batch: 10, or 20 when sample_around_best contributes extra
                # points. With init_batch_limit each batch has that size.
                if init_batch_limit is None:
                    batch_shape = torch.Size([20 if sample_around_best else 10])
                else:
                    batch_shape = torch.Size([init_batch_limit])
                expected_raw_samps_shape = batch_shape + torch.Size([1, 2])
                self.assertEqual(raw_samps.shape, expected_raw_samps_shape)
                if ffs is not None:
                    for idx, val in ffs.items():
                        self.assertTrue(
                            torch.all(batch_initial_conditions[..., idx] == val)
                        )
def test_gen_batch_initial_conditions_highdim(self):
    """Check gen_batch_initial_conditions when q * d exceeds the Sobol limit.

    With d = 2200 and q = 10 the requested dimension (22000) is larger than
    the Sobol maximum (21201), so a SamplingWarning is expected. Also checks
    the shape/device/dtype of the result and that fixed features (every
    even dimension, when enabled) are respected.
    """
    d = 2200  # 2200 * 10 (q) > 21201 (sobol max dim)
    bounds = torch.stack([torch.zeros(d), torch.ones(d)])
    ffs_map = {i: random() for i in range(0, d, 2)}
    mock_acqf = MockAcquisitionFunction()
    mock_acqf.objective = lambda y: y.squeeze(-1)
    for dtype in (torch.float, torch.double):
        bounds = bounds.to(device=self.device, dtype=dtype)
        mock_acqf.X_baseline = bounds  # for testing sample_around_best
        mock_acqf.model = MockModel(MockPosterior(mean=bounds[:, :1]))
        option_grid = product(
            [True, False],  # nonnegative
            [None, 1234],  # seed
            [None, ffs_map],  # fixed_features
            [True, False],  # sample_around_best
        )
        for nonnegative, seed, ffs, sample_around_best in option_grid:
            with warnings.catch_warnings(record=True) as ws, settings.debug(True):
                batch_initial_conditions = gen_batch_initial_conditions(
                    # Pass the configured mock: a fresh MockAcquisitionFunction()
                    # would not carry the objective / X_baseline / model set
                    # above, so sample_around_best=True would not be exercised.
                    acq_function=mock_acqf,
                    bounds=bounds,
                    q=10,
                    num_restarts=1,
                    raw_samples=2,
                    fixed_features=ffs,
                    options={
                        "nonnegative": nonnegative,
                        "eta": 0.01,
                        "alpha": 0.1,
                        "seed": seed,
                        "sample_around_best": sample_around_best,
                    },
                )
                self.assertTrue(
                    any(issubclass(w.category, SamplingWarning) for w in ws)
                )
            expected_shape = torch.Size([1, 10, d])
            self.assertEqual(batch_initial_conditions.shape, expected_shape)
            self.assertEqual(batch_initial_conditions.device, bounds.device)
            self.assertEqual(batch_initial_conditions.dtype, bounds.dtype)
            if ffs is not None:
                for idx, val in ffs.items():
                    self.assertTrue(
                        torch.all(batch_initial_conditions[..., idx] == val)
                    )