def apply_transform(  # type: ignore
    self, input: torch.Tensor, label: torch.Tensor, params: Dict[str, torch.Tensor]  # type: ignore
) -> Tuple[torch.Tensor, torch.Tensor]:
    height, width = input.size(2), input.size(3)
    num_mixes = params['mix_pairs'].size(0)
    batch_size = params['mix_pairs'].size(1)

    _shape_validation(params['mix_pairs'], [num_mixes, batch_size], 'mix_pairs')
    _shape_validation(params['crop_src'], [num_mixes, batch_size, 4, 2], 'crop_src')

    out_inputs = input.clone()
    out_labels = []
    for pair, crop in zip(params['mix_pairs'], params['crop_src']):
        # select the partner image/label of each sample for this mix round
        input_permute = input.index_select(dim=0, index=pair.to(input.device))
        labels_permute = label.index_select(dim=0, index=pair.to(label.device))
        # infer_bbox_shape returns (height, width); the product below is the crop area
        h, w = infer_bbox_shape(crop)
        lam = w.to(input.dtype) * h.to(input.dtype) / (width * height)  # crop area / image area
        # compute mask to match input shape
        mask = bbox_to_mask(crop, width, height).bool().unsqueeze(dim=1).repeat(1, input.size(1), 1, 1)
        # paste the partner's crop region into the output image
        out_inputs[mask] = input_permute[mask]
        out_labels.append(
            torch.stack([label.to(input.dtype), labels_permute.to(input.dtype), lam.to(label.device)], dim=1)
        )

    return out_inputs, torch.stack(out_labels, dim=0)
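The labels returned above are stacked per mix round as (original label, permuted label, lam), giving a tensor of shape (num_mixes, B, 3). Below is a minimal sketch of how such a triple is typically folded into a CutMix-style classification loss; the function name, the cross-entropy setting and the `logits` argument are illustrative assumptions, not part of the source.

import torch
import torch.nn.functional as F

def cutmix_loss(logits: torch.Tensor, mixed_labels: torch.Tensor) -> torch.Tensor:
    """Combine the two label sets of one mix round into a single loss.

    `mixed_labels` is one row of the (num_mixes, B, 3) tensor returned above:
    column 0 is the original label, column 1 the permuted label and
    column 2 the area ratio `lam` of the pasted crop.
    """
    y_a = mixed_labels[:, 0].long()   # labels of the base images
    y_b = mixed_labels[:, 1].long()   # labels of the pasted crops
    lam = mixed_labels[:, 2]          # fraction of each image covered by the crop
    loss_a = F.cross_entropy(logits, y_a, reduction='none')
    loss_b = F.cross_entropy(logits, y_b, reduction='none')
    # The pasted region comes from the permuted image, so its labels are weighted by `lam`.
    return ((1.0 - lam) * loss_a + lam * loss_b).mean()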
Example #2
def test_bounding_boxes_dim_inferring(self, device, dtype):
    boxes = torch.tensor(
        [[[1.0, 1.0], [3.0, 1.0], [3.0, 2.0], [1.0, 2.0]]],
        device=device,
        dtype=dtype,
    )
    h, w = infer_bbox_shape(boxes)
    assert (h, w) == (2, 3)
Example #3
def test_bounding_boxes_dim_inferring_batch(self, device, dtype):
    boxes = torch.tensor(
        [[[1.0, 1.0], [3.0, 1.0], [3.0, 2.0], [1.0, 2.0]], [[2.0, 2.0], [4.0, 2.0], [4.0, 3.0], [2.0, 3.0]]],
        device=device,
        dtype=dtype,
    )
    h, w = infer_bbox_shape(boxes)
    assert (h.unique().item(), w.unique().item()) == (2, 3)
Example #4
def crop_by_indices(
    input: torch.Tensor,
    src_box: torch.Tensor,
    size: Optional[Tuple] = None,
    interpolation: str = 'bilinear',
    align_corners: Optional[bool] = None,
    antialias: bool = False,
) -> torch.Tensor:
    """Crop tensors with naive indices.

    Args:
        input: the 2D image tensor with shape (B, C, H, W).
        src_box: a tensor with shape (B, 4, 2) containing the coordinates of the bounding boxes
            to be extracted. The tensor must have the shape of Bx4x2, where each box is defined in the clockwise
            order: top-left, top-right, bottom-right and bottom-left. The coordinates must be in x, y order.
        size: output size. The cropped slices are resized automatically if their sizes do not exactly match `size`.
            If None, the size is inferred from src_box.
        interpolation: algorithm used for upsampling: ``'nearest'`` | ``'linear'`` | ``'bilinear'`` |
            ``'bicubic'`` | ``'trilinear'`` | ``'area'``.
        align_corners: interpolation flag.
        antialias: if True, the image will be filtered with a Gaussian before downscaling.
            No effect when upscaling.
    """
    B, C, _, _ = input.shape
    src = torch.as_tensor(src_box, device=input.device, dtype=torch.long)
    x1 = src[:, 0, 0]
    x2 = src[:, 1, 0] + 1
    y1 = src[:, 0, 1]
    y2 = src[:, 3, 1] + 1

    if (
        len(x1.unique(sorted=False)) == len(x2.unique(sorted=False)) == len(
            y1.unique(sorted=False)) == len(y2.unique(sorted=False)) == 1
    ):
        # Fast path: every box has the same location and size, so a single slice covers the whole batch.
        out = input[..., y1[0]:y2[0], x1[0]:x2[0]]  # type:ignore
        if size is not None and out.shape[-2:] != size:
            out = resize(
                out,
                size,
                interpolation=interpolation,
                align_corners=align_corners,
                side="short",
                antialias=antialias
            )
        return out

    if size is None:
        h, w = infer_bbox_shape(src)
        # use plain ints so that `torch.empty` and the shape comparison below behave as expected
        size = (int(h.unique(sorted=False)), int(w.unique(sorted=False)))
    out = torch.empty(B, C, *size, device=input.device, dtype=input.dtype)
    # Find out the cropped shapes that need to be resized.
    shape_list = torch.stack([y2 - y1, x2 - x1], dim=-1)
    _size = torch.as_tensor(size, device=shape_list.device, dtype=shape_list.dtype)
    same_sized = (shape_list == _size).all(-1)
    for i, same in enumerate(same_sized):
        if not same:
            # keep a batch dimension on both sides so the assignment shapes match
            out[i:i + 1] = resize(
                input[i:i + 1, :, y1[i]:y2[i], x1[i]:x2[i]],  # type:ignore
                size,
                interpolation=interpolation,
                align_corners=align_corners,
                side="short",
                antialias=antialias
            )
        else:
            out[i:i + 1] = input[i:i + 1, :, y1[i]:y2[i], x1[i]:x2[i]]  # type:ignore
    return out
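A short usage sketch for `crop_by_indices`, assuming it is importable together with kornia's `resize` helper; the image and box values are made up. Boxes are given in (x, y) order, clockwise from the top-left corner, and are treated as inclusive pixel indices by the slicing above.

import torch

# Two 1-channel 5x5 images.
images = torch.arange(2 * 25, dtype=torch.float32).reshape(2, 1, 5, 5)

# One box per image, as inclusive pixel coordinates:
# a 2x3 (HxW) crop of image 0 and a 4x3 crop of image 1.
boxes = torch.tensor([
    [[1., 1.], [3., 1.], [3., 2.], [1., 2.]],
    [[0., 0.], [2., 0.], [2., 3.], [0., 3.]],
])

# The crops have different shapes, so an explicit output size is given;
# the non-matching slice is resized to (4, 3).
patches = crop_by_indices(images, boxes, size=(4, 3))
print(patches.shape)  # torch.Size([2, 1, 4, 3])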
Example #5
def crop_by_boxes(
    tensor: torch.Tensor,
    src_box: torch.Tensor,
    dst_box: torch.Tensor,
    mode: str = 'bilinear',
    padding_mode: str = 'zeros',
    align_corners: bool = True,
    validate_boxes: bool = True
) -> torch.Tensor:
    """Perform crop transform on 2D images (4D tensor) given two bounding boxes.

    Given an input tensor, this function selects the regions of interest defined by the source bounding boxes
    (src_box) and fits each of them onto the corresponding target bounding box (dst_box) with a perspective
    transformation. Because PyTorch does not support ragged tensors, all bounding boxes in a batch must be
    rectangles with the same width and height.

    Args:
        tensor: the 2D image tensor with shape (B, C, H, W).
        src_box: a tensor with shape (B, 4, 2) containing the coordinates of the bounding boxes
            to be extracted. The tensor must have the shape of Bx4x2, where each box is defined in the clockwise
            order: top-left, top-right, bottom-right and bottom-left. The coordinates must be in x, y order.
        dst_box: a tensor with shape (B, 4, 2) containing the coordinates of the bounding boxes
            to be placed. The tensor must have the shape of Bx4x2, where each box is defined in the clockwise
            order: top-left, top-right, bottom-right and bottom-left. The coordinates must be in x, y order.
        mode: interpolation mode to calculate output values
          ``'bilinear'`` | ``'nearest'``.
        padding_mode: padding mode for outside grid values
          ``'zeros'`` | ``'border'`` | ``'reflection'``.
        align_corners: mode for grid_generation.
        validate_boxes: flag to perform validation on boxes.

    Returns:
        torch.Tensor: the output tensor with patches.

    Examples:
        >>> input = torch.arange(16, dtype=torch.float32).reshape((1, 1, 4, 4))
        >>> src_box = torch.tensor([[
        ...     [1., 1.],
        ...     [2., 1.],
        ...     [2., 2.],
        ...     [1., 2.],
        ... ]])  # 1x4x2
        >>> dst_box = torch.tensor([[
        ...     [0., 0.],
        ...     [1., 0.],
        ...     [1., 1.],
        ...     [0., 1.],
        ... ]])  # 1x4x2
        >>> crop_by_boxes(input, src_box, dst_box, align_corners=True)
        tensor([[[[ 5.0000,  6.0000],
                  [ 9.0000, 10.0000]]]])

    Note:
        If the src_box is smaller than dst_box, the following error will be thrown:
        RuntimeError: solve_cpu: For batch 0: U(2,2) is zero, singular U.
    """
    if validate_boxes:
        validate_bbox(src_box)
        validate_bbox(dst_box)

    if len(tensor.shape) != 4:
        raise AssertionError(f"Only tensor with shape (B, C, H, W) supported. Got {tensor.shape}.")

    # compute transformation between points and warp
    # Note: Tensor.dtype must be float. "solve_cpu" not implemented for 'Long'
    dst_trans_src: torch.Tensor = get_perspective_transform(src_box.to(tensor), dst_box.to(tensor))

    bbox: Tuple[torch.Tensor, torch.Tensor] = infer_bbox_shape(dst_box)
    if not ((bbox[0] == bbox[0][0]).all() and (bbox[1] == bbox[1][0]).all()):
        raise AssertionError(
            f"Cropping height, width and depth must be exact same in a batch. "
            f"Got height {bbox[0]} and width {bbox[1]}."
        )

    h_out: int = int(bbox[0][0].item())
    w_out: int = int(bbox[1][0].item())

    return crop_by_transform_mat(
        tensor, dst_trans_src, (h_out, w_out), mode=mode, padding_mode=padding_mode, align_corners=align_corners
    )
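A batched usage sketch under the constraint spelled out above (every destination box must infer the same height and width); the tensors and variable names are illustrative.

import torch

# A batch of two 1-channel 4x4 images.
images = torch.arange(2 * 16, dtype=torch.float32).reshape(2, 1, 4, 4)

# A different, but equally sized, 2x2 source region per image,
# clockwise from the top-left corner in (x, y) order.
src_box = torch.tensor([
    [[0., 0.], [1., 0.], [1., 1.], [0., 1.]],
    [[2., 2.], [3., 2.], [3., 3.], [2., 3.]],
])

# Both regions are warped onto the same 2x2 destination box.
dst_box = torch.tensor([
    [[0., 0.], [1., 0.], [1., 1.], [0., 1.]],
]).repeat(2, 1, 1)

patches = crop_by_boxes(images, src_box, dst_box, align_corners=True)
print(patches.shape)  # torch.Size([2, 1, 2, 2])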