def compute_rotate_tranformation(input: torch.Tensor, params: Dict[str, torch.Tensor]):
    r"""Build the homogeneous 3x3 rotation matrix that is applied to ``input``.

    The rotation is taken around the point returned by ``_compute_tensor_center``
    and uses the angles stored in ``params['degrees']``.

    Args:
        input (torch.Tensor): Tensor to be transformed with shape (H, W), (C, H, W), (B, C, H, W).
            It is first normalised by ``_transform_input`` and must then have dtype
            float16, float32 or float64 (checked by ``_validate_input_dtype``).
        params (Dict[str, torch.Tensor]):
            - params['degrees']: degree to be applied; cast to the type of ``input``.

    Returns:
        torch.Tensor: The applied transformation matrix :math:`(B, 3, 3)`, where ``B`` is
        the leading dimension of ``input`` after ``_transform_input``.
    """
    input = _transform_input(input)
    _validate_input_dtype(
        input, accepted_dtypes=[torch.float16, torch.float32, torch.float64])

    batch_size = input.shape[0]
    degrees: torch.Tensor = params["degrees"].type_as(input)

    # TODO: This part should be inferred from rotate directly
    pivot: torch.Tensor = _compute_tensor_center(input).expand(degrees.shape[0], -1)
    affine_2x3: torch.Tensor = _compute_rotation_matrix(degrees, pivot)

    # Embed the B x 2 x 3 affine part into the top two rows of B identity
    # matrices; the bottom row stays [0, 0, 1].
    homogeneous: torch.Tensor = torch.eye(
        3, device=input.device, dtype=input.dtype).repeat(batch_size, 1, 1)
    homogeneous[:, :2] = affine_2x3[:, :2]
    return homogeneous
def compute_transformation(self, input: Tensor, params: Dict[str, Tensor]) -> Tensor:
    """Return the homogeneous 3x3 rotation matrices applied to ``input``.

    Args:
        input: the tensor being rotated. Its leading dimension sets how many
            matrices are returned, and its device/dtype are used for them.
        params: must contain ``"degrees"``, the rotation angles. They are moved
            to the device and dtype of ``input`` via ``.to(input)``.

    Returns:
        A tensor of shape ``(input.shape[0], 3, 3)``. Its top two rows hold the
        2x3 affine rotation about the centre given by ``_compute_tensor_center``,
        and its last row is ``[0, 0, 1]``.
    """
    # TODO: Update to use `get_rotation_matrix2d`
    degrees: Tensor = params["degrees"].to(input)
    n_angles = degrees.shape[0]
    pivot: Tensor = _compute_tensor_center(input).expand(n_angles, -1)
    affine: Tensor = _compute_rotation_matrix(degrees, pivot)

    # affine is B x 2 x 3; lift it into a B x 3 x 3 homogeneous matrix.
    out: Tensor = torch.eye(3, device=input.device, dtype=input.dtype).repeat(
        input.shape[0], 1, 1)
    out[:, :2] = affine[:, :2]
    return out
def _apply_rotation(input: torch.Tensor, params: Dict[str, torch.Tensor],
                    return_transform: bool = False):
    r"""Rotate a tensor image or a batch of tensor images a random amount of degrees.

    Input should be a tensor of shape (C, H, W) or a batch of tensors :math:`(*, C, H, W)`.
    All dimensions before the last three are flattened into a single batch dimension
    before rotating.

    Args:
        input (torch.Tensor): image tensor; its last three dimensions are treated as
            (C, H, W).
        params (dict): A dict that must have {'degrees': torch.Tensor}. Can be generated
            from kornia.augmentation.param_gen._random_rotation_gen
        return_transform (bool): if ``True`` return the matrix describing the
            transformation applied to each input tensor. If ``False`` and the input is a
            tuple the applied transformation won't be concatenated.

    Returns:
        torch.Tensor or Tuple[torch.Tensor, torch.Tensor]: the rotated tensor of shape
        (B, C, H, W), where B is the flattened batch size. When B == 1 the batch
        dimension is squeezed away. If ``return_transform`` is ``True``, the result is
        ``(transformed, trans_mat)`` with ``trans_mat`` of shape (B, 3, 3); it is not
        squeezed.

    Raises:
        TypeError: if ``input`` is not a ``torch.Tensor``.
    """
    if not torch.is_tensor(input):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}")
    device: torch.device = input.device
    dtype: torch.dtype = input.dtype

    # Flatten to (B, C, H, W). Unpack the trailing (C, H, W) directly: a
    # parenthesised starred expression such as ``(-1, (*shape))`` is rejected by
    # the Python >= 3.9 parser. ``reshape`` also accepts non-contiguous inputs
    # that ``view`` would reject.
    input = input.unsqueeze(0)
    input = input.reshape(-1, *input.shape[-3:])

    angles: torch.Tensor = params["degrees"].to(device, dtype)
    transformed: torch.Tensor = rotate(input, angles).squeeze(0)

    if return_transform:
        center: torch.Tensor = _compute_tensor_center(input)
        rotation_mat: torch.Tensor = _compute_rotation_matrix(
            angles, center.expand(angles.shape[0], -1))

        # rotation_mat is B x 2 x 3 and we need a B x 3 x 3 matrix
        trans_mat: torch.Tensor = torch.eye(3, device=device, dtype=dtype).repeat(
            input.shape[0], 1, 1)
        trans_mat[:, 0] = rotation_mat[:, 0]
        trans_mat[:, 1] = rotation_mat[:, 1]
        return transformed, trans_mat

    return transformed
