def test_toeplitz_mvn_kl_divergence_forward():
    # Forward-value test: gpytorch.mvn_kl_divergence on an interpolated
    # Toeplitz covariance should match the same call on its densified form.
    x = Variable(torch.linspace(0, 1, 5))
    rbf_covar = RBFKernel()
    rbf_covar.initialize(log_lengthscale=-4)
    covar_module = GridInterpolationKernel(rbf_covar)
    covar_module.initialize_interpolation_grid(10, grid_bounds=(0, 1))
    covar_x = covar_module.forward(x.unsqueeze(1), x.unsqueeze(1))

    # covar_x.c is the first column of the symmetric Toeplitz grid covariance.
    c = Variable(covar_x.c.data, requires_grad=True)
    mu1 = Variable(torch.randn(10), requires_grad=True)
    mu2 = Variable(torch.randn(10), requires_grad=True)

    # Densify the Toeplitz matrix entry by entry.
    T = Variable(torch.zeros(len(c), len(c)))
    for i in range(len(c)):
        for j in range(len(c)):
            T[i, j] = utils.toeplitz.toeplitz_getitem(c, c, i, j)

    # Random upper-triangular factor with a sign-corrected diagonal, used as
    # the Cholesky-like factor of the first distribution's covariance.
    U = torch.randn(10, 10).triu()
    U = Variable(U.mul(U.diag().sign().unsqueeze(1).expand_as(U).triu()),
                 requires_grad=True)

    actual = gpytorch.mvn_kl_divergence(mu1, U, mu2, T, num_samples=1000)

    res = gpytorch.mvn_kl_divergence(mu1, U, mu2, covar_x, num_samples=1000)

    # The estimate is stochastic (num_samples=1000): allow 15% relative error.
    assert all(torch.abs((res.data - actual.data) / actual.data) < 0.15)
Beispiel #2
0
    def _init_covar_module(self, covar_module):
        """Instantiate ``self.covar_module`` from a configuration dict.

        Args:
            covar_module: dict with a ``'type'`` key, one of ``'rbf'``,
                ``'matern'``, ``'sm'``, ``'kiss'`` or ``'skip'``; ``'sm'``
                additionally requires a ``'num_mixtures'`` key. ``None``
                falls back to ``'rbf'``.

        Raises:
            NotImplementedError: if the requested kernel type is unknown.
        """
        module = covar_module['type'] if covar_module is not None else 'rbf'

        # Index kernel does some scaling, hence, scale kernel is not used
        if module == 'rbf':
            self.covar_module = RBFKernel(ard_num_dims=self.num_dims)
        elif module == 'matern':
            self.covar_module = MaternKernel(nu=1.5,
                                             ard_num_dims=self.num_dims)

        elif module == 'sm':
            self.covar_module = SpectralMixtureKernel(
                num_mixtures=covar_module['num_mixtures'],
                ard_num_dims=self.num_dims)

        elif module == 'kiss':
            # KISS-GP: RBF kernel interpolated onto a dense grid.
            self.base_covar_module = RBFKernel()
            self.covar_module = GridInterpolationKernel(self.base_covar_module,
                                                        grid_size=100,
                                                        num_dims=self.num_dims)
        elif module == 'skip':
            # SKIP: product structure over 1-d KISS-GP kernels.
            self.base_covar_module = RBFKernel()
            self.covar_module = ProductStructureKernel(GridInterpolationKernel(
                self.base_covar_module, grid_size=100, num_dims=1),
                                                       num_dims=self.num_dims)
        else:
            # Name the offending kernel type rather than failing opaquely.
            raise NotImplementedError(
                "Unsupported covariance module type: {!r}".format(module))
