def __init__(self, input1_size, input2_size, output_size, h1 = None, s1 = None, h2 = None, s2 = None):
    """Build a compact bilinear pooling layer.

    Two count sketches project the inputs into ``output_size`` dimensions;
    the FFT / inverse-FFT modules are used by the forward pass to turn the
    sketch convolution into an element-wise product in the Fourier domain.

    Args:
        input1_size: dimensionality of the first input.
        input2_size: dimensionality of the second input.
        output_size: dimensionality of the pooled output.
        h1, s1, h2, s2: optional precomputed hash/sign tensors forwarded to
            the two ``CountSketch`` submodules (fresh random ones are drawn
            when left as ``None`` — presumably inside ``CountSketch``).
    """
    super(CompactBilinearPooling, self).__init__()
    self.output_size = output_size
    # Plain attribute assignment registers nn.Module values exactly like
    # add_module('sketch1', ...) would.
    self.sketch1 = CountSketch(input1_size, output_size, h1, s1)
    self.sketch2 = CountSketch(input2_size, output_size, h2, s2)
    # One forward FFT per input, plus the inverse transform for the product.
    self.fft = afft.Rfft()
    self.fft2 = afft.Rfft()
    self.ifft = afft.Irfft()
def backward(ctx, grad_output): h1, s1, h2, s2, x, y = ctx.saved_tensors # Recompute part of the forward pass to get the input to the complex product # Compute the count sketch of each input px = CountSketchFn_forward(h1, s1, ctx.output_size, x, ctx.force_cpu_scatter_add) py = CountSketchFn_forward(h2, s2, ctx.output_size, y, ctx.force_cpu_scatter_add) # Then convert the output to Fourier domain grad_output = grad_output.contiguous() grad_re_prod, grad_im_prod = afft.Rfft()(grad_output) # Compute the gradient of x first then y # Gradient of x # Recompute fy re_fy, im_fy = fft.rfft(py) del py re_fy = Variable(re_fy) im_fy = Variable(im_fy) # Compute the gradient of fx, then back to temporal space grad_re_fx = torch.addcmul(grad_re_prod * re_fy, 1, grad_im_prod, im_fy) grad_im_fx = torch.addcmul(grad_im_prod * re_fy, -1, grad_re_prod, im_fy) grad_fx = afft.Irfft()(grad_re_fx, grad_im_fx) # Finally compute the gradient of x grad_x = CountSketchFn_backward(Variable(h1), Variable(s1), ctx.x_size, grad_fx) del re_fy, im_fy, grad_re_fx, grad_im_fx, grad_fx # Gradient of y # Recompute fx re_fx, im_fx = fft.rfft(px) del px re_fx = Variable(re_fx) im_fx = Variable(im_fx) # Compute the gradient of fy, then back to temporal space grad_re_fy = torch.addcmul(grad_re_prod * re_fx, 1, grad_im_prod, im_fx) grad_im_fy = torch.addcmul(grad_im_prod * re_fx, -1, grad_re_prod, im_fx) grad_fy = afft.Irfft()(grad_re_fy, grad_im_fy) # Finally compute the gradient of y grad_y = CountSketchFn_backward(Variable(h2), Variable(s2), ctx.y_size, grad_fy) del re_fx, im_fx, grad_re_fy, grad_im_fy, grad_fy return None, None, None, None, None, grad_x, grad_y, None
def test_rfft_gradcheck():
    """Rfft must pass autograd gradcheck for even and odd last dimensions."""
    for width in (10, 11):
        sample = create_real_var(5, width)
        assert torch.autograd.gradcheck(afft.Rfft(), sample)