示例#1
0
 def __init__(self, base_kernel, grid, interpolation_mode=False, active_dims=None):
     """
     Build a grid kernel around ``base_kernel`` from a single grid tensor.

     Args:
         base_kernel: kernel evaluated on the grid; stored as ``self.base_kernel``.
         grid: tensor registered as the ``grid`` buffer.
         interpolation_mode: when False, the data produced by
             ``create_data_from_grid(grid)`` is also registered as the
             ``full_grid`` buffer; when True, no ``full_grid`` buffer is created.
         active_dims: forwarded unchanged to the parent class constructor.
     """
     super(GridKernel, self).__init__(active_dims=active_dims)
     self.interpolation_mode = interpolation_mode
     self.base_kernel = base_kernel
     self.register_buffer("grid", grid)
     if self.interpolation_mode:
         # Nothing else to materialize in interpolation mode.
         return
     expanded_grid = create_data_from_grid(grid)
     self.register_buffer("full_grid", expanded_grid)
示例#2
0
 def __init__(
     self,
     base_kernel: Kernel,
     grid: List[Tensor],
     interpolation_mode: bool = False,
     active_dims: "Optional[Tuple[int, ...]]" = None,
 ):
     """
     Build a grid kernel around ``base_kernel``.

     Args:
         base_kernel: kernel evaluated on the grid; stored as ``self.base_kernel``.
         grid: sequence of tensors whose length becomes ``self.num_dims``.
             A single tensor is also accepted for backward compatibility and is
             first converted with ``convert_legacy_grid``.
         interpolation_mode: when False, the data produced by
             ``create_data_from_grid(grid)`` is registered as the ``full_grid``
             buffer; when True, no ``full_grid`` buffer is created.
         active_dims: forwarded unchanged to the parent kernel's constructor.
             It was previously annotated as ``bool``, which does not match what
             the parent kernel accepts (an optional tuple of dimension indices).
             The annotation is a string, so it does not require ``Optional`` or
             ``Tuple`` to be imported at definition time.
     """
     super(GridKernel, self).__init__(active_dims=active_dims)
     if torch.is_tensor(grid):
         # Legacy single-tensor grids are normalized to the sequence form used below.
         grid = convert_legacy_grid(grid)
     self.interpolation_mode = interpolation_mode
     self.base_kernel = base_kernel
     # One grid entry per input dimension; update_grid-style callers compare against this.
     self.num_dims = len(grid)
     self.register_buffer_list("grid", grid)
     if not self.interpolation_mode:
         self.register_buffer("full_grid", create_data_from_grid(grid))
示例#3
0
    def update_grid(self, grid):
        """
        Supply a new `grid` if it ever changes.

        Replaces the ``grid`` buffer and, outside interpolation mode, the
        ``full_grid`` buffer. Any cached kernel matrix is dropped so it is not
        reused with the old grid.

        Args:
            grid: the new grid tensor. It may have a different shape from the
                current one. It is copied (not aliased) and cast to the dtype and
                device of the existing ``grid`` buffer.

        Returns:
            ``self``, to allow chaining.
        """
        # Re-assign the buffers rather than resizing them in place. Calling
        # ``resize_`` on a ``.detach()``-ed alias never changes the registered
        # buffer's shape (recent PyTorch raises an error instead), so a grid of a
        # different size could not be installed.
        self.grid = grid.detach().to(self.grid, copy=True)

        if not self.interpolation_mode:
            full_grid = create_data_from_grid(self.grid)
            # Previously the tensor itself was passed to ``resize_`` (which expects
            # sizes). Keep the existing buffer's dtype/device.
            self.full_grid = full_grid.to(self.full_grid)

        if hasattr(self, "_cached_kernel_mat"):
            del self._cached_kernel_mat
        return self
示例#4
0
    def update_grid(self, grid):
        """
        Supply a new `grid` if it ever changes.

        Args:
            grid: either a single tensor (converted with ``convert_legacy_grid``)
                or a sequence with exactly ``self.num_dims`` entries. Entry ``k``
                is stored as the attribute ``grid_<k>``.

        Returns:
            ``self``, to allow chaining.

        Raises:
            RuntimeError: if the new grid's length differs from ``self.num_dims``.
        """
        new_grid = convert_legacy_grid(grid) if torch.is_tensor(grid) else grid

        if self.num_dims != len(new_grid):
            raise RuntimeError("New grid should have the same number of dimensions as before.")

        for dim, dim_grid in enumerate(new_grid):
            setattr(self, f"grid_{dim}", dim_grid)

        if not self.interpolation_mode:
            # Rebuild the expanded grid data from the freshly stored per-dimension grids.
            self.full_grid = create_data_from_grid(self.grid)

        # A cached kernel matrix would have been computed from the old grid.
        if hasattr(self, "_cached_kernel_mat"):
            del self._cached_kernel_mat
        return self
示例#5
0
#!/usr/bin/env python3

import torch
import unittest
from gpytorch.kernels import RBFKernel, GridKernel
from gpytorch.lazy import KroneckerProductLazyTensor
from gpytorch.utils.grid import create_data_from_grid

# Per-dimension grid points: 5 evenly spaced points on [0, 1] and 3 on [0, 2].
grid = [
    torch.linspace(0, 1, 5),
    torch.linspace(0, 2, 3),
]
# Number of input dimensions covered by the grid.
d = len(grid)
# Input points built from the grid (presumably the Cartesian product of the
# per-dimension points — see create_data_from_grid).
grid_data = create_data_from_grid(grid)


class TestGridKernel(unittest.TestCase):
    """GridKernel must agree numerically with the base kernel it wraps."""

    def _assert_matches_base(self, kernel, base_kernel, x1, x2, tol):
        """Assert ``norm(kernel(x1, x2) - base_kernel(x1, x2)) < tol`` on evaluated matrices."""
        approx = kernel(x1, x2).evaluate()
        exact = base_kernel(x1, x2).evaluate()
        self.assertLess(torch.norm(approx - exact), tol)

    def test_grid_grid(self):
        base_kernel = RBFKernel()
        kernel = GridKernel(base_kernel, grid)
        # Evaluating on the grid itself is expected to yield a Kronecker-structured lazy tensor.
        lazy_covar = kernel(grid_data, grid_data).evaluate_kernel()
        self.assertIsInstance(lazy_covar, KroneckerProductLazyTensor)
        self._assert_matches_base(kernel, base_kernel, grid_data, grid_data, 2e-5)

    def test_nongrid_grid(self):
        base_kernel = RBFKernel()
        off_grid = torch.randn(5, d)
        kernel = GridKernel(base_kernel, grid)
        self._assert_matches_base(kernel, base_kernel, grid_data, off_grid, 1e-5)