Beispiel #3
0
def test_interpolated_toeplitz_gp_marginal_log_likelihood_forward():
    # Compare the interpolated-Toeplitz marginal log likelihood against a
    # dense computation on W_left T W_right^T + noise * I.
    x = Variable(torch.linspace(0, 1, 5))
    y = torch.randn(5)
    noise = torch.Tensor([1e-4])
    rbf_covar = RBFKernel()
    rbf_covar.initialize(log_lengthscale=-4)
    covar_module = GridInterpolationKernel(rbf_covar)
    covar_module.initialize_interpolation_grid(10, grid_bounds=(0, 1))
    covar_x = covar_module.forward(x.unsqueeze(1), x.unsqueeze(1))
    c = covar_x.c.data
    T = utils.toeplitz.sym_toeplitz(c)

    # Sparse interpolation matrices mapping grid values to data locations.
    W_left = index_coef_to_sparse(covar_x.J_left, covar_x.C_left, len(c))
    W_right = index_coef_to_sparse(covar_x.J_right, covar_x.C_right, len(c))

    W_left_dense = W_left.to_dense()
    W_right_dense = W_right.to_dense()

    # Dense training covariance with observation noise on the diagonal.
    WTW = W_left_dense.matmul(T.matmul(W_right_dense.t())) + torch.eye(len(x)) * 1e-4

    quad_form_actual = y.dot(WTW.inverse().matmul(y))
    chol_T = torch.potrf(WTW)
    # log det from the Cholesky factor: 2 * sum(log(diag(L))).
    log_det_actual = chol_T.diag().log().sum() * 2

    actual = -0.5 * (log_det_actual + quad_form_actual + math.log(2 * math.pi) * len(y))

    res = InterpolatedToeplitzGPMarginalLogLikelihood(W_left, W_right, num_samples=1000)(Variable(c),
                                                                                         Variable(y),
                                                                                         Variable(noise)).data
    # The log-det estimate is stochastic: allow 5% relative error.
    assert all(torch.abs((res - actual) / actual) < 0.05)
Beispiel #4
0
def test_trace_logdet_quad_form_factory():
    # Compare trace_logdet_quad_form_factory (value and gradients) against
    # the dense computation: log det T + mu' T^-1 mu + tr(T^-1 U'U).
    x = Variable(torch.linspace(0, 1, 10))
    rbf_covar = RBFKernel()
    rbf_covar.initialize(log_lengthscale=-4)
    covar_module = GridInterpolationKernel(rbf_covar)
    covar_module.initialize_interpolation_grid(4, grid_bounds=(0, 1))
    c = Variable(covar_module.forward(x.unsqueeze(1), x.unsqueeze(1)).c.data,
                 requires_grad=True)

    # Densify the symmetric Toeplitz matrix defined by column c.
    T = Variable(torch.zeros(4, 4))
    for i in range(4):
        for j in range(4):
            T[i, j] = utils.toeplitz.toeplitz_getitem(c, c, i, j)

    # Random upper-triangular factor with a sign-corrected diagonal.
    U = torch.randn(4, 4).triu()
    U = Variable(U.mul(U.diag().sign().unsqueeze(1).expand_as(U).triu()),
                 requires_grad=True)

    mu_diff = Variable(torch.randn(4), requires_grad=True)

    actual = _det(T).log() + mu_diff.dot(
        T.inverse().mv(mu_diff)) + T.inverse().mm(U.t().mm(U)).trace()
    actual.backward()

    # BUG FIX: clone the reference gradients. ``x.grad.data`` aliases the
    # gradient tensor itself, so without ``.clone()`` the ``fill_(0)`` calls
    # below would also zero the saved "actual" gradients and the final
    # approx_equal checks would trivially compare each tensor with itself.
    actual_c_grad = c.grad.data.clone()
    actual_mu_diff_grad = mu_diff.grad.data.clone()
    actual_U_grad = U.grad.data.clone()

    c.grad.data.fill_(0)
    mu_diff.grad.data.fill_(0)
    U.grad.data.fill_(0)

    def _mm_closure_factory(*args):
        # Matrix-multiply closure for the symmetric Toeplitz operator.
        c, = args
        return lambda mat2: utils.toeplitz.sym_toeplitz_mm(c, mat2)

    def _derivative_quadratic_form_factory(*args):
        # d(left' T right)/dc for the symmetric Toeplitz operator.
        return lambda left_vector, right_vector: (
            sym_toeplitz_derivative_quadratic_form(left_vector, right_vector
                                                   ), )

    covar_args = (c, )

    res = trace_logdet_quad_form_factory(
        _mm_closure_factory,
        _derivative_quadratic_form_factory)(num_samples=1000)(mu_diff, U,
                                                              *covar_args)
    res.backward()

    res_c_grad = c.grad.data
    res_mu_diff_grad = mu_diff.grad.data
    res_U_grad = U.grad.data

    # Stochastic estimate for the value; gradients should match closely.
    assert all(torch.abs((res.data - actual.data) / actual.data) < 0.15)
    assert utils.approx_equal(res_c_grad, actual_c_grad)
    assert utils.approx_equal(res_mu_diff_grad, actual_mu_diff_grad)
    assert utils.approx_equal(res_U_grad, actual_U_grad)
