def backward(ctx, grad_output):
    """Propagate the loss gradient from the spectral output back to the input.

    The upstream gradient ``grad_output`` is taken back to the spatial
    domain with a 2-D inverse FFT. Its channel 1 (the imaginary part in the
    legacy ``torch.fft`` complex layout) is zeroed, and the result is
    transformed forward again with a 2-D FFT. The spectrum is then packed
    with ``fft2vec`` using the ``apply_scaling`` flag stored on ``ctx``
    during the forward pass.

    Zeroing the imaginary part presumably projects the gradient onto
    real-valued spatial signals, matching a real-valued forward input.
    TODO: confirm against ``forward``.

    NOTE(review): ``torch.autograd.Function.backward`` must return one
    gradient per ``forward`` input. If ``forward`` also receives
    ``apply_scaling`` as an argument, this should return ``(grad, None)``.
    Confirm against ``forward``, which is not visible here.

    Args:
        ctx: autograd context; must carry ``apply_scaling``.
        grad_output: gradient of the loss w.r.t. the forward output. It is
            assumed to be a 4-D tensor whose last dim is (real, imag), given
            the ``[:, :, :, 1]`` indexing. TODO confirm.

    Returns:
        ``fft2vec`` applied to the real-projected spectrum.
    """
    spatial_grad = torch.ifft(grad_output, 2)
    # Keep only the real component of the spatial-domain gradient.
    spatial_grad[:, :, :, 1] = 0
    return fft2vec(torch.fft(spatial_grad, 2), ctx.apply_scaling)
# Round-trip sanity check: real tensor -> full 2-D spectrum -> fft2vec ->
# vec_fft -> spectrum comparison -> inverse transform back to real space.
# NOTE(review): B (batch size) is not defined in this chunk; presumably it is
# set earlier in the file. Confirm.
N = 64
tensor = 1000 * torch.randn(B, N, N)
print('tensor', tensor[0, :, :])

# Two-sided 2-D spectrum of the real input; channel 0 is the real part and
# channel 1 the imaginary part (legacy torch.rfft layout).
tensor_fft = torch.rfft(tensor, 2, onesided=False)
print('tensor_fft', tensor_fft[0, :, :, 0])
print('tensor_fft', tensor_fft[0, :, :, 1])

# Pack the spectrum into a vector and unpack it again.
vec = fft2vec(tensor_fft, apply_scaling=True)
fft_recovered = vec_fft(vec)
print('rec_tensor_fft', fft_recovered[0, :, :, 0])
print('rec_tensor_fft', fft_recovered[0, :, :, 1])

# Total and worst-case absolute spectral reconstruction error.
fft_abs_err = torch.abs(tensor_fft - fft_recovered)
print(fft_abs_err.sum())
print(fft_abs_err.max())

# Back to the spatial domain for visual comparison with `tensor`.
tensor_recovered = torch.irfft(fft_recovered, 2, onesided=False)
print('rec_tensor', tensor_recovered[0, :, :])