示例#1
0
 def test_affine_acquisition_objective(self):
     """Check that ScalarizedObjective applies an affine map to a posterior.

     For every combination of batch shape, number of outputs ``m`` and dtype:
     the scalarized posterior's mean must equal ``offset + mean @ weights``;
     its variance must equal ``(covar @ weights) @ weights``, and both must
     have shape ``batch_shape + [1, 1]``. Non-1D weights must raise
     ``ValueError``. ``evaluate`` must apply the same affine map to raw
     values ``Y``.
     """
     grid = itertools.product(([], [3]), (1, 2), (torch.float, torch.double))
     for batch_shape, m, dtype in grid:
         tkwargs = {"device": self.device, "dtype": dtype}
         # NOTE: the order of the random draws below is kept on purpose so the
         # sequence of RNG calls is unchanged.
         offset = torch.rand(1).item()
         weights = torch.randn(m, **tkwargs)
         obj = ScalarizedObjective(weights=weights, offset=offset)
         posterior = _get_test_posterior(batch_shape, m=m, **tkwargs)
         mean = posterior.mvn.mean
         covar = posterior.mvn.covariance_matrix
         scalarized = obj(posterior)
         expected_shape = torch.Size(batch_shape + [1, 1])

         # Mean: shape first, then the closed-form affine value.
         self.assertEqual(scalarized.mean.shape, expected_shape)
         self.assertTrue(
             torch.allclose(scalarized.mean[..., -1], offset + mean @ weights)
         )

         # Variance: w^T Sigma w, with a trailing singleton dim re-added so
         # it lines up with variance[..., -1].
         self.assertEqual(scalarized.variance.shape, expected_shape)
         expected_var = ((covar @ weights) @ weights).unsqueeze(-1)
         self.assertTrue(
             torch.allclose(scalarized.variance[..., -1], expected_var)
         )

         # A 2-D weight tensor must be rejected at construction time.
         with self.assertRaises(ValueError):
             ScalarizedObjective(weights=torch.rand(2, m))

         # evaluate() applies the same affine map directly to values Y.
         Y = torch.rand(2, m, **tkwargs)
         self.assertTrue(torch.equal(obj.evaluate(Y), offset + Y @ weights))
 def test_get_best_f_analytic(self):
     """Check get_best_f_analytic against the class's training-data fixtures.

     ``self.nbd_td``, and ``self.bd_td_mo`` without an objective, must raise
     ``NotImplementedError``. For ``self.bd_td`` the result must equal the max
     of its squeezed ``Y``. For ``self.bd_td_mo`` with a ScalarizedObjective,
     the result must equal the max of the objective evaluated on its ``Y``.
     """
     # Unsupported training data raises.
     with self.assertRaises(NotImplementedError):
         get_best_f_analytic(training_data=self.nbd_td)

     # Supported data: the result is the largest observed (squeezed) Y value.
     self.assertEqual(
         get_best_f_analytic(training_data=self.bd_td),
         self.bd_td.Y.squeeze().max(),
     )

     # Multi-output data is rejected unless an objective is supplied...
     with self.assertRaises(NotImplementedError):
         get_best_f_analytic(training_data=self.bd_td_mo)

     # ...and accepted once a ScalarizedObjective collapses the outputs.
     objective = ScalarizedObjective(weights=torch.rand(2))
     self.assertEqual(
         get_best_f_analytic(training_data=self.bd_td_mo, objective=objective),
         objective.evaluate(self.bd_td_mo.Y).max(),
     )