def test_computes_linear_function_square_batch(self):
        """Compare LinearKernel on a batched input with hand-built references.

        The kernel is initialized with ``variance=1.0`` and evaluated on the
        same batch for both arguments. Four results are checked: the full
        matrix, its diagonal, and both again with ``batch_dims=(0, 2)``.
        """
        tol = 1e-4
        batch = torch.tensor(
            [[[4, 1], [2, 0], [8, 3]], [[1, 1], [2, 1], [1, 3]]],
            dtype=torch.float,
        )

        kernel = LinearKernel().initialize(variance=1.0)
        kernel.eval()

        # Full kernel: per-batch product of the input with its own transpose.
        expected_full = batch.matmul(batch.transpose(-1, -2))
        computed = kernel(batch, batch).evaluate()
        self.assertLess(torch.norm(computed - expected_full), tol)

        # diag: one row of diagonal entries per batch element.
        computed = kernel(batch, batch).diag()
        expected_diag = expected_full.diagonal(dim1=-2, dim2=-1)
        self.assertLess(torch.norm(computed - expected_diag), tol)

        # batch_dims: the reference moves the last axis of the input to a new
        # leading axis, leaving a trailing axis of size one, then takes the
        # outer product over the last two axes.
        regrouped = batch.unsqueeze(0).transpose(0, -1)
        expected_grouped = torch.matmul(regrouped, regrouped.transpose(-2, -1))
        computed = kernel(batch, batch, batch_dims=(0, 2)).evaluate()
        self.assertLess(torch.norm(computed - expected_grouped), tol)

        # batch_dims + diag
        computed = kernel(batch, batch, batch_dims=(0, 2)).diag()
        expected_grouped_diag = expected_grouped.diagonal(dim1=-2, dim2=-1)
        self.assertLess(torch.norm(computed - expected_grouped_diag), tol)
    def test_computes_linear_function_square(self):
        """Compare LinearKernel (variance 3.14) on a 3x2 input with references.

        Checks the full matrix, its diagonal, and both again when the kernel
        is called with ``batch_dims=(0, 2)``.
        """
        variance = 3.14
        tol = 1e-4
        points = torch.tensor([[4, 1], [2, 0], [8, 3]], dtype=torch.float)

        kernel = LinearKernel().initialize(variance=variance)
        kernel.eval()

        # Full kernel: scaled Gram matrix of the rows.
        expected = points.matmul(points.t()) * variance
        computed = kernel(points, points).evaluate()
        self.assertLess(torch.norm(computed - expected), tol)

        # diag
        computed = kernel(points, points).diag()
        expected = torch.diag(expected)
        self.assertLess(torch.norm(computed - expected), tol)

        # batch_dims: each column of the input becomes one row of length 3,
        # and the reference is the scaled outer product of every such row
        # with itself.
        per_column = points.t().contiguous().view(-1, 3)
        expected = variance * (per_column.unsqueeze(-1) * per_column.unsqueeze(-2))
        computed = kernel(points, points, batch_dims=(0, 2)).evaluate()
        self.assertLess(torch.norm(computed - expected), tol)

        # batch_dims + diag: diagonal of each matrix along the leading axis.
        computed = kernel(points, points, batch_dims=(0, 2)).diag()
        expected = torch.stack([block.diag() for block in expected])
        self.assertLess(torch.norm(computed - expected), tol)
Example #3
0
    def test_computes_linear_function_square_batch(self):
        """Compare an offset LinearKernel on a batched input with references.

        The offset is drawn with ``torch.randn`` and no seed is set inside
        this test, so its values differ between runs unless seeded elsewhere.
        The references add 1 to the product of the offset-shifted inputs;
        presumably that 1 is the ``variance=1.0`` term -- confirm against
        LinearKernel.
        """
        tol = 1e-5
        batch = torch.tensor(
            [[[4, 1], [2, 0], [8, 3]], [[1, 1], [2, 1], [1, 3]]],
            dtype=torch.float,
        )

        offset = torch.randn(1, 1, 2)
        kernel = LinearKernel(num_dimensions=2).initialize(
            offset=offset, variance=1.0
        )
        kernel.eval()

        # Full kernel: per-batch product of the shifted input with itself.
        shifted = batch - offset
        expected = torch.matmul(shifted, shifted.transpose(-1, -2)) + 1
        computed = kernel(batch, batch).evaluate()
        self.assertLess(torch.norm(computed - expected), tol)

        # diag: one row of diagonal entries per batch element.
        computed = kernel(batch, batch).diag()
        expected = torch.stack([block.diag() for block in expected])
        self.assertLess(torch.norm(computed - expected), tol)

        # batch_dims: swap the last two axes and flatten to rows of length 3;
        # the reference is 1 plus the outer product of each row with itself.
        per_feature = (batch - offset).permute(0, 2, 1).contiguous().view(-1, 3)
        expected = per_feature.unsqueeze(-1) * per_feature.unsqueeze(-2) + 1
        computed = kernel(batch, batch, batch_dims=(0, 2)).evaluate()
        self.assertLess(torch.norm(computed - expected), tol)

        # batch_dims + diag
        computed = kernel(batch, batch, batch_dims=(0, 2)).diag()
        expected = torch.stack([block.diag() for block in expected])
        self.assertLess(torch.norm(computed - expected), tol)
    def test_computes_linear_function_square(self):
        """Compare a zero-offset LinearKernel on a 3x1 input with 1 + a a^T."""
        points = torch.Tensor([4, 2, 8]).view(3, 1)

        kernel = LinearKernel(num_dimensions=1).initialize(offset=0, variance=1.0)
        kernel.eval()

        # Reference is built on the raw tensor; the kernel gets Variable
        # wrappers, so the result is unwrapped with .data before comparing.
        expected = torch.matmul(points, points.t()) + 1
        lhs = Variable(points)
        rhs = Variable(points)
        computed = kernel(lhs, rhs).evaluate()

        gap = torch.norm(computed.data - expected)
        self.assertLess(gap, 1e-5)
    def test_computes_linear_function_rectangular(self):
        """Compare a zero-offset LinearKernel on 3x1 and 2x1 inputs with 1 + a b^T."""
        left = torch.tensor([4, 2, 8], dtype=torch.float).view(3, 1)
        right = torch.tensor([0, 2], dtype=torch.float).view(2, 1)

        kernel = LinearKernel(num_dimensions=1).initialize(
            offset=0, variance=1.0
        )
        kernel.eval()

        # The inputs have different lengths, so the reference is 3x2.
        expected = left.matmul(right.t()) + 1
        computed = kernel(left, right).evaluate()

        gap = torch.norm(computed - expected)
        self.assertLess(gap, 1e-5)
    def test_computes_linear_function_rectangular(self):
        """Compare LinearKernel (variance 1.0) on two distinct 3x1 inputs with a b^T.

        Checks both the full cross matrix and its diagonal.
        """
        tol = 1e-4
        left = torch.tensor([4, 2, 8], dtype=torch.float).view(3, 1)
        right = torch.tensor([0, 2, 1], dtype=torch.float).view(3, 1)

        kernel = LinearKernel().initialize(variance=1.0)
        kernel.eval()

        # Full cross matrix between the two inputs.
        expected_full = left.matmul(right.t())
        computed = kernel(left, right).evaluate()
        self.assertLess(torch.norm(computed - expected_full), tol)

        # diag
        computed = kernel(left, right).diag()
        expected_diag = torch.diag(expected_full)
        self.assertLess(torch.norm(computed - expected_diag), tol)