def my_tensordot2(x, y):
    # KeOps counterpart of a torch.tensordot over the feature axes of x and y,
    # followed by a sum reduction over the second (size-N) axis.
    xshape, yshape = x.shape[1:], y.shape[1:]
    f_keops = LazyTensor(
        x.reshape(M, 1, int(np.array(xshape).prod()))
    ).keops_tensordot(
        LazyTensor(y.reshape(1, N, int(np.array(yshape).prod()))),
        xshape,
        yshape,
        (2, 0),  # (2,0,1),
        (0, 1),  # (0,3,2)
    )
    return f_keops.sum_reduction(dim=1)
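
# A minimal usage sketch for my_tensordot2 (hypothetical shapes, assuming torch,
# numpy (as np) and LazyTensor are imported at module level as elsewhere in this
# file; M and N are the module-level sizes that the function reads, so we set
# them here for illustration). The feature shapes are chosen so that x's feature
# axes (2, 0) match y's feature axes (0, 1): x axis 2 <-> y axis 0 (size 4) and
# x axis 0 <-> y axis 1 (size 2).
M, N = 5, 7
x = torch.randn(M, 2, 3, 4, dtype=torch.float64)
y = torch.randn(N, 4, 2, 5, dtype=torch.float64)
# Torch reference: shift the contraction axes by one to account for the M/N
# batch axes, then sum out the N axis, mimicking sum_reduction(dim=1).
f_torch = torch.tensordot(x, y, dims=([3, 1], [1, 2]))  # shape (M, 3, N, 5)
print(torch.allclose(my_tensordot2(x, y).flatten(), f_torch.sum(2).flatten()))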
def test_TensorDot_with_permute(self):
    ############################################################
    import torch
    from pykeops.torch import LazyTensor

    def my_tensordot_perm(a, b, dims=None, perm=None):
        return torch.tensordot(a, b, dims=dims).sum(3).permute(perm)

    def invert_permutation_numpy(permutation):
        return np.arange(len(permutation))[np.argsort(permutation)]

    x = torch.randn(self.M, 2, 3, 2, 2, 4, requires_grad=True, dtype=torch.float64)
    y = torch.randn(
        self.N, 2, 4, 2, 3, 2, 3, requires_grad=True, dtype=torch.float64
    )

    dimfa, dimfb = x.shape[1:], y.shape[1:]
    contfa, contfb = [5, 1, 3], [2, 5, 3]
    perm = [4, 3, 2, 0, 1]
    perm_torch = (0,) + tuple((i + 1) for i in invert_permutation_numpy(perm))

    sum_f_torch2 = my_tensordot_perm(x, y, dims=(contfa, contfb), perm=perm_torch)

    f_keops = LazyTensor(
        x.reshape(self.M, 1, int(np.array(dimfa).prod()))
    ).keops_tensordot(
        LazyTensor(y.reshape(1, self.N, int(np.array(dimfb).prod()))),
        dimfa,
        dimfb,
        tuple(np.array(contfa) - 1),
        tuple(np.array(contfb) - 1),
        tuple(perm),
    )
    sum_f_keops = f_keops.sum_reduction(dim=1)
    self.assertTrue(torch.allclose(sum_f_keops.flatten(), sum_f_torch2.flatten()))

    e = torch.randn_like(sum_f_torch2)

    # checking gradients with respect to x and y
    grad_keops = torch.autograd.grad(
        sum_f_keops, x, e.reshape(self.M, -1), retain_graph=True
    )[0]
    grad_torch = torch.autograd.grad(sum_f_torch2, x, e, retain_graph=True)[0]
    self.assertTrue(
        torch.allclose(grad_keops.flatten(), grad_torch.flatten(), rtol=1e-4)
    )

    grad_keops = torch.autograd.grad(sum_f_keops, y, e.reshape(self.M, -1))[0]
    grad_torch = torch.autograd.grad(sum_f_torch2, y, e)[0]
    self.assertTrue(
        torch.allclose(grad_keops.flatten(), grad_torch.flatten(), rtol=1e-4)
    )
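
# A short worked example (an illustrative helper, not part of the original test)
# of the permutation bookkeeping above. The test's construction suggests that the
# `perm` argument of keops_tensordot and torch's .permute() follow mutually
# inverse conventions, with torch needing an extra leading axis 0 for the batch
# dimension that KeOps handles implicitly.
def _demo_perm_inversion():
    import numpy as np

    perm = [4, 3, 2, 0, 1]
    inv = np.arange(len(perm))[np.argsort(perm)]  # inverse permutation: [3, 4, 2, 1, 0]
    assert all(inv[perm[i]] == i for i in range(len(perm)))
    perm_torch = (0,) + tuple(int(i) + 1 for i in inv)  # -> (0, 4, 5, 3, 2, 1)
    return perm_torch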
# Then, a reduction is performed along the dimension of size N.

x = torch.randn(M, 4, 7, 3, requires_grad=True, dtype=torch.float64)
y = torch.randn(N, 7, 2, requires_grad=True, dtype=torch.float64)

f_torch = torch.tensordot(x, y, dims=([2], [1]))  # now has shape (M, 4, 3, N, 2)
sum_f_torch2 = f_torch.sum(3)  # ... yielding a result of shape (M, 4, 3, 2)

# In KeOps, we leave out the first axes of x and y (the reduction axes, of size M and N respectively). We then
# need to tell the compiler not only the contraction axes (1 and 0 respectively, both of dimension 7) but the
# feature shapes (4, 7, 3) and (7, 2) as well, keeping in mind that the two actual first axes of x and y (the
# reduction axes) are ignored, so the result has shape (M, 4*3*2) or (N, 4*3*2) depending on the chosen
# reduction axis.
f_keops = LazyTensor(x.reshape(M, 1, 4 * 7 * 3)).keops_tensordot(
    LazyTensor(y.reshape(1, N, 7 * 2)), (4, 7, 3), (7, 2), (1,), (0,)
)
sum_f_keops = f_keops.sum_reduction(dim=1)  # reduction is performed along the second axis
# print(sum_f_keops.flatten())  # ... yielding a result of dimension (M, 4*3*2)

print(
    "Compare the two tensordot implementations. All good?",
    torch.allclose(sum_f_keops.flatten(), sum_f_torch2.flatten(), rtol=1e-4),
)

########################################################################################################################
# As before, let us check the gradients

e = torch.randn(M, 4 * 3 * 2, dtype=torch.float64)
Ee = e.reshape(M, 4, 3, 2)

grad_keops = (
    torch.autograd.grad(sum_f_keops, x, e, retain_graph=True)[0].squeeze().numpy()
)
grad_torch = (
    torch.autograd.grad(sum_f_torch2, x, Ee, retain_graph=True)[0].squeeze().numpy()
)
# print(grad_keops[0, :, :, :])
# print(grad_torch[0, :, :, :])
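
########################################################################################################################
# As a final sanity check (a sketch mirroring the comparison above, assuming
# numpy has been imported as np earlier in the script), compare the two
# gradients numerically:

print(
    "Compare the two tensordot gradients. All good?",
    np.allclose(grad_keops, grad_torch, rtol=1e-4),
)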