def apply_adjust_gamma(input: torch.Tensor, params: Dict[str, torch.Tensor],
                       return_transform: bool = False):
    r"""Perform gamma correction on an image.

    The input image is expected to be in the range of [0, 1].

    Args:
        input (torch.Tensor): Image/Tensor to be adjusted in the shape of :math:`(*, N)`.
        params (Dict[str, torch.Tensor]): Sampled augmentation parameters. Only one key is read:

            - ``params['gamma_factor']``: non-negative real number(s), the :math:`\gamma` in the
              gamma-correction equation. A gamma larger than 1 makes the shadows darker, while a
              gamma smaller than 1 makes dark regions lighter. It is cast to ``input.dtype``
              before use.

        return_transform (bool): If ``True``, also return an identity transformation matrix,
            because this op applies no spatial transformation. Default: ``False``.

    Returns:
        torch.Tensor: The adjusted image when ``return_transform`` is ``False``.
        Otherwise a tuple ``(adjusted, identity)``. ``identity`` is a stack of
        :math:`3 \times 3` identity matrices of shape ``(input.shape[0], 3, 3)``, where
        ``input`` is the tensor produced by ``_transform_input``. It has the same device and
        dtype as that tensor.

    Note:
        Inputs whose dtype is not float16/float32/float64 are rejected by
        ``_validate_input_dtype``. The exact exception type is defined by that helper.
    """
    # Normalize the input layout (presumably to a batched tensor; see _transform_input).
    input = _transform_input(input)
    _validate_input_dtype(
        input, accepted_dtypes=[torch.float16, torch.float32, torch.float64])

    # Cast the sampled gamma so the power op does not mix dtypes (e.g. float16 input).
    transformed = adjust_gamma(input, params['gamma_factor'].to(input.dtype))

    if return_transform:
        # Gamma correction is photometric only, so the matching transform is the identity,
        # repeated once per leading (batch) element.
        identity: torch.Tensor = torch.eye(
            3, device=input.device, dtype=input.dtype).repeat(input.shape[0], 1, 1)
        return transformed, identity

    return transformed
def apply_adjust_gamma(input: torch.Tensor, params: Dict[str, torch.Tensor]) -> torch.Tensor:
    r"""Perform gamma correction on an image.

    The input image is expected to be in the range of [0, 1].

    Args:
        input (torch.Tensor): Image/Tensor to be adjusted in the shape of :math:`(*, N)`.
        params (Dict[str, torch.Tensor]): Sampled augmentation parameters. Only one key is read:

            - ``params['gamma_factor']``: non-negative real number(s), the :math:`\gamma` in the
              gamma-correction equation. A gamma larger than 1 makes the shadows darker, while a
              gamma smaller than 1 makes dark regions lighter. It is cast to ``input.dtype``
              before use.

    Returns:
        torch.Tensor: The gamma-adjusted image.

    Note:
        Inputs whose dtype is not float16/float32/float64 are rejected by
        ``_validate_input_dtype``. The exact exception type is defined by that helper.
    """
    # Normalize the input layout (presumably to a batched tensor; see _transform_input).
    input = _transform_input(input)
    _validate_input_dtype(input, accepted_dtypes=[torch.float16, torch.float32, torch.float64])

    # Cast the sampled gamma so the power op does not mix dtypes (e.g. float16 input).
    transformed = adjust_gamma(input, params['gamma_factor'].to(input.dtype))
    return transformed