Code Example #1
import numpy as np
from pykeops.torch import LazyTensor

def my_tensordot2(x, y):
    # The leading axes (sizes M and N) are the KeOps reduction axes.
    M, N = x.shape[0], y.shape[0]
    xshape, yshape = x.shape[1:], y.shape[1:]
    # Flatten the feature dimensions and let KeOps perform the tensordot.
    f_keops = LazyTensor(
        x.reshape(M, 1, int(np.array(xshape).prod()))
    ).keops_tensordot(
        LazyTensor(y.reshape(1, N, int(np.array(yshape).prod()))),
        xshape,   # feature shape of x
        yshape,   # feature shape of y
        (2, 0),   # contraction axes of x
        (0, 1),   # contraction axes of y
    )
    # Sum over the N axis, yielding a result indexed by M.
    return f_keops.sum_reduction(dim=1)
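A minimal, hypothetical usage sketch for the function above (the shapes are illustrative, not from the original; they only need to satisfy the contraction pattern, i.e. x's feature axes 2 and 0 must match y's feature axes 0 and 1 in size):

import torch

M, N = 50, 30  # hypothetical batch sizes
x = torch.randn(M, 2, 5, 3, dtype=torch.float64)  # feature shape (2, 5, 3)
y = torch.randn(N, 3, 2, dtype=torch.float64)     # feature shape (3, 2): axes of sizes 3 and 2 match
out = my_tensordot2(x, y)  # x's free axis (size 5) survives, so out has shape (M, 5)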
Code Example #2
File: unit_tests_pytorch.py  Project: getkeops/keops
    def test_TensorDot_with_permute(self):
        ############################################################
        import numpy as np
        import torch
        from pykeops.torch import LazyTensor

        def my_tensordot_perm(a, b, dims=None, perm=None):
            # torch reference: tensordot, sum over the N axis (axis 3), then permute
            return torch.tensordot(a, b, dims=dims).sum(3).permute(perm)

        def invert_permutation_numpy(permutation):
            return np.arange(len(permutation))[np.argsort(permutation)]

        x = torch.randn(self.M, 2, 3, 2, 2, 4, requires_grad=True, dtype=torch.float64)
        y = torch.randn(
            self.N, 2, 4, 2, 3, 2, 3, requires_grad=True, dtype=torch.float64
        )

        dimfa, dimfb = x.shape[1:], y.shape[1:]
        # Contraction axes, counting the batch axis (hence the "- 1" when passing them to KeOps below).
        contfa, contfb = [5, 1, 3], [2, 5, 3]
        perm = [4, 3, 2, 0, 1]
        # The torch reference keeps the batch axis and applies the inverse permutation.
        perm_torch = (0,) + tuple([(i + 1) for i in invert_permutation_numpy(perm)])
        sum_f_torch2 = my_tensordot_perm(x, y, dims=(contfa, contfb), perm=perm_torch)

        f_keops = LazyTensor(
            x.reshape(self.M, 1, int(np.array(dimfa).prod()))
        ).keops_tensordot(
            LazyTensor(y.reshape(1, self.N, int(np.array(dimfb).prod()))),
            dimfa,
            dimfb,
            tuple(np.array(contfa) - 1),  # 0-based contraction axes of x
            tuple(np.array(contfb) - 1),  # 0-based contraction axes of y
            tuple(perm),  # output permutation
        )
        sum_f_keops = f_keops.sum_reduction(dim=1)
        self.assertTrue(torch.allclose(sum_f_keops.flatten(), sum_f_torch2.flatten()))

        e = torch.randn_like(sum_f_torch2)
        # checking gradients
        grad_keops = torch.autograd.grad(
            sum_f_keops, x, e.reshape(self.M, -1), retain_graph=True
        )[0]
        grad_torch = torch.autograd.grad(sum_f_torch2, x, e, retain_graph=True)[0]
        self.assertTrue(
            torch.allclose(grad_keops.flatten(), grad_torch.flatten(), rtol=1e-4)
        )

        grad_keops = torch.autograd.grad(sum_f_keops, y, e.reshape(self.M, -1))[0]
        grad_torch = torch.autograd.grad(sum_f_torch2, y, e)[0]
        self.assertTrue(
            torch.allclose(grad_keops.flatten(), grad_torch.flatten(), rtol=1e-4)
        )
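For reference, the invert_permutation_numpy helper used above returns the inverse of a permutation via argsort. A quick standalone sketch of the idea, with the permutation value taken from the test:

import numpy as np

perm = np.array([4, 3, 2, 0, 1])
inv = np.arange(len(perm))[np.argsort(perm)]  # inverse permutation, here [3, 4, 2, 1, 0]
# Composing a permutation with its inverse restores the identity order:
assert (perm[inv] == np.arange(len(perm))).all()
assert (inv[perm] == np.arange(len(perm))).all()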
Code Example #3
# Then, a reduction is performed along the dimension of size N.

import torch
from pykeops.torch import LazyTensor

M, N = 5, 10  # hypothetical batch sizes, set here so this excerpt runs standalone

x = torch.randn(M, 4, 7, 3, requires_grad=True, dtype=torch.float64)
y = torch.randn(N, 7, 2, requires_grad=True, dtype=torch.float64)

f_torch = torch.tensordot(x, y, dims=([2], [1]))  # shape (M, 4, 3, N, 2)
sum_f_torch2 = f_torch.sum(3)  # sum over the N axis, yielding shape (M, 4, 3, 2)

# In KeOps, the first axes of x and y (the reduction axes, of sizes M and N respectively) are left out. We therefore
# have to tell the compiler not only the contraction axes (1 and 0 respectively, both of dimension 7) but also the
# feature shapes (4, 7, 3) and (7, 2), keeping in mind that the actual first axes of x and y are ignored, so the
# result has shape (M, 4*3*2) or (N, 4*3*2) depending on the chosen reduction axis.

f_keops = LazyTensor(x.reshape(M, 1, 4 * 7 * 3)).keops_tensordot(
    LazyTensor(y.reshape(1, N, 7 * 2)),
    (4, 7, 3),  # feature shape of x
    (7, 2),     # feature shape of y
    (1,),       # contraction axis of x (size 7)
    (0,),       # contraction axis of y (size 7)
)
sum_f_keops = f_keops.sum_reduction(dim=1)  # the reduction is performed along the second (N) axis
# print(sum_f_keops.flatten())                        # ... yielding a result of dimension (M,4*3*2)

print("Compare the two tensordot implementation. All good ?",
      torch.allclose(sum_f_keops.flatten(), sum_f_torch2.flatten(), rtol=1e-4))

########################################################################################################################
# As before, let us check the gradients

e = torch.randn(M, 4 * 3 * 2, dtype=torch.float64)  # cotangent for the flattened KeOps output
Ee = e.reshape(M, 4, 3, 2)  # the same cotangent, shaped for the torch output
grad_keops = torch.autograd.grad(sum_f_keops, x, e, retain_graph=True)[0].squeeze().numpy()
grad_torch = torch.autograd.grad(sum_f_torch2, x, Ee, retain_graph=True)[0].squeeze().numpy()

# print(grad_keops[0,:,:,:])
# print(grad_torch[0,:,:,:])
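# As a final check, compare the two gradients numerically (a minimal sketch,
# mirroring the allclose comparison used for the forward pass above):

import numpy as np

print("Compare the two tensordot gradients w.r.t. x. All good ?",
      np.allclose(grad_keops, grad_torch, rtol=1e-4))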