def apply_rotation(input: torch.Tensor, params: Dict[str, torch.Tensor]) -> torch.Tensor:
    r"""Rotate a tensor image or a batch of tensor images by the given angles.

    Input should be a tensor of shape (C, H, W) or a batch of tensors :math:`(B, C, H, W)`.

    Args:
        input (torch.Tensor): Tensor to be transformed with shape (H, W), (C, H, W), (B, C, H, W).
            It is passed through ``_transform_input`` and must be float16, float32 or float64.
        params (Dict[str, torch.Tensor]):
            - params['degrees']: rotation angle(s) in degrees; cast to ``input``'s dtype.
            - params['interpolation']: single-element tensor holding a ``Resample`` value; its
              lower-cased name is forwarded to ``rotate`` as ``mode``.
            - params['align_corners']: single-element tensor forwarded to ``rotate`` as
              ``align_corners``.

    Returns:
        torch.Tensor: The rotated input.
    """
    input = _transform_input(input)
    _validate_input_dtype(
        input, accepted_dtypes=[torch.float16, torch.float32, torch.float64])
    angles: torch.Tensor = params["degrees"].type_as(input)
    # ``rotate`` takes plain Python values for these options, so unwrap the tensors.
    resample_mode: str = Resample(params['interpolation'].item()).name.lower()
    align_corners: bool = cast(bool, params['align_corners'].item())
    transformed: torch.Tensor = rotate(input, angles, mode=resample_mode, align_corners=align_corners)
    return transformed
def apply_rotation(input: torch.Tensor, params: Dict[str, torch.Tensor]) -> torch.Tensor:
    r"""Rotate a tensor image or a batch of tensor images by the given angles.

    Input should be a tensor of shape (C, H, W) or a batch of tensors :math:`(*, C, H, W)`.

    Args:
        input (torch.Tensor): image tensor to rotate. It is passed through ``_transform_input``
            and must be float16, float32 or float64.
        params (Dict[str, torch.Tensor]): Can be generated from
            kornia.augmentation.random_generator.random_rotation_generator. Keys read:
            - params['degrees']: rotation angle(s) in degrees; cast to ``input``'s dtype.
            - params['interpolation']: single-element tensor holding a ``Resample`` value; its
              lower-cased name is forwarded to ``rotate`` as ``mode``.
            - params['align_corners']: single-element tensor forwarded to ``rotate`` as
              ``align_corners``.

    Returns:
        torch.Tensor: The rotated input.
    """
    input = _transform_input(input)
    _validate_input_dtype(
        input, accepted_dtypes=[torch.float16, torch.float32, torch.float64])
    angles: torch.Tensor = params["degrees"].type_as(input)
    # ``rotate`` takes plain Python values for these options, so unwrap the tensors.
    resample_mode: str = Resample(params['interpolation'].item()).name.lower()
    align_corners: bool = cast(bool, params['align_corners'].item())
    transformed: torch.Tensor = rotate(input, angles, mode=resample_mode, align_corners=align_corners)
    return transformed
def apply_rotation(input: torch.Tensor, params: Dict[str, torch.Tensor], flags: Dict[str, torch.Tensor]) -> torch.Tensor:
    r"""Rotate a tensor image or a batch of tensor images by the given angles.

    Input should be a tensor of shape (C, H, W) or a batch of tensors :math:`(B, C, H, W)`.

    Args:
        input (torch.Tensor): Tensor to be transformed with shape (H, W), (C, H, W), (B, C, H, W).
        params (Dict[str, torch.Tensor]):
            - params['degrees']: rotation angle(s) in degrees; cast to ``input``'s dtype.
        flags (Dict[str, torch.Tensor]):
            - flags['interpolation']: Integer tensor. NEAREST = 0, BILINEAR = 1.
            - flags['align_corners']: Boolean tensor.

    Returns:
        torch.Tensor: The rotated input.
    """
    angles: torch.Tensor = params["degrees"].type_as(input)
    # ``rotate`` takes plain Python values for these options, so unwrap the flag tensors.
    resample_mode: str = Resample(flags['interpolation'].item()).name.lower()
    align_corners: bool = cast(bool, flags['align_corners'].item())
    transformed: torch.Tensor = rotate(input, angles, mode=resample_mode, align_corners=align_corners)
    return transformed
