def test_ard(self):
    """ScaleKernel over an ARD RBFKernel on batched inputs: full matrix, diag,
    and the (older) ``batch_dims`` evaluation mode, each checked against a
    hand-computed reference.

    NOTE(review): this test uses the legacy ``batch_dims=(0, 2)`` keyword while
    ``test_ard_batch`` uses the newer ``last_dim_is_batch`` API — confirm which
    gpytorch version this file targets.
    """
    # Two identical batch copies of 2x2 inputs with 2 features each.
    a = torch.tensor([[[1, 2], [2, 4]]], dtype=torch.float).repeat(2, 1, 1)
    b = torch.tensor([[[1, 3], [0, 4]]], dtype=torch.float).repeat(2, 1, 1)
    # One lengthscale per feature dimension (ARD), broadcastable over batches.
    lengthscales = torch.tensor([1, 2], dtype=torch.float).view(1, 1, 2)

    base_kernel = RBFKernel(ard_num_dims=2)
    base_kernel.initialize(lengthscale=lengthscales)
    kernel = ScaleKernel(base_kernel)
    kernel.initialize(outputscale=torch.tensor([3], dtype=torch.float))
    kernel.eval()

    # Reference: RBF on lengthscale-normalized inputs, then scaled by the outputscale.
    scaled_a = a.div(lengthscales)
    scaled_b = b.div(lengthscales)
    actual = (scaled_a.unsqueeze(-2) - scaled_b.unsqueeze(-3)).pow(2).sum(dim=-1).mul_(-0.5).exp()
    actual.mul_(3)  # apply outputscale = 3 in place
    res = kernel(a, b).evaluate()
    self.assertLess(torch.norm(res - actual), 1e-5)

    # Diag: per-batch diagonals stacked along the batch dimension.
    res = kernel(a, b).diag()
    actual = torch.cat([actual[i].diag().unsqueeze(0) for i in range(actual.size(0))])
    self.assertLess(torch.norm(res - actual), 1e-5)

    # batch_dims: each feature dimension is treated as its own batch, so the
    # (batch=2, dims=2) pairs are flattened into 4 separate 2x2 kernels.
    actual = scaled_a.transpose(-1, -2).unsqueeze(-1) - scaled_b.transpose(-1, -2).unsqueeze(-2)
    actual = actual.pow(2).mul_(-0.5).exp().view(4, 2, 2)
    actual.mul_(3)  # outputscale applies to every per-dimension kernel
    res = kernel(a, b, batch_dims=(0, 2)).evaluate()
    self.assertLess(torch.norm(res - actual), 1e-5)

    # batch_dims and diag combined.
    res = kernel(a, b, batch_dims=(0, 2)).diag()
    actual = torch.cat([actual[i].diag().unsqueeze(0) for i in range(actual.size(0))])
    self.assertLess(torch.norm(res - actual), 1e-5)
def test_inherit_active_dims(self):
    """A ScaleKernel must report the same active_dims as the kernel it wraps."""
    base = RBFKernel(active_dims=(1, 2), ard_num_dims=2)
    base.initialize(lengthscale=torch.tensor([1, 1], dtype=torch.float))

    scaled = ScaleKernel(base)
    scaled.initialize(outputscale=torch.tensor([3], dtype=torch.float))
    scaled.eval()

    # active_dims should be inherited verbatim from the base kernel.
    self.assertTrue(torch.all(scaled.active_dims == base.active_dims))
def test_forward_batch_mode(self):
    """Batched ScaleKernel: each batch slice of the base RBF matrix is scaled
    by the matching entry of the per-batch outputscale.
    """
    lengthscale = 2
    outputscales = torch.Tensor([1, 2, 3, 4])
    # Same 3x1 / 2x1 inputs replicated across a batch of 4.
    x1 = torch.Tensor([4, 2, 8]).view(1, 3, 1).repeat(4, 1, 1)
    x2 = torch.Tensor([0, 2]).view(1, 2, 1).repeat(4, 1, 1)

    base_kernel = RBFKernel().initialize(log_lengthscale=math.log(lengthscale))
    kernel = ScaleKernel(base_kernel, batch_size=4)
    kernel.initialize(log_outputscale=outputscales.log())
    kernel.eval()

    # Unscaled RBF values from the squared pairwise distances.
    sq_dists = torch.Tensor([[16, 4], [4, 0], [64, 36]])
    rbf = sq_dists.mul_(-0.5).div_(lengthscale ** 2).exp()
    # Broadcast one outputscale onto each batch copy of the RBF matrix.
    expected = rbf.unsqueeze(0).mul(outputscales.view(4, 1, 1))

    result = kernel(x1, x2).evaluate()
    self.assertLess(torch.norm(result - expected), 1e-5)
