def test_ard(self):
    """Check an ARD ``CategoricalKernel`` against a hand-computed reference.

    With one lengthscale per input dimension, the reference covariance is
    ``exp(-mean_d(1[x1_d != x2_d] / l_d))``. The same reference is also
    checked through ``diag()`` and with ``last_dim_is_batch=True``. In that
    mode each input dimension gets its own kernel matrix and no mean over
    dimensions is taken.
    """
    left = torch.tensor([[4, 2], [3, 1], [8, 5]], dtype=torch.float)
    right = torch.tensor([[4, 2], [3, 0], [4, 4]], dtype=torch.float)
    ard_ls = torch.tensor([1, 2], dtype=torch.float).view(1, 1, 2)

    kernel = CategoricalKernel(ard_num_dims=2)
    kernel.initialize(lengthscale=ard_ls)
    kernel.eval()

    # Pairwise, per-dimension mismatch indicator scaled by that dimension's
    # lengthscale.
    mismatch = left.unsqueeze(-2) != right.unsqueeze(-3)
    scaled = mismatch / ard_ls.unsqueeze(-2)

    expected_full = torch.exp(-scaled.mean(-1))
    self.assertTrue(
        torch.allclose(kernel(left, right).evaluate(), expected_full)
    )

    # diag
    expected_diag = torch.diagonal(expected_full, dim1=-1, dim2=-2)
    self.assertTrue(torch.allclose(kernel(left, right).diag(), expected_diag))

    # batch_dims: keep every dimension separately and move it to the front
    # of the matrix dims.
    expected_batched = torch.exp(-scaled).transpose(-1, -3)
    self.assertTrue(
        torch.allclose(
            kernel(left, right, last_dim_is_batch=True).evaluate(),
            expected_batched,
        )
    )

    # batch_dims + diag
    self.assertTrue(
        torch.allclose(
            kernel(left, right, last_dim_is_batch=True).diag(),
            torch.diagonal(expected_batched, dim1=-1, dim2=-2),
        )
    )
def test_ard_separate_batch(self):
    """Check a batched ARD ``CategoricalKernel`` with per-batch lengthscales.

    The kernel has ``batch_shape=[2]`` and three ARD dimensions. It is
    initialized with a ``(2, 1, 3)`` lengthscale tensor, so each batch entry
    has its own lengthscales. The reference is
    ``exp(-mean_d(1[x1_d != x2_d] / l_{b,d}))``. It is checked in full,
    through ``diag()``, and with ``last_dim_is_batch=True``.
    """
    batch_a = torch.tensor(
        [
            [[4, 2, 1], [3, 1, 5]],
            [[3, 2, 3], [6, 1, 7]],
        ],
        dtype=torch.float,
    )
    batch_b = torch.tensor([[[4, 2, 1], [6, 0, 0]]], dtype=torch.float)
    per_batch_ls = torch.tensor(
        [[[1, 2, 1]], [[2, 1, 0.5]]], dtype=torch.float
    )

    kernel = CategoricalKernel(batch_shape=torch.Size([2]), ard_num_dims=3)
    kernel.initialize(lengthscale=per_batch_ls)
    kernel.eval()

    # Mismatch indicator for every (row of batch_a, row of batch_b, dim),
    # divided by the lengthscale of the matching batch entry and dimension.
    mismatch = batch_a.unsqueeze(-2) != batch_b.unsqueeze(-3)
    scaled = mismatch / per_batch_ls.unsqueeze(-2)

    expected_full = torch.exp(-scaled.mean(-1))
    self.assertTrue(
        torch.allclose(kernel(batch_a, batch_b).evaluate(), expected_full)
    )

    # diag
    expected_diag = torch.diagonal(expected_full, dim1=-1, dim2=-2)
    self.assertTrue(
        torch.allclose(kernel(batch_a, batch_b).diag(), expected_diag)
    )

    # batch_dims: one kernel matrix per input dimension, no mean over dims.
    expected_batched = torch.exp(-scaled).transpose(-1, -3)
    self.assertTrue(
        torch.allclose(
            kernel(batch_a, batch_b, last_dim_is_batch=True).evaluate(),
            expected_batched,
        )
    )

    # batch_dims + diag
    self.assertTrue(
        torch.allclose(
            kernel(batch_a, batch_b, last_dim_is_batch=True).diag(),
            torch.diagonal(expected_batched, dim1=-1, dim2=-2),
        )
    )
