示例#1
0
def apply_adjust_brightness(input: torch.Tensor,
                            params: Dict[str, torch.Tensor],
                            return_transform: bool = False) -> UnionType:
    """Wrapper for adjust_brightness for Torchvision-like param settings.

    The Torchvision-style factor is shifted by -1 before it is handed to
    ``adjust_brightness``, so a factor of 1 is passed through as 0.

    Args:
        input (torch.Tensor): Image/Input to be adjusted in the shape of (*, N).
        params (Dict[str, torch.Tensor]):
            - params['brightness_factor']: Brightness adjust factor per element
              in the batch. 0 gives a black image, 1 does not modify the input image
              and 2 gives a white image, while any other number modifies the brightness.
        return_transform (bool): if ``True``, also return a batch of 3x3 identity
          matrices describing the (trivial) geometric transform of this operation.

    Returns:
        torch.Tensor: Adjusted image when ``return_transform`` is ``False``.
        Tuple[torch.Tensor, torch.Tensor]: ``(adjusted_image, identity)`` when
        ``return_transform`` is ``True``. ``identity`` has shape (B, 3, 3), where B is
        the leading dimension of the input after ``_transform_input``.
    """
    input = _transform_input(input)
    _validate_input_dtype(
        input, accepted_dtypes=[torch.float16, torch.float32, torch.float64])

    # Torchvision's factor treats 1 as "unchanged"; shift it so that value is
    # forwarded to adjust_brightness as 0.
    shifted_factor = params['brightness_factor'].to(input.dtype) - 1
    transformed = adjust_brightness(input, shifted_factor)

    if not return_transform:
        return transformed

    # Brightness is a photometric change that does not move pixels, so the
    # accompanying per-sample transform is the 3x3 identity.
    batch_size = input.shape[0]
    identity: torch.Tensor = torch.eye(
        3, device=input.device, dtype=input.dtype).repeat(batch_size, 1, 1)
    return transformed, identity
示例#2
0
def apply_adjust_brightness(input: torch.Tensor, params: Dict[str, torch.Tensor]) -> torch.Tensor:
    """Apply ``adjust_brightness`` using a Torchvision-style parameter dictionary.

    Args:
        input (torch.Tensor): Image/Input to be adjusted in the shape of (*, N).
        params (Dict[str, torch.Tensor]):
            - params['brightness_factor']: brightness factor, one per element of the
              batch. 0 gives a black image, 1 leaves the input image unchanged and
              2 gives a white image; any other value modifies the brightness accordingly.

    Returns:
        torch.Tensor: The brightness-adjusted image.
    """
    supported_dtypes = [torch.float16, torch.float32, torch.float64]
    image = _transform_input(input)
    _validate_input_dtype(image, accepted_dtypes=supported_dtypes)

    # NOTE(review): the factor is forwarded unshifted. The "1 = unchanged" contract
    # above holds only if adjust_brightness itself treats 1 as the identity. If it
    # applies an additive offset instead, a `- 1` shift is needed here. Confirm.
    factor = params['brightness_factor'].to(image.dtype)
    return adjust_brightness(image, factor)