예제 #1
0
def apply_rotation(input: torch.Tensor,
                   params: Dict[str, torch.Tensor]) -> torch.Tensor:
    r"""Rotate a tensor image or a batch of tensor images by the angles given in ``params``.

    Input should be a tensor of shape (C, H, W) or a batch of tensors :math:`(B, C, H, W)`.

    Args:
        input (torch.Tensor): Tensor to be transformed with shape (H, W), (C, H, W), (B, C, H, W).
            Must be float16, float32 or float64.
        params (Dict[str, torch.Tensor]):
            - params['degrees']: rotation angle(s) to be applied; cast to ``input``'s type.
            - params['interpolation']: single-element tensor holding a ``Resample`` value;
              its lower-cased enum name is passed to ``rotate`` as ``mode``.
            - params['align_corners']: single-element boolean tensor forwarded to ``rotate``.

    Returns:
        torch.Tensor: The rotated input.
    """
    # presumably normalizes (H, W) / (C, H, W) inputs to a batched layout — TODO confirm
    input = _transform_input(input)
    _validate_input_dtype(
        input, accepted_dtypes=[torch.float16, torch.float32, torch.float64])
    # Match the angles' dtype/device to the image so rotate sees consistent tensors.
    angles: torch.Tensor = params["degrees"].type_as(input)

    resample_mode: str = Resample(params['interpolation'].item()).name.lower()
    align_corners: bool = cast(bool, params['align_corners'].item())

    transformed: torch.Tensor = rotate(input,
                                       angles,
                                       mode=resample_mode,
                                       align_corners=align_corners)

    return transformed
예제 #2
0
def apply_rotation(input: torch.Tensor,
                   params: Dict[str, torch.Tensor]) -> torch.Tensor:
    r"""Rotate a tensor image or a batch of tensor images by the angles given in ``params``.

    Input should be a tensor of shape (C, H, W) or a batch of tensors :math:`(*, C, H, W)`.

    Args:
        input (torch.Tensor): image tensor to rotate; must be float16, float32 or float64.
        params (dict): A dict that must have the following ``torch.Tensor`` entries. Can be
                       generated from
                       kornia.augmentation.random_generator.random_rotation_generator

                       - ``'degrees'``: rotation angle(s), cast to ``input``'s type.
                       - ``'interpolation'``: single-element tensor holding a ``Resample``
                         value; its lower-cased name is used as ``rotate``'s ``mode``.
                       - ``'align_corners'``: single-element boolean tensor forwarded to
                         ``rotate``.

    Returns:
        torch.Tensor: The rotated input.
    """
    # presumably normalizes unbatched inputs to a batched layout — TODO confirm
    input = _transform_input(input)
    _validate_input_dtype(
        input, accepted_dtypes=[torch.float16, torch.float32, torch.float64])
    # Match the angles' dtype/device to the image so rotate sees consistent tensors.
    angles: torch.Tensor = params["degrees"].type_as(input)

    resample_mode: str = Resample(params['interpolation'].item()).name.lower()
    align_corners: bool = cast(bool, params['align_corners'].item())

    transformed: torch.Tensor = rotate(input,
                                       angles,
                                       mode=resample_mode,
                                       align_corners=align_corners)

    return transformed
예제 #3
0
def apply_rotation(input: torch.Tensor, params: Dict[str, torch.Tensor],
                   flags: Dict[str, torch.Tensor]) -> torch.Tensor:
    r"""Rotate a tensor image or a batch of tensor images by the angles given in ``params``.

    Input should be a tensor of shape (C, H, W) or a batch of tensors :math:`(B, C, H, W)`.

    Args:
        input (torch.Tensor): Tensor to be transformed with shape (H, W), (C, H, W), (B, C, H, W).
        params (Dict[str, torch.Tensor]):
            - params['degrees']: degree to be applied; cast to ``input``'s type.
        flags (Dict[str, torch.Tensor]):
            - flags['interpolation']: Integer tensor. NEAREST = 0, BILINEAR = 1.
              Converted to a ``Resample`` member whose lower-cased name is used as
              ``rotate``'s ``mode``.
            - flags['align_corners']: Boolean tensor forwarded to ``rotate``.

    Returns:
        torch.Tensor: The rotated input.
    """
    # Match the angles' dtype/device to the image so rotate sees consistent tensors.
    angles: torch.Tensor = params["degrees"].type_as(input)

    resample_mode: str = Resample(flags['interpolation'].item()).name.lower()
    align_corners: bool = cast(bool, flags['align_corners'].item())

    transformed: torch.Tensor = rotate(input,
                                       angles,
                                       mode=resample_mode,
                                       align_corners=align_corners)

    return transformed
    def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]:
        """Fetch dataset item ``index``, rotate it, and return it with its target.

        When the stored item is a ``Sequence`` (presumably an ``(image, label)``
        record — confirm against the wrapped dataset), only its first element is
        kept. A leading dimension of size 1 is added; if ``self.apply_all`` is
        truthy, that dimension is repeated ``self.target_dim`` times.

        Returns:
            Tuple[Tensor, Tensor]: the rotated sample and the target ``y`` that
            ``self._sample_angles()`` produced alongside the angles.
        """
        sample = self.dataset[index]
        if isinstance(sample, Sequence):
            # Discard everything but the first field of the record.
            sample = sample[0]
        batch = sample.unsqueeze(0)
        if self.apply_all:
            batch = batch.repeat_interleave(self.target_dim, dim=0)
        target, rotation_angles = self._sample_angles()
        return rotate(batch, rotation_angles), target
