Example #1
0
    def test_ard(self):
        """ARD categorical kernel: full matrix, diagonal, and last-dim-as-batch modes.

        The reference divides the per-dimension mismatch indicator by that
        dimension's lengthscale, then exponentiates the negative result. In
        the full/diag modes the scaled mismatch is first averaged over the
        feature dimension.
        """
        left = torch.tensor([[4, 2], [3, 1], [8, 5]], dtype=torch.float)
        right = torch.tensor([[4, 2], [3, 0], [4, 4]], dtype=torch.float)
        # One lengthscale per input dimension, shaped (1, 1, 2).
        ls = torch.tensor([1, 2], dtype=torch.float).view(1, 1, 2)

        kernel = CategoricalKernel(ard_num_dims=2)
        kernel.initialize(lengthscale=ls)
        kernel.eval()

        # Pairwise mismatch indicator, scaled per dimension: shape (3, 3, 2).
        scaled = (left.unsqueeze(-2) != right.unsqueeze(-3)) / ls.unsqueeze(-2)

        # Full covariance matrix.
        expected_full = torch.exp(-scaled.mean(-1))
        got_full = kernel(left, right).evaluate()
        self.assertTrue(torch.allclose(got_full, expected_full))

        # Diagonal of the covariance matrix.
        got_diag = kernel(left, right).diag()
        expected_diag = torch.diagonal(expected_full, dim1=-1, dim2=-2)
        self.assertTrue(torch.allclose(got_diag, expected_diag))

        # last_dim_is_batch: each feature dimension becomes its own batch entry.
        expected_batched = torch.exp(-scaled).transpose(-1, -3)
        got_batched = kernel(left, right, last_dim_is_batch=True).evaluate()
        self.assertTrue(torch.allclose(got_batched, expected_batched))

        # last_dim_is_batch combined with diag.
        got_batched_diag = kernel(left, right, last_dim_is_batch=True).diag()
        self.assertTrue(
            torch.allclose(
                got_batched_diag,
                torch.diagonal(expected_batched, dim1=-1, dim2=-2),
            )
        )
Example #2
0
    def test_ard_separate_batch(self):
        """ARD kernel over a batch of 2 where each batch has its own lengthscales.

        Covers the full matrix, the diagonal, and last-dim-as-batch modes
        (with and without diag).
        """
        # x1 is (2, 2, 3); x2 is (1, 2, 3) and broadcasts against x1's batch.
        first = torch.tensor(
            [
                [[4, 2, 1], [3, 1, 5]],
                [[3, 2, 3], [6, 1, 7]],
            ],
            dtype=torch.float,
        )
        second = torch.tensor([[[4, 2, 1], [6, 0, 0]]], dtype=torch.float)
        # Distinct lengthscales per batch entry: shape (2, 1, 3).
        ls = torch.tensor([[[1, 2, 1]], [[2, 1, 0.5]]], dtype=torch.float)

        kernel = CategoricalKernel(batch_shape=torch.Size([2]), ard_num_dims=3)
        kernel.initialize(lengthscale=ls)
        kernel.eval()

        mismatch = first.unsqueeze(-2) != second.unsqueeze(-3)
        scaled = mismatch / ls.unsqueeze(-2)

        # Full covariance.
        expected_full = torch.exp(-scaled.mean(-1))
        self.assertTrue(
            torch.allclose(kernel(first, second).evaluate(), expected_full)
        )

        # Diagonal only.
        expected_diag = torch.diagonal(expected_full, dim1=-1, dim2=-2)
        self.assertTrue(torch.allclose(kernel(first, second).diag(), expected_diag))

        # Feature dimension treated as a batch dimension.
        expected_batched = torch.exp(-scaled).transpose(-1, -3)
        got_batched = kernel(first, second, last_dim_is_batch=True).evaluate()
        self.assertTrue(torch.allclose(got_batched, expected_batched))

        # Feature-as-batch combined with diag.
        got_batched_diag = kernel(first, second, last_dim_is_batch=True).diag()
        expected_batched_diag = torch.diagonal(expected_batched, dim1=-1, dim2=-2)
        self.assertTrue(torch.allclose(got_batched_diag, expected_batched_diag))
Example #3
0
 def test_forward(self):
     """Non-ARD categorical kernel against a hand-computed reference.

     Expected value: exp(-mean(mismatch / lengthscale)) over the feature axis.
     """
     scale = 2
     left = torch.tensor([[4, 2], [3, 1], [8, 5], [7, 6]], dtype=torch.float)
     right = torch.tensor([[4, 2], [3, 0], [4, 4]], dtype=torch.float)

     kernel = CategoricalKernel().initialize(lengthscale=scale)
     kernel.eval()

     mismatch = left.unsqueeze(-2) != right.unsqueeze(-3)
     expected = torch.exp(-(mismatch / scale).mean(-1))
     got = kernel(left, right).evaluate()
     self.assertTrue(torch.allclose(got, expected))
Example #4
0
 def test_active_dims(self):
     """Check that ``active_dims`` restricts the kernel to the selected column.

     With ``active_dims=[0]`` only the first column of each input should
     contribute. The reference is therefore the ordinary categorical kernel
     evaluated on ``x[:, :1]``.
     """
     x1 = torch.tensor([[4, 2], [3, 1], [8, 5], [7, 6]], dtype=torch.float)
     x2 = torch.tensor([[4, 2], [3, 0], [4, 4]], dtype=torch.float)
     lengthscale = 2
     kernel = CategoricalKernel(active_dims=[0]).initialize(
         lengthscale=lengthscale)
     kernel.eval()
     # 0/1 mismatch indicator on the active column only.
     dists = x1[:, :1].unsqueeze(-2) != x2[:, :1].unsqueeze(-3)
     # The categorical kernel scales the mismatch linearly by the lengthscale,
     # matching the non-ARD forward test, which uses the same inputs.
     # Squaring the lengthscale (an RBF-style reference) would expect
     # exp(-0.25) instead of exp(-0.5) for every mismatching pair.
     sc_dists = dists / lengthscale
     actual = torch.exp(-sc_dists.mean(-1))
     res = kernel(x1, x2).evaluate()
     self.assertTrue(torch.allclose(res, actual))
Example #5
0
    def test_ard_batch(self):
        """Batched ARD kernel where a single lengthscale row is shared by the batch.

        The lengthscale tensor is (1, 1, 3) while the kernel has
        ``batch_shape=[2]``. This presumably relies on ``initialize``
        broadcasting it across the batch (confirm against the kernel base
        class).
        """
        shared_ls = torch.tensor([[[1, 2, 1]]], dtype=torch.float)
        left = torch.tensor(
            [
                [[4, 2, 1], [3, 1, 5]],
                [[3, 2, 3], [6, 1, 7]],
            ],
            dtype=torch.float,
        )
        right = torch.tensor([[[4, 2, 1], [6, 0, 0]]], dtype=torch.float)

        kernel = CategoricalKernel(batch_shape=torch.Size([2]), ard_num_dims=3)
        kernel.initialize(lengthscale=shared_ls)
        kernel.eval()

        # Per-dimension mismatch scaled by the shared lengthscales, then averaged.
        scaled = (left.unsqueeze(-2) != right.unsqueeze(-3)) / shared_ls.unsqueeze(-2)
        expected = torch.exp(-scaled.mean(-1))
        got = kernel(left, right).evaluate()
        self.assertTrue(torch.allclose(got, expected))