def crop_by_boxes(tensor: torch.Tensor, src_box: torch.Tensor, dst_box: torch.Tensor,
                  interpolation: str = 'bilinear',
                  align_corners: bool = False) -> torch.Tensor:
    """A wrapper performs crop transform with bounding boxes.

    The transform that maps ``src_box`` onto ``dst_box`` is computed with
    ``get_perspective_transform``; its first two rows are then passed to
    ``warp_affine`` together with an output size derived from
    ``_infer_bounding_box(dst_box)``.

    Args:
        tensor (torch.Tensor): image tensor with shape (C, H, W) or (B, C, H, W).
            A 3-dimensional input gets a temporary leading batch dimension.
        src_box (torch.Tensor): corner points of the region to extract. Cast to
            ``tensor.dtype`` before the transform is computed.
        dst_box (torch.Tensor): corner points of the region to fill. Cast to
            ``tensor.dtype`` before the transform is computed.
        interpolation (str): forwarded to ``warp_affine`` as ``flags``.
            Default: 'bilinear'.
        align_corners (bool): forwarded to ``warp_affine``. Default: False.

    Returns:
        torch.Tensor: the warped patches. The temporary batch dimension added
        for a 3-dimensional input is removed again.

    Raises:
        TypeError: if ``tensor`` is neither 3- nor 4-dimensional.

    Note:
        If the src_box is smaller than dst_box, the following error will be thrown.
        RuntimeError: solve_cpu: For batch 0: U(2,2) is zero, singular U.
    """
    if tensor.ndimension() not in [3, 4]:
        raise TypeError(f"Only tensor with shape (C, H, W) and (B, C, H, W) supported. Got {tensor.shape}")
    # warping needs data in the shape of BCHW
    is_unbatched: bool = tensor.ndimension() == 3
    if is_unbatched:
        tensor = torch.unsqueeze(tensor, dim=0)

    # compute transformation between points and warp
    # Note: Tensor.dtype must be float. "solve_cpu" not implemented for 'Long'
    dst_trans_src: torch.Tensor = get_perspective_transform(src_box.to(tensor.dtype), dst_box.to(tensor.dtype))
    # simulate broadcasting
    dst_trans_src = dst_trans_src.expand(tensor.shape[0], -1, -1).type_as(tensor)

    # NOTE(review): the two entries are presumably (height, width) of the
    # destination box -- confirm against _infer_bounding_box.
    bbox = _infer_bounding_box(dst_box)
    out_size = (int(bbox[0].int().item()), int(bbox[1].int().item()))
    patches: torch.Tensor = warp_affine(
        tensor, dst_trans_src[:, :2, :], out_size,
        flags=interpolation, align_corners=align_corners)

    # return in the original shape
    if is_unbatched:
        patches = torch.squeeze(patches, dim=0)

    return patches
def affine(tensor: torch.Tensor,
           matrix: torch.Tensor,
           mode: str = 'bilinear') -> torch.Tensor:
    r"""Apply an affine transformation to the image.

    Args:
        tensor (torch.Tensor): The image tensor to be warped. A 3-dimensional
            input is treated as a single unbatched image and gets a temporary
            leading batch dimension; any other input is passed on unchanged.
        matrix (torch.Tensor): The 2x3 affine transformation matrix. It is
            expanded along the first dimension to match the batch size.
        mode (str): 'bilinear' | 'nearest'

    Returns:
        torch.Tensor: The warped image. The output size passed to
        ``warp_affine`` is the input's last two dimensions.
    """
    # warping needs data in the shape of BCHW
    is_unbatched: bool = tensor.ndimension() == 3
    if is_unbatched:
        tensor = torch.unsqueeze(tensor, dim=0)

    # we enforce broadcasting since by default grid_sample it does not
    # give support for that
    matrix = matrix.expand(tensor.shape[0], -1, -1)

    # warp the input tensor
    height: int = tensor.shape[-2]
    width: int = tensor.shape[-1]
    warped: torch.Tensor = warp_affine(tensor, matrix, (height, width), mode)

    # return in the original shape
    if is_unbatched:
        warped = torch.squeeze(warped, dim=0)

    return warped
