Пример #1
0
    def apply_transform(
        self, input: Tensor, params: Dict[str, Tensor], transform: Optional[Tensor] = None
    ) -> Tensor:
        """Run the colour operations on ``input`` in the sequence given by ``params["order"]``.

        Args:
            input: image tensor to be jittered.
            params: sampled parameters. ``params["order"]`` is a tensor of indices
                (0=brightness, 1=contrast, 2=saturation, 3=hue) that is converted with
                ``.tolist()``. Only the factor keys of the operations that order selects are
                read: ``"brightness_factor"``, ``"contrast_factor"``,
                ``"saturation_factor"`` and ``"hue_factor"``.
            transform: not used by this method; accepted for interface compatibility.

        Returns:
            The tensor after every selected operation has been applied in order.
        """
        # ops[i] is the operation picked by index i of params["order"]. Each factor is looked
        # up lazily, inside its lambda, so missing keys only matter for operations that run.
        ops = [
            # NOTE(review): the -1 presumably turns a factor centred on 1 into an additive
            # offset expected by adjust_brightness — confirm against the sampler.
            lambda img: adjust_brightness(img, params["brightness_factor"] - 1),
            lambda img: adjust_contrast(img, params["contrast_factor"]),
            lambda img: adjust_saturation(img, params["saturation_factor"]),
            # NOTE(review): scaling by 2*pi presumably maps a fraction-of-a-turn factor to
            # radians for adjust_hue — confirm.
            lambda img: adjust_hue(img, params["hue_factor"] * 2 * pi),
        ]

        out = input
        for op_index in params["order"].tolist():
            out = ops[op_index](out)
        return out
Пример #2
0
def apply_adjust_saturation(input: torch.Tensor, params: Dict[str, torch.Tensor]) -> torch.Tensor:
    """Apply saturation adjustment.

    Wrapper for adjust_saturation for Torchvision-like param settings.

    Args:
        input (torch.Tensor): Tensor to be transformed with shape :math:`(*, C, H, W)`.
        params (Dict[str, torch.Tensor]):
            - params['saturation_factor']:  How much to adjust the saturation. 0 will give a black
              and white image, 1 will give the original image while 2 will enhance the saturation
              by a factor of 2. It is cast to ``input.dtype`` before use.

    Returns:
        torch.Tensor: Adjusted image. This wrapper does not reshape ``input``, so the result is
        expected to keep the input's :math:`(*, C, H, W)` layout (as returned by
        ``adjust_saturation``).

    Raises:
        KeyError: if ``params`` has no ``'saturation_factor'`` entry.
    """
    # Match the factor's dtype to the image so that half/double-precision inputs are not
    # mixed with a differently typed factor.
    saturation_factor = params['saturation_factor'].to(input.dtype)
    return adjust_saturation(input, saturation_factor)
Пример #3
0
def apply_adjust_saturation(input: torch.Tensor, params: Dict[str, torch.Tensor]) -> torch.Tensor:
    """Adjust image saturation using Torchvision-style parameters.

    The input is first normalised by ``_transform_input`` and then checked to be a
    floating-point tensor (float16/32/64) by ``_validate_input_dtype``. Afterwards the
    saturation factor is cast to the image dtype and ``adjust_saturation`` is applied.

    Args:
        input (torch.Tensor): Tensor to be transformed with shape (H, W), (C, H, W), (B, C, H, W).
        params (Dict[str, torch.Tensor]):
            - params['saturation_factor']:  How much to adjust the saturation. 0 will give a black
              and white image, 1 will give the original image while 2 will enhance the saturation
              by a factor of 2.

    Returns:
        torch.Tensor: Adjusted image.
    """
    # NOTE(review): _transform_input presumably promotes 2D/3D inputs to a batched 4D
    # tensor — confirm against its definition.
    image = _transform_input(input)
    _validate_input_dtype(image, accepted_dtypes=[torch.float16, torch.float32, torch.float64])

    factor = params['saturation_factor'].to(image.dtype)
    return adjust_saturation(image, factor)
Пример #4
0
def color(x: Tensor, v: float) -> Tensor:
    """Return ``x`` with its saturation adjusted by factor ``v``.

    Thin wrapper that delegates directly to ``E.adjust_saturation``.
    """
    adjusted = E.adjust_saturation(x, v)
    return adjusted