def forward(ctx, h1, s1, h2, s2, output_size, x, y,
            force_cpu_scatter_add=False):
        """Compact bilinear pooling of ``x`` and ``y`` via count sketch + FFT.

        Each input is count-sketched to dimension ``output_size`` and the two
        sketches are convolved by multiplying their real FFTs and
        transforming the product back to the real domain.

        Args:
            ctx: autograd context; receives everything ``backward`` reads.
            h1, s1: count-sketch parameters for ``x`` (presumably hash
                indices and signs -- confirm against CountSketchFn_forward).
            h2, s2: count-sketch parameters for ``y``.
            output_size: sketch dimension passed to CountSketchFn_forward.
            x, y: the two inputs to pool.
            force_cpu_scatter_add: forwarded to CountSketchFn_forward and
                stored on ``ctx`` so the sketches can be rebuilt identically
                during the backward pass.

        Returns:
            The result of ``fft.irfft`` on the product of the two sketch
            spectra.
        """
        ctx.save_for_backward(h1, s1, h2, s2, x, y)
        ctx.x_size = tuple(x.size())
        ctx.y_size = tuple(y.size())
        ctx.output_size = output_size
        # The backward pass recomputes the sketches and reads this flag from
        # ctx, so it has to be recorded here.
        ctx.force_cpu_scatter_add = force_cpu_scatter_add

        # Count-sketch each input and take its real FFT. Each sketch is
        # released as soon as its spectrum exists to limit peak memory.
        px = CountSketchFn_forward(h1, s1, output_size, x,
                                   force_cpu_scatter_add)
        re_fx, im_fx = fft.rfft(px)
        del px
        py = CountSketchFn_forward(h2, s2, output_size, y,
                                   force_cpu_scatter_add)
        re_fy, im_fy = fft.rfft(py)
        del py

        # Convolution of the two sketches == pointwise complex product of
        # their spectra.
        re_prod, im_prod = ComplexMultiply_forward(re_fx, im_fx, re_fy, im_fy)

        # Back to the real domain; the imaginary part should be zero.
        re = fft.irfft(re_prod, im_prod)

        return re
    def backward(ctx, grad_output):
        """Backpropagate through the sketch convolution computed in forward.

        The count sketches are not saved by the forward pass; they are
        recomputed here from the saved inputs, trading compute for memory.
        Each sketch is rebuilt only when it is needed and freed right after.

        Args:
            ctx: autograd context; reads ``saved_tensors``, ``output_size``,
                ``x_size``, ``y_size`` and (optionally)
                ``force_cpu_scatter_add``.
            grad_output: gradient of the loss w.r.t. the forward output.

        Returns:
            An 8-tuple: ``None`` for the five leading non-differentiable
            arguments (h1, s1, h2, s2, output_size), then ``grad_x`` and
            ``grad_y``, then a trailing ``None`` (presumably for a
            force_cpu_scatter_add argument -- confirm against forward).
        """
        h1, s1, h2, s2, x, y = ctx.saved_tensors
        # Fall back to False rather than raising AttributeError when forward
        # did not record the flag.
        force_cpu_scatter_add = getattr(ctx, 'force_cpu_scatter_add', False)

        # Bring the incoming gradient into the Fourier domain once; it is
        # shared by both input gradients.
        grad_output = grad_output.contiguous()
        grad_re_prod, grad_im_prod = afft.Rfft()(grad_output)

        # --- Gradient of x: needs the spectrum of y's sketch. ---
        py = CountSketchFn_forward(h2, s2, ctx.output_size, y,
                                   force_cpu_scatter_add)
        re_fy, im_fy = fft.rfft(py)
        del py
        re_fy = Variable(re_fy)
        im_fy = Variable(im_fy)
        # grad_fx = G * conj(Fy), written out in real/imaginary parts.
        grad_re_fx = torch.addcmul(grad_re_prod * re_fy, 1, grad_im_prod,
                                   im_fy)
        grad_im_fx = torch.addcmul(grad_im_prod * re_fy, -1, grad_re_prod,
                                   im_fy)
        grad_fx = afft.Irfft()(grad_re_fx, grad_im_fx)
        # Undo the count sketch to get back to x's shape.
        grad_x = CountSketchFn_backward(Variable(h1), Variable(s1), ctx.x_size,
                                        grad_fx)
        del re_fy, im_fy, grad_re_fx, grad_im_fx, grad_fx

        # --- Gradient of y: needs the spectrum of x's sketch. ---
        px = CountSketchFn_forward(h1, s1, ctx.output_size, x,
                                   force_cpu_scatter_add)
        re_fx, im_fx = fft.rfft(px)
        del px
        re_fx = Variable(re_fx)
        im_fx = Variable(im_fx)
        # grad_fy = G * conj(Fx), written out in real/imaginary parts.
        grad_re_fy = torch.addcmul(grad_re_prod * re_fx, 1, grad_im_prod,
                                   im_fx)
        grad_im_fy = torch.addcmul(grad_im_prod * re_fx, -1, grad_re_prod,
                                   im_fx)
        grad_fy = afft.Irfft()(grad_re_fy, grad_im_fy)
        # Undo the count sketch to get back to y's shape.
        grad_y = CountSketchFn_backward(Variable(h2), Variable(s2), ctx.y_size,
                                        grad_fy)
        del re_fx, im_fx, grad_re_fy, grad_im_fy, grad_fy

        return None, None, None, None, None, grad_x, grad_y, None