def affine(tensor: torch.Tensor, matrix: torch.Tensor, mode: str = 'bilinear',
           align_corners: bool = False) -> torch.Tensor:
    r"""Apply an affine transformation to the image.

    Args:
        tensor (torch.Tensor): The image tensor to be warped in shapes of
            :math:`(D, H, W)` or :math:`(B, C, H, W)`. A 3-dimensional input
            gets a temporary leading batch dimension.
            NOTE(review): earlier documentation also listed :math:`(H, W)`,
            but this function has no special handling for 2-dimensional input.
        matrix (torch.Tensor): The 2x3 affine transformation matrix. It is
            expanded along the first dimension to match the batch size.
        mode (str): 'bilinear' | 'nearest'
        align_corners(bool): interpolation flag. Default: False. See
            https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.interpolate for detail

    Returns:
        torch.Tensor: The warped image.
    """
    # warping needs data in the shape of BCHW
    is_unbatched: bool = tensor.ndimension() == 3
    if is_unbatched:
        tensor = torch.unsqueeze(tensor, dim=0)

    # we enforce broadcasting since by default grid_sample it does not
    # give support for that
    matrix = matrix.expand(tensor.shape[0], -1, -1)

    # warp the input tensor
    height: int = tensor.shape[-2]
    width: int = tensor.shape[-1]
    # NOTE(review): the tail of this call was missing; align_corners is
    # forwarded by keyword because it was otherwise an unused parameter.
    warped: torch.Tensor = warp_affine(tensor, matrix, (height, width), mode,
                                       align_corners=align_corners)

    # return in the original shape
    if is_unbatched:
        warped = torch.squeeze(warped, dim=0)

    return warped
def affine(tensor: torch.Tensor,
           matrix: torch.Tensor,
           mode: str = 'bilinear',
           align_corners: bool = False) -> torch.Tensor:
    r"""Apply an affine transformation to the image.

    Args:
        tensor (torch.Tensor): The image tensor to be warped. A 3-dimensional
            input gets a temporary leading batch dimension before warping.
        matrix (torch.Tensor): The affine transformation matrix. It is
            expanded along the first dimension to match the batch size.
        mode (str): interpolation mode forwarded to ``warp_affine``.
            Default: 'bilinear'.
        align_corners (bool): forwarded to ``warp_affine``. Default: False.

    Returns:
        torch.Tensor: The warped image; the temporary batch dimension is
        removed again for a 3-dimensional input.
    """
    # warping needs data in the shape of BCHW
    is_unbatched = tensor.ndimension() == 3
    if is_unbatched:
        tensor = torch.unsqueeze(tensor, dim=0)

    # we enforce broadcasting since by default grid_sample it does not
    # give support for that
    matrix = matrix.expand(tensor.shape[0], -1, -1)

    # warp the input tensor
    height = tensor.shape[-2]
    width = tensor.shape[-1]
    # NOTE(review): the tail of this call was missing; mode and align_corners
    # are forwarded because they were otherwise unused parameters.
    warped = warp_affine(tensor,
                         matrix, (height, width),
                         mode,
                         align_corners=align_corners)

    # return in the original shape
    if is_unbatched:
        warped = torch.squeeze(warped, dim=0)

    return warped
