def _apply_color_jitter(input: torch.Tensor, params: Dict[str, torch.Tensor],
                        return_transform: bool = False) -> UnionType:
    r"""Jitter the colors of an image (or batch) using pre-sampled parameters.

    The input may be shaped :math:`(H, W)`, :math:`(C, H, W)` or
    :math:`(*, C, H, W)`; it is first normalised by ``_transform_input``.
    Brightness, contrast, saturation and hue adjustments are then applied
    one after another, in an order drawn from ``torch.randperm(4)``.

    Args:
        params (dict): must provide the keys ``'brightness_factor'``,
            ``'contrast_factor'``, ``'saturation_factor'`` and ``'hue_factor'``,
            each a ``torch.Tensor``. Such a dict can be produced by
            ``kornia.augmentation.param_gen._random_color_jitter_gen``.
        return_transform (bool): when ``True``, also return a stack of
            3x3 identity matrices, one per entry along the first dimension of
            the normalised input.

    Returns:
        torch.Tensor: the color jittered input.
        torch.Tensor: the identity transformation matrices, only when
        ``return_transform`` is ``True``.

    Raises:
        TypeError: if ``return_transform`` is not a ``bool``.
    """
    # TODO: params validation
    input = _transform_input(input)

    if not isinstance(return_transform, bool):
        raise TypeError(
            f"The return_transform flag must be a bool. Got {type(return_transform)}"
        )

    device: torch.device = input.device
    dtype: torch.dtype = input.dtype

    def _factor_on_device(key: str) -> torch.Tensor:
        # Only the device is aligned with the input; the dtype is left untouched.
        return params[key].to(device)

    adjusters = nn.ModuleList([
        AdjustBrightness(_factor_on_device('brightness_factor')),
        AdjustContrast(_factor_on_device('contrast_factor')),
        AdjustSaturation(_factor_on_device('saturation_factor')),
        AdjustHue(_factor_on_device('hue_factor')),
    ])

    # Apply the four adjustments in a random order.
    order = torch.randperm(4).tolist()
    result = input
    for position in order:
        result = adjusters[position](result)

    if not return_transform:
        return result

    batch_size = input.shape[0]
    eye_stack: torch.Tensor = torch.eye(3, device=device, dtype=dtype).repeat(batch_size, 1, 1)
    return result, eye_stack
