def crop_by_boxes(tensor, src_box, dst_box,
                  return_transform: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """Warp the region bounded by ``src_box`` onto ``dst_box`` with a perspective transform.

    Args:
        tensor: image tensor with shape (C, H, W) or (B, C, H, W).
        src_box: corner points of the region to extract.
        dst_box: corner points the region is mapped onto. The first two
            entries of ``_infer_bounding_box(dst_box)`` are used as the output
            size handed to ``warp_perspective`` (presumably height, width).
        return_transform: if True, also return the transform that was applied.

    Returns:
        The patches, with the same rank as ``tensor`` (a 3-D input yields a
        3-D output), or ``(patches, transform)`` when ``return_transform`` is True.

    Raises:
        TypeError: if ``tensor`` is neither 3-D nor 4-D.
    """
    rank = tensor.ndimension()
    if rank not in (3, 4):
        raise TypeError("Only tensor with shape (C, H, W) and (B, C, H, W) supported. Got %s" % str(tensor.shape))

    # The warp operates on BCHW data: lift a single image to a batch of one
    # and drop that axis again before returning.
    add_batch_axis: bool = rank == 3
    batch = torch.unsqueeze(tensor, dim=0) if add_batch_axis else tensor

    # Solve for the homography on the image's device and in its dtype.
    src_pts = src_box.to(batch.device).to(batch.dtype)
    dst_pts = dst_box.to(batch.device).to(batch.dtype)
    dst_trans_src: torch.Tensor = get_perspective_transform(src_pts, dst_pts)
    # Broadcast the transform over every image in the batch.
    dst_trans_src = dst_trans_src.expand(batch.shape[0], -1, -1)

    extent = _infer_bounding_box(dst_box)
    out_size = (int(extent[0].int().data.item()), int(extent[1].int().data.item()))
    patches: torch.Tensor = warp_perspective(batch, dst_trans_src, out_size)

    if add_batch_axis:
        patches = torch.squeeze(patches, dim=0)
    return (patches, dst_trans_src) if return_transform else patches
def crop_by_boxes(tensor, src_box, dst_box, interpolation: str = 'bilinear',
                  align_corners: bool = False) -> torch.Tensor:
    """Crop ``tensor`` by mapping the ``src_box`` region onto ``dst_box``.

    Args:
        tensor: image tensor with shape (C, H, W) or (B, C, H, W).
        src_box: corner points of the region to extract.
        dst_box: corner points the region is mapped onto; the first two entries
            of ``_infer_bounding_box(dst_box)`` give the output size.
        interpolation: forwarded to ``warp_affine`` as ``flags``.
        align_corners: forwarded to ``warp_affine``.

    Returns:
        The patches, with the same rank as ``tensor``.

    Raises:
        TypeError: if ``tensor`` is neither 3-D nor 4-D.

    Note:
        If the src_box is smaller than dst_box, the following error will be thrown.
        RuntimeError: solve_cpu: For batch 0: U(2,2) is zero, singular U.
    """
    if tensor.ndimension() not in (3, 4):
        raise TypeError("Only tensor with shape (C, H, W) and (B, C, H, W) supported. Got %s" % str(tensor.shape))

    # warp_affine expects BCHW, so a single image becomes a batch of one.
    was_3d: bool = tensor.ndimension() == 3
    batch = torch.unsqueeze(tensor, dim=0) if was_3d else tensor

    # The linear solve needs floating-point points ("solve_cpu" is not
    # implemented for 'Long'), so cast the boxes to the image dtype first.
    transform: torch.Tensor = get_perspective_transform(src_box.to(batch.dtype), dst_box.to(batch.dtype))
    # Broadcast over the batch and move to the image's device/dtype.
    transform = transform.expand(batch.shape[0], -1, -1).type_as(batch)

    extent = _infer_bounding_box(dst_box)
    out_h = int(extent[0].int().data.item())
    out_w = int(extent[1].int().data.item())
    # warp_affine consumes only the top 2x3 of each 3x3 matrix; this is exact
    # when the homography is affine (e.g. box-to-box mappings between rectangles).
    patches: torch.Tensor = warp_affine(
        batch, transform[:, :2, :], (out_h, out_w),
        flags=interpolation, align_corners=align_corners)

    if was_3d:
        patches = torch.squeeze(patches, dim=0)
    return patches
def crop_tensor(image, center, bbox_size, crop_size, interpolation='bilinear', align_corners=False):
    """Crop a square box around each center and resample it to ``crop_size`` x ``crop_size``.

    Args:
        image (torch.Tensor): batch of images of shape BxCxHxW.
        center (torch.Tensor): box centers of shape [bz, 2], in (x, y) order.
        bbox_size (torch.Tensor): side length of each square box, shape [bz, 1].
        crop_size (int): side length, in pixels, of the output patch.
        interpolation (str): Interpolation flag. Default: 'bilinear'.
        align_corners (bool): mode for grid_generation. Default: False. See
            https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.interpolate
            for details.

    Returns:
        cropped_image (torch.Tensor): the patches, spatial size (crop_size, crop_size).
        tform (torch.Tensor): the [bz, 3, 3] source-to-crop transform, transposed.
    """
    dtype = image.dtype
    device = image.device
    batch_size = image.shape[0]

    # Half side length per batch item; the same value is used for every corner
    # so the source box is always a square.
    half = bbox_size[:, 0] * 0.5
    cx = center[:, 0]
    cy = center[:, 1]

    # Source corners in (x, y): top-left, top-right, bottom-right, bottom-left.
    src_pts = torch.zeros((batch_size, 4, 2), dtype=dtype, device=device)
    src_pts[:, 0, 0] = cx - half
    src_pts[:, 0, 1] = cy - half
    src_pts[:, 1, 0] = cx + half
    src_pts[:, 1, 1] = cy - half
    src_pts[:, 2, 0] = cx + half
    src_pts[:, 2, 1] = cy + half
    src_pts[:, 3, 0] = cx - half
    src_pts[:, 3, 1] = cy + half

    # Destination corners span the whole output patch, in the same order.
    last = crop_size - 1
    dst_pts = torch.tensor([[
        [0, 0],
        [last, 0],
        [last, last],
        [0, last],
    ]], dtype=dtype, device=device).expand(batch_size, -1, -1)

    dst_trans_src = get_perspective_transform(src_pts, dst_pts)

    # A square-to-square mapping is affine, so the top 2x3 block is sufficient.
    cropped_image = warp_affine(image, dst_trans_src[:, :2, :], (crop_size, crop_size),
                                flags=interpolation, align_corners=align_corners)

    # NOTE(review): returned transposed, presumably so callers can right-multiply
    # homogeneous row vectors ([x, y, 1] @ tform) — confirm against callers.
    tform = torch.transpose(dst_trans_src, 2, 1)
    return cropped_image, tform
def crop_by_boxes(tensor: torch.Tensor, src_box: torch.Tensor, dst_box: torch.Tensor,
                  interpolation: str = 'bilinear', align_corners: bool = False) -> torch.Tensor:
    """Crop a batch of 2D images (4D tensor) by mapping source boxes onto destination boxes.

    The areas enclosed by ``src_box`` are extracted from ``tensor`` and fitted
    into ``dst_box`` through a perspective transformation. Because PyTorch has
    no ragged tensors, every destination box in the batch must be a rectangle
    of the same width and height.

    Args:
        tensor (torch.Tensor): the 2D image tensor with shape (B, C, H, W).
        src_box (torch.Tensor): Bx4x2 corner coordinates of the regions to
            extract, clockwise from top-left (top-left, top-right,
            bottom-right, bottom-left), each point in x, y order.
        dst_box (torch.Tensor): Bx4x2 corner coordinates of where the regions
            are placed, using the same corner order and x, y convention.
        interpolation (str): Interpolation flag. Default: 'bilinear'.
        align_corners (bool): mode for grid_generation. Default: False. See
            https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.interpolate
            for details.

    Returns:
        torch.Tensor: the output tensor with patches.

    Examples:
        >>> input = torch.arange(16, dtype=torch.float32).reshape((1, 1, 4, 4))
        >>> src_box = torch.tensor([[
        ...     [1., 1.],
        ...     [2., 1.],
        ...     [2., 2.],
        ...     [1., 2.],
        ... ]])  # 1x4x2
        >>> dst_box = torch.tensor([[
        ...     [0., 0.],
        ...     [1., 0.],
        ...     [1., 1.],
        ...     [0., 1.],
        ... ]])  # 1x4x2
        >>> crop_by_boxes(input, src_box, dst_box, align_corners=True)
        tensor([[[[ 5.0000,  6.0000],
                  [ 9.0000, 10.0000]]]])

    Note:
        If the src_box is smaller than dst_box, the following error will be thrown.
        RuntimeError: solve_cpu: For batch 0: U(2,2) is zero, singular U.
    """
    for boxes in (src_box, dst_box):
        validate_bboxes(boxes)
    assert tensor.dim() == 4, f"Only tensor with shape (B, C, H, W) supported. Got {tensor.shape}."

    # The solver requires floating point ("solve_cpu" not implemented for
    # 'Long'), so cast the boxes to the image dtype.
    float_src = src_box.to(tensor.dtype)
    float_dst = dst_box.to(tensor.dtype)
    transform: torch.Tensor = get_perspective_transform(float_src, float_dst)
    # Broadcast over the batch and align device/dtype with the image.
    transform = transform.expand(tensor.shape[0], -1, -1).type_as(tensor)

    box_shape = infer_box_shape(dst_box)
    heights = box_shape[0]
    widths = box_shape[1]
    assert (heights == heights[0]).all() and (widths == widths[0]).all(), (
        f"Cropping height, width and depth must be exact same in a batch. Got height {heights} and width {widths}.")

    out_size = (int(heights[0].item()), int(widths[0].item()))
    return warp_affine(tensor, transform[:, :2, :], out_size,
                       flags=interpolation, align_corners=align_corners)
def crop_by_boxes(tensor: torch.Tensor, src_box: torch.Tensor, dst_box: torch.Tensor,
                  mode: str = 'bilinear', padding_mode: str = 'zeros',
                  align_corners: Optional[bool] = None) -> torch.Tensor:
    """Crop a batch of 2D images (4D tensor) given source and destination bounding boxes.

    Each region enclosed by ``src_box`` is pulled out of ``tensor`` and fitted
    into the matching ``dst_box`` via a perspective transformation. PyTorch has
    no ragged tensors, so all destination boxes in a batch must be rectangles
    sharing one width and one height.

    Args:
        tensor (torch.Tensor): the 2D image tensor with shape (B, C, H, W).
        src_box (torch.Tensor): Bx4x2 corners of the regions to extract, in
            clockwise order (top-left, top-right, bottom-right, bottom-left),
            each point given as x, y.
        dst_box (torch.Tensor): Bx4x2 corners of the regions to place the
            crops into, with the same ordering and x, y convention.
        mode (str): interpolation mode to calculate output values
            'bilinear' | 'nearest'. Default: 'bilinear'.
        padding_mode (str): padding mode for outside grid values
            'zeros' | 'border' | 'reflection'. Default: 'zeros'.
        align_corners (bool, optional): mode for grid_generation. Default: None.

    Returns:
        torch.Tensor: the output tensor with patches.

    Examples:
        >>> input = torch.arange(16, dtype=torch.float32).reshape((1, 1, 4, 4))
        >>> src_box = torch.tensor([[
        ...     [1., 1.],
        ...     [2., 1.],
        ...     [2., 2.],
        ...     [1., 2.],
        ... ]])  # 1x4x2
        >>> dst_box = torch.tensor([[
        ...     [0., 0.],
        ...     [1., 0.],
        ...     [1., 1.],
        ...     [0., 1.],
        ... ]])  # 1x4x2
        >>> crop_by_boxes(input, src_box, dst_box, align_corners=True)
        tensor([[[[ 5.0000,  6.0000],
                  [ 9.0000, 10.0000]]]])

    Note:
        If the src_box is smaller than dst_box, the following error will be thrown.
        RuntimeError: solve_cpu: For batch 0: U(2,2) is zero, singular U.
    """
    # TODO: improve this since might slow down the function
    for boxes in (src_box, dst_box):
        validate_bboxes(boxes)
    assert tensor.dim() == 4, f"Only tensor with shape (B, C, H, W) supported. Got {tensor.shape}."

    # `.to(tensor)` matches both dtype and device; the solver needs floats
    # ("solve_cpu" not implemented for 'Long').
    transform: torch.Tensor = get_perspective_transform(src_box.to(tensor), dst_box.to(tensor))
    # One transform per image in the batch.
    transform = transform.expand(tensor.shape[0], -1, -1)

    box_shape: Tuple[torch.Tensor, torch.Tensor] = infer_box_shape(dst_box)
    heights = box_shape[0]
    widths = box_shape[1]
    assert (heights == heights[0]).all() and (widths == widths[0]).all(), (
        f"Cropping height, width and depth must be exact same in a batch. "
        f"Got height {heights} and width {widths}.")

    out_size = (int(heights[0].item()), int(widths[0].item()))
    return warp_affine(tensor, transform[:, :2, :], out_size,
                       mode=mode, padding_mode=padding_mode, align_corners=align_corners)
def crop_and_resize(tensor: torch.Tensor, boxes: torch.Tensor,
                    size: Tuple[int, int]) -> torch.Tensor:
    r"""Extracts crops from the input tensor and resizes them.

    Args:
        tensor (torch.Tensor): the reference tensor of shape CxHxW or BxCxHxW.
        boxes (torch.Tensor): a tensor containing the coordinates of the
            bounding boxes to be extracted. The tensor must have the shape
            of Bx4x2, where each box is defined in the following (clockwise)
            order: top-left, top-right, bottom-right and bottom-left. The
            coordinates must be in the x, y order.
        size (Tuple[int, int]): a tuple with the height and width that will be
            used to resize the extracted patches.

    Returns:
        torch.Tensor: the patches, of shape BxCxN1xN2 (CxN1xN2 for an unbatched
        input), where (N1, N2) == size.

    Raises:
        TypeError: if ``tensor`` or ``boxes`` is not a torch.Tensor.
        ValueError: if ``tensor`` is not 3-D or 4-D, or ``size`` is not a
            tuple/list of exactly two elements.

    Example:
        >>> input = torch.tensor([[
        ...     [1., 2., 3., 4.],
        ...     [5., 6., 7., 8.],
        ...     [9., 10., 11., 12.],
        ...     [13., 14., 15., 16.],
        ... ]])
        >>> boxes = torch.tensor([[
        ...     [1., 1.],
        ...     [2., 1.],
        ...     [2., 2.],
        ...     [1., 2.],
        ... ]])  # 1x4x2
        >>> kornia.crop_and_resize(input, boxes, (2, 2))
        tensor([[[ 6.0000,  7.0000],
                 [10.0000, 11.0000]]])
    """
    if not torch.is_tensor(tensor):
        raise TypeError("Input tensor type is not a torch.Tensor. Got {}".format(type(tensor)))
    if not torch.is_tensor(boxes):
        raise TypeError("Input boxes type is not a torch.Tensor. Got {}".format(type(boxes)))
    if len(tensor.shape) not in (3, 4):
        raise ValueError("Input tensor must be in the shape of CxHxW or "
                         "BxCxHxW. Got {}".format(tensor.shape))
    # Require exactly a (height, width) pair: reject scalars and sequences of
    # any other length instead of failing later with an unrelated error.
    if not isinstance(size, (tuple, list)) or len(size) != 2:
        raise ValueError("Input size must be a tuple/list of length 2. Got {}".format(size))

    dst_h, dst_w = int(size[0]), int(size[1])

    # [x, y] origin: top-left, top-right, bottom-right, bottom-left
    points_src: torch.Tensor = boxes.to(tensor.device).to(tensor.dtype)

    # [x, y] destination: the full output patch, same corner order.
    points_dst: torch.Tensor = torch.tensor([[
        [0, 0],
        [dst_w - 1, 0],
        [dst_w - 1, dst_h - 1],
        [0, dst_h - 1],
    ]], device=tensor.device, dtype=tensor.dtype).repeat(points_src.shape[0], 1, 1)

    # warping needs data in the shape of BCHW
    is_unbatched: bool = tensor.ndimension() == 3
    if is_unbatched:
        tensor = torch.unsqueeze(tensor, dim=0)

    # compute transformation between points and warp
    dst_trans_src: torch.Tensor = get_perspective_transform(points_src, points_dst)
    # simulate broadcasting over the image batch
    dst_trans_src = dst_trans_src.expand(tensor.shape[0], -1, -1)

    patches: torch.Tensor = warp_perspective(tensor, dst_trans_src, (dst_h, dst_w))

    # return in the original shape
    if is_unbatched:
        patches = torch.squeeze(patches, dim=0)
    return patches
def center_crop(
    tensor: torch.Tensor, size: Tuple[int, int], return_transform: bool = False
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    r"""Crops the given tensor at the center.

    Args:
        tensor (torch.Tensor): the input tensor with shape (C, H, W) or
            (B, C, H, W).
        size (Tuple[int, int]): a tuple with the expected height and width
            of the output patch.
        return_transform (bool): if True, also return the per-batch 3x3
            transform used for the warp. Default: False.

    Returns:
        torch.Tensor: the output tensor with patches, or a tuple
        ``(patches, transform)`` when ``return_transform`` is True.

    Raises:
        TypeError: if ``tensor`` is not a torch.Tensor.
        ValueError: if ``tensor`` is not 3-D or 4-D, or ``size`` is not a
            tuple/list of exactly two elements.

    Examples:
        >>> input = torch.tensor([[
        ...     [1., 2., 3., 4.],
        ...     [5., 6., 7., 8.],
        ...     [9., 10., 11., 12.],
        ...     [13., 14., 15., 16.],
        ... ]])
        >>> kornia.center_crop(input, (2, 4))
        tensor([[[ 5.0000,  6.0000,  7.0000,  8.0000],
                 [ 9.0000, 10.0000, 11.0000, 12.0000]]])
    """
    if not torch.is_tensor(tensor):
        raise TypeError("Input tensor type is not a torch.Tensor. Got {}".format(type(tensor)))
    if len(tensor.shape) not in (3, 4):
        raise ValueError("Input tensor must be in the shape of CxHxW or "
                         "BxCxHxW. Got {}".format(tensor.shape))
    # Require exactly a (height, width) pair: reject scalars and sequences of
    # any other length instead of failing later with an unrelated error.
    if not isinstance(size, (tuple, list)) or len(size) != 2:
        raise ValueError("Input size must be a tuple/list of length 2. Got {}".format(size))

    dst_h, dst_w = int(size[0]), int(size[1])
    src_h, src_w = tensor.shape[-2], tensor.shape[-1]

    # Top-left corner of the centered window. x comes from widths and y from
    # heights. The offset is fractional when the source/target difference is
    # odd; NOTE(review): the warp then presumably interpolates between pixels
    # rather than copying them — consider integer offsets.
    start_x: float = src_w / 2 - dst_w / 2
    start_y: float = src_h / 2 - dst_h / 2
    end_x: float = start_x + dst_w - 1
    end_y: float = start_y + dst_h - 1

    # [x, y] origin: top-left, top-right, bottom-right, bottom-left
    points_src: torch.Tensor = torch.tensor([[
        [start_x, start_y],
        [end_x, start_y],
        [end_x, end_y],
        [start_x, end_y],
    ]], device=tensor.device, dtype=tensor.dtype)

    # [x, y] destination: the full output patch, same corner order.
    points_dst: torch.Tensor = torch.tensor([[
        [0, 0],
        [dst_w - 1, 0],
        [dst_w - 1, dst_h - 1],
        [0, dst_h - 1],
    ]], device=tensor.device, dtype=tensor.dtype)

    # warping needs data in the shape of BCHW
    is_unbatched: bool = tensor.ndimension() == 3
    if is_unbatched:
        tensor = torch.unsqueeze(tensor, dim=0)

    # compute transformation between points (a pure translation) and warp
    dst_trans_src: torch.Tensor = get_perspective_transform(points_src, points_dst)
    dst_trans_src = dst_trans_src.repeat(tensor.shape[0], 1, 1)

    patches: torch.Tensor = warp_perspective(tensor, dst_trans_src, (dst_h, dst_w))

    # return in the original shape
    if is_unbatched:
        patches = torch.squeeze(patches, dim=0)

    if return_transform:
        return patches, dst_trans_src
    return patches