def test_computes_linear_function_square_batch(self):
    """Batched square inputs: kernel(x, x) must equal the batchwise x @ x^T."""
    x = torch.tensor(
        [[[4, 1], [2, 0], [8, 3]], [[1, 1], [2, 1], [1, 3]]],
        dtype=torch.float,
    )
    kernel = LinearKernel().initialize(variance=1.0)
    kernel.eval()

    # Full covariance; variance is 1.0 so no extra scaling is expected.
    expected = torch.matmul(x, x.transpose(-1, -2))
    output = kernel(x, x).evaluate()
    self.assertLess(torch.norm(output - expected), 1e-4)

    # diag: should match the per-batch diagonals of the full matrix.
    output = kernel(x, x).diag()
    expected = torch.cat(
        [expected[i].diag().unsqueeze(0) for i in range(expected.size(0))]
    )
    self.assertLess(torch.norm(output - expected), 1e-4)

    # batch_dims: the feature dimension is lifted into an extra batch dim.
    grouped = x.unsqueeze(0).transpose(0, -1)
    expected = grouped.matmul(grouped.transpose(-2, -1))
    output = kernel(x, x, batch_dims=(0, 2)).evaluate()
    self.assertLess(torch.norm(output - expected), 1e-4)

    # batch_dims + diag
    output = kernel(x, x, batch_dims=(0, 2)).diag()
    expected = expected.diagonal(dim1=-2, dim2=-1)
    self.assertLess(torch.norm(output - expected), 1e-4)
def test_computes_linear_function_square(self):
    """Square inputs with variance=3.14: kernel(x, x) must equal 3.14 * x @ x^T."""
    x = torch.tensor([[4, 1], [2, 0], [8, 3]], dtype=torch.float)
    kernel = LinearKernel().initialize(variance=3.14)
    kernel.eval()

    # Full covariance, scaled by the variance hyperparameter.
    expected = torch.matmul(x, x.t()) * 3.14
    output = kernel(x, x).evaluate()
    self.assertLess(torch.norm(output - expected), 1e-4)

    # diag: should match the diagonal of the full matrix.
    output = kernel(x, x).diag()
    expected = expected.diag()
    self.assertLess(torch.norm(output - expected), 1e-4)

    # batch_dims: each feature column becomes its own batch entry.
    grouped = x.permute(1, 0).contiguous().view(-1, 3)
    expected = 3.14 * torch.mul(grouped.unsqueeze(-1), grouped.unsqueeze(-2))
    output = kernel(x, x, batch_dims=(0, 2)).evaluate()
    self.assertLess(torch.norm(output - expected), 1e-4)

    # batch_dims + diag
    output = kernel(x, x, batch_dims=(0, 2)).diag()
    expected = torch.cat(
        [expected[i].diag().unsqueeze(0) for i in range(expected.size(0))]
    )
    self.assertLess(torch.norm(output - expected), 1e-4)
def test_computes_linear_function_square_batch(self):
    """Batched inputs with an offset: expects 1 + (x - o) @ (x - o)^T per batch.

    NOTE(review): a sibling test in this file shares this method name; if both
    live in the same TestCase class, the later definition shadows the earlier
    one and only one of them runs — consider renaming.
    """
    a = torch.tensor(
        [[[4, 1], [2, 0], [8, 3]], [[1, 1], [2, 1], [1, 3]]],
        dtype=torch.float,
    )
    # Use a dedicated, seeded generator so the test is deterministic and
    # neither reads from nor perturbs the global RNG state.
    rng = torch.Generator().manual_seed(0)
    offset = torch.randn(1, 1, 2, generator=rng)
    kernel = LinearKernel(num_dimensions=2).initialize(offset=offset, variance=1.0)
    kernel.eval()

    # Full covariance on the offset-shifted inputs.
    actual = 1 + torch.matmul(a - offset, (a - offset).transpose(-1, -2))
    res = kernel(a, a).evaluate()
    self.assertLess(torch.norm(res - actual), 1e-5)

    # diag: per-batch diagonals of the full matrix.
    res = kernel(a, a).diag()
    actual = torch.cat(
        [actual[i].diag().unsqueeze(0) for i in range(actual.size(0))]
    )
    self.assertLess(torch.norm(res - actual), 1e-5)

    # batch_dims: flatten (batch, feature) into a single leading batch dim.
    dim_group_a = a - offset
    dim_group_a = dim_group_a.permute(0, 2, 1).contiguous().view(-1, 3)
    actual = 1 + torch.mul(dim_group_a.unsqueeze(-1), dim_group_a.unsqueeze(-2))
    res = kernel(a, a, batch_dims=(0, 2)).evaluate()
    self.assertLess(torch.norm(res - actual), 1e-5)

    # batch_dims + diag
    res = kernel(a, a, batch_dims=(0, 2)).diag()
    actual = torch.cat(
        [actual[i].diag().unsqueeze(0) for i in range(actual.size(0))]
    )
    self.assertLess(torch.norm(res - actual), 1e-5)
def test_computes_linear_function_square(self):
    """1-D inputs with zero offset: expects 1 + a @ a^T.

    NOTE(review): a sibling test in this file shares this method name; if both
    live in the same TestCase class, the later definition shadows the earlier
    one and only one of them runs — consider renaming.
    """
    # torch.tensor(..., dtype=torch.float) matches the convention used by the
    # sibling tests (rather than the legacy torch.Tensor constructor).
    a = torch.tensor([4, 2, 8], dtype=torch.float).view(3, 1)
    kernel = LinearKernel(num_dimensions=1).initialize(offset=0, variance=1.0)
    kernel.eval()
    actual = 1 + torch.matmul(a, a.t())
    # Tensors and Variables were merged in PyTorch 0.4, so the deprecated
    # Variable(...) wrapper and the `.data` access are no longer needed.
    res = kernel(a, a).evaluate()
    self.assertLess(torch.norm(res - actual), 1e-5)
def test_computes_linear_function_rectangular(self):
    """Rectangular (3x1 vs 2x1) inputs with zero offset: expects 1 + a @ b^T."""
    left = torch.tensor([4, 2, 8], dtype=torch.float).view(3, 1)
    right = torch.tensor([0, 2], dtype=torch.float).view(2, 1)
    kernel = LinearKernel(num_dimensions=1).initialize(offset=0, variance=1.0)
    kernel.eval()
    expected = 1 + torch.matmul(left, right.t())
    output = kernel(left, right).evaluate()
    self.assertLess(torch.norm(output - expected), 1e-5)
def test_computes_linear_function_rectangular(self):
    """Distinct left/right inputs (both 3x1 here): kernel(a, b) must equal a @ b^T."""
    left = torch.tensor([4, 2, 8], dtype=torch.float).view(3, 1)
    right = torch.tensor([0, 2, 1], dtype=torch.float).view(3, 1)
    kernel = LinearKernel().initialize(variance=1.0)
    kernel.eval()

    # Full cross-covariance.
    expected = torch.matmul(left, right.t())
    output = kernel(left, right).evaluate()
    self.assertLess(torch.norm(output - expected), 1e-4)

    # diag: the result is square (3x3), so its diagonal is well defined.
    output = kernel(left, right).diag()
    expected = expected.diag()
    self.assertLess(torch.norm(output - expected), 1e-4)