def apply_transform(
    self, input: Tensor, params: Dict[str, Tensor], transform: Optional[Tensor] = None
) -> Tensor:
    """Run the brightness/contrast/saturation/hue adjustments in a sampled order.

    Args:
        input: Image tensor to be jittered.
        params: Sampled parameters. Keys read here: ``"brightness_factor"``,
            ``"contrast_factor"``, ``"saturation_factor"``, ``"hue_factor"`` and
            ``"order"``. ``params["order"]`` is a tensor of indices into the
            operation table below (0 = brightness, 1 = contrast,
            2 = saturation, 3 = hue), applied left to right.
        transform: Not used by this method; kept for interface compatibility.

    Returns:
        The tensor after every adjustment listed in ``params["order"]``.
    """

    def _brightness(img: Tensor) -> Tensor:
        # The sampled factor is shifted by -1 before use.
        # NOTE(review): presumably the factor is sampled around 1 while
        # adjust_brightness takes an additive offset -- confirm.
        return adjust_brightness(img, params["brightness_factor"] - 1)

    def _contrast(img: Tensor) -> Tensor:
        return adjust_contrast(img, params["contrast_factor"])

    def _saturation(img: Tensor) -> Tensor:
        return adjust_saturation(img, params["saturation_factor"])

    def _hue(img: Tensor) -> Tensor:
        # Scaled by 2*pi.
        # NOTE(review): presumably converts a fraction of a turn into radians
        # for adjust_hue -- confirm.
        return adjust_hue(img, params["hue_factor"] * 2 * pi)

    # Position in this table is the index that params["order"] refers to.
    operations = [_brightness, _contrast, _saturation, _hue]

    result = input
    for op_index in params["order"].tolist():
        result = operations[op_index](result)
    return result
def apply_adjust_saturation(input: torch.Tensor, params: Dict[str, torch.Tensor]) -> torch.Tensor:
    """Adjust saturation using Torchvision-like parameter settings.

    Thin wrapper around ``adjust_saturation``. It casts the sampled factor to
    the dtype of ``input`` before delegating.

    Args:
        input (torch.Tensor): Tensor to be transformed with shape :math:`(*, C, H, W)`.
        params (Dict[str, torch.Tensor]):
            - params['saturation_factor']: How much to adjust the saturation. 0 will give a black
              and white image, 1 will give the original image while 2 will enhance the saturation
              by a factor of 2.

    Returns:
        torch.Tensor: Adjusted image as returned by ``adjust_saturation``.
        NOTE(review): an earlier version of this docstring said :math:`(B, C, H, W)`,
        but the input is documented as :math:`(*, C, H, W)`. Presumably the input
        shape is preserved -- confirm against ``adjust_saturation``.
    """
    # NOTE(review): another ``apply_adjust_saturation`` definition follows in this
    # chunk. If both live in the same module, the later one shadows this one --
    # confirm which is intended.
    factor = params['saturation_factor'].to(input.dtype)
    return adjust_saturation(input, factor)
def apply_adjust_saturation(input: torch.Tensor, params: Dict[str, torch.Tensor]) -> torch.Tensor:
    """Adjust saturation using Torchvision-like parameter settings.

    The input is first normalised by ``_transform_input`` and checked to be a
    floating-point tensor (float16/32/64). The sampled factor is then cast to
    the input dtype and passed to ``adjust_saturation``.

    Args:
        input (torch.Tensor): Tensor to be transformed with shape (H, W), (C, H, W), (B, C, H, W).
        params (Dict[str, torch.Tensor]):
            - params['saturation_factor']: How much to adjust the saturation. 0 will give a black
              and white image, 1 will give the original image while 2 will enhance the saturation
              by a factor of 2.

    Returns:
        torch.Tensor: Adjusted image.
    """
    image = _transform_input(input)

    # Only floating-point images are accepted; the validator decides what
    # happens otherwise.
    float_dtypes = [torch.float16, torch.float32, torch.float64]
    _validate_input_dtype(image, accepted_dtypes=float_dtypes)

    factor = params['saturation_factor'].to(image.dtype)
    return adjust_saturation(image, factor)
def color(x: Tensor, v: float) -> Tensor:
    """Return ``x`` with its saturation adjusted by factor ``v``.

    Delegates directly to ``E.adjust_saturation``.
    """
    saturated = E.adjust_saturation(x, v)
    return saturated