コード例 #1
0
    def backward(ctx, grad_output):
        """Backpropagate through the sketch -> FFT -> complex product -> IFFT pipeline.

        The count sketches of ``x`` and ``y`` are recomputed here rather than
        read from ``ctx``. Each sketch is built only while it is needed (the
        sketch of ``y`` for the gradient of ``x``, and vice versa). It is
        released right after its FFT, so at most one recomputed sketch is
        alive at any time.

        Returns:
            An 8-tuple with the gradients of ``x`` and ``y`` in positions 6
            and 7 and ``None`` everywhere else (presumably matching the order
            of the forward arguments -- confirm against ``forward``).
        """
        h1, s1, h2, s2, x, y = ctx.saved_tensors

        # The Fourier-domain gradient of the product is shared by both inputs,
        # so it is computed only once.
        grad_output = grad_output.contiguous()
        grad_re_prod, grad_im_prod = afft.Rfft()(grad_output)

        def input_grad(h, s, input_size, h_other, s_other, other):
            # Gradient for one input. (h, s) are the arguments of its own
            # sketch; (h_other, s_other, other) rebuild the other input's sketch.
            p_other = CountSketchFn_forward(h_other, s_other, ctx.output_size,
                                            other, ctx.force_cpu_scatter_add)
            re_fo, im_fo = fft.rfft(p_other)
            # Drop the recomputed sketch as soon as its spectrum exists.
            del p_other
            re_fo = Variable(re_fo)
            im_fo = Variable(im_fo)
            # Chain rule for prod = f * f_other (element-wise complex product):
            #   dL/d re_f = g_re * re_fo + g_im * im_fo
            #   dL/d im_f = g_im * re_fo - g_re * im_fo
            # Plain arithmetic replaces the deprecated positional
            # addcmul(input, value, tensor1, tensor2) overload.
            grad_re_f = grad_re_prod * re_fo + grad_im_prod * im_fo
            grad_im_f = grad_im_prod * re_fo - grad_re_prod * im_fo
            grad_f = afft.Irfft()(grad_re_f, grad_im_f)
            # Map the temporal-domain gradient back through the count sketch.
            return CountSketchFn_backward(Variable(h), Variable(s), input_size,
                                          grad_f)

        grad_x = input_grad(h1, s1, ctx.x_size, h2, s2, y)
        grad_y = input_grad(h2, s2, ctx.y_size, h1, s1, x)

        return None, None, None, None, None, grad_x, grad_y, None
 def __init__(self, input1_size, input2_size, output_size,
              h1=None, s1=None, h2=None, s2=None):
     """Create the two count-sketch submodules and the FFT helper objects.

     Args:
         input1_size: first argument of the ``CountSketch`` for input 1.
         input2_size: first argument of the ``CountSketch`` for input 2.
         output_size: shared second argument of both ``CountSketch`` objects;
             also stored as ``self.output_size``.
         h1, s1: forwarded as-is to the first ``CountSketch`` (default None).
         h2, s2: forwarded as-is to the second ``CountSketch`` (default None).
     """
     super(CompactBilinearPooling, self).__init__()
     # Register 'sketch1' then 'sketch2' as named submodules.
     sketch_specs = (
         ('sketch1', input1_size, h1, s1),
         ('sketch2', input2_size, h2, s2),
     )
     for module_name, in_size, h, s in sketch_specs:
         self.add_module(module_name, CountSketch(in_size, output_size, h, s))
     # One Rfft object per input plus one Irfft. The separate instances are
     # presumably needed because legacy autograd Function objects are not
     # reusable within a graph -- NOTE(review): confirm.
     self.fft = afft.Rfft()
     self.fft2 = afft.Rfft()
     self.ifft = afft.Irfft()
     self.output_size = output_size
コード例 #3
0
def test_irfft_gradcheck():
    """Numerically verify the analytic gradient of ``afft.Irfft`` with gradcheck."""
    # Inputs come from the create_complex_var helper called with (5, 11);
    # presumably the real and imaginary parts of a complex tensor -- confirm.
    inputs = create_complex_var(5, 11)
    irfft_fn = afft.Irfft()
    gradients_match = torch.autograd.gradcheck(irfft_fn, inputs)
    assert gradients_match