def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]:
    """Fetch sample ``index``, rotate it, and return ``(rotated_batch, target)``.

    Both the target and the rotation angles come from ``self._sample_angles()``.
    The image gets a new leading dimension; when ``self.apply_all`` is set it is
    repeated ``self.target_dim`` times along that dimension before rotating.
    """
    sample = self.dataset[index]
    # When the dataset yields a sequence (e.g. image plus label), keep only the first item.
    image = sample[0] if isinstance(sample, Sequence) else sample
    batch = image.unsqueeze(0)
    if self.apply_all:
        batch = batch.repeat_interleave(self.target_dim, dim=0)
    target, angles = self._sample_angles()
    return rotate(batch, angles), target
def apply_rotation(input: torch.Tensor, params: Dict[str, torch.Tensor], return_transform: bool = False):
    r"""Rotate a tensor image or a batch of tensor images by the sampled angles.

    Input should be a tensor of shape (C, H, W) or a batch of tensors :math:`(*, C, H, W)`.

    Args:
        input (torch.Tensor): image tensor to rotate. It is passed through ``_transform_input``
            and must be float16, float32 or float64.
        params (dict): must contain ``'degrees'`` (rotation angles, cast to ``input``'s dtype)
            and ``'interpolation'`` (a ``Resample`` value whose lower-cased name becomes the
            ``rotate`` mode). Can be generated from
            kornia.augmentation.random_generator.random_rotation_generator
        return_transform (bool): if ``True``, also return the 3x3 matrix describing the
            transformation applied to each input tensor.

    Returns:
        The rotated tensor, or a ``(rotated, transform)`` tuple when ``return_transform``
        is ``True``.
    """
    input = _transform_input(input)
    _validate_input_dtype(input, accepted_dtypes=[torch.float16, torch.float32, torch.float64])

    degrees: torch.Tensor = params["degrees"].type_as(input)
    mode: str = Resample(params['interpolation'].item()).name.lower()
    rotated: torch.Tensor = rotate(input, degrees, mode=mode)

    if not return_transform:
        return rotated

    # TODO: This part should be inferred from rotate directly
    center: torch.Tensor = _compute_tensor_center(input)
    affine_2x3: torch.Tensor = _compute_rotation_matrix(degrees, center.expand(degrees.shape[0], -1))
    # rotation_mat is B x 2 x 3 and we need a B x 3 x 3 matrix: overwrite the top two
    # rows of a batch of identities, leaving the homogeneous [0, 0, 1] row intact.
    homogeneous: torch.Tensor = torch.eye(3, device=input.device, dtype=input.dtype).repeat(
        input.shape[0], 1, 1)
    homogeneous[:, :2] = affine_2x3[:, :2]
    return rotated, homogeneous
def __call__(self, x):
    """Rotate each sample of ``x`` by an angle drawn with replacement from ``self.angles``.

    Args:
        x: batched tensor; dimension 0 is the batch dimension.

    Returns:
        The result of ``rotate(x, angles)``.
    """
    # One independent draw per sample in the batch.
    chosen = random.choices(self.angles, k=x.size(0))
    # Create the angle tensor on the target device in one step. For floating inputs,
    # match x's dtype, as the other rotation helpers do (``type_as(input)``). Otherwise
    # an integer angle list would become an int64 tensor. Non-floating inputs keep the
    # previous inferred dtype.
    dtype = x.dtype if x.is_floating_point() else None
    angles = torch.tensor(chosen, dtype=dtype, device=self.device)
    return rotate(x, angles)
def grad_rot(input, a, b, c):
    """Rotate ``input`` by 30 degrees, then pass it to ``enhance.equalize_clahe(..., a, b, c)``."""
    # ``device`` is not a parameter; it is resolved from the enclosing scope.
    angle = torch.tensor(30.0, dtype=input.dtype, device=device)
    rotated = rotate(input, angle)
    return enhance.equalize_clahe(rotated, a, b, c)
def _augment(self, data: Tensor) -> Tensor:
    """Rotate each sample in ``data`` by a uniformly chosen entry among indices 0-3 of ``self._pos_angles``."""
    # NOTE(review): high=4 is hard-coded; presumably equals len(self._pos_angles) — confirm.
    picks = torch.randint(low=0, high=4, size=(data.size(0), ))
    # Drop the trailing singleton dimension so there is one angle per sample.
    return rotate(data, self._pos_angles[picks].squeeze(-1))