def test_ard(self):
        """Compare a kernel built with ``ard_num_dims=2`` to a manual reference.

        The reference value is ``exp(-mean(mismatch / lengthscale))``, where
        ``mismatch`` marks the coordinates in which a row of ``x1`` differs
        from a row of ``x2``. Four outputs are checked: the dense matrix, its
        diagonal, and the same two with ``last_dim_is_batch=True``.
        """
        lengthscales = torch.tensor([1, 2], dtype=torch.float).view(1, 1, 2)
        x1 = torch.tensor([[4, 2], [3, 1], [8, 5]], dtype=torch.float)
        x2 = torch.tensor([[4, 2], [3, 0], [4, 4]], dtype=torch.float)

        kernel = CategoricalKernel(ard_num_dims=2)
        kernel.initialize(lengthscale=lengthscales)
        kernel.eval()

        # Coordinate-wise "differs" indicator for every (x1 row, x2 row) pair,
        # divided by the lengthscale of the corresponding coordinate.
        mismatch = x1.unsqueeze(-2) != x2.unsqueeze(-3)
        scaled_mismatch = mismatch / lengthscales.unsqueeze(-2)

        # dense kernel matrix
        expected_dense = torch.exp(-scaled_mismatch.mean(-1))
        dense = kernel(x1, x2).evaluate()
        self.assertTrue(torch.allclose(dense, expected_dense))

        # diagonal of the kernel matrix
        diag = kernel(x1, x2).diag()
        expected_diag = torch.diagonal(expected_dense, dim1=-1, dim2=-2)
        self.assertTrue(torch.allclose(diag, expected_diag))

        # last dimension treated as a batch: no averaging over coordinates,
        # and the coordinate axis is swapped to the front of the matrix dims
        expected_per_dim = torch.exp(-scaled_mismatch).transpose(-1, -3)
        per_dim = kernel(x1, x2, last_dim_is_batch=True).evaluate()
        self.assertTrue(torch.allclose(per_dim, expected_per_dim))

        # last dimension treated as a batch, diagonal only
        per_dim_diag = kernel(x1, x2, last_dim_is_batch=True).diag()
        expected_per_dim_diag = torch.diagonal(expected_per_dim, dim1=-1, dim2=-2)
        self.assertTrue(torch.allclose(per_dim_diag, expected_per_dim_diag))
    def test_ard_separate_batch(self):
        """Check a batched kernel given a distinct lengthscale row per batch.

        The kernel is built with ``batch_shape=torch.Size([2])`` and
        ``ard_num_dims=3``; ``lengthscales`` holds two rows of three values.
        The reference is ``exp(-mean(mismatch / lengthscale))`` and is compared
        for the dense matrix, its diagonal, and both
        ``last_dim_is_batch=True`` variants.
        """
        lengthscales = torch.tensor([[[1, 2, 1]], [[2, 1, 0.5]]], dtype=torch.float)
        x1 = torch.tensor(
            [
                [[4, 2, 1], [3, 1, 5]],
                [[3, 2, 3], [6, 1, 7]],
            ],
            dtype=torch.float,
        )
        x2 = torch.tensor([[[4, 2, 1], [6, 0, 0]]], dtype=torch.float)

        kernel = CategoricalKernel(batch_shape=torch.Size([2]), ard_num_dims=3)
        kernel.initialize(lengthscale=lengthscales)
        kernel.eval()

        # Coordinate-wise "differs" indicator for every (x1 row, x2 row) pair,
        # divided by the lengthscale of the corresponding coordinate.
        mismatch = x1.unsqueeze(-2) != x2.unsqueeze(-3)
        scaled_mismatch = mismatch / lengthscales.unsqueeze(-2)

        # dense kernel matrix
        expected_dense = torch.exp(-scaled_mismatch.mean(-1))
        dense = kernel(x1, x2).evaluate()
        self.assertTrue(torch.allclose(dense, expected_dense))

        # diagonal of the kernel matrix
        diag = kernel(x1, x2).diag()
        expected_diag = torch.diagonal(expected_dense, dim1=-1, dim2=-2)
        self.assertTrue(torch.allclose(diag, expected_diag))

        # last dimension treated as a batch: no averaging over coordinates,
        # and the coordinate axis is swapped to the front of the matrix dims
        expected_per_dim = torch.exp(-scaled_mismatch).transpose(-1, -3)
        per_dim = kernel(x1, x2, last_dim_is_batch=True).evaluate()
        self.assertTrue(torch.allclose(per_dim, expected_per_dim))

        # last dimension treated as a batch, diagonal only
        per_dim_diag = kernel(x1, x2, last_dim_is_batch=True).diag()
        expected_per_dim_diag = torch.diagonal(expected_per_dim, dim1=-1, dim2=-2)
        self.assertTrue(torch.allclose(per_dim_diag, expected_per_dim_diag))
    def test_ard_batch(self):
        """Check a batched kernel initialized from one row of lengthscales.

        The kernel is built with ``batch_shape=torch.Size([2])`` and
        ``ard_num_dims=3``, while ``lengthscales`` contains a single row of
        three values. Only the dense kernel matrix is compared, against
        ``exp(-mean(mismatch / lengthscale))``.
        """
        lengthscales = torch.tensor([[[1, 2, 1]]], dtype=torch.float)
        x1 = torch.tensor(
            [
                [[4, 2, 1], [3, 1, 5]],
                [[3, 2, 3], [6, 1, 7]],
            ],
            dtype=torch.float,
        )
        x2 = torch.tensor([[[4, 2, 1], [6, 0, 0]]], dtype=torch.float)

        kernel = CategoricalKernel(batch_shape=torch.Size([2]), ard_num_dims=3)
        kernel.initialize(lengthscale=lengthscales)
        kernel.eval()

        # Coordinate-wise "differs" indicator for every (x1 row, x2 row) pair,
        # divided by the lengthscale of the corresponding coordinate.
        mismatch = x1.unsqueeze(-2) != x2.unsqueeze(-3)
        scaled_mismatch = mismatch / lengthscales.unsqueeze(-2)

        expected_dense = torch.exp(-scaled_mismatch.mean(-1))
        dense = kernel(x1, x2).evaluate()
        self.assertTrue(torch.allclose(dense, expected_dense))
 def test_initialize_lengthscale_batch(self):
     """Initialize a batch-of-two kernel and read the lengthscales back.

     The stored ``kernel.lengthscale`` must equal the initial values (viewed
     in the parameter's own shape) to within a norm of ``1e-5``.
     """
     init_values = torch.tensor([1.0, 2.0])
     kernel = CategoricalKernel(batch_shape=torch.Size([2]))
     kernel.initialize(lengthscale=init_values)

     expected = init_values.view_as(kernel.lengthscale)
     deviation = torch.norm(kernel.lengthscale - expected)
     self.assertLess(deviation, 1e-5)
 def test_initialize_lengthscale(self):
     """Initialize a default kernel with lengthscale 1 and read it back.

     The stored ``kernel.lengthscale`` must equal ``1.0`` (viewed in the
     parameter's own shape) to within a norm of ``1e-5``.
     """
     kernel = CategoricalKernel()
     kernel.initialize(lengthscale=1)

     expected = torch.tensor(1.0).view_as(kernel.lengthscale)
     deviation = torch.norm(kernel.lengthscale - expected)
     self.assertLess(deviation, 1e-5)