示例#1
0
    def test_inv_matmul(self):
        """Check InterpolatedLazyTensor.inv_matmul against a dense reference.

        The reference builds the left/right interpolation matrices explicitly
        with ``scatter_``, forms ``W_left @ K @ W_right^T`` densely and solves
        it with ``gpytorch.inv_matmul``. The forward results must agree, and so
        must the gradients w.r.t. the base matrix and the left interpolation
        values.
        """
        base_lazy_tensor_mat = torch.randn(6, 6)
        # A^T A: symmetric positive semi-definite base matrix.
        base_lazy_tensor_mat = base_lazy_tensor_mat.t().matmul(
            base_lazy_tensor_mat)
        test_matrix = torch.randn(3, 4)

        left_interp_indices = torch.LongTensor([[2, 3], [3, 4], [4, 5]])
        left_interp_values = torch.tensor([[1, 2], [0.5, 1], [1, 3]],
                                          dtype=torch.float)
        left_interp_values_copy = left_interp_values.clone()
        left_interp_values.requires_grad = True
        left_interp_values_copy.requires_grad = True

        right_interp_indices = torch.LongTensor([[2, 3], [3, 4], [4, 5]])
        right_interp_values = torch.tensor([[1, 2], [0.5, 1], [1, 3]],
                                           dtype=torch.float)
        right_interp_values_copy = right_interp_values.clone()
        right_interp_values.requires_grad = True
        right_interp_values_copy.requires_grad = True

        # The lazy path and the dense reference need *independent* leaf
        # tensors. If they share one tensor, both backward passes accumulate
        # into the same ``.grad`` and the gradient check below compares a
        # tensor with itself. Clone before enabling requires_grad so the
        # copies are leaves too.
        base_lazy_tensor = base_lazy_tensor_mat
        base_lazy_tensor_copy = base_lazy_tensor_mat.clone()
        base_lazy_tensor.requires_grad = True
        base_lazy_tensor_copy.requires_grad = True

        test_matrix_tensor = test_matrix
        test_matrix_tensor_copy = test_matrix.clone()
        test_matrix_tensor.requires_grad = True
        test_matrix_tensor_copy.requires_grad = True

        interp_lazy_tensor = InterpolatedLazyTensor(
            NonLazyTensor(base_lazy_tensor),
            left_interp_indices,
            left_interp_values,
            right_interp_indices,
            right_interp_values,
        )
        res = interp_lazy_tensor.inv_matmul(test_matrix_tensor)

        # Dense (3, 6) interpolation matrices: row i holds the interp values
        # at the columns given by the interp indices.
        left_matrix = torch.zeros(3, 6)
        right_matrix = torch.zeros(3, 6)
        left_matrix.scatter_(1, left_interp_indices, left_interp_values_copy)
        right_matrix.scatter_(1, right_interp_indices,
                              right_interp_values_copy)
        actual_mat = left_matrix.matmul(base_lazy_tensor_copy).matmul(
            right_matrix.transpose(-1, -2))
        actual = gpytorch.inv_matmul(actual_mat, test_matrix_tensor_copy)

        self.assertTrue(approx_equal(res, actual))

        # Backward pass: each path now writes gradients into its own leaves.
        res.sum().backward()
        actual.sum().backward()

        self.assertTrue(
            approx_equal(base_lazy_tensor.grad, base_lazy_tensor_copy.grad))
        self.assertTrue(
            approx_equal(left_interp_values.grad,
                         left_interp_values_copy.grad))
示例#2
0
    def test_inv_matmul_batch(self):
        """Batched inv_matmul check over 5 copies of the same problem.

        Runs ``InterpolatedLazyTensor.inv_matmul`` on batched inputs and
        compares it with a dense reference built from explicit interpolation
        matrices and ``gpytorch.inv_matmul``. Then checks that the gradients
        w.r.t. the base matrices and the left interpolation values match.
        """
        batch_size = 5

        seed_mat = torch.randn(6, 6)
        gram = seed_mat.t().matmul(seed_mat)
        base_lazy_tensor = gram.unsqueeze(0).repeat(batch_size, 1, 1)
        base_lazy_tensor_copy = base_lazy_tensor.clone()
        for leaf in (base_lazy_tensor, base_lazy_tensor_copy):
            leaf.requires_grad_(True)

        rhs = torch.randn(batch_size, 3, 4)
        rhs_copy = rhs.clone()
        for leaf in (rhs, rhs_copy):
            leaf.requires_grad_(True)

        index_rows = [[2, 3], [3, 4], [4, 5]]
        value_rows = [[1, 2], [0.5, 1], [1, 3]]

        def repeated(rows, dtype):
            # Same (3, 2) table repeated along a new leading batch dimension.
            table = torch.tensor(rows, dtype=dtype)
            return table.unsqueeze(0).repeat(batch_size, 1, 1)

        left_idx = repeated(index_rows, torch.long)
        right_idx = repeated(index_rows, torch.long)
        left_vals = repeated(value_rows, torch.float)
        left_vals_copy = left_vals.clone()
        right_vals = repeated(value_rows, torch.float)
        right_vals_copy = right_vals.clone()
        for leaf in (left_vals, left_vals_copy, right_vals, right_vals_copy):
            leaf.requires_grad_(True)

        lazy = InterpolatedLazyTensor(
            NonLazyTensor(base_lazy_tensor),
            left_idx,
            left_vals,
            right_idx,
            right_vals,
        )
        res = lazy.inv_matmul(rhs)

        def dense_interp(indices, values):
            # Scatter each batch entry into a (3, 6) matrix, then stack the
            # results into a (batch, 3, 6) tensor.
            return torch.stack([
                torch.zeros(3, 6).scatter_(1, idx, vals)
                for idx, vals in zip(indices, values)
            ])

        left_dense = dense_interp(left_idx, left_vals_copy)
        right_dense = dense_interp(right_idx, right_vals_copy)
        reference_mat = left_dense.matmul(base_lazy_tensor_copy)
        reference_mat = reference_mat.matmul(right_dense.transpose(-1, -2))
        actual = gpytorch.inv_matmul(reference_mat, rhs_copy)

        self.assertTrue(approx_equal(res, actual))

        # Backward pass through the lazy path and the dense reference.
        res.sum().backward()
        actual.sum().backward()

        self.assertTrue(
            approx_equal(base_lazy_tensor.grad, base_lazy_tensor_copy.grad))
        self.assertTrue(approx_equal(left_vals.grad, left_vals_copy.grad))