def test_forward(self):
    """Check a single-lengthscale ``CategoricalKernel`` against its closed form.

    All dimensions share the lengthscale ``2``, so the reference covariance is
    ``exp(-mean_d(1[x1_d != x2_d] / 2))``.
    """
    ls = 2
    points_a = torch.tensor(
        [[4, 2], [3, 1], [8, 5], [7, 6]], dtype=torch.float
    )
    points_b = torch.tensor([[4, 2], [3, 0], [4, 4]], dtype=torch.float)

    kernel = CategoricalKernel().initialize(lengthscale=ls)
    kernel.eval()

    differs = points_a.unsqueeze(-2) != points_b.unsqueeze(-3)
    expected = torch.exp(-(differs / ls).mean(-1))

    self.assertTrue(
        torch.allclose(kernel(points_a, points_b).evaluate(), expected)
    )
def test_active_dims(self):
    """Check that ``active_dims`` restricts the kernel to the selected columns.

    Only column 0 is active, so the reference is built from ``x[:, :1]``
    alone. It uses the same ``exp(-mean(1[x1 != x2] / lengthscale))`` form
    that ``test_forward`` checks on the full input.
    """
    x1 = torch.tensor([[4, 2], [3, 1], [8, 5], [7, 6]], dtype=torch.float)
    x2 = torch.tensor([[4, 2], [3, 0], [4, 4]], dtype=torch.float)
    lengthscale = 2

    kernel = CategoricalKernel(active_dims=[0]).initialize(
        lengthscale=lengthscale
    )
    kernel.eval()

    # Mismatch indicator restricted to the single active column.
    dists = x1[:, :1].unsqueeze(-2) != x2[:, :1].unsqueeze(-3)
    # Scale the 0/1 mismatch linearly by the lengthscale, as every other test
    # in this file does. The previous ``dists**2 / lengthscale**2`` divided by
    # 4 instead of 2, expecting exp(-1/4) rather than exp(-1/2) for a
    # mismatch.
    sc_dists = dists / lengthscale
    actual = torch.exp(-sc_dists.mean(-1))

    res = kernel(x1, x2).evaluate()
    self.assertTrue(torch.allclose(res, actual))
def test_ard_batch(self):
    """Check a batched ARD ``CategoricalKernel`` initialized with shared lengthscales.

    The kernel has ``batch_shape=[2]`` and three ARD dimensions. It is
    initialized from a single ``(1, 1, 3)`` lengthscale tensor. The reference
    therefore broadcasts the same lengthscales over both batch entries.
    Presumably ``initialize`` broadcasts them the same way inside the kernel;
    that is what this test checks.
    """
    batch_a = torch.tensor(
        [
            [[4, 2, 1], [3, 1, 5]],
            [[3, 2, 3], [6, 1, 7]],
        ],
        dtype=torch.float,
    )
    batch_b = torch.tensor([[[4, 2, 1], [6, 0, 0]]], dtype=torch.float)
    shared_ls = torch.tensor([[[1, 2, 1]]], dtype=torch.float)

    kernel = CategoricalKernel(batch_shape=torch.Size([2]), ard_num_dims=3)
    kernel.initialize(lengthscale=shared_ls)
    kernel.eval()

    scaled = (
        batch_a.unsqueeze(-2) != batch_b.unsqueeze(-3)
    ) / shared_ls.unsqueeze(-2)
    expected = torch.exp(-scaled.mean(-1))

    self.assertTrue(
        torch.allclose(kernel(batch_a, batch_b).evaluate(), expected)
    )