예제 #1
0
def apply_adjust_gamma(input: torch.Tensor,
                       params: Dict[str, torch.Tensor],
                       return_transform: bool = False):
    r"""Perform gamma correction on an image.

    The input image is expected to be in the range of [0, 1].

    Args:
        input (torch.Tensor): Image/Tensor to be adjusted in the shape of (\*, N).
            It is first passed through ``_transform_input`` and must have a
            floating-point dtype (float16, float32 or float64).
        params (Dict[str, torch.Tensor]):
            - params['gamma_factor']: Non negative real number, same as :math:`\gamma`
              in the equation. Gamma larger than 1 makes the shadows darker, while
              gamma smaller than 1 makes dark regions lighter. It is cast to the
              dtype of ``input`` before being applied.
        return_transform (bool): If ``True``, additionally return a batch of 3x3
            identity matrices, one per entry along dim 0 of the transformed input.
            Default: ``False``.

    Returns:
        torch.Tensor: Adjusted image when ``return_transform`` is ``False``.
        Tuple[torch.Tensor, torch.Tensor]: ``(adjusted_image, identity)`` when
        ``return_transform`` is ``True``; ``identity`` has shape ``(B, 3, 3)`` with
        ``B = input.shape[0]`` (after ``_transform_input``), and the input's
        device and dtype.
    """
    input = _transform_input(input)
    # Gamma correction is only defined here for floating-point images.
    _validate_input_dtype(
        input, accepted_dtypes=[torch.float16, torch.float32, torch.float64])

    transformed = adjust_gamma(input, params['gamma_factor'].to(input.dtype))

    if return_transform:
        # The reported transform for this op is always the identity; build one
        # 3x3 matrix per batch element (repeat copies, so callers may mutate it).
        identity: torch.Tensor = torch.eye(3,
                                           device=input.device,
                                           dtype=input.dtype).repeat(
                                               input.shape[0], 1, 1)
        return transformed, identity

    return transformed
예제 #2
0
파일: functional.py 프로젝트: rdevon/kornia
def apply_adjust_gamma(input: torch.Tensor, params: Dict[str, torch.Tensor]) -> torch.Tensor:
    r"""Apply gamma correction to an image tensor.

    Args:
        input (torch.Tensor): Image/Tensor to be adjusted in the shape of (\*, N).
            After ``_transform_input`` it must be float16, float32 or float64.
        params (Dict[str, torch.Tensor]):
            - params['gamma_factor']: Non negative real number, the :math:`\gamma`
              exponent of the correction. Values above 1 make the shadows darker;
              values below 1 make dark regions lighter.

    Returns:
        torch.Tensor: Adjusted image.
    """
    image = _transform_input(input)
    float_dtypes = [torch.float16, torch.float32, torch.float64]
    _validate_input_dtype(image, accepted_dtypes=float_dtypes)

    # Cast the gamma tensor to the image's dtype before handing it to adjust_gamma.
    gamma = params['gamma_factor'].to(image.dtype)
    return adjust_gamma(image, gamma)