예제 #5
0
def apply_rotation(input: torch.Tensor,
                   params: Dict[str, torch.Tensor],
                   return_transform: bool = False):
    r"""Rotate a tensor image or a batch of tensor images by the angles given in ``params``.

    Input should be a tensor of shape (C, H, W) or a batch of tensors :math:`(*, C, H, W)`.

    Args:
        input (torch.Tensor): image tensor to rotate; must be float16, float32 or float64.
        params (dict): A dict that must have {'degrees': torch.Tensor} and
                       {'interpolation': torch.Tensor}. Can be generated from
                       kornia.augmentation.random_generator.random_rotation_generator
                       ``'interpolation'`` holds a ``Resample`` value whose lower-cased
                       name is used as ``rotate``'s ``mode``.
        return_transform (bool): if ``True`` also return the matrix describing the
                                 transformation applied to each input tensor.

    Returns:
        torch.Tensor: the rotated input, when ``return_transform`` is ``False``.
        Tuple[torch.Tensor, torch.Tensor]: ``(rotated, trans_mat)`` otherwise, where
        ``trans_mat`` is a stack of 3x3 matrices, one per batch element.
    """
    # presumably normalizes unbatched inputs to a batched layout — TODO confirm
    input = _transform_input(input)
    _validate_input_dtype(
        input, accepted_dtypes=[torch.float16, torch.float32, torch.float64])
    # Match the angles' dtype/device to the image so rotate sees consistent tensors.
    angles: torch.Tensor = params["degrees"].type_as(input)

    transformed: torch.Tensor = rotate(
        input,
        angles,
        mode=Resample(params['interpolation'].item()).name.lower())

    if not return_transform:
        return transformed

    # TODO: This part should be inferred from rotate directly
    center: torch.Tensor = _compute_tensor_center(input)
    rotation_mat: torch.Tensor = _compute_rotation_matrix(
        angles, center.expand(angles.shape[0], -1))

    # rotation_mat is B x 2 x 3; copy its two rows into the top of a B x 3 x 3
    # identity so the result is a full homogeneous transformation matrix.
    trans_mat: torch.Tensor = torch.eye(
        3, device=input.device, dtype=input.dtype).repeat(input.shape[0], 1, 1)
    trans_mat[:, :2] = rotation_mat[:, :2]

    return transformed, trans_mat
 def __call__(self, x):
     """Rotate every sample of the batch ``x`` by an angle drawn from ``self.angles``.

     One angle is drawn per sample via ``random.choices`` (independently, with
     replacement). The angle tensor is created directly on ``self.device`` in
     ``x``'s dtype, mirroring the ``type_as(input)`` casts used by other rotation
     helpers, so ``rotate`` never receives an integer or mismatched-precision
     angle tensor even if ``self.angles`` holds plain ints.

     Args:
         x: batch tensor whose first dimension is the batch size.

     Returns:
         The rotated batch, as returned by ``rotate``.
     """
     batch_size = x.size(0)
     angles = random.choices(self.angles, k=batch_size)
     # Build on the target device in one step (no extra CPU tensor + .to() copy).
     angles = torch.tensor(angles, dtype=x.dtype, device=self.device)
     return rotate(x, angles)
예제 #7
0
 def grad_rot(input, a, b, c):
     """Rotate ``input`` by a fixed angle of 30.0, then apply ``enhance.equalize_clahe``.

     ``a``, ``b`` and ``c`` are forwarded positionally to ``equalize_clahe``.
     The angle tensor uses ``input``'s dtype and the ``device`` name taken from
     the enclosing scope.
     """
     fixed_angle = torch.tensor(30.0, dtype=input.dtype, device=device)
     rotated = rotate(input, fixed_angle)
     return enhance.equalize_clahe(rotated, a, b, c)
예제 #8
0
 def _augment(self, data: Tensor) -> Tensor:
     """Rotate each sample of ``data`` by an angle picked from ``self._pos_angles``.

     For every batch element an index in ``[0, 4)`` is drawn uniformly, so only
     the first four rows of ``self._pos_angles`` are ever used (NOTE(review):
     confirm ``_pos_angles`` has exactly four rows). The picked rows are squeezed
     on their last axis before being passed to ``rotate``.
     """
     batch_size = data.size(0)
     picks = torch.randint(low=0, high=4, size=(batch_size,))
     per_sample_angles = self._pos_angles[picks].squeeze(-1)
     return rotate(data, per_sample_angles)