def test_toeplitz_mvn_kl_divergence_backward():
    # Gradient test: gpytorch.mvn_kl_divergence on an interpolated Toeplitz
    # covariance vs. a closed-form dense KL-divergence computation.
    x = Variable(torch.linspace(0, 1, 5))
    rbf_covar = RBFKernel()
    rbf_covar.initialize(log_lengthscale=-4)
    covar_module = GridInterpolationKernel(rbf_covar)
    covar_module.initialize_interpolation_grid(4, grid_bounds=(0, 1))
    covar_x = covar_module.forward(x.unsqueeze(1), x.unsqueeze(1))
    # Re-wrap the Toeplitz column so gradients flow back to it.
    covar_x.c = Variable(covar_x.c.data, requires_grad=True)

    c = covar_x.c
    mu1 = Variable(torch.randn(4), requires_grad=True)
    mu2 = Variable(torch.randn(4), requires_grad=True)

    mu_diff = mu2 - mu1

    # Densify the Toeplitz matrix entry by entry.
    T = Variable(torch.zeros(len(c), len(c)))
    for i in range(len(c)):
        for j in range(len(c)):
            T[i, j] = utils.toeplitz.toeplitz_getitem(c, c, i, j)

    # Random upper-triangular factor with a sign-corrected diagonal.
    U = torch.randn(4, 4).triu()
    U = Variable(U.mul(U.diag().sign().unsqueeze(1).expand_as(U).triu()),
                 requires_grad=True)

    # Closed-form KL divergence between N(mu1, U'U) and N(mu2, T).
    actual = 0.5 * (_det(T).log() + mu_diff.dot(T.inverse().mv(mu_diff)) +
                    T.inverse().mm(U.t().mm(U)).trace() -
                    U.diag().log().sum(0) * 2 - len(mu_diff))
    actual.backward()

    # Clone: .grad.data aliases the gradient tensors, which are zeroed below.
    actual_c_grad = c.grad.data.clone()
    actual_mu1_grad = mu1.grad.data.clone()
    actual_mu2_grad = mu2.grad.data.clone()
    actual_U_grad = U.grad.data.clone()

    c.grad.data.fill_(0)
    mu1.grad.data.fill_(0)
    mu2.grad.data.fill_(0)
    U.grad.data.fill_(0)

    res = gpytorch.mvn_kl_divergence(mu1, U, mu2, covar_x, num_samples=1000)
    res.backward()

    res_c_grad = c.grad.data
    res_mu1_grad = mu1.grad.data
    res_mu2_grad = mu2.grad.data
    res_U_grad = U.grad.data

    # Relative-error tolerances: looser where the estimate is stochastic.
    assert torch.abs(
        (res_c_grad - actual_c_grad)).sum() / actual_c_grad.abs().sum() < 1e-1
    assert torch.abs(
        (res_mu1_grad -
         actual_mu1_grad)).sum() / actual_mu1_grad.abs().sum() < 1e-5
    assert torch.abs(
        (res_mu2_grad -
         actual_mu2_grad)).sum() / actual_mu2_grad.abs().sum() < 1e-5
    assert torch.abs(
        (res_U_grad - actual_U_grad)).sum() / actual_U_grad.abs().sum() < 1e-2