def crop_by_transform_mat(
        tensor: torch.Tensor,
        transform: torch.Tensor,
        out_size: Tuple[int, int],
        mode: str = 'bilinear',
        padding_mode: str = 'zeros',
        align_corners: Optional[bool] = None) -> torch.Tensor:
    """Perform crop transform on 2D images (4D tensor) given a perspective transformation matrix.

    Args:
        tensor (torch.Tensor): the 2D image tensor with shape (B, C, H, W).
        transform (torch.Tensor): a perspective transformation matrix with shape (B, 3, 3).
            It is expanded along the first dimension to the batch size of ``tensor``.
        out_size (Tuple[int, int]): size of the output image (height, width).
        mode (str): interpolation mode to calculate output values
          'bilinear' | 'nearest'. Default: 'bilinear'.
        padding_mode (str): padding mode for outside grid values
          'zeros' | 'border' | 'reflection'. Default: 'zeros'.
        align_corners (bool, optional): mode for grid_generation. Default: None.

    Returns:
        torch.Tensor: the output tensor with patches.
    """
    # simulate broadcasting
    # NOTE(review): the tails of the two calls below were missing. They are
    # completed from the signature: the transform is moved to the image's
    # device/dtype, and out_size, mode, padding_mode and align_corners
    # (otherwise unused) are forwarded to warp_affine.
    dst_trans_src = torch.as_tensor(transform.expand(tensor.shape[0], -1, -1),
                                    device=tensor.device,
                                    dtype=tensor.dtype)

    # only the top two rows (the affine part) are handed to warp_affine
    patches: torch.Tensor = warp_affine(tensor,
                                        dst_trans_src[:, :2, :],
                                        out_size,
                                        mode=mode,
                                        padding_mode=padding_mode,
                                        align_corners=align_corners)

    return patches
def crop_tensor(image,
                center,
                bbox_size,
                crop_size,
                interpolation='bilinear',
                align_corners=False):
    """Crop a square region around ``center`` from every image of a batch.

    NOTE(review): the signature after ``image`` was missing; the remaining
    parameters are reconstructed from the names used in the body and in the
    original docstring -- confirm their order against callers.

    Args:
        image (torch.Tensor): the reference tensor of shape BXHxWXC.
            NOTE(review): it is passed to ``warp_affine`` unchanged -- confirm
            the channel layout that helper expects.
        center: [bz, 2] box centres.
        bbox_size: [bz, 1] side length of the square box per sample.
        crop_size: side length of the square output; the destination corners
            span ``0 .. crop_size - 1``.
        interpolation (str): Interpolation flag. Default: 'bilinear'.
        align_corners (bool): mode for grid_generation. Default: False. See
          https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.interpolate for details

    Returns:
        tuple: ``(cropped_image, tform)`` where ``cropped_image`` is the result
        of ``warp_affine`` and ``tform`` is the estimated perspective
        transform with its last two dimensions transposed.
    """
    dtype = image.dtype
    device = image.device
    batch_size = image.shape[0]
    # points: top-left, top-right, bottom-right, bottom-left
    # .contiguous() materialises the expanded view so it can be written in place
    src_pts = torch.zeros([4, 2], dtype=dtype,
                          device=device).unsqueeze(0).expand(
                              batch_size, -1, -1).contiguous()

    src_pts[:, 0, :] = center - bbox_size * 0.5  # / (self.crop_size - 1)
    src_pts[:, 1, 0] = center[:, 0] + bbox_size[:, 0] * 0.5
    src_pts[:, 1, 1] = center[:, 1] - bbox_size[:, 0] * 0.5
    src_pts[:, 2, :] = center + bbox_size * 0.5
    src_pts[:, 3, 0] = center[:, 0] - bbox_size[:, 0] * 0.5
    src_pts[:, 3, 1] = center[:, 1] + bbox_size[:, 0] * 0.5

    DST_PTS = torch.tensor([[
        [0, 0],
        [crop_size - 1, 0],
        [crop_size - 1, crop_size - 1],
        [0, crop_size - 1],
    ]],
                           dtype=dtype,
                           device=device).expand(batch_size, -1, -1)
    # estimate transformation between points
    dst_trans_src = get_perspective_transform(src_pts, DST_PTS)
    # simulate broadcasting
    # dst_trans_src = dst_trans_src.expand(batch_size, -1, -1)

    # warp images
    cropped_image = warp_affine(image,
                                dst_trans_src[:, :2, :],
                                (crop_size, crop_size),
                                flags=interpolation,
                                align_corners=align_corners)

    tform = torch.transpose(dst_trans_src, 2, 1)
    # tform = torch.inverse(dst_trans_src)
    return cropped_image, tform