def compute_rotate_tranformation(input: torch.Tensor, params: Dict[str, torch.Tensor]):
    r"""Return the homogeneous matrix for rotating ``input`` by ``params['degrees']``.

    Args:
        input (torch.Tensor): tensor to rotate. It is normalised by
            ``_transform_input`` and must then have dtype float16, float32 or
            float64 (checked by ``_validate_input_dtype``).
        params (Dict[str, torch.Tensor]): must contain ``'degrees'``, the rotation
            angles; they are cast to the type of ``input``.

    Returns:
        torch.Tensor: a stack of matrices of shape :math:`(B, 3, 3)`, where ``B`` is
        the leading dimension of the normalised input.
    """
    image = _transform_input(input)
    _validate_input_dtype(
        image, accepted_dtypes=[torch.float16, torch.float32, torch.float64])

    theta = params["degrees"].type_as(image)

    # TODO: This part should be inferred from rotate directly
    centre = _compute_tensor_center(image)
    rot_2x3 = _compute_rotation_matrix(theta, centre.expand(theta.shape[0], -1))

    # Start from identity matrices and overwrite the first two rows with the
    # 2x3 affine part; the last row stays [0, 0, 1].
    result = torch.eye(3, device=image.device, dtype=image.dtype)
    result = result.repeat(image.shape[0], 1, 1)
    result[:, 0:2] = rot_2x3[:, 0:2]
    return result
def apply_rotation(input: torch.Tensor, params: Dict[str, torch.Tensor],
                   return_transform: bool = False):
    r"""Rotate a tensor image or a batch of tensor images a random amount of degrees.

    Input should be a tensor of shape (C, H, W) or a batch of tensors :math:`(*, C, H, W)`.

    Args:
        input (torch.Tensor): image(s) to rotate. They are normalised by
            ``_transform_input`` and must then have dtype float16, float32 or float64.
        params (dict): A dict that must have {'degrees': torch.Tensor} and an
            ``'interpolation'`` tensor. The ``.item()`` value of ``'interpolation'``
            selects a ``Resample`` member, whose lower-cased name is passed to
            ``rotate`` as ``mode``. Can be generated from
            kornia.augmentation.random_generator.random_rotation_generator
        return_transform (bool): if ``True`` return the matrix describing the
            transformation applied to each input tensor. If ``False`` and the input is
            a tuple the applied transformation won't be concatenated.

    Returns:
        torch.Tensor: the rotated tensor. If ``return_transform`` is ``True``, the
        result is ``(rotated, matrix)``, where ``matrix`` has shape
        :math:`(B, 3, 3)` and ``B`` is the leading dimension of the normalised input.
    """
    image = _transform_input(input)
    _validate_input_dtype(
        image, accepted_dtypes=[torch.float16, torch.float32, torch.float64])

    degrees = params["degrees"].type_as(image)
    mode_name = Resample(params['interpolation'].item()).name.lower()
    rotated = rotate(image, degrees, mode=mode_name)

    if not return_transform:
        return rotated

    # TODO: This part should be inferred from rotate directly
    pivot = _compute_tensor_center(image).expand(degrees.shape[0], -1)
    affine = _compute_rotation_matrix(degrees, pivot)

    # affine is B x 2 x 3; pad it to a B x 3 x 3 homogeneous matrix.
    homogeneous = torch.eye(3, device=image.device, dtype=image.dtype).repeat(
        image.shape[0], 1, 1)
    homogeneous[:, :2] = affine[:, :2]
    return rotated, homogeneous