def color_jitter(input: torch.Tensor, brightness: FloatUnionType = 0., contrast: FloatUnionType = 0.,
                 saturation: FloatUnionType = 0., hue: FloatUnionType = 0.,
                 return_transform: bool = False) -> UnionType:
    r"""Random color jitter of an image or batch of images.

    See :class:`~kornia.augmentation.ColorJitter` for details.

    Each factor is converted into a ``[low, high]`` range, and one value per
    sample is drawn uniformly from that range:

    * ``brightness``: a number ``b >= 0`` gives ``[-b, b]`` (no clamping);
      a ``(min, max)`` pair is used as given.
    * ``contrast`` / ``saturation``: a number ``c >= 0`` gives
      ``[1 - c, 1 + c]`` clamped to ``[0, inf)``; a pair must lie in ``[0, inf)``.
    * ``hue``: a number ``h >= 0`` gives ``[-h, h]`` clamped to ``[-pi, pi]``;
      a pair must lie in ``[-pi, pi]``.

    The four adjustments are then applied in a random order.

    Args:
        input (torch.Tensor): image of shape :math:`(H, W)`, :math:`(C, H, W)`
            or a batch :math:`(*, C, H, W)`. It is reshaped to
            :math:`(B, C, H, W)` before jittering; an :math:`(H, W)` image
            becomes :math:`(1, 1, H, W)`.
        brightness, contrast, saturation, hue: a non-negative scalar (Python
            ``int``/``float`` or 0-dim tensor), or a 2-element tuple/list/1-D
            tensor ``(min, max)`` with ``min <= max``.
        return_transform (bool): if ``True``, also return a stack of 3x3
            identity matrices, one per sample.

    Returns:
        torch.Tensor: the jittered batch.
        torch.Tensor: identity matrices of shape :math:`(B, 3, 3)`, only when
        ``return_transform`` is ``True``.

    Raises:
        TypeError: if ``input`` is not a tensor, ``return_transform`` is not a
            bool, or a factor has an unsupported type or shape.
        ValueError: if a scalar factor is negative, or a range is unordered or
            outside its allowed bounds.
    """
    def _check_and_bound(factor: FloatUnionType, name: str, center: float = 0.,
                         bounds: Tuple[float, float] = (0, float('inf')),
                         device: torch.device = torch.device('cpu'),
                         dtype: torch.dtype = torch.float32) -> torch.Tensor:
        r"""Validate ``factor`` and return its ``[low, high]`` range as a 2-element tensor."""
        # Accept any real Python scalar (int or float). ``bool`` is an ``int``
        # subclass but is never a meaningful factor, so it still falls through
        # to the TypeError below.
        if isinstance(factor, (int, float)) and not isinstance(factor, bool):
            if factor < 0:
                raise ValueError(f"If {name} is a single number number, it must be non negative. Got {factor}")
            factor_bound = torch.tensor([center - factor, center + factor])
            factor_bound = torch.clamp(factor_bound, bounds[0], bounds[1])
        elif isinstance(factor, torch.Tensor) and factor.dim() == 0:
            if factor < 0:
                raise ValueError(f"If {name} is a single number number, it must be non negative. Got {factor}")
            factor_bound = torch.tensor([torch.tensor(center) - factor, torch.tensor(center) + factor])
            factor_bound = torch.clamp(factor_bound, bounds[0], bounds[1])
        elif isinstance(factor, (tuple, list)) and len(factor) == 2:
            if not bounds[0] <= factor[0] <= factor[1] <= bounds[1]:
                raise ValueError(f"{name}[0] should be smaller than {name}[1] got {factor}")
            factor_bound = torch.tensor(factor)
        elif isinstance(factor, torch.Tensor) and factor.shape[0] == 2 and factor.dim() == 1:
            if not bounds[0] <= factor[0] <= factor[1] <= bounds[1]:
                raise ValueError(f"{name}[0] should be smaller than {name}[1] got {factor}")
            factor_bound = factor
        else:
            raise TypeError(
                f"The {name} should be a float number or a tuple with length 2 whose values move between {bounds}.")
        return factor_bound.to(device).to(dtype)

    if not torch.is_tensor(input):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}")

    if not isinstance(return_transform, bool):
        raise TypeError(f"The return_transform flag must be a bool. Got {type(return_transform)}")

    device: torch.device = input.device
    dtype: torch.dtype = input.dtype

    brightness_bound: torch.Tensor = _check_and_bound(
        brightness, 'brightness', bounds=(float('-inf'), float('inf')), device=device, dtype=dtype)
    contrast_bound: torch.Tensor = _check_and_bound(contrast, 'contrast', center=1., device=device, dtype=dtype)
    saturation_bound: torch.Tensor = _check_and_bound(
        saturation, 'saturation', center=1., device=device, dtype=dtype)
    hue_bound: torch.Tensor = _check_and_bound(hue, 'hue', bounds=(-pi.item(), pi.item()), device=device, dtype=dtype)

    # Collapse every leading dimension into one batch dimension. ``reshape``
    # (not ``view``) also works on non-contiguous inputs, and spreading the
    # shape directly avoids ``(*shape)``, which is a SyntaxError on Python >= 3.9.
    input = input.unsqueeze(0)
    input = input.reshape(-1, *input.shape[-3:])
    batch_size = input.shape[0]

    # Sampling order (brightness, contrast, hue, saturation) is kept fixed so
    # results under a given RNG seed stay reproducible.
    brightness_factor = Uniform(brightness_bound[0], brightness_bound[1]).rsample([batch_size])
    contrast_factor = Uniform(contrast_bound[0], contrast_bound[1]).rsample([batch_size])
    hue_factor = Uniform(hue_bound[0], hue_bound[1]).rsample([batch_size])
    saturation_factor = Uniform(saturation_bound[0], saturation_bound[1]).rsample([batch_size])

    transforms = nn.ModuleList([AdjustBrightness(brightness_factor), AdjustContrast(contrast_factor),
                                AdjustSaturation(saturation_factor), AdjustHue(hue_factor)])

    # Apply the four adjustments in a random order.
    jittered = input
    for idx in torch.randperm(4).tolist():
        jittered = transforms[idx](jittered)

    if return_transform:
        identity: torch.Tensor = torch.eye(3, device=device, dtype=dtype).repeat(batch_size, 1, 1)
        return jittered, identity

    return jittered