def affine(
    tensor: torch.Tensor,
    matrix: torch.Tensor,
    mode: str = 'bilinear',
    padding_mode: str = 'zeros',
    align_corners: Optional[bool] = None,
) -> torch.Tensor:
    r"""Apply an affine transformation to the image.

    Args:
        tensor (torch.Tensor): The image tensor to be warped in shapes of
            :math:`(D, H, W)` or :math:`(B, C, H, W)`. A 3-dimensional input
            gets a temporary leading batch dimension.
            NOTE(review): earlier documentation also listed :math:`(H, W)`,
            but this function has no special handling for 2-dimensional input.
        matrix (torch.Tensor): The 2x3 affine transformation matrix. It is
            expanded along the first dimension to match the batch size.
        mode (str): interpolation mode to calculate output values
          'bilinear' | 'nearest'. Default: 'bilinear'.
        padding_mode (str): padding mode for outside grid values
          'zeros' | 'border' | 'reflection'. Default: 'zeros'.
        align_corners(bool, optional): interpolation flag. Default: None.

    Returns:
        torch.Tensor: The warped image with the same shape as the input.

    Example:
        >>> img = torch.rand(1, 2, 3, 5)
        >>> aff = torch.eye(2, 3)[None]
        >>> out = affine(img, aff)
        >>> print(out.shape)
        torch.Size([1, 2, 3, 5])
    """
    # warping needs data in the shape of BCHW
    is_unbatched: bool = tensor.ndimension() == 3
    if is_unbatched:
        tensor = torch.unsqueeze(tensor, dim=0)

    # we enforce broadcasting since by default grid_sample it does not
    # give support for that
    matrix = matrix.expand(tensor.shape[0], -1, -1)

    # warp the input tensor
    height: int = tensor.shape[-2]
    width: int = tensor.shape[-1]
    warped: torch.Tensor = warp_affine(tensor, matrix, (height, width), mode,
                                       padding_mode, align_corners)

    # return in the original shape
    if is_unbatched:
        warped = torch.squeeze(warped, dim=0)

    return warped
