示例#1
0
def apply_adjust_hue(input: torch.Tensor,
                     params: Dict[str, torch.Tensor],
                     return_transform: bool = False):
    """Wrapper for adjust_hue for Torchvision-like param settings.

    Args:
        input (torch.Tensor): Image/Tensor to be adjusted in the shape of (*, N).
        params (Dict[str, torch.Tensor]):
            - params['hue_factor']: How much to shift the hue channel. Should be in
              [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in HSV
              space in positive and negative direction respectively. 0 means no shift.
              Therefore, both -0.5 and 0.5 will give an image with complementary colors
              while 0 gives the original image. The factor is cast to the input dtype
              and multiplied by 2 * pi (mapping [-0.5, 0.5] onto [-pi, pi]) before
              being passed to ``adjust_hue``.
        return_transform (bool): If True, additionally return a 3x3 identity matrix
            per batch element. Default: False.

    Returns:
        torch.Tensor: Adjusted image when ``return_transform`` is False.
        Tuple[torch.Tensor, torch.Tensor]: ``(adjusted, identity)`` when
        ``return_transform`` is True, where ``identity`` has shape (B, 3, 3) with
        B = ``input.shape[0]`` taken after ``_transform_input``, and the same
        device/dtype as the input.
    """
    input = _transform_input(input)
    _validate_input_dtype(
        input, accepted_dtypes=[torch.float16, torch.float32, torch.float64])

    # Convert the Torchvision-style fraction-of-a-turn factor into an angle in
    # [-pi, pi]; presumably the unit adjust_hue expects — confirm against adjust_hue.
    transformed = adjust_hue(input,
                             params['hue_factor'].to(input.dtype) * 2 * pi)

    if return_transform:
        # A hue change is a color operation, so the reported transform is the
        # identity for every batch element (presumably so callers composing
        # geometric transforms can treat this op as a no-op).
        identity: torch.Tensor = torch.eye(3,
                                           device=input.device,
                                           dtype=input.dtype).repeat(
                                               input.shape[0], 1, 1)
        return transformed, identity

    return transformed
示例#2
0
def apply_adjust_hue(input: torch.Tensor, params: Dict[str, torch.Tensor]) -> torch.Tensor:
    """Shift the hue of ``input`` using a Torchvision-style ``hue_factor``.

    Args:
        input (torch.Tensor): Image/Tensor to be adjusted in the shape of (*, N).
        params (Dict[str, torch.Tensor]): Parameter dictionary that must contain
            ``'hue_factor'`` — the hue shift as a fraction of a full turn, expected in
            [-0.5, 0.5]. 0 leaves the image unchanged; -0.5 and 0.5 both fully reverse
            the hue channel in HSV space (in opposite directions), producing an image
            with complementary colors.

    Returns:
        torch.Tensor: The hue-adjusted image.
    """
    image = _transform_input(input)
    _validate_input_dtype(image, accepted_dtypes=[torch.float16, torch.float32, torch.float64])

    # Scale the fraction-of-a-turn factor by 2 * pi so [-0.5, 0.5] maps onto
    # [-pi, pi]; presumably the angle unit adjust_hue expects — confirm there.
    hue_angle = params['hue_factor'].to(image.dtype) * 2 * pi
    return adjust_hue(image, hue_angle)