Beispiel #6
0
def test_interpolated_toeplitz_gp_marginal_log_likelihood_backward():
    # Gradient test: the interpolated-Toeplitz marginal log likelihood
    # (w.r.t. the Toeplitz column c, the targets y, and the noise) should
    # match gradients of the dense computation of the same quantity.
    x = Variable(torch.linspace(0, 1, 5))
    y = Variable(torch.randn(5), requires_grad=True)
    noise = Variable(torch.Tensor([1e-4]), requires_grad=True)

    rbf_covar = RBFKernel()
    rbf_covar.initialize(log_lengthscale=-4)
    covar_module = GridInterpolationKernel(rbf_covar)
    covar_module.initialize_interpolation_grid(10, grid_bounds=(0, 1))
    covar_x = covar_module.forward(x.unsqueeze(1), x.unsqueeze(1))

    c = Variable(covar_x.c.data, requires_grad=True)

    # Sparse interpolation matrices mapping grid values to data locations.
    W_left = index_coef_to_sparse(covar_x.J_left, covar_x.C_left, len(c))
    W_right = index_coef_to_sparse(covar_x.J_right, covar_x.C_right, len(c))

    W_left_dense = Variable(W_left.to_dense())
    W_right_dense = Variable(W_right.to_dense())

    # Densify the symmetric Toeplitz matrix defined by column c.
    T = Variable(torch.zeros(len(c), len(c)))
    for i in range(len(c)):
        for j in range(len(c)):
            T[i, j] = utils.toeplitz.sym_toeplitz_getitem(c, i, j)

    WTW = W_left_dense.matmul(T.matmul(W_right_dense.t())) + Variable(torch.eye(len(x))) * noise

    quad_form_actual = y.dot(WTW.inverse().matmul(y))
    log_det_actual = _det(WTW).log()

    actual_nll = -0.5 * (log_det_actual + quad_form_actual + math.log(2 * math.pi) * len(y))
    actual_nll.backward()

    # BUG FIX: clone the reference gradients. ``x.grad.data`` aliases the
    # gradient tensor itself, so without ``.clone()`` the ``fill_(0)`` calls
    # below would also zero the saved "actual" gradients and the final
    # approx_equal checks would trivially compare each tensor with itself.
    actual_c_grad = c.grad.data.clone()
    actual_y_grad = y.grad.data.clone()
    actual_noise_grad = noise.grad.data.clone()

    c.grad.data.fill_(0)
    y.grad.data.fill_(0)
    noise.grad.data.fill_(0)

    res = InterpolatedToeplitzGPMarginalLogLikelihood(W_left, W_right, num_samples=1000)(c, y, noise)
    res.backward()

    res_c_grad = c.grad.data
    res_y_grad = y.grad.data
    res_noise_grad = noise.grad.data

    assert utils.approx_equal(actual_c_grad, res_c_grad)
    assert utils.approx_equal(actual_y_grad, res_y_grad)
    assert utils.approx_equal(actual_noise_grad, res_noise_grad)
 def __init__(self, train_x, train_y, likelihood):
     """Near-zero constant mean with a KISS-GP interpolated RBF kernel."""
     super(GPRegressionModel, self).__init__(train_x, train_y, likelihood)
     # Mean is pinned to effectively zero via tight constant bounds.
     self.mean_module = ConstantMean(constant_bounds=[-1e-5, 1e-5])
     # RBF base kernel, interpolated on a 50-point grid over [0, 1].
     self.base_covar_module = RBFKernel(log_lengthscale_bounds=(-5, 6))
     self.covar_module = GridInterpolationKernel(
         self.base_covar_module, grid_size=50, grid_bounds=[(0, 1)])
Beispiel #8
0
 def __init__(self, train_x, train_y, likelihood):
     """KISS-GP scaled RBF kernel plus a fixed white-noise kernel."""
     super(GPRegressionModel, self).__init__(train_x, train_y, likelihood)
     self.mean_module = ConstantMean(prior=SmoothedBoxPrior(-1e-5, 1e-5))
     # Scaled RBF with a smoothed-box prior on the lengthscale.
     self.base_covar_module = ScaleKernel(
         RBFKernel(lengthscale_prior=SmoothedBoxPrior(exp(-5), exp(6), sigma=0.1)))
     self.grid_covar_module = GridInterpolationKernel(
         self.base_covar_module, grid_size=50, num_dims=1)
     # Fixed per-point observation noise for 100 training points.
     self.noise_covar_module = WhiteNoiseKernel(variances=torch.ones(100) * 0.001)
     self.covar_module = self.grid_covar_module + self.noise_covar_module
Beispiel #9
0
 def __init__(self, x_train, y_train, likelihood, feature_extractor):
     """Deep-kernel model: extracted features fed to a 2-d KISS-GP kernel."""
     super(GPRegressionModel, self).__init__(x_train, y_train, likelihood)
     self.mean = ConstantMean()
     # Scaled ARD RBF kernel interpolated on a 100-point grid per dimension.
     self.covar = GridInterpolationKernel(
         ScaleKernel(RBFKernel(ard_num_dims=2)), num_dims=2, grid_size=100)
     self.feature_extractor = feature_extractor