def crop_by_boxes(
    tensor: torch.Tensor,
    src_box: torch.Tensor,
    dst_box: torch.Tensor,
    return_transform: bool = False
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """A wrapper performs crop transform with bounding boxes.

    NOTE(review): the leading parameters and the arguments of the
    ``get_perspective_transform`` call were missing; they are reconstructed
    from the names the body uses -- confirm against callers.

    Args:
        tensor (torch.Tensor): image tensor with shape (C, H, W) or (B, C, H, W).
            A 3-dimensional input gets a temporary leading batch dimension.
        src_box (torch.Tensor): corner points of the region to extract.
        dst_box (torch.Tensor): corner points of the region to fill; also used
            to derive the output size via ``_infer_bounding_box``.
        return_transform (bool): if True, also return the transform that was
            applied. Default: False.

    Returns:
        torch.Tensor, or a ``(patches, dst_trans_src)`` tuple when
        ``return_transform`` is True. ``dst_trans_src`` is returned in its
        batch-expanded form even for an unbatched input.

    Raises:
        TypeError: if ``tensor`` is neither 3- nor 4-dimensional.
    """
    if tensor.ndimension() not in [3, 4]:
        raise TypeError(
            "Only tensor with shape (C, H, W) and (B, C, H, W) supported. Got %s"
            % str(tensor.shape))
    # warping needs data in the shape of BCHW
    is_unbatched: bool = tensor.ndimension() == 3
    if is_unbatched:
        tensor = torch.unsqueeze(tensor, dim=0)

    # compute transformation between points and warp
    dst_trans_src: torch.Tensor = get_perspective_transform(
        src_box.to(device=tensor.device, dtype=tensor.dtype),
        dst_box.to(device=tensor.device, dtype=tensor.dtype))
    # simulate broadcasting
    dst_trans_src = dst_trans_src.expand(tensor.shape[0], -1, -1)

    bbox = _infer_bounding_box(dst_box)
    patches: torch.Tensor = warp_affine(
        tensor, dst_trans_src[:, :2, :],
        (int(bbox[0].int().data.item()), int(bbox[1].int().data.item())))

    # return in the original shape
    if is_unbatched:
        patches = torch.squeeze(patches, dim=0)

    if return_transform:
        return patches, dst_trans_src

    return patches
def crop_by_boxes(tensor: torch.Tensor, src_box: torch.Tensor, dst_box: torch.Tensor,
                  interpolation: str = 'bilinear', align_corners: bool = False) -> torch.Tensor:
    """Perform crop transform on 2D images (4D tensor) by bounding boxes.

    Given an input tensor, this function selected the interested areas by the provided bounding boxes (src_box).
    Then the selected areas would be fitted into the targeted bounding boxes (dst_box) by a perspective transformation.
    So far, the ragged tensor is not supported by PyTorch right now. This function hereby requires the bounding boxes
    in a batch must be rectangles with same width and height.

    Args:
        tensor (torch.Tensor): the 2D image tensor with shape (B, C, H, W).
        src_box (torch.Tensor): a tensor with shape (B, 4, 2) containing the coordinates of the bounding boxes
            to be extracted. The tensor must have the shape of Bx4x2, where each box is defined in the clockwise
            order: top-left, top-right, bottom-right and bottom-left. The coordinates must be in x, y order.
        dst_box (torch.Tensor): a tensor with shape (B, 4, 2) containing the coordinates of the bounding boxes
            to be placed. The tensor must have the shape of Bx4x2, where each box is defined in the clockwise
            order: top-left, top-right, bottom-right and bottom-left. The coordinates must be in x, y order.
        interpolation (str): Interpolation flag. Default: 'bilinear'.
        align_corners (bool): mode for grid_generation. Default: False. See
          https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.interpolate for details

    Returns:
        torch.Tensor: the output tensor with patches.

    Raises:
        AssertionError: if ``tensor`` is not 4-dimensional, or if the box
            sizes inferred from ``dst_box`` differ within the batch.

    Examples:
        >>> input = torch.arange(16, dtype=torch.float32).reshape((1, 1, 4, 4))
        >>> src_box = torch.tensor([[
        ...     [1., 1.],
        ...     [2., 1.],
        ...     [2., 2.],
        ...     [1., 2.],
        ... ]])  # 1x4x2
        >>> dst_box = torch.tensor([[
        ...     [0., 0.],
        ...     [1., 0.],
        ...     [1., 1.],
        ...     [0., 1.],
        ... ]])  # 1x4x2
        >>> crop_by_boxes(input, src_box, dst_box, align_corners=True)
        tensor([[[[ 5.0000,  6.0000],
                  [ 9.0000, 10.0000]]]])

    Note:
        If the src_box is smaller than dst_box, the following error will be thrown.
        RuntimeError: solve_cpu: For batch 0: U(2,2) is zero, singular U.
    """
    assert len(tensor.shape) == 4, f"Only tensor with shape (B, C, H, W) supported. Got {tensor.shape}."

    # compute transformation between points and warp
    # Note: Tensor.dtype must be float. "solve_cpu" not implemented for 'Long'
    dst_trans_src: torch.Tensor = get_perspective_transform(src_box.to(tensor.dtype), dst_box.to(tensor.dtype))
    # simulate broadcasting
    dst_trans_src = dst_trans_src.expand(tensor.shape[0], -1, -1).type_as(tensor)

    # every sample must share one output size because the result is a single dense tensor
    bbox = infer_box_shape(dst_box)
    assert (bbox[0] == bbox[0][0]).all() and (bbox[1] == bbox[1][0]).all(), (
        f"Cropping height, width and depth must be exact same in a batch. Got height {bbox[0]} and width {bbox[1]}.")
    patches: torch.Tensor = warp_affine(
        tensor, dst_trans_src[:, :2, :], (int(bbox[0][0].item()), int(bbox[1][0].item())),
        flags=interpolation, align_corners=align_corners)

    return patches
def crop_by_boxes(tensor: torch.Tensor, src_box: torch.Tensor, dst_box: torch.Tensor,
                  mode: str = 'bilinear', padding_mode: str = 'zeros',
                  align_corners: Optional[bool] = None) -> torch.Tensor:
    """Perform crop transform on 2D images (4D tensor) given two bounding boxes.

    Given an input tensor, this function selected the interested areas by the provided bounding boxes (src_box).
    Then the selected areas would be fitted into the targeted bounding boxes (dst_box) by a perspective transformation.
    So far, the ragged tensor is not supported by PyTorch right now. This function hereby requires the bounding boxes
    in a batch must be rectangles with same width and height.

    Args:
        tensor (torch.Tensor): the 2D image tensor with shape (B, C, H, W).
        src_box (torch.Tensor): a tensor with shape (B, 4, 2) containing the coordinates of the bounding boxes
            to be extracted. The tensor must have the shape of Bx4x2, where each box is defined in the clockwise
            order: top-left, top-right, bottom-right and bottom-left. The coordinates must be in x, y order.
        dst_box (torch.Tensor): a tensor with shape (B, 4, 2) containing the coordinates of the bounding boxes
            to be placed. The tensor must have the shape of Bx4x2, where each box is defined in the clockwise
            order: top-left, top-right, bottom-right and bottom-left. The coordinates must be in x, y order.
        mode (str): interpolation mode to calculate output values
          'bilinear' | 'nearest'. Default: 'bilinear'.
        padding_mode (str): padding mode for outside grid values
          'zeros' | 'border' | 'reflection'. Default: 'zeros'.
        align_corners (bool, optional): mode for grid_generation. Default: None.

    Returns:
        torch.Tensor: the output tensor with patches.

    Raises:
        AssertionError: if ``tensor`` is not 4-dimensional, or if the box
            sizes inferred from ``dst_box`` differ within the batch.

    Examples:
        >>> input = torch.arange(16, dtype=torch.float32).reshape((1, 1, 4, 4))
        >>> src_box = torch.tensor([[
        ...     [1., 1.],
        ...     [2., 1.],
        ...     [2., 2.],
        ...     [1., 2.],
        ... ]])  # 1x4x2
        >>> dst_box = torch.tensor([[
        ...     [0., 0.],
        ...     [1., 0.],
        ...     [1., 1.],
        ...     [0., 1.],
        ... ]])  # 1x4x2
        >>> crop_by_boxes(input, src_box, dst_box, align_corners=True)
        tensor([[[[ 5.0000,  6.0000],
                  [ 9.0000, 10.0000]]]])

    Note:
        If the src_box is smaller than dst_box, the following error will be thrown.
        RuntimeError: solve_cpu: For batch 0: U(2,2) is zero, singular U.
    """
    # TODO: improve this since might slow down the function

    assert len(tensor.shape) == 4, f"Only tensor with shape (B, C, H, W) supported. Got {tensor.shape}."

    # compute transformation between points and warp
    # Note: Tensor.dtype must be float. "solve_cpu" not implemented for 'Long'
    # `.to(tensor)` matches both the dtype and the device of the image.
    dst_trans_src: torch.Tensor = get_perspective_transform(
        src_box.to(tensor), dst_box.to(tensor))

    # simulate broadcasting
    dst_trans_src = dst_trans_src.expand(tensor.shape[0], -1, -1)

    # every sample must share one output size because the result is a single dense tensor
    bbox: Tuple[torch.Tensor, torch.Tensor] = infer_box_shape(dst_box)
    assert (bbox[0] == bbox[0][0]).all() and (bbox[1] == bbox[1][0]).all(), (
        f"Cropping height, width and depth must be exact same in a batch. "
        f"Got height {bbox[0]} and width {bbox[1]}.")

    h_out: int = int(bbox[0][0].item())
    w_out: int = int(bbox[1][0].item())

    patches: torch.Tensor = warp_affine(
        tensor, dst_trans_src[:, :2, :], (h_out, w_out),
        mode=mode, padding_mode=padding_mode, align_corners=align_corners)

    return patches