def forward(ctx, input, x_min, x_max, y_min, y_max, reparametrization_h,
                reparametrization_w, normalize, exact):
        """Compute the box convolution output and record state on `ctx`.

        The four box-boundary arguments are first passed through
        `reparametrize(..., inverse=True)`; the values it returns are the
        ones handed to the extension kernel and saved for the backward pass.
        The input is converted with `cpp_cuda.integral_image` before being
        given to `cpp_cuda.box_convolution_forward`.

        Args:
            ctx: context object; receives the non-tensor arguments as
                attributes and the tensors via `save_for_backward`.
            input: tensor passed to `cpp_cuda.integral_image`.
            x_min, x_max, y_min, y_max: box boundaries in the
                reparametrized form expected by `reparametrize`
                (NOTE(review): exact units/shape are not visible here).
            reparametrization_h, reparametrization_w: forwarded to
                `reparametrize` and kept on `ctx`.
            normalize: forwarded to the kernel; also decides whether the
                output is saved for backward.
            exact: forwarded to the kernel and kept on `ctx`.

        Returns:
            The result of `cpp_cuda.box_convolution_forward`.
        """
        # non-tensor arguments are kept as plain attributes on `ctx`
        ctx.exact = exact
        ctx.normalize = normalize
        ctx.reparametrization_w = reparametrization_w
        ctx.reparametrization_h = reparametrization_h

        # undo the reparametrization before handing the boxes to the kernel
        box_x_min, box_x_max, box_y_min, box_y_max = reparametrize(
            x_min, x_max, y_min, y_max,
            reparametrization_h, reparametrization_w,
            inverse=True)

        integral = cpp_cuda.integral_image(input)
        result = cpp_cuda.box_convolution_forward(
            integral,
            box_x_min, box_x_max, box_y_min, box_y_max,
            normalize, exact)

        # the output itself is only saved when `normalize` is set
        saved_output = result if normalize else None
        ctx.save_for_backward(integral,
                              box_x_min, box_x_max, box_y_min, box_y_max,
                              saved_output)

        return result
def test_integral_image(device):
    # or use torch.cumsum
    def integral_image_reference(input):
        """Build a zero-padded integral image over the last two dimensions.

        For an input of shape ``(..., h, w)`` the result has shape
        ``(..., h + 1, w + 1)`` with the dtype and device of `input`.
        Row 0 and column 0 are zero; every other entry ``[i, j]`` holds
        the sum of ``input[..., :i, :j]``. Running sums are kept in
        float64 and converted to the result dtype when written back.
        """
        assert input.ndimension() >= 2
        h, w = input.shape[-2:]
        batch_shape = input.shape[:-2]
        result = input.new_empty(batch_shape + (h + 1, w + 1))

        # the first row and the first column are all zeros
        result[..., 0, :] = 0
        result[..., :, 0] = 0

        # pass 1: running sum down the rows, written into columns 1..w
        body = result[..., :, 1:]
        running = torch.zeros_like(input[..., 0, :], dtype=torch.float64)
        for r in range(h):
            running += input[..., r, :].double()
            body[..., r + 1, :].copy_(running)

        # pass 2: running sum across the columns, updating them in place
        running = torch.zeros_like(result[..., :, 0], dtype=torch.float64)
        for c in range(1, w + 1):
            column = result[..., :, c]
            running += column.double()
            column.copy_(running)

        return result

    from box_convolution_cpp_cuda import integral_image

    # check IntegralImageFunction vs reference implementation
    for test_idx in tqdm(range(50)):
        batch_size = random.randint(1, 3)
        in_planes = random.randint(1, 3)
        stride_h, stride_w = 1, 1  # may change in the future
        h, w = (
            random.randint(1 + stride_h, 10),
            random.randint(1 + stride_w, 10),
        )

        input_image = torch.rand(batch_size,
                                 in_planes,
                                 h,
                                 w,
                                 requires_grad=True,
                                 device=device)
        grad_output = torch.rand(batch_size, in_planes, h + 1, w + 1) < 0.15
        grad_output = grad_output.to(device, input_image.dtype)

        reference_result = integral_image_reference(input_image)
        our_result = integral_image(input_image)

        if not our_result.allclose(reference_result):
            raise ValueError(
                "Test %d failed at forward pass.\n\nInput:\n%s\n\n"
                "Our output:\n%s\n\nReference output:\n%s\n\n" %
                (test_idx, input_image, our_result, reference_result))