Example #1
0
 def apply_transform(
     self, input: Tensor, params: Dict[str, Tensor], transform: Optional[Tensor] = None
 ) -> Tensor:
     """Warp the input volume with the supplied 3D perspective transform.

     NOTE(review): `transform` is typed Optional but is cast unconditionally —
     presumably the framework always passes it; confirm against the caller.
     """
     matrix = cast(Tensor, transform)
     # Keep the output at the input's spatial size (D, H, W).
     out_size = (input.shape[-3], input.shape[-2], input.shape[-1])
     resample_mode = self.flags["resample"].name.lower()
     return warp_perspective3d(
         input,
         matrix,
         out_size,
         flags=resample_mode,
         align_corners=self.flags["align_corners"],
     )
Example #2
0
 def apply_transform(
     self, input: torch.Tensor, params: Dict[str, torch.Tensor], transform: Optional[torch.Tensor] = None
 ) -> torch.Tensor:
     """Warp the input volume with the supplied 3D perspective transform.

     NOTE(review): `transform` is typed Optional but is cast unconditionally —
     presumably the framework always passes it; confirm against the caller.
     """
     matrix = cast(torch.Tensor, transform)
     depth, height, width = input.shape[-3], input.shape[-2], input.shape[-1]
     return warp_perspective3d(
         input,
         matrix,
         (depth, height, width),
         flags=self.resample.name.lower(),
         align_corners=self.align_corners,
     )
Example #3
0
def apply_perspective3d(input: torch.Tensor, params: Dict[str, torch.Tensor],
                        flags: Dict[str, torch.Tensor]) -> torch.Tensor:
    r"""Perform perspective transform of the given torch.Tensor or batch of tensors.

    Args:
        input (torch.Tensor): Tensor to be transformed with shape (D, H, W), (C, D, H, W), (B, C, D, H, W).
        params (Dict[str, torch.Tensor]):
            - params['start_points']: Tensor containing [top-left, top-right, bottom-right,
              bottom-left] of the original image with shape Bx8x3.
            - params['end_points']: Tensor containing [top-left, top-right, bottom-right,
              bottom-left] of the transformed image with shape Bx8x3.
        flags (Dict[str, torch.Tensor]):
            - flags['interpolation']: Integer tensor. NEAREST = 0, BILINEAR = 1.
            - flags['align_corners']: Boolean tensor.

    Returns:
        torch.Tensor: Perspectively transformed tensor.
    """
    # Normalize to the canonical (B, C, D, H, W) layout, then reject unsupported dtypes.
    input = _transform_input3d(input)
    _validate_input_dtype(
        input, accepted_dtypes=[torch.float16, torch.float32, torch.float64])

    # Output keeps the spatial size of the (normalized) input volume.
    depth, height, width = input.shape[-3:]

    # Homography mapping params['start_points'] -> params['end_points'].
    transform: torch.Tensor = compute_perspective_transformation3d(
        input, params)

    resample_name: str = Resample(flags['interpolation'].item()).name.lower()
    align_corners: bool = cast(bool, flags['align_corners'].item())

    # The warp is rendered at the input's own (D, H, W), so the result already
    # matches input's shape — no clone or view_as reshape is needed.
    return warp_perspective3d(input,
                              transform, (depth, height, width),
                              flags=resample_name,
                              align_corners=align_corners)