def test_forward(self):
    """Non-batch ScaleKernel: the base RBF covariance times a scalar outputscale of 3."""
    lengthscale = 2
    x1 = torch.Tensor([4, 2, 8]).view(3, 1)
    x2 = torch.Tensor([0, 2]).view(2, 1)

    base_kernel = RBFKernel().initialize(log_lengthscale=math.log(lengthscale))
    kernel = ScaleKernel(base_kernel)
    kernel.initialize(log_outputscale=torch.Tensor([3]).log())
    kernel.eval()

    # Expected entries: 3 * exp(-0.5 * d^2 / lengthscale^2) for each pairwise
    # squared distance between x1 and x2.
    sq_dists = torch.Tensor([[16, 4], [4, 0], [64, 36]])
    expected = sq_dists.mul_(-0.5).div_(lengthscale ** 2).exp() * 3

    result = kernel(x1, x2).evaluate()
    self.assertLess(torch.norm(result - expected), 1e-5)
def test_ard_batch(self):
    """ScaleKernel with per-batch outputscales over a batched ARD RBFKernel:
    full matrix, diag, and the ``last_dim_is_batch`` mode, each checked
    against a hand-computed reference.
    """
    # Batch of 2, each with 2 points of 3 features.
    a = torch.tensor([[[1, 2, 3], [2, 4, 0]], [[-1, 1, 2], [2, 1, 4]]], dtype=torch.float)
    b = torch.tensor([[[1, 3, 1]], [[2, -1, 0]]], dtype=torch.float).repeat(1, 2, 1)
    # One ARD lengthscale per feature, shared across the batch via broadcasting.
    lengthscales = torch.tensor([[[1, 2, 1]]], dtype=torch.float)

    base_kernel = RBFKernel(batch_shape=torch.Size([2]), ard_num_dims=3)
    base_kernel.initialize(lengthscale=lengthscales)
    kernel = ScaleKernel(base_kernel, batch_shape=torch.Size([2]))
    # Distinct outputscale per batch element: 1 for batch 0, 2 for batch 1.
    kernel.initialize(outputscale=torch.tensor([1, 2], dtype=torch.float))
    kernel.eval()

    # Reference: RBF on lengthscale-normalized inputs.
    scaled_a = a.div(lengthscales)
    scaled_b = b.div(lengthscales)
    actual = (scaled_a.unsqueeze(-2) - scaled_b.unsqueeze(-3)).pow(2).sum(dim=-1).mul_(-0.5).exp()
    actual[1].mul_(2)  # apply the batch-1 outputscale in place (batch 0's scale is 1)
    res = kernel(a, b).evaluate()
    self.assertLess(torch.norm(res - actual), 1e-5)

    # diag: per-batch diagonals stacked along the batch dimension.
    res = kernel(a, b).diag()
    actual = torch.cat(
        [actual[i].diag().unsqueeze(0) for i in range(actual.size(0))])
    self.assertLess(torch.norm(res - actual), 1e-5)

    # batch_dims: with last_dim_is_batch each feature dimension becomes an
    # extra batch dimension, giving one 2x2 kernel per (batch, feature) pair.
    double_batch_a = scaled_a.transpose(-1, -2)
    double_batch_b = scaled_b.transpose(-1, -2)
    actual = double_batch_a.unsqueeze(-1) - double_batch_b.unsqueeze(-2)
    actual = actual.pow(2).mul_(-0.5).exp()
    actual[1, :, :, :].mul_(2)  # batch-1 outputscale applies across all its feature slices
    res = kernel(a, b, last_dim_is_batch=True).evaluate()
    self.assertLess(torch.norm(res - actual), 1e-5)

    # batch_dims and diag combined.
    res = kernel(a, b, last_dim_is_batch=True).diag()
    actual = actual.diagonal(dim1=-2, dim2=-1)
    self.assertLess(torch.norm(res - actual), 1e-5)