def test_get_extra_mll_args(self):
    """Verify the extra forward-pass arguments reported for each MLL type.

    Cases exercised, in order:
      * ExactMarginalLogLikelihood -> one extra argument, equal to the
        training inputs of the model.
      * VariationalELBO -> no extra arguments.
      * SumMarginalLogLikelihood over a one-model ModelListGP -> one entry
        per sub-model, each holding that sub-model's training inputs.
      * base MarginalLogLikelihood -> ``_get_extra_mll_args`` raises
        ``ValueError``.
    """
    X = torch.rand(3, 5)
    Y = torch.rand(3)
    gp = SingleTaskGP(train_X=X, train_Y=Y)

    def check_single_input(args):
        # Exactly one extra argument, and it must be the training inputs.
        self.assertEqual(len(args), 1)
        self.assertTrue(torch.equal(args[0], X))

    # Exact MLL: the training inputs are the only extra argument.
    exact = ExactMarginalLogLikelihood(gp.likelihood, gp)
    check_single_input(_get_extra_mll_args(mll=exact))

    # Variational ELBO: nothing beyond model output and targets.
    variational = VariationalELBO(gp.likelihood, gp, num_data=X.size(0))
    self.assertEqual(len(_get_extra_mll_args(mll=variational)), 0)

    # Sum MLL: one nested list of inputs per sub-model in the list.
    gp_list = ModelListGP(gp_models=[gp])
    summed = SumMarginalLogLikelihood(gp_list.likelihood, gp_list)
    per_model_args = _get_extra_mll_args(mll=summed)
    self.assertEqual(len(per_model_args), 1)
    check_single_input(per_model_args[0])

    # Plain base-class MLL is not a recognised type and must be rejected.
    base = MarginalLogLikelihood(gp.likelihood, gp)
    with self.assertRaises(ValueError):
        _get_extra_mll_args(mll=base)
def test_get_extra_mll_args(self):
    """Verify the extra forward-pass arguments reported for each MLL type.

    Cases exercised, in order:
      * ExactMarginalLogLikelihood -> one extra argument, equal to the
        training inputs of the model.
      * SumMarginalLogLikelihood over a one-model ModelListGP -> one entry
        per sub-model, each holding that sub-model's training inputs.
      * base MarginalLogLikelihood -> ``_get_extra_mll_args`` returns an
        empty list rather than raising.
    """
    X = torch.rand(3, 5)
    Y = torch.rand(3, 1)
    gp = SingleTaskGP(train_X=X, train_Y=Y)

    def check_single_input(args):
        # Exactly one extra argument, and it must be the training inputs.
        self.assertEqual(len(args), 1)
        self.assertTrue(torch.equal(args[0], X))

    # Exact MLL: the training inputs are the only extra argument.
    exact = ExactMarginalLogLikelihood(gp.likelihood, gp)
    check_single_input(_get_extra_mll_args(mll=exact))

    # Sum MLL: one nested list of inputs per sub-model in the list.
    gp_list = ModelListGP(gp)
    summed = SumMarginalLogLikelihood(gp_list.likelihood, gp_list)
    per_model_args = _get_extra_mll_args(mll=summed)
    self.assertEqual(len(per_model_args), 1)
    check_single_input(per_model_args[0])

    # Unrecognised MLL types fall through to "no extra arguments".
    base = MarginalLogLikelihood(gp.likelihood, gp)
    self.assertEqual(_get_extra_mll_args(mll=base), [])