Beispiel #10
0
 def __init__(self, train_x, train_y, likelihood):
     """Additive structure over a 1-d KISS-GP base kernel, for 2-d inputs."""
     super(GPRegressionModel, self).__init__(train_x, train_y, likelihood)
     self.mean_module = ZeroMean()
     self.base_covar_module = ScaleKernel(RBFKernel(ard_num_dims=2))
     # Sum one grid-interpolated 1-d kernel per input dimension.
     self.covar_module = AdditiveStructureKernel(
         GridInterpolationKernel(self.base_covar_module, grid_size=100,
                                 num_dims=1),
         num_dims=2)
Beispiel #11
0
 def __init__(self, train_x, train_y, likelihood):
     """2-d KISS-GP on a 500-point grid with a trainable output scale."""
     super(GPRegressionModel, self).__init__(train_x, train_y, likelihood)
     self.mean_module = ConstantMean(constant_bounds=[-5,5])
     self.base_covar_module = RBFKernel(log_lengthscale_bounds=(-5, 6))
     self.covar_module = GridInterpolationKernel(
         self.base_covar_module, grid_size=500,
         grid_bounds=[(-10, 10), (-10, 10)])
     # Output scale is learned as a log-space parameter.
     self.register_parameter('log_outputscale',
                             nn.Parameter(torch.Tensor([0])),
                             bounds=(-5,6))
