Example #1
def crop_by_boxes(tensor, src_box, dst_box,
                  return_transform: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """A wrapper performs crop transform with bounding boxes.

    """
    if tensor.ndimension() not in [3, 4]:
        raise TypeError("Only tensor with shape (C, H, W) and (B, C, H, W) supported. Got %s" % str(tensor.shape))
    # warping needs data in the shape of BCHW
    is_unbatched: bool = tensor.ndimension() == 3
    if is_unbatched:
        tensor = torch.unsqueeze(tensor, dim=0)

    # compute transformation between points and warp
    dst_trans_src: torch.Tensor = get_perspective_transform(
        src_box.to(tensor.device).to(tensor.dtype),
        dst_box.to(tensor.device).to(tensor.dtype)
    )
    # simulate broadcasting
    dst_trans_src = dst_trans_src.expand(tensor.shape[0], -1, -1)

    bbox = _infer_bounding_box(dst_box)
    patches: torch.Tensor = warp_perspective(
        tensor, dst_trans_src, (int(bbox[0].int().data.item()), int(bbox[1].int().data.item())))

    # return in the original shape
    if is_unbatched:
        patches = torch.squeeze(patches, dim=0)

    if return_transform:
        return patches, dst_trans_src

    return patches
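A minimal usage sketch for the variant above (hypothetical input values; kornia's get_perspective_transform, warp_perspective and the private _infer_bounding_box helper are assumed to be in scope):

import torch

# 1x4x4 single-channel image, unbatched (C, H, W)
image = torch.arange(16, dtype=torch.float32).reshape(1, 4, 4)
# clockwise corners in x, y order: top-left, top-right, bottom-right, bottom-left
src_box = torch.tensor([[[1., 1.], [2., 1.], [2., 2.], [1., 2.]]])  # 1x4x2
dst_box = torch.tensor([[[0., 0.], [1., 0.], [1., 1.], [0., 1.]]])  # 1x4x2

patch, trans = crop_by_boxes(image, src_box, dst_box, return_transform=True)
# expected: patch of shape (1, 2, 2) with the region around rows/cols 1..2,
# and trans of shape (1, 3, 3), the homography mapping src_box onto dst_box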
Example #2
def crop_by_boxes(tensor, src_box, dst_box,
                  interpolation: str = 'bilinear',
                  align_corners: bool = False) -> torch.Tensor:
    """A wrapper performs crop transform with bounding boxes.

    Note:
        If the src_box is smaller than dst_box, the following error will be thrown.
        RuntimeError: solve_cpu: For batch 0: U(2,2) is zero, singular U.
    """
    if tensor.ndimension() not in [3, 4]:
        raise TypeError("Only tensor with shape (C, H, W) and (B, C, H, W) supported. Got %s" % str(tensor.shape))
    # warping needs data in the shape of BCHW
    is_unbatched: bool = tensor.ndimension() == 3
    if is_unbatched:
        tensor = torch.unsqueeze(tensor, dim=0)

    # compute transformation between points and warp
    # Note: Tensor.dtype must be float. "solve_cpu" not implemented for 'Long'
    dst_trans_src: torch.Tensor = get_perspective_transform(src_box.to(tensor.dtype), dst_box.to(tensor.dtype))
    # simulate broadcasting
    dst_trans_src = dst_trans_src.expand(tensor.shape[0], -1, -1).type_as(tensor)

    bbox = _infer_bounding_box(dst_box)
    patches: torch.Tensor = warp_affine(
        tensor, dst_trans_src[:, :2, :], (int(bbox[0].int().data.item()), int(bbox[1].int().data.item())),
        flags=interpolation, align_corners=align_corners)

    # return in the original shape
    if is_unbatched:
        patches = torch.squeeze(patches, dim=0)

    return patches
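A sketch of calling this warp_affine-based variant (assuming the definition above and kornia's helpers are importable); because an axis-aligned rectangle-to-rectangle mapping has no perspective component, the 2x3 affine slice of the homography is enough:

import torch

image = torch.arange(25, dtype=torch.float32).reshape(1, 1, 5, 5)  # (B, C, H, W)
src_box = torch.tensor([[[1., 1.], [3., 1.], [3., 3.], [1., 3.]]])  # 3x3 source region
dst_box = torch.tensor([[[0., 0.], [2., 0.], [2., 2.], [0., 2.]]])  # 3x3 destination

patch = crop_by_boxes(image, src_box, dst_box,
                      interpolation='nearest', align_corners=True)
# expected: patch of shape (1, 1, 3, 3) containing rows/cols 1..3 of the input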
Example #3
def crop_tensor(image,
                center,
                bbox_size,
                crop_size,
                interpolation='bilinear',
                align_corners=False):
    ''' Crop a batch of images.
    Args:
        image (torch.Tensor): the reference tensor of shape BxCxHxW.
        center (torch.Tensor): crop centers in x, y order, with shape [bz, 2].
        bbox_size (torch.Tensor): side length of the source box, with shape [bz, 1].
        crop_size (int): side length of the square output crop.
        interpolation (str): Interpolation flag. Default: 'bilinear'.
        align_corners (bool): mode for grid_generation. Default: False. See
          https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.interpolate for details
    Returns:
        cropped_image (torch.Tensor): the cropped patches with shape [bz, C, crop_size, crop_size].
        tform (torch.Tensor): the transposed src -> dst transforms with shape [bz, 3, 3].
    '''
    dtype = image.dtype
    device = image.device
    batch_size = image.shape[0]
    # points: top-left, top-right, bottom-right, bottom-left
    src_pts = torch.zeros([4, 2], dtype=dtype,
                          device=device).unsqueeze(0).expand(
                              batch_size, -1, -1).contiguous()

    src_pts[:, 0, :] = center - bbox_size * 0.5  # / (self.crop_size - 1)
    src_pts[:, 1, 0] = center[:, 0] + bbox_size[:, 0] * 0.5
    src_pts[:, 1, 1] = center[:, 1] - bbox_size[:, 0] * 0.5
    src_pts[:, 2, :] = center + bbox_size * 0.5
    src_pts[:, 3, 0] = center[:, 0] - bbox_size[:, 0] * 0.5
    src_pts[:, 3, 1] = center[:, 1] + bbox_size[:, 0] * 0.5

    DST_PTS = torch.tensor([[
        [0, 0],
        [crop_size - 1, 0],
        [crop_size - 1, crop_size - 1],
        [0, crop_size - 1],
    ]],
                           dtype=dtype,
                           device=device).expand(batch_size, -1, -1)
    # estimate transformation between points
    dst_trans_src = get_perspective_transform(src_pts, DST_PTS)
    # simulate broadcasting
    # dst_trans_src = dst_trans_src.expand(batch_size, -1, -1)

    # warp images
    cropped_image = warp_affine(image,
                                dst_trans_src[:, :2, :],
                                (crop_size, crop_size),
                                flags=interpolation,
                                align_corners=align_corners)

    tform = torch.transpose(dst_trans_src, 2, 1)
    # tform = torch.inverse(dst_trans_src)
    return cropped_image, tform
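A hedged usage sketch for crop_tensor (the tensor values and sizes are illustrative; kornia's get_perspective_transform and warp_affine are assumed to be in scope, and the image layout is the BxCxHxW that warp_affine expects):

import torch

images = torch.rand(2, 3, 224, 224)                  # (B, C, H, W)
centers = torch.tensor([[112., 112.], [80., 60.]])   # (B, 2), x, y order
bbox_sizes = torch.tensor([[100.], [64.]])           # (B, 1), square source box side
crops, tforms = crop_tensor(images, centers, bbox_sizes, crop_size=128)
# crops: (2, 3, 128, 128) resized patches
# tforms: (2, 3, 3) transposed src -> dst homographies, as returned above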
Example #4
def crop_by_boxes(tensor: torch.Tensor, src_box: torch.Tensor, dst_box: torch.Tensor,
                  interpolation: str = 'bilinear', align_corners: bool = False) -> torch.Tensor:
    """Perform crop transform on 2D images (4D tensor) by bounding boxes.

    Given an input tensor, this function selects the areas of interest specified by the source bounding
    boxes (src_box). The selected areas are then fitted into the target bounding boxes (dst_box) by a
    perspective transformation. Ragged tensors are not yet supported by PyTorch, so this function requires
    all bounding boxes in a batch to be rectangles with the same width and height.

    Args:
        tensor (torch.Tensor): the 2D image tensor with shape (B, C, H, W).
        src_box (torch.Tensor): a tensor with shape (B, 4, 2) containing the coordinates of the bounding boxes
            to be extracted. The tensor must have the shape of Bx4x2, where each box is defined in the clockwise
            order: top-left, top-right, bottom-right and bottom-left. The coordinates must be in x, y order.
        dst_box (torch.Tensor): a tensor with shape (B, 4, 2) containing the coordinates of the bounding boxes
            to be placed. The tensor must have the shape of Bx4x2, where each box is defined in the clockwise
            order: top-left, top-right, bottom-right and bottom-left. The coordinates must be in x, y order.
        interpolation (str): Interpolation flag. Default: 'bilinear'.
        align_corners (bool): mode for grid_generation. Default: False. See
          https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.interpolate for details

    Returns:
        torch.Tensor: the output tensor with patches.

    Examples:
        >>> input = torch.arange(16, dtype=torch.float32).reshape((1, 1, 4, 4))
        >>> src_box = torch.tensor([[
        ...     [1., 1.],
        ...     [2., 1.],
        ...     [2., 2.],
        ...     [1., 2.],
        ... ]])  # 1x4x2
        >>> dst_box = torch.tensor([[
        ...     [0., 0.],
        ...     [1., 0.],
        ...     [1., 1.],
        ...     [0., 1.],
        ... ]])  # 1x4x2
        >>> crop_by_boxes(input, src_box, dst_box, align_corners=True)
        tensor([[[[ 5.0000,  6.0000],
                  [ 9.0000, 10.0000]]]])

    Note:
        If src_box is smaller than dst_box, the following error will be thrown:
        RuntimeError: solve_cpu: For batch 0: U(2,2) is zero, singular U.
    """
    validate_bboxes(src_box)
    validate_bboxes(dst_box)

    assert len(tensor.shape) == 4, f"Only tensor with shape (B, C, H, W) supported. Got {tensor.shape}."

    # compute transformation between points and warp
    # Note: Tensor.dtype must be float. "solve_cpu" not implemented for 'Long'
    dst_trans_src: torch.Tensor = get_perspective_transform(src_box.to(tensor.dtype), dst_box.to(tensor.dtype))
    # simulate broadcasting
    dst_trans_src = dst_trans_src.expand(tensor.shape[0], -1, -1).type_as(tensor)

    bbox = infer_box_shape(dst_box)
    assert (bbox[0] == bbox[0][0]).all() and (bbox[1] == bbox[1][0]).all(), (
        f"Cropping height and width must be exactly the same within a batch. Got height {bbox[0]} and width {bbox[1]}.")
    patches: torch.Tensor = warp_affine(
        tensor, dst_trans_src[:, :2, :], (int(bbox[0][0].item()), int(bbox[1][0].item())),
        flags=interpolation, align_corners=align_corners)

    return patches
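A sketch of batched usage for this variant (hypothetical data; validate_bboxes, infer_box_shape and kornia's warps are assumed in scope). Both crops share the same 2x2 size, as required by the assertion above:

import torch

images = torch.arange(2 * 16, dtype=torch.float32).reshape(2, 1, 4, 4)
# two different source locations, but the same 2x2 crop size per batch element
src_box = torch.tensor([
    [[0., 0.], [1., 0.], [1., 1.], [0., 1.]],
    [[2., 2.], [3., 2.], [3., 3.], [2., 3.]],
])
dst_box = torch.tensor([[[0., 0.], [1., 0.], [1., 1.], [0., 1.]]]).repeat(2, 1, 1)

patches = crop_by_boxes(images, src_box, dst_box, align_corners=True)
# expected: patches of shape (2, 1, 2, 2)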
Example #5
def crop_by_boxes(tensor: torch.Tensor, src_box: torch.Tensor, dst_box: torch.Tensor,
                  mode: str = 'bilinear', padding_mode: str = 'zeros',
                  align_corners: Optional[bool] = None) -> torch.Tensor:
    """Perform crop transform on 2D images (4D tensor) given two bounding boxes.

    Given an input tensor, this function selects the areas of interest specified by the source bounding
    boxes (src_box). The selected areas are then fitted into the target bounding boxes (dst_box) by a
    perspective transformation. Ragged tensors are not yet supported by PyTorch, so this function requires
    all bounding boxes in a batch to be rectangles with the same width and height.

    Args:
        tensor (torch.Tensor): the 2D image tensor with shape (B, C, H, W).
        src_box (torch.Tensor): a tensor with shape (B, 4, 2) containing the coordinates of the bounding boxes
            to be extracted. The tensor must have the shape of Bx4x2, where each box is defined in the clockwise
            order: top-left, top-right, bottom-right and bottom-left. The coordinates must be in x, y order.
        dst_box (torch.Tensor): a tensor with shape (B, 4, 2) containing the coordinates of the bounding boxes
            to be placed. The tensor must have the shape of Bx4x2, where each box is defined in the clockwise
            order: top-left, top-right, bottom-right and bottom-left. The coordinates must be in x, y order.
        mode (str): interpolation mode to calculate output values
          'bilinear' | 'nearest'. Default: 'bilinear'.
        padding_mode (str): padding mode for outside grid values
          'zeros' | 'border' | 'reflection'. Default: 'zeros'.
        align_corners (bool, optional): mode for grid_generation. Default: None.

    Returns:
        torch.Tensor: the output tensor with patches.

    Examples:
        >>> input = torch.arange(16, dtype=torch.float32).reshape((1, 1, 4, 4))
        >>> src_box = torch.tensor([[
        ...     [1., 1.],
        ...     [2., 1.],
        ...     [2., 2.],
        ...     [1., 2.],
        ... ]])  # 1x4x2
        >>> dst_box = torch.tensor([[
        ...     [0., 0.],
        ...     [1., 0.],
        ...     [1., 1.],
        ...     [0., 1.],
        ... ]])  # 1x4x2
        >>> crop_by_boxes(input, src_box, dst_box, align_corners=True)
        tensor([[[[ 5.0000,  6.0000],
                  [ 9.0000, 10.0000]]]])

    Note:
        If src_box is smaller than dst_box, the following error will be thrown:
        RuntimeError: solve_cpu: For batch 0: U(2,2) is zero, singular U.
    """
    # TODO: improve this since might slow down the function
    validate_bboxes(src_box)
    validate_bboxes(dst_box)

    assert len(tensor.shape) == 4, f"Only tensor with shape (B, C, H, W) supported. Got {tensor.shape}."

    # compute transformation between points and warp
    # Note: Tensor.dtype must be float. "solve_cpu" not implemented for 'Long'
    dst_trans_src: torch.Tensor = get_perspective_transform(
        src_box.to(tensor), dst_box.to(tensor))

    # simulate broadcasting
    dst_trans_src = dst_trans_src.expand(tensor.shape[0], -1, -1)

    bbox: Tuple[torch.Tensor, torch.Tensor] = infer_box_shape(dst_box)
    assert (bbox[0] == bbox[0][0]).all() and (bbox[1] == bbox[1][0]).all(), (
        f"Cropping height and width must be exactly the same within a batch. "
        f"Got height {bbox[0]} and width {bbox[1]}.")

    h_out: int = int(bbox[0][0].item())
    w_out: int = int(bbox[1][0].item())

    patches: torch.Tensor = warp_affine(
        tensor, dst_trans_src[:, :2, :], (h_out, w_out),
        mode=mode, padding_mode=padding_mode, align_corners=align_corners)

    return patches
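A minimal sketch showing the padding_mode argument of this variant (assuming the definition above and kornia's warp_affine are available); a source box that extends past the image border is filled according to padding_mode:

import torch

image = torch.arange(16, dtype=torch.float32).reshape(1, 1, 4, 4)
src_box = torch.tensor([[[2., 2.], [4., 2.], [4., 4.], [2., 4.]]])  # runs off the image
dst_box = torch.tensor([[[0., 0.], [2., 0.], [2., 2.], [0., 2.]]])

patch = crop_by_boxes(image, src_box, dst_box,
                      mode='nearest', padding_mode='border', align_corners=True)
# expected: patch of shape (1, 1, 3, 3); samples beyond the border repeat edge pixels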
Example #6
def crop_and_resize(tensor: torch.Tensor, boxes: torch.Tensor,
                    size: Tuple[int, int]) -> torch.Tensor:
    r"""Extracts crops from the input tensor and resizes them.

    Args:
        tensor (torch.Tensor): the reference tensor of shape BxCxHxW.
        boxes (torch.Tensor): a tensor containing the coordinates of the
          bounding boxes to be extracted. The tensor must have the shape
          of Bx4x2, where each box is defined in the following (clockwise)
          order: top-left, top-right, bottom-right and bottom-left. The
          coordinates must be in the x, y order.
        size (Tuple[int, int]): a tuple with the height and width that will be
          used to resize the extracted patches.

    Returns:
        torch.Tensor: tensor containing the patches with shape BxCxN1xN2
          (CxN1xN2 if the input is unbatched).

    Example:
        >>> input = torch.tensor([[
        ...     [1., 2., 3., 4.],
        ...     [5., 6., 7., 8.],
        ...     [9., 10., 11., 12.],
        ...     [13., 14., 15., 16.],
        ... ]])
        >>> boxes = torch.tensor([[
        ...     [1., 1.],
        ...     [2., 1.],
        ...     [2., 2.],
        ...     [1., 2.],
        ... ]])  # 1x4x2
        >>> kornia.crop_and_resize(input, boxes, (2, 2))
        tensor([[[ 6.0000,  7.0000],
                 [10.0000, 11.0000]]])
    """
    if not torch.is_tensor(tensor):
        raise TypeError(
            "Input tensor type is not a torch.Tensor. Got {}".format(
                type(tensor)))
    if not torch.is_tensor(boxes):
        raise TypeError(
            "Input boxes type is not a torch.Tensor. Got {}".format(
                type(boxes)))
    if len(tensor.shape) not in (3, 4):
        raise ValueError("Input tensor must be in the shape of CxHxW or "
                         "BxCxHxW. Got {}".format(tensor.shape))
    if not (isinstance(size, (tuple, list)) and len(size) == 2):
        raise ValueError(
            "Input size must be a tuple/list of length 2. Got {}".format(size))
    # unpack input data
    dst_h: torch.Tensor = torch.tensor(size[0])
    dst_w: torch.Tensor = torch.tensor(size[1])

    # [x, y] origin
    # top-left, top-right, bottom-right, bottom-left
    points_src: torch.Tensor = boxes.to(tensor.device).to(tensor.dtype)

    # [x, y] destination
    # top-left, top-right, bottom-right, bottom-left
    points_dst: torch.Tensor = torch.tensor([[
        [0, 0],
        [dst_w - 1, 0],
        [dst_w - 1, dst_h - 1],
        [0, dst_h - 1],
    ]]).repeat(points_src.shape[0], 1, 1).to(tensor.device).to(tensor.dtype)

    # warping needs data in the shape of BCHW
    is_unbatched: bool = tensor.ndimension() == 3
    if is_unbatched:
        tensor = torch.unsqueeze(tensor, dim=0)

    # compute transformation between points and warp
    dst_trans_src: torch.Tensor = get_perspective_transform(
        points_src, points_dst)

    # simulate broadcasting
    dst_trans_src = dst_trans_src.expand(tensor.shape[0], -1, -1)

    patches: torch.Tensor = warp_perspective(tensor, dst_trans_src,
                                             (int(dst_h), int(dst_w)))

    # return in the original shape
    if is_unbatched:
        patches = torch.squeeze(patches, dim=0)

    return patches
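A hedged usage sketch for crop_and_resize (illustrative values; kornia's get_perspective_transform and warp_perspective are assumed importable); here a 2x2 region is resized up to 3x3:

import torch

image = torch.arange(16, dtype=torch.float32).reshape(1, 4, 4)  # (C, H, W)
# clockwise corners in x, y order: top-left, top-right, bottom-right, bottom-left
boxes = torch.tensor([[[1., 1.], [2., 1.], [2., 2.], [1., 2.]]])  # 1x4x2
patch = crop_and_resize(image, boxes, (3, 3))
# expected: patch of shape (1, 3, 3), the 2x2 region around rows/cols 1..2
# bilinearly resized to 3x3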
Example #7
def center_crop(
    tensor: torch.Tensor,
    size: Tuple[int, int],
    return_transform: bool = False
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    r"""Crops the given tensor at the center.

    Args:
        tensor (torch.Tensor): the input tensor with shape (C, H, W) or
          (B, C, H, W).
        size (Tuple[int, int]): a tuple with the expected height and width
          of the output patch.
        return_transform (bool): if True, also return the transformation matrix
          used to perform the crop. Default: False.

    Returns:
        torch.Tensor: the output tensor with patches. If return_transform is True,
          a tuple of the patches and the transformation matrix is returned.

    Examples:
        >>> input = torch.tensor([[
        ...     [1., 2., 3., 4.],
        ...     [5., 6., 7., 8.],
        ...     [9., 10., 11., 12.],
        ...     [13., 14., 15., 16.],
        ... ]])
        >>> kornia.center_crop(input, (2, 4))
        tensor([[[ 5.0000,  6.0000,  7.0000,  8.0000],
                 [ 9.0000, 10.0000, 11.0000, 12.0000]]])
    """
    if not torch.is_tensor(tensor):
        raise TypeError(
            "Input tensor type is not a torch.Tensor. Got {}".format(
                type(tensor)))

    if len(tensor.shape) not in (3, 4):
        raise ValueError("Input tensor must be in the shape of CxHxW or "
                         "BxCxHxW. Got {}".format(tensor.shape))

    if not (isinstance(size, (tuple, list)) and len(size) == 2):
        raise ValueError(
            "Input size must be a tuple/list of length 2. Got {}".format(size))

    # unpack input sizes
    dst_h: torch.Tensor = torch.tensor(size[0])
    dst_w: torch.Tensor = torch.tensor(size[1])
    src_h: torch.Tensor = torch.tensor(tensor.shape[-2])
    src_w: torch.Tensor = torch.tensor(tensor.shape[-1])

    # compute start/end offsets
    dst_h_half: torch.Tensor = dst_h / 2
    dst_w_half: torch.Tensor = dst_w / 2
    src_h_half: torch.Tensor = src_h / 2
    src_w_half: torch.Tensor = src_w / 2

    start_x: torch.Tensor = src_h_half - dst_h_half
    start_y: torch.Tensor = src_w_half - dst_w_half

    end_x: torch.Tensor = start_x + dst_w - 1
    end_y: torch.Tensor = start_y + dst_h - 1

    # [y, x] origin
    # top-left, top-right, bottom-left, bottom-right
    points_src: torch.Tensor = torch.tensor([[
        [start_y, start_x],
        [start_y, end_x],
        [end_y, start_x],
        [end_y, end_x],
    ]]).to(tensor.device).to(tensor.dtype)

    # [y, x] destination
    # top-left, top-right, bottom-left, bottom-right
    points_dst: torch.Tensor = torch.tensor([[
        [0, 0],
        [0, dst_w - 1],
        [dst_h - 1, 0],
        [dst_h - 1, dst_w - 1],
    ]]).to(tensor.device).to(tensor.dtype)

    # warping needs data in the shape of BCHW
    is_unbatched: bool = tensor.ndimension() == 3
    if is_unbatched:
        tensor = torch.unsqueeze(tensor, dim=0)

    # compute transformation between points and warp
    dst_trans_src: torch.Tensor = get_perspective_transform(
        points_src, points_dst)
    dst_trans_src = dst_trans_src.repeat(tensor.shape[0], 1, 1)

    patches: torch.Tensor = warp_perspective(tensor, dst_trans_src,
                                             (int(dst_h), int(dst_w)))

    # return in the original shape
    if is_unbatched:
        patches = torch.squeeze(patches, dim=0)

    if return_transform:
        return patches, dst_trans_src

    return patches
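A minimal usage sketch for center_crop (illustrative values; kornia's helpers are assumed in scope); return_transform=True also hands back the homography used for the crop:

import torch

image = torch.arange(36, dtype=torch.float32).reshape(1, 6, 6)  # (C, H, W)
patch, trans = center_crop(image, (2, 4), return_transform=True)
# expected: patch of shape (1, 2, 4), the central 2x4 window,
# and trans of shape (1, 3, 3), mapping source coordinates into the crop frame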