def forward(ctx, h1, s1, h2, s2, output_size, x, y,
            force_cpu_scatter_add=False):
    """Combine ``x`` and ``y`` by multiplying the spectra of their count sketches.

    Each input is count-sketched, both sketches are taken to the Fourier
    domain, multiplied there as complex numbers, and the product is brought
    back to the real domain.

    Args:
        ctx: autograd context; receives the tensors and settings that
            ``backward`` reads back.
        h1, s1: count-sketch parameters used for ``x``.
        h2, s2: count-sketch parameters used for ``y``.
        output_size: size argument handed to the count sketch.
        x, y: the two inputs to combine.
        force_cpu_scatter_add: forwarded to ``CountSketchFn_forward`` and
            recorded on ``ctx`` so that ``backward`` recomputes the sketches
            with the same setting.

    Returns:
        The result of ``fft.irfft`` applied to the complex product.
    """
    ctx.save_for_backward(h1, s1, h2, s2, x, y)
    ctx.x_size = tuple(x.size())
    ctx.y_size = tuple(y.size())
    ctx.output_size = output_size
    # backward reads this attribute; without it the backward pass raises
    # AttributeError.
    ctx.force_cpu_scatter_add = force_cpu_scatter_add

    # Count sketch of each input, moved to the Fourier domain. The sketch is
    # released as soon as its spectrum is available.
    px = CountSketchFn_forward(h1, s1, output_size, x, force_cpu_scatter_add)
    re_fx, im_fx = fft.rfft(px)
    del px

    py = CountSketchFn_forward(h2, s2, output_size, y, force_cpu_scatter_add)
    re_fy, im_fy = fft.rfft(py)
    del py

    # Convolution of the two sketches, done as a complex multiplication of
    # their spectra.
    re_prod, im_prod = ComplexMultiply_forward(re_fx, im_fx, re_fy, im_fy)

    # Back to the real domain; the imaginary part should be zero.
    re = fft.irfft(re_prod, im_prod)
    return re
def backward(ctx, grad_output):
    """Compute the gradients of the sketched product w.r.t. ``x`` and ``y``.

    Both count sketches are recomputed from the saved inputs, because the
    inputs of the complex product are not kept from the forward pass.

    Args:
        ctx: autograd context filled in by ``forward``. Reads
            ``saved_tensors``, ``output_size``, ``x_size``, ``y_size`` and,
            when it has been set, ``force_cpu_scatter_add``.
        grad_output: gradient with respect to the output of ``forward``.

    Returns:
        A tuple with one slot per forward argument: ``None`` for ``h1``,
        ``s1``, ``h2``, ``s2`` and ``output_size``, then ``grad_x`` and
        ``grad_y``, followed by one trailing ``None`` for an optional
        ``force_cpu_scatter_add`` argument.
    """
    h1, s1, h2, s2, x, y = ctx.saved_tensors

    # Only forward the scatter-add flag when it was recorded on ctx.
    # Reading it unconditionally raises AttributeError when it was never
    # set; in that case the sketch function's own default is used instead.
    sketch_extra = ()
    if hasattr(ctx, "force_cpu_scatter_add"):
        sketch_extra = (ctx.force_cpu_scatter_add,)

    # Recompute part of the forward pass to get the inputs of the complex
    # product: the count sketch of each input.
    px = CountSketchFn_forward(h1, s1, ctx.output_size, x, *sketch_extra)
    py = CountSketchFn_forward(h2, s2, ctx.output_size, y, *sketch_extra)

    # Convert the incoming gradient to the Fourier domain.
    grad_output = grad_output.contiguous()
    grad_re_prod, grad_im_prod = afft.Rfft()(grad_output)

    # ---- Gradient of x ----
    # Recompute the spectrum of y's sketch.
    re_fy, im_fy = fft.rfft(py)
    del py
    re_fy = Variable(re_fy)
    im_fy = Variable(im_fy)

    # Gradient of fx, then back to temporal space:
    #   re: grad_re_prod * re_fy + grad_im_prod * im_fy
    #   im: grad_im_prod * re_fy - grad_re_prod * im_fy
    grad_re_fx = torch.addcmul(grad_re_prod * re_fy, 1, grad_im_prod, im_fy)
    grad_im_fx = torch.addcmul(grad_im_prod * re_fy, -1, grad_re_prod, im_fy)
    grad_fx = afft.Irfft()(grad_re_fx, grad_im_fx)

    # Finally the gradient of x, through the count sketch.
    grad_x = CountSketchFn_backward(Variable(h1), Variable(s1), ctx.x_size, grad_fx)
    # Drop the intermediates before the y branch allocates its own.
    del re_fy, im_fy, grad_re_fx, grad_im_fx, grad_fx

    # ---- Gradient of y ----
    # Recompute the spectrum of x's sketch.
    re_fx, im_fx = fft.rfft(px)
    del px
    re_fx = Variable(re_fx)
    im_fx = Variable(im_fx)

    # Gradient of fy, then back to temporal space (same form as above with
    # the roles of the two spectra swapped).
    grad_re_fy = torch.addcmul(grad_re_prod * re_fx, 1, grad_im_prod, im_fx)
    grad_im_fy = torch.addcmul(grad_im_prod * re_fx, -1, grad_re_prod, im_fx)
    grad_fy = afft.Irfft()(grad_re_fy, grad_im_fy)

    # Finally the gradient of y, through the count sketch.
    grad_y = CountSketchFn_backward(Variable(h2), Variable(s2), ctx.y_size, grad_fy)
    del re_fx, im_fx, grad_re_fy, grad_im_fy, grad_fy

    return None, None, None, None, None, grad_x, grad_y, None