Beispiel #12
0
 def __init__(self, train_x, train_y, likelihood):
     """SKIP: product structure over a 1-d KISS-GP kernel, for 2-d inputs."""
     super(GPRegressionModel, self).__init__(train_x, train_y, likelihood)
     self.mean_module = ConstantMean(prior=SmoothedBoxPrior(-1, 1))
     self.base_covar_module = ScaleKernel(RBFKernel())
     # Multiply one grid-interpolated 1-d kernel per input dimension.
     self.covar_module = ProductStructureKernel(
         GridInterpolationKernel(self.base_covar_module, grid_size=100,
                                 num_dims=1),
         num_dims=2)
 def __init__(self, train_x, train_y, likelihood):
     """2-d KISS-GP over an ARD RBF base kernel on a 16-point grid."""
     super(GPRegressionModel, self).__init__(train_x, train_y, likelihood)
     self.mean_module = ConstantMean(prior=SmoothedBoxPrior(-1, 1))
     self.base_covar_module = RBFKernel(ard_num_dims=2)
     self.covar_module = GridInterpolationKernel(
         self.base_covar_module, grid_size=16, num_dims=2)
 def __init__(self, train_x, train_y, likelihood):
     """KISS-GP over a scaled RBF kernel on a 64-point grid in the unit square."""
     super(GPRegressionModel, self).__init__(train_x, train_y, likelihood)
     self.mean_module = ConstantMean(prior=SmoothedBoxPrior(-1, 1))
     # Smoothed-box prior keeps the log-lengthscale roughly in [-3, 3].
     self.base_covar_module = ScaleKernel(
         RBFKernel(log_lengthscale_prior=SmoothedBoxPrior(
             exp(-3), exp(3), sigma=0.1, log_transform=True)))
     self.covar_module = GridInterpolationKernel(
         self.base_covar_module, grid_size=64, grid_bounds=[(0, 1), (0, 1)])
 def __init__(self):
     """GP classifier: Bernoulli likelihood with a KISS-GP RBF kernel."""
     super(GPClassificationModel, self).__init__(BernoulliLikelihood())
     # Mean is pinned to effectively zero via tight constant bounds.
     self.mean_module = ConstantMean(constant_bounds=[-1e-5, 1e-5])
     self.covar_module = RBFKernel(log_lengthscale_bounds=(-5, 6))
     self.grid_covar_module = GridInterpolationKernel(self.covar_module)
     # Output scale is learned as a log-space parameter.
     self.register_parameter('log_outputscale',
                             nn.Parameter(torch.Tensor([0])),
                             bounds=(-5, 6))
     self.initialize_interpolation_grid(50, grid_bounds=[(0, 1)])
    def setUp(self):
        """Build a tiny 1-d KISS-GP problem plus reference/test strategies."""
        dtype = torch.double
        self.xs = torch.tensor([0.20, 0.30, 0.40, 0.10, 0.70], dtype=dtype)
        self.kernel = GridInterpolationKernel(
            RBFKernel(), grid_size=4, grid_bounds=[(-0.4, 1.4)]).double()
        # Zero prior mean; labels are sin(x) plus fixed offsets.
        self.mean_vec = torch.sin(self.xs).double() * 0.0
        self.labels = torch.sin(self.xs) + torch.tensor(
            [0.1, 0.2, -0.1, -0.2, -0.2], dtype=dtype)
        self.lik = GaussianLikelihood().double()
        self.lik.noise = 0.1
        self.train_train_covar = self.kernel(self.xs).evaluate_kernel()
        self.distr = MultivariateNormal(self.mean_vec, self.train_train_covar)

        # Separate distribution/likelihood objects so the expected strategy
        # shares no state with the strategy under test.
        e_distr = MultivariateNormal(self.mean_vec, self.train_train_covar)
        e_lik = GaussianLikelihood().double()
        e_lik.noise = 0.1
        self.expected_strategy = InterpolatedPredictionStrategyWithFantasy(
            self.xs, e_distr, self.labels, e_lik)
        self.strategy = ShermanMorrisonOnlineStrategy(
            self.xs, self.distr, self.labels, self.lik)

        # Held-out points and the covariances needed to query them.
        self.new_points = torch.tensor([0.5, 0.8], dtype=dtype)
        self.test_mean = torch.sin(self.new_points) * 0.0
        self.test_train_covar = self.kernel(
            self.new_points, self.xs).evaluate_kernel()
        self.test_test_covar = self.kernel(
            self.new_points, self.new_points).evaluate_kernel()
 def __init__(self):
     """Exact GP with a KISS-GP RBF kernel on a 2-d 10-point grid."""
     likelihood = GaussianLikelihood(log_noise_bounds=(-3, 3))
     super(Model, self).__init__(likelihood)
     self.mean_module = ConstantMean(constant_bounds=(-1, 1))
     base_kernel = RBFKernel()
     self.grid_covar_module = GridInterpolationKernel(base_kernel)
     self.initialize_interpolation_grid(10, [(0, 1), (0, 1)])
 def __init__(self, train_x, train_y, likelihood, amountinducing):
     """SKIP kernel whose per-dimension grid size is ``amountinducing``."""
     super(GPRegressionModel, self).__init__(train_x, train_y, likelihood)
     self.mean_module = ConstantMean()
     self.base_covar_module = ScaleKernel(RBFKernel())
     # One 1-d KISS-GP kernel per input dimension, combined by product.
     num_dims = train_x.size(-1)
     self.covar_module = ProductStructureKernel(
         GridInterpolationKernel(self.base_covar_module,
                                 grid_size=amountinducing, num_dims=1),
         num_dims=num_dims)
 def __init__(self):
     """KISS-GP model with the RBF lengthscale initialized to exp(-2)."""
     likelihood = GaussianLikelihood(log_noise_bounds=(-3, 3))
     super(KissGPModel, self).__init__(likelihood)
     self.mean_module = ConstantMean(constant_bounds=(-1, 1))
     # Effectively unbounded lengthscale, initialized in log space.
     base_kernel = RBFKernel(log_lengthscale_bounds=(-100, 100))
     base_kernel.log_lengthscale.data = torch.FloatTensor([-2])
     self.grid_covar_module = GridInterpolationKernel(base_kernel)
     self.initialize_interpolation_grid(300, grid_bounds=[(0, 1)])
