def test_upper_confidence_bound(self, cuda=False):
    """Check UpperConfidenceBound values on a single-output mock posterior.

    Runs for float and double, on CUDA when ``cuda`` is True and on CPU
    otherwise. Also checks that a mock model with two outputs leads to an
    ``UnsupportedError`` when the acquisition function is evaluated.
    """
    device = torch.device("cuda" if cuda else "cpu")
    for dtype in (torch.float, torch.double):
        tkwargs = {"device": device, "dtype": dtype}
        model = MockModel(
            MockPosterior(
                mean=torch.tensor([[0.0]], **tkwargs),
                variance=torch.tensor([[1.0]], **tkwargs),
            )
        )
        X = torch.zeros(1, 1, **tkwargs)
        # Posterior mean 0 / variance 1 with beta=1: expect 1.0 when no
        # maximize argument is given and -1.0 when maximize=False.
        cases = (({}, 1.0), ({"maximize": False}, -1.0))
        for extra_kwargs, expected_value in cases:
            acqf = UpperConfidenceBound(model=model, beta=1.0, **extra_kwargs)
            ucb = acqf(X)
            expected = torch.tensor([expected_value], **tkwargs)
            self.assertTrue(torch.allclose(ucb, expected, atol=1e-4))
        # check for proper error if multi-output model
        multi_output_model = MockModel(
            MockPosterior(
                mean=torch.rand(1, 2, **tkwargs),
                variance=torch.rand(1, 2, **tkwargs),
            )
        )
        multi_output_acqf = UpperConfidenceBound(model=multi_output_model, beta=1.0)
        # The error is expected at call time, not at construction.
        with self.assertRaises(UnsupportedError):
            multi_output_acqf(X)
def test_upper_confidence_bound_batch(self):
    """Check UpperConfidenceBound over a batch of two single-output posteriors.

    Also checks that building the acquisition function on a batched mock
    model with two outputs raises ``UnsupportedError``.
    """
    for dtype in (torch.float, torch.double):
        tkwargs = {"device": self.device, "dtype": dtype}
        mean = torch.tensor([0.0, 0.5], **tkwargs).view(2, 1, 1)
        variance = torch.tensor([1.0, 4.0], **tkwargs).view(2, 1, 1)
        acqf = UpperConfidenceBound(
            model=MockModel(MockPosterior(mean=mean, variance=variance)),
            beta=1.0,
        )
        ucb = acqf(torch.zeros(2, 1, 1, **tkwargs))
        # Each expected value equals mean + sqrt(beta * variance) for its
        # batch element: 0.0 + 1.0 and 0.5 + 2.0.
        expected = torch.tensor([1.0, 2.5], **tkwargs)
        self.assertTrue(torch.allclose(ucb, expected, atol=1e-4))
        # check for proper error if multi-output model
        mean2 = torch.rand(3, 1, 2, **tkwargs)
        variance2 = torch.rand(3, 1, 2, **tkwargs)
        multi_output_model = MockModel(MockPosterior(mean=mean2, variance=variance2))
        # Here the error is expected at construction time.
        with self.assertRaises(UnsupportedError):
            UpperConfidenceBound(model=multi_output_model, beta=1.0)
def get_new_points_acq_func_vals(
    model, acq_fn_label, new_points, best_response, acq_fn_hyperparams=None
):
    """Evaluate an acquisition function at each candidate point separately.

    Args:
        model: Model handed to the acquisition-function constructor.
        acq_fn_label: ``'expected_improvement'`` or ``'ucb'``.
        new_points: Candidate points. The reshape below reads ``shape[0]``
            and ``shape[1]``, so this is expected to be a 2-D ``(n, d)``
            tensor.
        best_response: ``best_f`` for expected improvement. It is not used
            for UCB.
        acq_fn_hyperparams: Optional keyword overrides that are merged on
            top of ``{'beta': 2}`` for UCB. They are NOT forwarded to
            expected improvement.

    Returns:
        Whatever the acquisition function returns for the ``(n, 1, d)``
        input (presumably one value per candidate; confirm against the
        acquisition function in use).

    Raises:
        NotImplementedError: If ``acq_fn_label`` is not one of the two
            supported labels.
    """
    if acq_fn_label == 'expected_improvement':
        acquisition = ExpectedImprovement(
            model, best_f=best_response, maximize=True
        )
    elif acq_fn_label == 'ucb':
        ucb_kwargs = {'beta': 2}
        if acq_fn_hyperparams is not None:
            ucb_kwargs.update(acq_fn_hyperparams)
        acquisition = UpperConfidenceBound(model, **ucb_kwargs)
    else:
        raise NotImplementedError(
            f'acq_fn_label {acq_fn_label} does not match implemented types'
        )
    num_points = new_points.shape[0]
    num_dims = new_points.shape[1]
    # Give every candidate its own q=1 batch: (n, d) -> (n, 1, d).
    q_batched = new_points.view(num_points, 1, num_dims)
    return acquisition(q_batched)
def test_acquisition_functions(self):
    """Check output shapes of acquisition functions on a fully Bayesian model.

    The model is fitted with ``fit_fully_bayesian_model_nuts``. Each
    acquisition function is then evaluated on random inputs with batch
    shapes ``[5]`` and ``[6, 5, 2]``; the trailing dims are ``1 x 4``
    (presumably q=1 and d=4). The output shape must equal the batch shape.
    """
    tkwargs = {"device": self.device, "dtype": torch.double}
    # The third returned value (train_Yvar) is not used by this test.
    train_X, train_Y, _, model = self._get_data_and_model(
        infer_noise=True, **tkwargs
    )
    fit_fully_bayesian_model_nuts(
        model, warmup_steps=8, num_samples=5, thinning=2, disable_progbar=True
    )
    sampler = IIDNormalSampler(num_samples=2)

    def zero_ref_point():
        # Build a fresh two-element zero reference point on every call.
        return torch.zeros(2, **tkwargs)

    acquisition_functions = [
        ExpectedImprovement(model=model, best_f=train_Y.max()),
        ProbabilityOfImprovement(model=model, best_f=train_Y.max()),
        PosteriorMean(model=model),
        UpperConfidenceBound(model=model, beta=4),
        qExpectedImprovement(model=model, best_f=train_Y.max(), sampler=sampler),
        qNoisyExpectedImprovement(model=model, X_baseline=train_X, sampler=sampler),
        qProbabilityOfImprovement(
            model=model, best_f=train_Y.max(), sampler=sampler
        ),
        qSimpleRegret(model=model, sampler=sampler),
        qUpperConfidenceBound(model=model, beta=4, sampler=sampler),
        qNoisyExpectedHypervolumeImprovement(
            model=ModelListGP(model, model),
            X_baseline=train_X,
            ref_point=zero_ref_point(),
            sampler=sampler,
        ),
        qExpectedHypervolumeImprovement(
            model=ModelListGP(model, model),
            ref_point=zero_ref_point(),
            sampler=sampler,
            partitioning=NondominatedPartitioning(
                ref_point=zero_ref_point(), Y=train_Y.repeat([1, 2])
            ),
        ),
    ]
    for acqf in acquisition_functions:
        for batch_shape in ([5], [6, 5, 2]):
            test_X = torch.rand(*batch_shape, 1, 4, **tkwargs)
            self.assertEqual(acqf(test_X).shape, torch.Size(batch_shape))