示例#1
0
def apply_rotation3d(input: torch.Tensor, params: Dict[str, torch.Tensor],
                     flags: Dict[str, torch.Tensor]) -> torch.Tensor:
    r"""Rotate a 3D volume or a batch of volumes by the given yaw, pitch and roll angles.

    The angles are taken from ``params`` (the randomness, if any, comes from whoever
    generated ``params``); this function only applies them. The input is first passed
    through ``_transform_input3d`` (presumably normalising it to :math:`(B, C, D, H, W)`
    — confirm against that helper) and its dtype must be float16, float32 or float64.

    Args:
        input (torch.Tensor): Tensor to be transformed. Presumably of shape
            :math:`(D, H, W)`, :math:`(C, D, H, W)` or :math:`(B, C, D, H, W)`
            (assumed from ``_transform_input3d``; TODO confirm).
        params (Dict[str, torch.Tensor]):
            - params['yaw']: yaw angle(s) to be applied.
            - params['pitch']: pitch angle(s) to be applied.
            - params['roll']: roll angle(s) to be applied.
            Each is cast to the tensor type of ``input`` before use. Units are
            presumably degrees, as expected by ``rotate3d`` — confirm.
        flags (Dict[str, torch.Tensor]):
            - flags['resample']: Integer tensor. NEAREST = 0, BILINEAR = 1.
            - flags['align_corners']: Boolean tensor.

    Returns:
        torch.Tensor: The rotated input.

    Raises:
        Presumably whatever ``_validate_input_dtype`` raises when ``input`` has a
        dtype outside float16/float32/float64 (TODO confirm exception type).
    """
    input = _transform_input3d(input)
    _validate_input_dtype(
        input, accepted_dtypes=[torch.float16, torch.float32, torch.float64])
    # Match the angle tensors to the input's tensor type so rotate3d can combine them.
    yaw: torch.Tensor = params["yaw"].type_as(input)
    pitch: torch.Tensor = params["pitch"].type_as(input)
    roll: torch.Tensor = params["roll"].type_as(input)

    # flags hold single-element tensors; .item() turns them into Python scalars.
    # Resample(...).name.lower() maps the integer code to the mode string rotate3d expects.
    resample_mode: str = Resample(flags['resample'].item()).name.lower()
    align_corners: bool = cast(bool, flags['align_corners'].item())

    transformed: torch.Tensor = rotate3d(input,
                                         yaw,
                                         pitch,
                                         roll,
                                         mode=resample_mode,
                                         align_corners=align_corners)

    return transformed
示例#2
0
def apply_rotation3d(input: torch.Tensor, params: Dict[str, torch.Tensor],
                     flags: Dict[str, torch.Tensor]) -> torch.Tensor:
    r"""Rotate a 3D volume or a batch of volumes by the given yaw, pitch and roll angles.

    The angles are taken from ``params`` (any randomness comes from whoever generated
    ``params``); this function only applies them via ``rotate3d``.

    Args:
        input (torch.Tensor): Tensor to be transformed with shape :math:`(*, C, D, H, W)`.
        params (Dict[str, torch.Tensor]):
            - params['yaw']: yaw angle(s) to be applied.
            - params['pitch']: pitch angle(s) to be applied.
            - params['roll']: roll angle(s) to be applied.
            Each is moved to the dtype and device of ``input`` before use. Units are
            presumably degrees, as expected by ``rotate3d`` — confirm.
        flags (Dict[str, torch.Tensor]):
            - flags['resample']: Integer tensor. NEAREST = 0, BILINEAR = 1.
            - flags['align_corners']: Boolean tensor.

    Returns:
        torch.Tensor: The rotated input.
    """
    # Tensor.to(other) matches both dtype and device of the input.
    yaw: torch.Tensor = params["yaw"].to(input)
    pitch: torch.Tensor = params["pitch"].to(input)
    roll: torch.Tensor = params["roll"].to(input)

    # flags hold single-element tensors; .item() turns them into Python scalars.
    # Resample(...).name.lower() maps the integer code to the mode string rotate3d expects.
    resample_mode: str = Resample(flags['resample'].item()).name.lower()
    align_corners: bool = cast(bool, flags['align_corners'].item())

    transformed: torch.Tensor = rotate3d(input, yaw, pitch, roll, mode=resample_mode, align_corners=align_corners)

    return transformed