Beispiel #20
0
 def __init__(self, train_x, train_y, likelihood, feature_extractor=None):
     """KISS-GP regression model with an optional feature extractor.

     Args:
         train_x: training inputs.
         train_y: training targets.
         likelihood: Gaussian likelihood used for exact GP inference.
         feature_extractor: optional module mapping raw inputs to features.
             Defaults to ``None`` so existing three-argument calls still work.
     """
     super(GPRegressionModel, self).__init__(train_x, train_y, likelihood)
     self.mean_module = ConstantMean(prior=SmoothedBoxPrior(-1e-5, 1e-5))
     self.base_covar_module = ScaleKernel(
         RBFKernel(log_lengthscale_prior=SmoothedBoxPrior(
             exp(-5), exp(6), sigma=0.1, log_transform=True)))
     self.covar_module = GridInterpolationKernel(
         self.base_covar_module, grid_size=50, num_dims=1)
     # BUG FIX: ``feature_extractor`` was referenced without being a
     # parameter, raising NameError at construction; it is now passed in
     # explicitly (with a backward-compatible default).
     self.feature_extractor = feature_extractor
 def __init__(self, train_x, train_y, likelihood):
     """KISS-GP RBF kernel plus fixed white noise over 100 training points."""
     super(GPRegressionModel, self).__init__(train_x, train_y, likelihood)
     # Mean is pinned to effectively zero via tight constant bounds.
     self.mean_module = ConstantMean(constant_bounds=[-1e-5, 1e-5])
     self.base_covar_module = RBFKernel(log_lengthscale_bounds=(-5, 6))
     self.grid_covar_module = GridInterpolationKernel(
         self.base_covar_module, grid_size=50, grid_bounds=[(0, 1)])
     # Fixed per-point observation noise.
     self.noise_covar_module = WhiteNoiseKernel(
         variances=torch.ones(100) * 0.001)
     self.covar_module = self.grid_covar_module + self.noise_covar_module
 def __init__(self, train_x, train_y, likelihood):
     """Multitask GP: one KISS-GP data kernel shared across two tasks."""
     super(MultitaskGPModel, self).__init__(train_x, train_y, likelihood)
     self.mean_module = MultitaskMean(ConstantMean(), num_tasks=2)
     self.data_covar_module = GridInterpolationKernel(
         RBFKernel(), grid_size=100, num_dims=1)
     # Rank-1 task covariance couples the two outputs.
     self.covar_module = MultitaskKernel(
         self.data_covar_module, num_tasks=2, rank=1)
Beispiel #23
0
 def __init__(self, train_x, train_y, likelihood):
     """Additive structure over a 2-d KISS-GP ARD RBF base kernel."""
     super(GPRegressionModel, self).__init__(train_x, train_y, likelihood)
     self.mean_module = ZeroMean()
     # Smoothed-box prior keeps each log-lengthscale roughly in [-3, 3].
     self.base_covar_module = ScaleKernel(RBFKernel(
         ard_num_dims=2,
         log_lengthscale_prior=SmoothedBoxPrior(
             exp(-3), exp(3), sigma=0.1, log_transform=True)))
     self.covar_module = AdditiveStructureKernel(
         GridInterpolationKernel(self.base_covar_module, grid_size=100,
                                 num_dims=2),
         num_dims=2)
Beispiel #24
0
 def __init__(self, train_x, train_y, likelihood):
     """2-d KISS-GP over a scaled ARD RBF kernel on a 64-point grid."""
     super(GPRegressionModel, self).__init__(train_x, train_y, likelihood)
     self.mean_module = ConstantMean(prior=SmoothedBoxPrior(-1, 1))
     # Smoothed-box prior keeps each lengthscale roughly in [e^-3, e^3].
     self.base_covar_module = ScaleKernel(RBFKernel(
         ard_num_dims=2,
         lengthscale_prior=SmoothedBoxPrior(exp(-3), exp(3), sigma=0.1)))
     self.covar_module = GridInterpolationKernel(
         self.base_covar_module, grid_size=64, num_dims=2)
Beispiel #25
0
 def __init__(self, train_x, train_y, likelihood):
     """KISS-GP on a 400-point grid with a trainable output scale."""
     super(GPRegressionModel, self).__init__(train_x, train_y, likelihood)
     self.mean_module = ConstantMean(constant_bounds=[-3, 3])
     # Grid interpolation kernel wrapped around an RBF base kernel.
     self.base_covar_module = RBFKernel(log_lengthscale_bounds=(-6, 6))
     self.covar_module = GridInterpolationKernel(
         self.base_covar_module, grid_size=400, grid_bounds=[(0, 1.2)])
     # Register the output scale as a trainable log-space parameter.
     self.register_parameter('log_outputscale',
                             nn.Parameter(torch.Tensor([0])),
                             bounds=(-6, 6))
Beispiel #26
0
 def __init__(self, train_x, train_y, likelihood):
     """KISS-GP over a 30-point grid on [0, 32]^2 with near-zero mean."""
     super(GPRegressionModel, self).__init__(train_x, train_y, likelihood)
     # Mean constrained to effectively zero.
     self.mean_module = ConstantMean(constant_bounds=[-1e-5, 1e-5])
     # Grid-interpolated RBF kernel over the 2-d input domain.
     self.base_covar_module = RBFKernel(log_lengthscale_bounds=(-5, 6))
     self.covar_module = GridInterpolationKernel(
         self.base_covar_module, grid_size=30,
         grid_bounds=[(0, 32), (0, 32)])
     # Register the output scale as a trainable log-space parameter.
     self.register_parameter('log_outputscale',
                             nn.Parameter(torch.Tensor([0])),
                             bounds=(-5, 6))
