def apply_hflip(input: torch.Tensor, params: Dict[str, torch.Tensor], return_transform: bool = False) -> UnionType:
    r"""Horizontally flip selected images of a tensor using precomputed parameters.

    Accepts a tensor of shape (H, W), (C, H, W) or a batch :math:`(*, C, H, W)`.
    Inputs with two or three dimensions are promoted to four dimensions by
    prepending singleton axes before any processing.

    Args:
        input (torch.Tensor): image tensor to flip.
        params (dict): must contain the key ``'batch_prob'``; its value is
            moved to the input device and used to index the batch axis,
            selecting which entries get flipped. Can be generated from
            kornia.augmentation.param_gen._random_prob_gen.
        return_transform (bool): if ``True`` also return the matrix describing
            the transformation applied to each input tensor.

    Returns:
        torch.Tensor: the flipped input, with every singleton dimension
        squeezed out.
        torch.Tensor: the applied transformation matrices :math:`(*, 3, 3)`,
        only when ``return_transform`` is ``True``.

    Raises:
        TypeError: if ``input`` is not a tensor or ``return_transform`` is not
            a bool.
        ValueError: if ``input`` is not 4-dimensional after promotion.
    """
    if not torch.is_tensor(input):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}")

    # Promote (H, W) -> (1, 1, H, W) and (C, H, W) -> (1, C, H, W).
    while input.dim() in (2, 3):
        input = input.unsqueeze(0)

    if input.dim() != 4:
        raise ValueError(
            f"Input size must have a shape of (*, C, H, W). Got {input.shape}")

    if not isinstance(return_transform, bool):
        raise TypeError(
            f"The return_transform flag must be a bool. Got {type(return_transform)}"
        )

    flip_mask = params['batch_prob'].to(input.device)

    # Work on a copy so the caller's tensor is never modified.
    output: torch.Tensor = input.clone()
    output[flip_mask] = hflip(input[flip_mask])
    # Drops *all* size-1 dimensions, not only the ones added above.
    output.squeeze_()

    if not return_transform:
        return output

    batch_size: int = input.shape[0]
    width: int = input.shape[-1]

    # One identity matrix per batch entry; repeat() allocates real storage so
    # the masked assignment below only touches the selected entries.
    matrices: torch.Tensor = torch.eye(
        3, device=input.device, dtype=input.dtype).repeat(batch_size, 1, 1)

    # NOTE(review): this maps x -> width - x; confirm the intended coordinate
    # convention (a pixel-index mirror would be width - 1 - x).
    mirror: torch.Tensor = torch.tensor([[-1, 0, width],
                                         [0, 1, 0],
                                         [0, 0, 1]])
    matrices[flip_mask] = mirror.to(input.device).to(input.dtype)

    return output, matrices
def random_hflip(input: torch.Tensor, p: float = 0.5, return_transform: bool = False) -> UnionType:
    r"""Horizontally flip a tensor image or a batch of tensor images randomly
    with a given probability.

    Input should be a tensor of shape (C, H, W) or a batch of tensors
    :math:`(*, C, H, W)`. All leading dimensions are collapsed into a single
    batch axis and each batch entry is flipped independently.

    Args:
        input (torch.Tensor): image tensor to flip.
        p (float): probability of the image being flipped. Default value is 0.5
        return_transform (bool): if ``True`` return the matrix describing the
            transformation applied to each input tensor.

    Returns:
        torch.Tensor: The horizontally flipped input, with every singleton
        dimension squeezed out.
        torch.Tensor: The applied transformation matrix :math:`(*, 3, 3)` if
        return_transform flag is set to ``True``

    Raises:
        TypeError: if ``input`` is not a tensor, ``p`` is not a float, or
            ``return_transform`` is not a bool.
    """
    if not torch.is_tensor(input):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}")

    if not isinstance(p, float):
        raise TypeError(
            f"The probability should be a float number. Got {type(p)}")

    if not isinstance(return_transform, bool):
        raise TypeError(
            f"The return_transform flag must be a bool. Got {type(return_transform)}"
        )

    device: torch.device = input.device
    dtype: torch.dtype = input.dtype

    # Collapse every leading dimension into one batch axis: (B, C, H, W).
    # The extra unsqueeze guarantees at least one batch entry for (C, H, W).
    input = input.unsqueeze(0)
    input = input.view(-1, *input.shape[-3:])

    # Draw one uniform sample per batch entry and flip where it falls below p.
    probs: torch.Tensor = torch.empty(input.shape[0],
                                      device=device).uniform_(0, 1)
    to_flip: torch.Tensor = probs < p

    # Work on a copy so the caller's tensor is never modified.
    flipped: torch.Tensor = input.clone()
    flipped[to_flip] = hflip(input[to_flip])
    # Drops *all* size-1 dimensions, not only the batch axis added above.
    flipped.squeeze_()

    if return_transform:
        # repeat() (not expand()) so every batch entry owns its own 3x3
        # storage; an expanded view aliases a single matrix and cannot be
        # safely written per entry.
        trans_mat: torch.Tensor = torch.eye(3, device=device,
                                            dtype=dtype).repeat(
                                                input.shape[0], 1, 1)

        # Width is the last dimension of (B, C, H, W).
        w: int = input.shape[-1]
        # NOTE(review): this maps x -> w - x; confirm the intended coordinate
        # convention (a pixel-index mirror would be w - 1 - x).
        flip_mat: torch.Tensor = torch.tensor([[-1, 0, w],
                                               [0, 1, 0],
                                               [0, 0, 1]])
        trans_mat[to_flip] = flip_mat.to(device).to(dtype)

        return flipped, trans_mat

    return flipped