Пример #1
0
    def test_get_extra_mll_args(self):
        """Verify what ``_get_extra_mll_args`` returns for each MLL flavour.

        Exercises the exact MLL, the variational ELBO, the sum MLL built on
        a one-element model list, and a plain ``MarginalLogLikelihood``
        instance, which this test expects to be rejected with ``ValueError``.
        """
        inputs = torch.rand(3, 5)
        targets = torch.rand(3)
        gp = SingleTaskGP(train_X=inputs, train_Y=targets)

        # Exact MLL: exactly one extra argument, equal to the training inputs.
        mll_exact = ExactMarginalLogLikelihood(gp.likelihood, gp)
        extras_exact = _get_extra_mll_args(mll=mll_exact)
        self.assertEqual(len(extras_exact), 1)
        self.assertTrue(torch.equal(extras_exact[0], inputs))

        # Variational ELBO: no extra arguments at all.
        mll_elbo = VariationalELBO(gp.likelihood, gp, num_data=inputs.shape[0])
        extras_elbo = _get_extra_mll_args(mll=mll_elbo)
        self.assertEqual(len(extras_elbo), 0)

        # Sum MLL: one nested entry (one per sub-model) holding the inputs.
        gp_list = ModelListGP(gp_models=[gp])
        mll_sum = SumMarginalLogLikelihood(gp_list.likelihood, gp_list)
        extras_sum = _get_extra_mll_args(mll=mll_sum)
        self.assertEqual(len(extras_sum), 1)
        self.assertEqual(len(extras_sum[0]), 1)
        self.assertTrue(torch.equal(extras_sum[0][0], inputs))

        # Plain MarginalLogLikelihood is not a supported type here.
        mll_plain = MarginalLogLikelihood(gp.likelihood, gp)
        with self.assertRaises(ValueError):
            _get_extra_mll_args(mll=mll_plain)
Пример #2
0
    def test_get_extra_mll_args(self):
        """Verify what ``_get_extra_mll_args`` returns for each MLL flavour.

        Exercises the exact MLL, the sum MLL built on a one-element model
        list, and a plain ``MarginalLogLikelihood`` instance, for which this
        test expects an empty list of extra arguments.
        """
        inputs = torch.rand(3, 5)
        targets = torch.rand(3, 1)
        gp = SingleTaskGP(train_X=inputs, train_Y=targets)

        # Exact MLL: exactly one extra argument, equal to the training inputs.
        mll_exact = ExactMarginalLogLikelihood(gp.likelihood, gp)
        extras_exact = _get_extra_mll_args(mll=mll_exact)
        self.assertEqual(len(extras_exact), 1)
        self.assertTrue(torch.equal(extras_exact[0], inputs))

        # Sum MLL: one nested entry (one per sub-model) holding the inputs.
        gp_list = ModelListGP(gp)
        mll_sum = SumMarginalLogLikelihood(gp_list.likelihood, gp_list)
        extras_sum = _get_extra_mll_args(mll=mll_sum)
        self.assertEqual(len(extras_sum), 1)
        self.assertEqual(len(extras_sum[0]), 1)
        self.assertTrue(torch.equal(extras_sum[0][0], inputs))

        # Plain MarginalLogLikelihood: expected to yield no extra arguments.
        mll_plain = MarginalLogLikelihood(gp.likelihood, gp)
        extras_plain = _get_extra_mll_args(mll=mll_plain)
        self.assertEqual(extras_plain, [])