def test_kp_toeplitz_gp_marginal_log_likelihood_forward():
    # With 3-d inputs the grid kernel produces a Kronecker product of
    # Toeplitz matrices; its lazy marginal log likelihood should match the
    # dense evaluation of the same covariance.
    x = torch.cat([Variable(torch.linspace(0, 1, 2)).unsqueeze(1)] * 3, 1)
    y = torch.randn(2)
    rbf_module = RBFKernel()
    rbf_module.initialize(log_lengthscale=-2)
    covar_module = GridInterpolationKernel(rbf_module)
    covar_module.eval()
    # 5 grid points per dimension over the unit cube.
    covar_module.initialize_interpolation_grid(5, [(0, 1), (0, 1), (0, 1)])

    kronecker_var = covar_module.forward(x, x)
    kronecker_var_eval = kronecker_var.evaluate()
    res = kronecker_var.exact_gp_marginal_log_likelihood(Variable(y)).data
    actual = gpytorch.exact_gp_marginal_log_likelihood(kronecker_var_eval,
                                                       Variable(y)).data
    # Allow 5% relative error between lazy and dense computations.
    assert all(torch.abs((res - actual) / actual) < 0.05)
def create_full_kernel(d, ard=False, ski=False, grid_size=None, kernel_type='RBF', init_lengthscale_range=(1.0, 1.0),
                       keops=False):
    """Helper to create an RBF kernel object with these options.

    d: input dimensionality; ard: one lengthscale per dimension;
    ski: wrap the kernel in grid interpolation (KISS-GP) using grid_size;
    kernel_type/keops: forwarded to the kernel factory;
    init_lengthscale_range: range lengthscales are sampled from.
    """
    ard_num_dims = d if ard else None

    kernel = _map_to_kernel(True, kernel_type, keops, ard_num_dims=ard_num_dims)

    # One lengthscale sample per ARD dimension, otherwise a single sample.
    n_samples = d if ard else 1
    kernel.initialize(lengthscale=_sample_from_range(n_samples, init_lengthscale_range))

    if ski:
        kernel = GridInterpolationKernel(kernel, num_dims=d, grid_size=grid_size)
    return kernel
    def test_standard(self):
        """Grid interpolation should closely match the exact base kernel."""
        base_kernel = RBFKernel()
        kernel = GridInterpolationKernel(base_kernel,
                                         num_dims=2,
                                         grid_size=128,
                                         grid_bounds=[(-1.2, 1.2)] * 2)

        # Evaluating the kernel should yield an interpolated lazy tensor.
        xs = torch.randn(5, 2).clamp(-1, 1)
        interp_covar = kernel(xs, xs).evaluate_kernel()
        self.assertIsInstance(interp_covar, InterpolatedLazyTensor)

        # Non-batch and batch inputs should both approximate the exact kernel.
        for shape in ((5, 2), (3, 5, 2)):
            xs = torch.randn(*shape).clamp(-1, 1)
            grid_eval = kernel(xs, xs).evaluate()
            actual_eval = base_kernel(xs, xs).evaluate()
            self.assertLess(torch.norm(grid_eval - actual_eval), 2e-5)
    def test_batch_base_kernel(self):
        """A batched base kernel should still be approximated on the grid."""
        base_kernel = RBFKernel(batch_shape=torch.Size([3]))
        kernel = GridInterpolationKernel(base_kernel,
                                         num_dims=2,
                                         grid_size=128,
                                         grid_bounds=[(-1.2, 1.2)] * 2)

        # Check broadcasting against non-batch, batch, and double-batch inputs.
        for shape in ((5, 2), (3, 5, 2), (4, 3, 5, 2)):
            xs = torch.randn(*shape).clamp(-1, 1)
            grid_eval = kernel(xs, xs).evaluate()
            actual_eval = base_kernel(xs, xs).evaluate()
            self.assertLess(torch.norm(grid_eval - actual_eval), 2e-5)