def test_local_homography_warper(self):
        """Warp a checkerboard with a per-pixel (local) identity homography.

        A separate 3x3 identity matrix is attached to every pixel, giving an
        NxHxWx3x3 homography tensor. Warping with it must keep the patch
        shape and reproduce the source content.
        """
        # generate input data
        batch_size = 1
        height, width = 16, 32
        eye_size = 3  # identity 3x3

        # create checkerboard
        board = utils.create_checkerboard(height, width, 4)
        patch_src = torch.from_numpy(board).view(
            1, 1, height, width).expand(batch_size, 1, height, width)

        # create local homography: flatten each identity to 9 values and
        # replicate it once per pixel. Reshape with batch_size rather than a
        # literal 1 so the tensor stays consistent with the batch dimension.
        dst_homo_src = utils.create_eye_batch(batch_size, eye_size)
        dst_homo_src = dst_homo_src.view(batch_size, -1).unsqueeze(1)
        dst_homo_src = dst_homo_src.repeat(1, height * width, 1).view(
            batch_size, height, width, 3, 3)  # NxHxWx3x3

        # warp reference patch
        patch_src_to_i = tgm.homography_warp(
            patch_src, dst_homo_src, (height, width))

        # the warp is requested at the source size, so the shape must match
        assert patch_src_to_i.shape == patch_src.shape

        # every local homography is the identity, so the content must match
        error = utils.compute_patch_error(patch_src, patch_src_to_i,
                                          height, width)
        assert error.item() < self.threshold
    def test_homography_warper(self, batch_size, device_type):
        """Round-trip a checkerboard through a homography and its inverse.

        The patch is warped with ``dst_homo_src`` and then with its inverse.
        The result must match the source patch. The class-based
        ``HomographyWarper`` and the functional ``homography_warp`` must also
        produce equal results.
        """
        # generate input data
        height, width = 128, 64
        eye_size = 3  # identity 3x3
        device = torch.device(device_type)

        # create checkerboard
        board = utils.create_checkerboard(height, width, 4)
        patch_src = torch.from_numpy(board).view(1, 1, height, width).expand(
            batch_size, 1, height, width)
        patch_src = patch_src.to(device)

        # create base homography
        dst_homo_src = utils.create_eye_batch(batch_size, eye_size).to(device)

        # instantiate warper
        warper = tgm.HomographyWarper(height, width)

        for _ in range(self.num_tests):
            # homography perturbation.
            # NOTE(review): the delta is all zeros, so every iteration warps
            # with the base (identity) homography. TODO: consider a random
            # delta (e.g. torch.rand_like) once self.threshold is validated
            # for non-identity round trips.
            homo_delta = torch.zeros_like(dst_homo_src)
            dst_homo_src_i = dst_homo_src + homo_delta

            # the inverse is reused by both the class and functional warps
            src_homo_dst_i = torch.inverse(dst_homo_src_i)

            # warp src -> dst, then map back dst -> src with the inverse
            patch_dst = warper(patch_src, dst_homo_src_i)
            patch_dst_to_src = warper(patch_dst, src_homo_dst_i)

            # the round trip must reproduce the initial (source) patch.
            # Compare against patch_src, not patch_dst: the two only coincide
            # when the homography is the identity.
            error = utils.compute_patch_error(patch_src, patch_dst_to_src,
                                              height, width)

            assert error.item() < self.threshold

            # check functional api
            patch_dst_to_src_functional = tgm.homography_warp(
                patch_dst, src_homo_dst_i, (height, width))

            assert utils.check_equal_torch(patch_dst_to_src,
                                           patch_dst_to_src_functional)