Example #1
    def test_degree3(self):
        # Smoke test: degree-3 construction and evaluation should not raise.
        AddK = NewtonGirardAdditiveKernel(RBFKernel(ard_num_dims=3), 3, 3)
        self.assertEqual(AddK.base_kernel.lengthscale.numel(), 3)
        self.assertEqual(AddK.outputscale.numel(), 3)

        testvals = torch.tensor([[1, 2, 3], [7, 5, 2]], dtype=torch.float)
        add_k_val = AddK(testvals, testvals).evaluate()

        manual_k1 = ScaleKernel(
            AdditiveKernel(RBFKernel(active_dims=0), RBFKernel(active_dims=1),
                           RBFKernel(active_dims=2)))
        manual_k1.initialize(outputscale=1 / 3)
        manual_k2 = ScaleKernel(
            AdditiveKernel(RBFKernel(active_dims=[0, 1]),
                           RBFKernel(active_dims=[1, 2]),
                           RBFKernel(active_dims=[0, 2])))
        manual_k2.initialize(outputscale=1 / 3)

        manual_k3 = ScaleKernel(AdditiveKernel(RBFKernel()))
        manual_k3.initialize(outputscale=1 / 3)
        manual_k = AdditiveKernel(manual_k1, manual_k2, manual_k3)
        manual_add_k_val = manual_k(testvals, testvals).evaluate()
        self.assertTrue(torch.allclose(add_k_val, manual_add_k_val, atol=1e-5))
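What this test pins down: NewtonGirardAdditiveKernel over d = 3 dimensions with max_degree=3 computes a weighted sum of the elementary symmetric polynomials of the per-dimension base kernels, which is exactly what the three manual kernels spell out (an RBF over a group of dims with a shared lengthscale factors into the product of the per-dim RBFs, hence the active_dims groups). The 1/3 weights match the default per-degree outputscale initialization of 1/max_degree that the test relies on. In symbols:

k_{\mathrm{add}}(x, x') = \sigma_1^2 \sum_i k_i \;+\; \sigma_2^2 \sum_{i<j} k_i k_j \;+\; \sigma_3^2\, k_1 k_2 k_3, \qquad \sigma_1^2 = \sigma_2^2 = \sigma_3^2 = \tfrac{1}{3}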
Example #2
    def test_computes_sum_three_radial_basis_function_gradient(self):
        softplus = torch.nn.functional.softplus
        a = torch.tensor([4, 2, 8], dtype=torch.float).view(3, 1)
        b = torch.tensor([0, 2, 2], dtype=torch.float).view(3, 1)
        lengthscale = 2

        param = math.log(math.exp(lengthscale) - 1) * torch.ones(3, 3)
        param.requires_grad_()
        diffs = a.expand(3, 3) - b.expand(3, 3).transpose(0, 1)
        actual_output = (-0.5 * (diffs / softplus(param))**2).exp()
        actual_output.backward(torch.eye(3))
        actual_param_grad = param.grad.sum() * 3

        kernel_1 = RBFKernel().initialize(lengthscale=lengthscale)
        kernel_2 = RBFKernel().initialize(lengthscale=lengthscale)
        kernel_3 = RBFKernel().initialize(lengthscale=lengthscale)
        kernel = AdditiveKernel(kernel_1, kernel_2, kernel_3)
        kernel.eval()

        output = kernel(a, b).evaluate()
        output.backward(gradient=torch.eye(3))
        res = (kernel.kernels[0].raw_lengthscale.grad +
               kernel.kernels[1].raw_lengthscale.grad +
               kernel.kernels[2].raw_lengthscale.grad)
        self.assertLess(torch.norm(res - actual_param_grad), 2e-5)
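Why the manual computation goes through softplus: gpytorch stores lengthscales as unconstrained raw_lengthscale parameters and maps them through the default Positive() constraint (a softplus), so the test seeds param with the inverse softplus log(exp(l) - 1) and differentiates through the same mapping. A minimal standalone check of that correspondence (not part of the original test):

import math
import torch
from gpytorch.kernels import RBFKernel

k = RBFKernel().initialize(lengthscale=2.0)
raw = math.log(math.exp(2.0) - 1)  # inverse softplus of the lengthscale
assert torch.allclose(k.raw_lengthscale, torch.full_like(k.raw_lengthscale, raw))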
Example #3
def learn_projections(base_kernels,
                      xs,
                      ys,
                      max_projections=10,
                      mse_threshold=0.0001,
                      post_fit=False,
                      backfit_iters=5,
                      **optim_kwargs):
    n, d = xs.shape
    pred_means = torch.zeros(max_projections, n)
    models = []
    for bf_iter in range(backfit_iters):
        for i in range(max_projections):
            # Backfit: subtract every other component's current prediction.
            residuals = ys - pred_means[:i, :].sum(
                dim=0) - pred_means[i + 1:, :].sum(dim=0)
            if bf_iter == 0:
                with torch.no_grad():
                    coef = torch.pinverse(xs).matmul(residuals).reshape(1, -1)
                base_kernel = base_kernels[i]
                projection = torch.nn.Linear(d, 1, bias=False).to(xs)
                projection.weight.data = coef
                kernel = ScaledProjectionKernel(projection, base_kernel)
                model = ExactGPModel(xs, residuals, GaussianLikelihood(),
                                     kernel).to(xs)
            else:
                if i >= len(models):
                    break  # the first pass stopped early; no model to refit
                model = models[i]
            mll = ExactMarginalLogLikelihood(model.likelihood, model).to(xs)
            model.train()
            train_to_convergence(model,
                                 xs,
                                 residuals,
                                 objective=mll,
                                 **optim_kwargs)

            model.eval()
            if bf_iter == 0:
                # Register each model once; later backfit passes reuse it.
                models.append(model)
            with torch.no_grad():
                pred_mean = model(xs).mean
                pred_means[i, :] = pred_mean
                residuals = residuals - pred_mean
                mse = (residuals**2).mean()
                print(mse.item(), end='; ')
                if mse < mse_threshold:
                    break
    print()
    joint_kernel = AdditiveKernel(*[model.covar_module for model in models])
    joint_model = ExactGPModel(xs, ys, GaussianLikelihood(),
                               joint_kernel).to(xs)

    if post_fit:
        mll = ExactMarginalLogLikelihood(joint_model.likelihood,
                                         joint_model).to(xs)
        train_to_convergence(joint_model,
                             xs,
                             ys,
                             objective=mll,
                             **optim_kwargs)

    return joint_model
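A hypothetical call, for orientation only (the toy data and the max_iter keyword forwarded through **optim_kwargs to train_to_convergence are assumptions, not from the source):

import torch
from gpytorch.kernels import RBFKernel

xs = torch.randn(200, 5)                        # toy inputs, d = 5
ys = torch.randn(200)                           # toy targets
base_kernels = [RBFKernel() for _ in range(10)]
joint_model = learn_projections(base_kernels, xs, ys,
                                max_projections=10, post_fit=True,
                                max_iter=100)   # assumed optimizer kwarg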
Example #4
    def test_computes_sum_of_three_radial_basis_function(self):
        a = torch.tensor([4, 2, 8], dtype=torch.float).view(3, 1)
        b = torch.tensor([0, 2], dtype=torch.float).view(2, 1)
        lengthscale = 2

        kernel_1 = RBFKernel().initialize(lengthscale=lengthscale)
        kernel_2 = RBFKernel().initialize(lengthscale=lengthscale)
        kernel_3 = RBFKernel().initialize(lengthscale=lengthscale)
        kernel = AdditiveKernel(kernel_1, kernel_2, kernel_3)

        actual = torch.tensor(
            [[16, 4], [4, 0], [64, 36]], dtype=torch.float
        ).mul_(-0.5).div_(lengthscale ** 2).exp() * 3

        kernel.eval()
        res = kernel(a, b).evaluate()
        self.assertLess(torch.norm(res - actual), 2e-5)
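The hard-coded matrix is just the squared pairwise differences of a and b, pushed through the RBF closed form exp(-d^2 / (2 l^2)) and tripled because the three summands are identical. A quick standalone derivation of the same expected values:

import torch

a = torch.tensor([4., 2., 8.]).view(3, 1)
b = torch.tensor([0., 2.]).view(2, 1)
sq_diffs = (a - b.t()) ** 2               # [[16, 4], [4, 0], [64, 36]]
expected = 3 * torch.exp(-0.5 * sq_diffs / 2 ** 2)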
Example #5
    def __init__(self, train_x, train_y, num_mixtures=10):
        smk = SpectralMixtureKernel(num_mixtures)
        smk.initialize_from_data(train_x, train_y)
        kernel = AdditiveKernel(
            smk,
            PolynomialKernel(2),
            RBFKernel(),
        )
        super(CompositeKernelGP, self).__init__(kernel, train_x, train_y)
        self.mean = gp.means.ConstantMean()
        self.smk = smk
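For context, the same composite kernel built outside the model class (the toy data here is illustrative); initialize_from_data seeds the spectral mixture parameters from the training data, which is why the raw tensors are passed into the constructor:

import torch
from gpytorch.kernels import (AdditiveKernel, PolynomialKernel, RBFKernel,
                              SpectralMixtureKernel)

train_x = torch.linspace(0, 1, 50).unsqueeze(-1)
train_y = torch.sin(6.0 * train_x).squeeze(-1)
smk = SpectralMixtureKernel(num_mixtures=10)
smk.initialize_from_data(train_x, train_y)   # data-driven mixture init
kernel = AdditiveKernel(smk, PolynomialKernel(power=2), RBFKernel())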
Example #6
    def test_diag(self):
        AddK = NewtonGirardAdditiveKernel(RBFKernel(ard_num_dims=3), 3, 2)
        self.assertEqual(AddK.base_kernel.lengthscale.numel(), 3)
        self.assertEqual(AddK.outputscale.numel(), 2)

        testvals = torch.tensor([[1, 2, 3], [7, 5, 2]], dtype=torch.float)
        add_k_val = AddK(testvals, testvals).diag()

        manual_k1 = ScaleKernel(AdditiveKernel(RBFKernel(active_dims=0),
                                               RBFKernel(active_dims=1),
                                               RBFKernel(active_dims=2)))
        manual_k1.initialize(outputscale=1 / 2)
        manual_k2 = ScaleKernel(AdditiveKernel(RBFKernel(active_dims=[0, 1]),
                                               RBFKernel(active_dims=[1, 2]),
                                               RBFKernel(active_dims=[0, 2])))
        manual_k2.initialize(outputscale=1 / 2)
        manual_k = AdditiveKernel(manual_k1, manual_k2)
        manual_add_k_val = manual_k(testvals, testvals).diag()

        self.assertTrue(torch.allclose(add_k_val, manual_add_k_val, atol=1e-5))
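.diag() evaluates only the kernel diagonal rather than materializing the full matrix. A quick sketch of the consistency one would expect against the dense path:

import torch
from gpytorch.kernels import NewtonGirardAdditiveKernel, RBFKernel

AddK = NewtonGirardAdditiveKernel(RBFKernel(ard_num_dims=3), 3, 2)
x = torch.tensor([[1., 2., 3.], [7., 5., 2.]])
assert torch.allclose(AddK(x, x).diag(), AddK(x, x).evaluate().diagonal())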
Example #7
    def test_ard(self):
        base_k = RBFKernel(ard_num_dims=3)
        base_k.initialize(lengthscale=[1., 2., 3.])
        AddK = NewtonGirardAdditiveKernel(base_k, 3, max_degree=1)

        testvals = torch.tensor([[1, 2, 3], [7, 5, 2]], dtype=torch.float)
        add_k_val = AddK(testvals, testvals).evaluate()

        ks = []
        for i in range(3):
            k = RBFKernel(active_dims=i)
            k.initialize(lengthscale=i + 1)
            ks.append(k)
        manual_k = ScaleKernel(AdditiveKernel(*ks))
        manual_k.initialize(outputscale=1.)
        manual_add_k_val = manual_k(testvals, testvals).evaluate()

        self.assertTrue(torch.allclose(add_k_val, manual_add_k_val, atol=1e-5))