def apply_rotation3d(input: torch.Tensor, params: Dict[str, torch.Tensor],
                     flags: Dict[str, torch.Tensor]) -> torch.Tensor:
    r"""Rotate a 3D tensor volume or a batch of volumes by given yaw, pitch and roll angles.

    The input is first passed through ``_transform_input3d`` (presumably to bring it
    to a batched :math:`(B, C, D, H, W)` layout -- TODO confirm) and its dtype is then
    validated to be a floating-point type.

    Args:
        input (torch.Tensor): Tensor to be transformed. Presumably of shape
            :math:`(D, H, W)`, :math:`(C, D, H, W)` or :math:`(B, C, D, H, W)`
            -- confirm against ``_transform_input3d``. Its dtype must be
            ``torch.float16``, ``torch.float32`` or ``torch.float64``.
        params (Dict[str, torch.Tensor]):
            - params['yaw']: yaw angle(s) to apply.
            - params['pitch']: pitch angle(s) to apply.
            - params['roll']: roll angle(s) to apply.
            Each angle tensor is cast to the type of ``input`` before use. The unit
            is whatever ``rotate3d`` expects (presumably degrees -- TODO confirm).
        flags (Dict[str, torch.Tensor]):
            - flags['resample']: Integer tensor. NEAREST = 0, BILINEAR = 1.
            - flags['align_corners']: Boolean tensor.

    Returns:
        torch.Tensor: The rotated input.
    """
    input = _transform_input3d(input)
    # Only floating-point volumes are accepted; the helper is expected to raise
    # otherwise (exact exception type defined by _validate_input_dtype).
    _validate_input_dtype(
        input, accepted_dtypes=[torch.float16, torch.float32, torch.float64])

    # Cast the angle tensors to the input's tensor type so rotate3d operates on
    # matching types.
    yaw: torch.Tensor = params["yaw"].type_as(input)
    pitch: torch.Tensor = params["pitch"].type_as(input)
    roll: torch.Tensor = params["roll"].type_as(input)

    # The integer flag is mapped onto the Resample enum; its lowercased member
    # name is the interpolation mode string passed to rotate3d.
    resample_mode: str = Resample(flags['resample'].item()).name.lower()
    align_corners: bool = cast(bool, flags['align_corners'].item())

    transformed: torch.Tensor = rotate3d(input, yaw, pitch, roll,
                                         mode=resample_mode,
                                         align_corners=align_corners)
    return transformed
def apply_rotation3d(input: torch.Tensor, params: Dict[str, torch.Tensor],
                     flags: Dict[str, torch.Tensor]) -> torch.Tensor:
    r"""Rotate a 3D tensor volume or a batch of volumes by given yaw, pitch and roll angles.

    Args:
        input (torch.Tensor): Tensor to be transformed with shape :math:`(*, C, D, H, W)`.
        params (Dict[str, torch.Tensor]):
            - params['yaw']: yaw angle(s) to apply.
            - params['pitch']: pitch angle(s) to apply.
            - params['roll']: roll angle(s) to apply.
            Each angle tensor is moved to the dtype and device of ``input`` before
            use. The unit is whatever ``rotate3d`` expects (presumably degrees --
            TODO confirm).
        flags (Dict[str, torch.Tensor]):
            - flags['resample']: Integer tensor. NEAREST = 0, BILINEAR = 1.
            - flags['align_corners']: Boolean tensor.

    Returns:
        torch.Tensor: The rotated input.
    """
    # Match the angles' dtype and device to the input so rotate3d does not mix
    # devices or precisions.
    yaw: torch.Tensor = params["yaw"].to(input)
    pitch: torch.Tensor = params["pitch"].to(input)
    roll: torch.Tensor = params["roll"].to(input)

    # The integer flag is mapped onto the Resample enum; its lowercased member
    # name is the interpolation mode string passed to rotate3d.
    resample_mode: str = Resample(flags['resample'].item()).name.lower()
    align_corners: bool = cast(bool, flags['align_corners'].item())

    transformed: torch.Tensor = rotate3d(input, yaw, pitch, roll,
                                         mode=resample_mode,
                                         align_corners=align_corners)
    return transformed