Example #1
    def forward(self, prediction: torch.Tensor,
                target: torch.Tensor) -> torch.Tensor:
        r"""Computation of Content loss between feature representations of prediction and target tensors.
        Args:
            prediction: Tensor with shape (H, W), (C, H, W) or (N, C, H, W).
            target: Tensor with shape (H, W), (C, H, W) or (N, C, H, W).
        """
        _validate_input(input_tensors=(prediction, target),
                        allow_5d=False,
                        allow_negative=True)
        prediction, target = _adjust_dimensions(input_tensors=(prediction,
                                                               target))

        self.model.to(prediction)
        prediction_features = self.get_features(prediction)
        target_features = self.get_features(target)

        distances = self.compute_distance(prediction_features, target_features)

        # Scale distances, then average in spatial dimensions, then stack and sum in channels dimension
        loss = torch.cat([(d * w.to(d)).mean(dim=[2, 3])
                          for d, w in zip(distances, self.weights)],
                         dim=1).sum(dim=1)

        if self.reduction == 'none':
            return loss

        return {'mean': loss.mean, 'sum': loss.sum}[self.reduction](dim=0)
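
A hedged usage sketch (not from the source): `criterion` stands for an instance of the perceptual loss class that owns this forward method (for example a VGG-based content loss such as piq's ContentLoss); the class name and constructor arguments are assumptions, and the tensor shapes follow the docstring above.

import torch

criterion = ContentLoss(reduction='mean')  # assumed class name; adjust to the actual loss class
prediction = torch.rand(2, 3, 128, 128, requires_grad=True)
target = torch.rand(2, 3, 128, 128)
loss = criterion(prediction, target)
loss.backward()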
Example #2
    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""
        Forward pass of a batch of square patches with shape (N, C, FEATURES, FEATURES).

        Returns:
            features: Concatenation of model features from different scales
            x11: Outputs of the last convolutional layer used as weights
        """
        _validate_input(input_tensors=x, allow_5d=False, allow_negative=False)
        x = _adjust_dimensions(input_tensors=x)
        assert x.shape[2] == x.shape[3] == self.FEATURES, \
            f"Expected square input with shape {self.FEATURES, self.FEATURES}, got {x.shape}"

        # conv1 -> relu -> conv2 -> relu -> pool -> conv3 -> relu
        x3 = F.relu(self.conv3(self.pool(F.relu(self.conv2(F.relu(self.conv1(x)))))))
        # conv4 -> relu -> pool -> conv5 -> relu
        x5 = F.relu(self.conv5(self.pool(F.relu(self.conv4(x3)))))
        # conv6 -> relu -> pool -> conv7 -> relu
        x7 = F.relu(self.conv7(self.pool(F.relu(self.conv6(x5)))))
        # conv8 -> relu -> pool -> conv9 -> relu
        x9 = F.relu(self.conv9(self.pool(F.relu(self.conv8(x7)))))
        # conv10 -> relu -> pool -> conv11 -> relu
        x11 = self.flatten(F.relu(self.conv11(self.pool(F.relu(self.conv10(x9))))))
        # flatten and concatenate
        features = torch.cat((self.flatten(x3), self.flatten(x5), self.flatten(x7), self.flatten(x9), x11), dim=1)
        return features, x11
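
A hedged usage sketch (not from the source): `extractor` stands for a hypothetical instance of the feature-extraction module defined above; the patch side length is read from its FEATURES attribute so the shape assertion passes.

import torch

patches = torch.rand(8, 3, extractor.FEATURES, extractor.FEATURES)  # `extractor` is an assumed instance
features, weights = extractor(patches)
print(features.shape, weights.shape)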
Example #3
def gmsd(
    x: torch.Tensor,
    y: torch.Tensor,
    reduction: str = 'mean',
    data_range: Union[int, float] = 1.,
    t: float = 170 / (255.**2)
) -> torch.Tensor:
    r"""Compute Gradient Magnitude Similarity Deviation.

    Inputs are supposed to be in range [0, data_range] with RGB channel order for colour images.

    Args:
        x: Tensor with shape (H, W), (C, H, W) or (N, C, H, W).
        y: Tensor with shape (H, W), (C, H, W) or (N, C, H, W).
        reduction: Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
            ``'mean'``: the sum of the output will be divided by the number of
            elements in the output, ``'sum'``: the output will be summed.
        data_range: The difference between the maximum and minimum of the pixel value,
            i.e., if for image x it holds min(x) = 0 and max(x) = 1, then data_range = 1.
            The pixel value interval of both input and output should remain the same.
        t: Constant from the reference paper for numerical stability of the similarity map.

    Returns:
        gmsd : Gradient Magnitude Similarity Deviation between given tensors.

    References:
        Wufeng Xue et al. Gradient Magnitude Similarity Deviation (2013)
        https://arxiv.org/pdf/1308.3052.pdf
    """

    _validate_input(input_tensors=(x, y),
                    allow_5d=False,
                    scale_weights=None,
                    data_range=data_range)
    x, y = _adjust_dimensions(input_tensors=(x, y))

    # Rescale
    x = x / data_range
    y = y / data_range

    num_channels = x.size(1)
    if num_channels == 3:
        x = rgb2yiq(x)[:, :1]
        y = rgb2yiq(y)[:, :1]
    up_pad = 0
    down_pad = max(x.shape[2] % 2, x.shape[3] % 2)
    pad_to_use = [up_pad, down_pad, up_pad, down_pad]
    x = F.pad(x, pad=pad_to_use)
    y = F.pad(y, pad=pad_to_use)

    x = F.avg_pool2d(x, kernel_size=2, stride=2, padding=0)
    y = F.avg_pool2d(y, kernel_size=2, stride=2, padding=0)

    score = _gmsd(x=x, y=y, t=t)
    if reduction == 'none':
        return score

    return {'mean': score.mean, 'sum': score.sum}[reduction](dim=0)
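
A hedged usage sketch (not from the source), assuming the gmsd function above is in scope; random tensors stand in for real images.

import torch

x = torch.rand(4, 3, 96, 96)  # distorted batch
y = torch.rand(4, 3, 96, 96)  # reference batch
score = gmsd(x, y, data_range=1.)            # scalar, reduced over the batch
per_image = gmsd(x, y, reduction='none')     # tensor of shape (4,)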
Example #4
def brisque(x: torch.Tensor,
            kernel_size: int = 7,
            kernel_sigma: float = 7 / 6,
            data_range: Union[int, float] = 1.,
            reduction: str = 'mean',
            interpolation: str = 'nearest') -> torch.Tensor:
    r"""Interface of BRISQUE index.

    Args:
        x: Tensor with shape (H, W), (C, H, W) or (N, C, H, W). RGB channel order for colour images.
        kernel_size: The side-length of the sliding window used in comparison. Must be an odd value.
        kernel_sigma: Sigma of normal distribution.
        data_range: Maximum value range of input images (usually 1.0 or 255).
        reduction: Reduction over samples in batch: "mean"|"sum"|"none".
        interpolation: Interpolation to be used for scaling.

    Returns:
        Value of BRISQUE index.

    Note:
        Backpropagation is not available with torch==1.5.0 due to a bug in argmin/argmax backpropagation.
        Update torch and torchvision to the latest versions.

    References:
        .. [1] Anish Mittal et al. "No-Reference Image Quality Assessment in the Spatial Domain",
        https://live.ece.utexas.edu/publications/2012/TIP%20BRISQUE.pdf
    """
    if '1.5.0' in torch.__version__:
        warnings.warn(
            f'BRISQUE does not support back propagation due to a bug in torch={torch.__version__}. '
            f'Update torch to the latest version to access full functionality of the BRISQUE. '
            f'More info is available at https://github.com/photosynthesis-team/piq/pull/79 and '
            f'https://github.com/pytorch/pytorch/issues/38869.')

    _validate_input(input_tensors=x, allow_5d=False, kernel_size=kernel_size)
    x = _adjust_dimensions(input_tensors=x)

    assert data_range >= x.max(), \
        f'Expected data_range to be greater than or equal to the maximum value, got {data_range} and {x.max()}.'
    x = x * 255. / data_range

    if x.size(1) == 3:
        x = rgb2yiq(x)[:, :1]
    features = []
    num_of_scales = 2
    for _ in range(num_of_scales):
        features.append(_natural_scene_statistics(x, kernel_size,
                                                  kernel_sigma))
        x = F.interpolate(x,
                          size=(x.size(2) // 2, x.size(3) // 2),
                          mode=interpolation)

    features = torch.cat(features, dim=-1)
    scaled_features = _scale_features(features)
    score = _score_svr(scaled_features)
    if reduction == 'none':
        return score

    return {'mean': score.mean, 'sum': score.sum}[reduction](dim=0)
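
A hedged usage sketch (not from the source), assuming the brisque function above is in scope; BRISQUE is a no-reference metric, so only one batch of images is needed.

import torch

x = torch.rand(4, 3, 96, 96)
score = brisque(x, data_range=1.)            # mean over the batch
per_image = brisque(x, reduction='none')     # tensor of shape (4,)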
Example #5
def gmsd(
    prediction: torch.Tensor,
    target: torch.Tensor,
    reduction: Optional[str] = 'mean',
    data_range: Union[int, float] = 1.,
    t: float = 170 / (255.**2)
) -> torch.Tensor:
    r"""Compute Gradient Magnitude Similarity Deviation
    Both inputs are supposed to be in range [0, 1] with RGB channel order.
    Args:
        prediction: Tensor of shape :math:`(N, C, H, W)` holding a distorted image.
        target: Tensor of shape :math:`(N, C, H, W)` holding a target image.
        reduction: Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
            ``'mean'``: the sum of the output will be divided by the number of
            elements in the output, ``'sum'``: the output will be summed.
        data_range: The difference between the maximum and minimum of the pixel value,
            i.e., if for image x it holds min(x) = 0 and max(x) = 1, then data_range = 1.
            The pixel value interval of both input and output should remain the same.
        t: Constant from the reference paper for numerical stability of the similarity map.

    Returns:
        gmsd : Gradient Magnitude Similarity Deviation between given tensors.

    References:
        https://arxiv.org/pdf/1308.3052.pdf
    """

    _validate_input(input_tensors=(prediction, target),
                    allow_5d=False,
                    scale_weights=None)
    prediction, target = _adjust_dimensions(input_tensors=(prediction, target))

    prediction = prediction / float(data_range)
    target = target / float(data_range)

    num_channels = prediction.size(1)
    if num_channels == 3:
        prediction = rgb2yiq(prediction)[:, :1]
        target = rgb2yiq(target)[:, :1]
    up_pad = 0
    down_pad = max(prediction.shape[2] % 2, prediction.shape[3] % 2)
    pad_to_use = [up_pad, down_pad, up_pad, down_pad]
    prediction = F.pad(prediction, pad=pad_to_use)
    target = F.pad(target, pad=pad_to_use)

    prediction = F.avg_pool2d(prediction, kernel_size=2, stride=2, padding=0)
    target = F.avg_pool2d(target, kernel_size=2, stride=2, padding=0)

    score = _gmsd(prediction=prediction, target=target, t=t)
    if reduction == 'none':
        return score

    return {'mean': score.mean, 'sum': score.sum}[reduction](dim=0)
Example #6
def ssim(x: torch.Tensor, y: torch.Tensor, kernel_size: int = 11, kernel_sigma: float = 1.5,
         data_range: Union[int, float] = 1., reduction: str = 'mean', full: bool = False,
         k1: float = 0.01, k2: float = 0.03) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    r"""Interface of Structural Similarity (SSIM) index.

    Args:
        x: Batch of images. Required to be 2D (H, W), 3D (C,H,W), 4D (N,C,H,W) or 5D (N,C,H,W,2), channels first.
        y: Batch of images. Required to be 2D (H, W), 3D (C,H,W), 4D (N,C,H,W) or 5D (N,C,H,W,2), channels first.
        kernel_size: The side-length of the sliding window used in comparison. Must be an odd value.
        kernel_sigma: Sigma of normal distribution.
        data_range: Value range of input images (usually 1.0 or 255).
        reduction: Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
        full: Return cs map or not.
        k1: Algorithm parameter, K1 (small constant, see [1]).
        k2: Algorithm parameter, K2 (small constant, see [1]).
            Try a larger K2 constant (e.g. 0.4) if you get negative or NaN results.

    Returns:
        Value of Structural Similarity (SSIM) index. In case of 5D input tensors, complex value is returned
        as a tensor of size 2.

    References:
        .. [1] Wang, Z., Bovik, A. C., Sheikh, H. R., & Simoncelli, E. P.
           (2004). Image quality assessment: From error visibility to
           structural similarity. IEEE Transactions on Image Processing,
           13, 600-612.
           https://ece.uwaterloo.ca/~z70wang/publications/ssim.pdf,
           :DOI:`10.1109/TIP.2003.819861`
    """
    _validate_input(input_tensors=(x, y), allow_5d=True, kernel_size=kernel_size, scale_weights=None)
    x, y = _adjust_dimensions(input_tensors=(x, y))
    if isinstance(x, torch.ByteTensor) or isinstance(y, torch.ByteTensor):
        x = x.type(torch.float32)
        y = y.type(torch.float32)

    kernel = gaussian_filter(kernel_size, kernel_sigma).repeat(x.size(1), 1, 1, 1).to(y)
    _compute_ssim_per_channel = _ssim_per_channel_complex if x.dim() == 5 else _ssim_per_channel
    ssim_map, cs_map = _compute_ssim_per_channel(x=x, y=y, kernel=kernel, data_range=data_range, k1=k1, k2=k2)
    ssim_val = ssim_map.mean(1)
    cs = cs_map.mean(1)

    if reduction != 'none':
        reduction_operation = {'mean': torch.mean,
                               'sum': torch.sum}
        ssim_val = reduction_operation[reduction](ssim_val, dim=0)
        cs = reduction_operation[reduction](cs, dim=0)

    if full:
        return ssim_val, cs

    return ssim_val
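
A hedged usage sketch (not from the source), assuming the ssim function above is in scope.

import torch

x = torch.rand(4, 3, 96, 96)
y = torch.rand(4, 3, 96, 96)
ssim_val = ssim(x, y, data_range=1.)        # reduced over the batch
ssim_val, cs = ssim(x, y, full=True)        # also return the contrast-structure term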
Example #7
def total_variation(x: torch.Tensor,
                    reduction: str = 'mean',
                    norm_type: str = 'l2') -> torch.Tensor:
    r"""Compute Total Variation metric

    Args:
        x: Tensor with shape (N, C, H, W).
        reduction: Reduction over samples in batch: "mean"|"sum"|"none"
        norm_type: one of {'l1', 'l2', 'l2_squared'}, defines which type of norm to use, isotropic or anisotropic.

    Returns:
        score : Total variation of a given tensor

    References:
        https://www.wikiwand.com/en/Total_variation_denoising
        https://remi.flamary.com/demos/proxtv.html
    """
    _validate_input(x, allow_5d=False)
    x = _adjust_dimensions(x)

    if norm_type == 'l1':
        w_variance = torch.sum(torch.abs(x[:, :, :, 1:] - x[:, :, :, :-1]),
                               dim=[1, 2, 3])
        h_variance = torch.sum(torch.abs(x[:, :, 1:, :] - x[:, :, :-1, :]),
                               dim=[1, 2, 3])
        score = (h_variance + w_variance)
    elif norm_type == 'l2':
        w_variance = torch.sum(torch.pow(x[:, :, :, 1:] - x[:, :, :, :-1], 2),
                               dim=[1, 2, 3])
        h_variance = torch.sum(torch.pow(x[:, :, 1:, :] - x[:, :, :-1, :], 2),
                               dim=[1, 2, 3])
        score = torch.sqrt(h_variance + w_variance)
    elif norm_type == 'l2_squared':
        w_variance = torch.sum(torch.pow(x[:, :, :, 1:] - x[:, :, :, :-1], 2),
                               dim=[1, 2, 3])
        h_variance = torch.sum(torch.pow(x[:, :, 1:, :] - x[:, :, :-1, :], 2),
                               dim=[1, 2, 3])
        score = (h_variance + w_variance)
    else:
        raise ValueError(
            "Incorrect norm type, should be one of {'l1', 'l2', 'l2_squared'}"
        )

    if reduction == 'none':
        return score

    return {'mean': score.mean, 'sum': score.sum}[reduction](dim=0)
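
A hedged usage sketch (not from the source), assuming the total_variation function above is in scope.

import torch

x = torch.rand(4, 3, 64, 64)
tv = total_variation(x, norm_type='l2')               # isotropic TV, mean over the batch
tv_per_image = total_variation(x, reduction='none')   # tensor of shape (4,)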
Example #8
def psnr(x: torch.Tensor,
         y: torch.Tensor,
         data_range: Union[int, float] = 1.0,
         reduction: str = 'mean',
         convert_to_greyscale: bool = False) -> torch.Tensor:
    r"""Compute Peak Signal-to-Noise Ratio for a batch of images.
    Supports both greyscale and color images with RGB channel order.

    Args:
        x: Predicted images set :math:`x`.
            Shape (H, W), (C, H, W) or (N, C, H, W).
        y: Target images set :math:`y`.
            Shape (H, W), (C, H, W) or (N, C, H, W).
        data_range: Value range of input images (usually 1.0 or 255). Default: 1.0
        reduction: Reduction over samples in batch: "mean"|"sum"|"none"
        convert_to_greyscale: Convert RGB image to YCbCr format and compute PSNR
            only on the luminance channel if `True`. Compute on all 3 channels otherwise.

    Returns:
        PSNR: Index of similarity between two images.

    References:
        https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
    """
    _validate_input((x, y), allow_5d=False, data_range=data_range)
    x, y = _adjust_dimensions(input_tensors=(x, y))

    # Constant for numerical stability
    EPS = 1e-8

    x = x / data_range
    y = y / data_range

    if (x.size(1) == 3) and convert_to_greyscale:
        # Convert RGB image to YCbCr and take luminance: Y = 0.299 R + 0.587 G + 0.114 B
        rgb_to_grey = torch.tensor([0.299, 0.587, 0.114]).view(1, -1, 1,
                                                               1).to(x)
        x = torch.sum(x * rgb_to_grey, dim=1, keepdim=True)
        y = torch.sum(y * rgb_to_grey, dim=1, keepdim=True)

    mse = torch.mean((x - y)**2, dim=[1, 2, 3])
    score: torch.Tensor = -10 * torch.log10(mse + EPS)

    if reduction == 'none':
        return score

    return {'mean': score.mean, 'sum': score.sum}[reduction](dim=0)
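
A hedged usage sketch (not from the source), assuming the psnr function above is in scope.

import torch

x = torch.rand(4, 3, 64, 64)
y = torch.rand(4, 3, 64, 64)
score = psnr(x, y, data_range=1.0)                   # mean PSNR over the batch
score_y = psnr(x, y, convert_to_greyscale=True)      # PSNR on the luminance channel only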
Example #9
    def forward(self, prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        r"""
        Computation of PieAPP between feature representations of prediction and target tensors.

        Args:
            prediction: Tensor with shape (H, W), (C, H, W) or (N, C, H, W).
            target: Tensor with shape (H, W), (C, H, W) or (N, C, H, W).
        """
        _validate_input(
            input_tensors=(prediction, target), allow_5d=False, allow_negative=True, data_range=self.data_range)
        prediction, target = _adjust_dimensions(input_tensors=(prediction, target))

        N, C, _, _ = prediction.shape
        if C == 1:
            prediction = prediction.repeat(1, 3, 1, 1)
            target = target.repeat(1, 3, 1, 1)
            warnings.warn('The original PieAPP supports only RGB images. '
                          'The input images were converted to RGB by copying the grey channel 3 times.')

        self.model.to(device=prediction.device)
        prediction_features, prediction_weights = self.get_features(prediction)
        target_features, target_weights = self.get_features(target)

        distances, weights = self.model.compute_difference(
            target_features - prediction_features,
            target_weights - prediction_weights
        )

        distances = distances.reshape(N, -1)
        weights = weights.reshape(N, -1)

        # Scale scores, then average across patches
        loss = torch.stack([(d * w).sum() / w.sum() for d, w in zip(distances, weights)])

        if self.reduction == 'none':
            return loss

        return {'mean': loss.mean,
                'sum': loss.sum
                }[self.reduction](dim=0)
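
A hedged usage sketch (not from the source): `criterion` stands for an instance of the PieAPP-style loss class that owns this forward method; the class name and constructor arguments are assumptions.

import torch

criterion = PieAPP(reduction='mean')  # assumed class name; adjust to the actual loss class
prediction = torch.rand(2, 3, 256, 256, requires_grad=True)
target = torch.rand(2, 3, 256, 256)
loss = criterion(prediction, target)
loss.backward()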
Example #10
def multi_scale_ssim(x: torch.Tensor, y: torch.Tensor, kernel_size: int = 11, kernel_sigma: float = 1.5,
                     data_range: Union[int, float] = 1., reduction: str = 'mean',
                     scale_weights: Optional[Union[Tuple[float], List[float], torch.Tensor]] = None,
                     k1: float = 0.01, k2: float = 0.03) -> torch.Tensor:
    r""" Interface of Multi-scale Structural Similarity (MS-SSIM) index.

    Args:
        x: Batch of images. Required to be 2D (H, W), 3D (C,H,W), 4D (N,C,H,W) or 5D (N,C,H,W,2), channels first.
            The size of the image should be (kernel_size - 1) * 2 ** (levels - 1) + 1.
        y: Batch of images. Required to be 2D (H, W), 3D (C,H,W), 4D (N,C,H,W) or 5D (N,C,H,W,2), channels first.
            The size of the image should be (kernel_size - 1) * 2 ** (levels - 1) + 1.
        kernel_size: The side-length of the sliding window used in comparison. Must be an odd value.
        kernel_sigma: Sigma of normal distribution.
        data_range: Value range of input images (usually 1.0 or 255).
        reduction: Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
        scale_weights: Weights for different scales.
            If None, default weights from the paper [1] will be used.
            Default weights: (0.0448, 0.2856, 0.3001, 0.2363, 0.1333).
        k1: Algorithm parameter, K1 (small constant, see [2]).
        k2: Algorithm parameter, K2 (small constant, see [2]).
            Try a larger K2 constant (e.g. 0.4) if you get negative or NaN results.

    Returns:
        Value of Multi-scale Structural Similarity (MS-SSIM) index. In case of 5D input tensors,
        complex value is returned as a tensor of size 2.

    References:
        .. [1] Wang, Z., Simoncelli, E. P., Bovik, A. C. (2003).
           Multi-scale Structural Similarity for Image Quality Assessment.
           IEEE Asilomar Conference on Signals, Systems and Computers, 37,
           https://ieeexplore.ieee.org/document/1292216
           :DOI:`10.1109/ACSSC.2003.1292216`
        .. [2] Wang, Z., Bovik, A. C., Sheikh, H. R., & Simoncelli, E. P.
           (2004). Image quality assessment: From error visibility to
           structural similarity. IEEE Transactions on Image Processing,
           13, 600-612.
           https://ece.uwaterloo.ca/~z70wang/publications/ssim.pdf,
           :DOI:`10.1109/TIP.2003.819861`
    """
    _validate_input(input_tensors=(x, y), allow_5d=True, kernel_size=kernel_size, scale_weights=scale_weights)
    x, y = _adjust_dimensions(input_tensors=(x, y))
    if isinstance(x, torch.ByteTensor) or isinstance(y, torch.ByteTensor):
        x = x.type(torch.float32)
        y = y.type(torch.float32)

    if scale_weights is None:
        scale_weights_from_ms_ssim_paper = [0.0448, 0.2856, 0.3001, 0.2363, 0.1333]
        scale_weights = scale_weights_from_ms_ssim_paper

    scale_weights_tensor = scale_weights if isinstance(scale_weights, torch.Tensor) else torch.tensor(scale_weights)
    scale_weights_tensor = scale_weights_tensor.to(y)
    kernel = gaussian_filter(kernel_size, kernel_sigma).repeat(x.size(1), 1, 1, 1).to(y)
    
    _compute_msssim = _multi_scale_ssim_complex if x.dim() == 5 else _multi_scale_ssim
    msssim_val = _compute_msssim(
        x=x,
        y=y,
        data_range=data_range,
        kernel=kernel,
        scale_weights_tensor=scale_weights_tensor,
        k1=k1,
        k2=k2
    )

    if reduction == 'none':
        return msssim_val

    return {'mean': torch.mean,
            'sum': torch.sum}[reduction](msssim_val, dim=0)
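
A hedged usage sketch (not from the source), assuming the multi_scale_ssim function above is in scope; with the default five scale weights and an 11x11 kernel the images must be at least (11 - 1) * 2 ** 4 + 1 = 161 pixels per side.

import torch

x = torch.rand(2, 3, 256, 256)
y = torch.rand(2, 3, 256, 256)
ms_ssim_val = multi_scale_ssim(x, y, data_range=1.)   # reduced over the batch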
Example #11
def multi_scale_gmsd(x: torch.Tensor,
                     y: torch.Tensor,
                     data_range: Union[int, float] = 1.,
                     reduction: str = 'mean',
                     scale_weights: Optional[Union[torch.Tensor, Tuple[float,
                                                                       ...],
                                                   List[float]]] = None,
                     chromatic: bool = False,
                     alpha: float = 0.5,
                     beta1: float = 0.01,
                     beta2: float = 0.32,
                     beta3: float = 15.,
                     t: float = 170) -> torch.Tensor:
    r"""Computation of Multi-Scale GMSD.

    Inputs are supposed to be in range [0, data_range] with RGB channel order for colour images.
    The height and width should be at least 2 ** scales + 1.

    Args:
        x: Tensor with shape (H, W), (C, H, W) or (N, C, H, W).
        y: Tensor with shape (H, W), (C, H, W) or (N, C, H, W).
        data_range: The difference between the maximum and minimum of the pixel value,
            i.e., if for image x it holds min(x) = 0 and max(x) = 1, then data_range = 1.
            The pixel value interval of both input and output should remain the same.
        reduction: Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
            ``'mean'``: the sum of the output will be divided by the number of
            elements in the output, ``'sum'``: the output will be summed. Default: ``'mean'``
        scale_weights: Weights for different scales. Can contain any number of floating point values.
        chromatic: Flag to use MS-GMSDc algorithm from paper.
            It also evaluates chromatic components of the image. Default: False
        alpha: Masking coefficient. See [1] for details.
        beta1: Algorithm parameter. Weight of chromatic component in the loss.
        beta2: Algorithm parameter. Small constant, see [1].
        beta3: Algorithm parameter. Small constant, see [1].
        t: Constant from the reference paper for numerical stability of the similarity map.

    Returns:
        Value of MS-GMSD. 0 <= GMSD loss <= 1.
    """
    _validate_input(input_tensors=(x, y),
                    allow_5d=False,
                    scale_weights=scale_weights,
                    data_range=data_range)
    x, y = _adjust_dimensions(input_tensors=(x, y))

    # Rescale
    x = x / data_range * 255
    y = y / data_range * 255

    # Values from the paper
    if scale_weights is None:
        scale_weights = torch.tensor([0.096, 0.596, 0.289, 0.019])
    else:
        # Normalize scale weights
        scale_weights = torch.tensor(scale_weights) / torch.tensor(
            scale_weights).sum()

    scale_weights = cast(torch.Tensor, scale_weights).to(x)

    # Check that input is big enough
    num_scales = scale_weights.size(0)
    min_size = 2**num_scales + 1

    if x.size(-1) < min_size or x.size(-2) < min_size:
        raise ValueError(
            f'Invalid size of the input images, expected at least {min_size}x{min_size}.'
        )

    num_channels = x.size(1)
    if num_channels == 3:
        x = rgb2yiq(x)
        y = rgb2yiq(y)

    ms_gmds = []
    for scale in range(num_scales):
        if scale > 0:

            # Average by 2x2 filter and downsample
            up_pad = 0
            down_pad = max(x.shape[2] % 2, x.shape[3] % 2)
            pad_to_use = [up_pad, down_pad, up_pad, down_pad]
            x = F.pad(x, pad=pad_to_use)
            y = F.pad(y, pad=pad_to_use)
            x = F.avg_pool2d(x, kernel_size=2, padding=0)
            y = F.avg_pool2d(y, kernel_size=2, padding=0)

        score = _gmsd(x[:, :1], y[:, :1], t=t, alpha=alpha)
        ms_gmds.append(score)

    # Stack results in different scales and multiply by weight
    ms_gmds_val = scale_weights.view(1, num_scales) * (torch.stack(ms_gmds,
                                                                   dim=1)**2)

    # Sum and take sqrt per-image
    ms_gmds_val = torch.sqrt(torch.sum(ms_gmds_val, dim=1))

    # Shape: (batch_size, )
    score = ms_gmds_val

    if chromatic:
        assert x.size(
            1) == 3, "Chromatic component can be computed only for RGB images!"

        x_iq = x[:, 1:]
        y_iq = y[:, 1:]

        rmse_iq = torch.sqrt(torch.mean((x_iq - y_iq)**2, dim=[2, 3]))
        rmse_chrome = torch.sqrt(torch.sum(rmse_iq**2, dim=1))
        gamma = 2 / (1 + beta2 * torch.exp(-beta3 * ms_gmds_val)) - 1

        score = gamma * ms_gmds_val + (1 - gamma) * beta1 * rmse_chrome

    if reduction == 'none':
        return score

    return {'mean': score.mean, 'sum': score.sum}[reduction](dim=0)
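
A hedged usage sketch (not from the source), assuming the multi_scale_gmsd function above is in scope; the chromatic flag enables the MS-GMSDc variant and requires RGB input.

import torch

x = torch.rand(2, 3, 96, 96)
y = torch.rand(2, 3, 96, 96)
score = multi_scale_gmsd(x, y, data_range=1.)        # luminance-only MS-GMSD
score_c = multi_scale_gmsd(x, y, chromatic=True)     # MS-GMSDc on RGB input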
Example #12
def vif_p(prediction: torch.Tensor,
          target: torch.Tensor,
          sigma_n_sq: float = 2.0,
          data_range: Union[int, float] = 1.0,
          reduction: str = 'mean') -> torch.Tensor:
    r"""Compute Visual Information Fidelity in **pixel** domain for a batch of images.
    This metric isn't symmetric, so make sure to place arguments in correct order.

    Both inputs are supposed to have RGB channel order.
    Args:
        prediction: Batch of predicted images with shape (batch_size x channels x H x W)
        target: Batch of target images with shape  (batch_size x channels x H x W)
        sigma_n_sq: HVS model parameter (variance of the visual noise).
        data_range: Value range of input images (usually 1.0 or 255). Default: 1.0
        reduction: Reduction over samples in batch: "mean"|"sum"|"none"
        
    Returns:
        VIF: Index of similarity between two images. Usually in [0, 1] interval.
            Can be bigger than 1 for predicted images with higher contrast than the original ones.
    Note:
        In original paper this method was used for bands in discrete wavelet decomposition.
        Later on authors released code to compute VIF approximation in pixel domain.
        See https://live.ece.utexas.edu/research/Quality/VIF.htm for details.
        
    """
    _validate_input((prediction, target), allow_5d=False)
    prediction, target = _adjust_dimensions(input_tensors=(prediction, target))

    min_size = 41
    if prediction.size(-1) < min_size or prediction.size(-2) < min_size:
        raise ValueError(
            f'Invalid size of the input images, expected at least {min_size}x{min_size}.'
        )

    if data_range == 255:
        prediction = prediction / 255.
        target = target / 255.

    # Convert RGB image to YCbCr and take luminance: Y = 0.299 R + 0.587 G + 0.114 B
    num_channels = prediction.size(1)
    if num_channels == 3:
        prediction = (0.299 * prediction[:, 0, :, :] + 0.587 * prediction[:, 1, :, :]
                      + 0.114 * prediction[:, 2, :, :])
        target = (0.299 * target[:, 0, :, :] + 0.587 * target[:, 1, :, :]
                  + 0.114 * target[:, 2, :, :])

        # Add channel dimension
        prediction = prediction[:, None, :, :]
        target = target[:, None, :, :]

    # Constant for numerical stability
    EPS = 1e-8

    # Progressively downsample images and compute VIF on different scales
    prediction_vif, target_vif = 0, 0
    for scale in range(1, 5):
        kernel_size = 2**(5 - scale) + 1
        kernel = _gaussian_kernel2d(kernel_size, sigma=kernel_size / 5)
        kernel = kernel.view(1, 1, kernel_size, kernel_size).to(prediction)

        if scale > 1:
            # Convolve and downsample
            prediction = F.conv2d(prediction,
                                  kernel)[:, :, ::2, ::2]  # valid padding
            target = F.conv2d(target, kernel)[:, :, ::2, ::2]  # valid padding

        mu_trgt, mu_pred = F.conv2d(target,
                                    kernel), F.conv2d(prediction,
                                                      kernel)  # valid padding
        mu_trgt_sq, mu_pred_sq, mu_trgt_pred = mu_trgt * mu_trgt, mu_pred * mu_pred, mu_trgt * mu_pred

        sigma_trgt_sq = F.conv2d(target**2, kernel) - mu_trgt_sq
        sigma_pred_sq = F.conv2d(prediction**2, kernel) - mu_pred_sq
        sigma_trgt_pred = F.conv2d(target * prediction, kernel) - mu_trgt_pred

        # Zero small negative values
        sigma_trgt_sq = torch.relu(sigma_trgt_sq)
        sigma_pred_sq = torch.relu(sigma_pred_sq)

        g = sigma_trgt_pred / (sigma_trgt_sq + EPS)
        sigma_v_sq = sigma_pred_sq - g * sigma_trgt_pred

        g = torch.where(sigma_trgt_sq >= EPS, g, torch.zeros_like(g))
        sigma_v_sq = torch.where(sigma_trgt_sq >= EPS, sigma_v_sq,
                                 sigma_pred_sq)
        sigma_trgt_sq = torch.where(sigma_trgt_sq >= EPS, sigma_trgt_sq,
                                    torch.zeros_like(sigma_trgt_sq))

        g = torch.where(sigma_pred_sq >= EPS, g, torch.zeros_like(g))
        sigma_v_sq = torch.where(sigma_pred_sq >= EPS, sigma_v_sq,
                                 torch.zeros_like(sigma_v_sq))

        sigma_v_sq = torch.where(g >= 0, sigma_v_sq, sigma_pred_sq)
        g = torch.relu(g)

        sigma_v_sq = torch.where(sigma_v_sq > EPS, sigma_v_sq,
                                 torch.ones_like(sigma_v_sq) * EPS)

        pred_vif_scale = torch.log10(1.0 + (g**2.) * sigma_trgt_sq /
                                     (sigma_v_sq + sigma_n_sq))
        prediction_vif = prediction_vif + torch.sum(pred_vif_scale,
                                                    dim=[1, 2, 3])
        target_vif = target_vif + torch.sum(
            torch.log10(1.0 + sigma_trgt_sq / sigma_n_sq), dim=[1, 2, 3])

    score = (prediction_vif + EPS) / (target_vif + EPS)

    # Reduce if needed
    if reduction == 'none':
        return score

    return {'mean': score.mean, 'sum': score.sum}[reduction](dim=0)
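
A hedged usage sketch (not from the source), assuming the vif_p function above is in scope; inputs must be at least 41x41 pixels and the argument order matters because the metric is not symmetric.

import torch

prediction = torch.rand(2, 3, 96, 96)
target = torch.rand(2, 3, 96, 96)
score = vif_p(prediction, target, data_range=1.0)        # mean over the batch
per_image = vif_p(prediction, target, reduction='none')  # tensor of shape (2,)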
Example #13
def fsim(x: torch.Tensor,
         y: torch.Tensor,
         reduction: str = 'mean',
         data_range: Union[int, float] = 1.0,
         chromatic: bool = True,
         scales: int = 4,
         orientations: int = 4,
         min_length: int = 6,
         mult: int = 2,
         sigma_f: float = 0.55,
         delta_theta: float = 1.2,
         k: float = 2.0) -> torch.Tensor:
    r"""Compute Feature Similarity Index Measure for a batch of images.

    Args:
        x: Predicted images set :math:`x`.
            Shape (H, W), (C, H, W) or (N, C, H, W).
        y: Target images set :math:`y`.
            Shape (H, W), (C, H, W) or (N, C, H, W).
        reduction: Reduction over samples in batch: "mean"|"sum"|"none"
        data_range: Value range of input images (usually 1.0 or 255). Default: 1.0
        chromatic: Flag to compute FSIMc, which also takes into account chromatic components
        scales: Number of wavelets used for computation of phase congruency maps
        orientations: Number of filter orientations used for computation of phase congruency maps
        min_length: Wavelength of smallest scale filter
        mult: Scaling factor between successive filters
        sigma_f: Ratio of the standard deviation of the Gaussian describing the log Gabor filter's
            transfer function in the frequency domain to the filter center frequency.
        delta_theta: Ratio of angular interval between filter orientations and the standard deviation
            of the angular Gaussian function used to construct filters in the frequency plane.
        k: Number of standard deviations of the noise energy beyond the mean at which the noise
            threshold is set; phase congruency values below it get penalized.
        
    Returns:
        FSIM: Index of similarity between two images. Usually in [0, 1] interval.
            Can be bigger than 1 for predicted (x) images with higher contrast than the original ones.
    Note:
        This implementation is based on the original MATLAB code.
        https://www4.comp.polyu.edu.hk/~cslzhang/IQA/FSIM/FSIM.htm
        
    """

    _validate_input(input_tensors=(x, y),
                    allow_5d=False,
                    data_range=data_range)
    x, y = _adjust_dimensions(input_tensors=(x, y))

    # Rescale to [0, 255] range, because all constants are calculated for this range
    x = x / data_range * 255
    y = y / data_range * 255

    # Apply average pooling
    kernel_size = max(1, round(min(x.shape[-2:]) / 256))
    x = torch.nn.functional.avg_pool2d(x, kernel_size)
    y = torch.nn.functional.avg_pool2d(y, kernel_size)

    num_channels = x.size(1)

    # Convert RGB to YIQ color space https://en.wikipedia.org/wiki/YIQ
    if num_channels == 3:
        x_yiq = rgb2yiq(x)
        y_yiq = rgb2yiq(y)

        x_lum = x_yiq[:, :1]
        y_lum = y_yiq[:, :1]

        x_i = x_yiq[:, 1:2]
        y_i = y_yiq[:, 1:2]
        x_q = x_yiq[:, 2:]
        y_q = y_yiq[:, 2:]

    else:
        x_lum = x
        y_lum = y

    # Compute phase congruency maps
    pc_x = _phase_congruency(x_lum,
                             scales=scales,
                             orientations=orientations,
                             min_length=min_length,
                             mult=mult,
                             sigma_f=sigma_f,
                             delta_theta=delta_theta,
                             k=k)
    pc_y = _phase_congruency(y_lum,
                             scales=scales,
                             orientations=orientations,
                             min_length=min_length,
                             mult=mult,
                             sigma_f=sigma_f,
                             delta_theta=delta_theta,
                             k=k)

    # Gradient maps
    kernels = torch.stack([scharr_filter(), scharr_filter().transpose(-1, -2)])
    grad_map_x = gradient_map(x_lum, kernels)
    grad_map_y = gradient_map(y_lum, kernels)

    # Constants from the paper
    T1, T2, T3, T4, lmbda = 0.85, 160, 200, 200, 0.03

    # Compute FSIM
    PC = similarity_map(pc_x, pc_y, T1)
    GM = similarity_map(grad_map_x, grad_map_y, T2)
    pc_max = torch.where(pc_x > pc_y, pc_x, pc_y)
    score = GM * PC * pc_max

    if chromatic:
        assert num_channels == 3, "Chromatic component can be computed only for RGB images!"
        S_I = similarity_map(x_i, y_i, T3)
        S_Q = similarity_map(x_q, y_q, T4)
        score = score * torch.abs(S_I * S_Q)**lmbda
        # Complex gradients will work in PyTorch 1.6.0
        # score = score * torch.real((S_I * S_Q).to(torch.complex64) ** lmbda)

    result = score.sum(dim=[1, 2, 3]) / pc_max.sum(dim=[1, 2, 3])

    if reduction == 'none':
        return result

    return {'mean': result.mean, 'sum': result.sum}[reduction](dim=0)
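
A hedged usage sketch (not from the source), assuming the fsim function above is in scope; chromatic=True computes FSIMc and requires RGB input, while greyscale input works with chromatic=False.

import torch

x = torch.rand(2, 3, 128, 128)
y = torch.rand(2, 3, 128, 128)
fsim_c = fsim(x, y, chromatic=True)                     # FSIMc on RGB input
fsim_grey = fsim(x[:, :1], y[:, :1], chromatic=False)   # luminance-only FSIM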
Example #14
def multi_scale_gmsd(
    prediction: torch.Tensor,
    target: torch.Tensor,
    data_range: Union[int, float] = 1.,
    reduction: str = 'mean',
    scale_weights: Optional[Union[torch.Tensor, Tuple[float, ...],
                                  List[float]]] = None,
    chromatic: bool = False,
    beta1: float = 0.01,
    beta2: float = 0.32,
    beta3: float = 15.,
    t: float = 170 / (255.**2)) -> torch.Tensor:
    r"""Computation of Multi-Scale GMSD.

    Args:
        prediction: Tensor of prediction of the network. The height and width should be at least 2 ** scales + 1.
        target: Reference tensor. The height and width should be at least 2 ** scales + 1.
        data_range: The difference between the maximum and minimum of the pixel value,
            i.e., if for image x it holds min(x) = 0 and max(x) = 1, then data_range = 1.
            The pixel value interval of both input and output should remain the same.
        reduction: Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
            ``'mean'``: the sum of the output will be divided by the number of
            elements in the output, ``'sum'``: the output will be summed. Default: ``'mean'``
        scale_weights: Weights for different scales. Can contain any number of floating point values.
        chromatic: Flag to use MS-GMSDc algorithm from paper.
            It also evaluates chromatic components of the image. Default: False
        beta1: Algorithm parameter. Weight of chromatic component in the loss.
        beta2: Algorithm parameter. Small constant, see [1].
        beta3: Algorithm parameter. Small constant, see [1].
        t: Constant from the reference paper for numerical stability of the similarity map.

    Returns:
        Value of MS-GMSD. 0 <= GMSD loss <= 1.
    """
    _validate_input(input_tensors=(prediction, target),
                    allow_5d=False,
                    scale_weights=scale_weights)
    prediction, target = _adjust_dimensions(input_tensors=(prediction, target))

    # Values from the paper
    if scale_weights is None:
        scale_weights = torch.tensor([0.096, 0.596, 0.289, 0.019])
    elif isinstance(scale_weights, torch.Tensor):
        scale_weights = scale_weights / scale_weights.sum()
    else:
        # Normalize scale weights
        scale_weights = torch.tensor(scale_weights) / torch.tensor(
            scale_weights).sum()

    # Check that input is big enough
    num_scales = scale_weights.size(0)
    min_size = 2**num_scales + 1

    if prediction.size(-1) < min_size or prediction.size(-2) < min_size:
        raise ValueError(
            f'Invalid size of the input images, expected at least {min_size}x{min_size}.'
        )

    prediction = prediction / float(data_range)
    target = target / float(data_range)

    num_channels = prediction.size(1)
    if num_channels == 3:
        prediction = rgb2yiq(prediction)
        target = rgb2yiq(target)

    scale_weights = scale_weights.to(prediction)
    ms_gmds = []
    for scale in range(num_scales):
        if scale > 0:
            # Average by 2x2 filter and downsample
            up_pad = 0
            down_pad = max(prediction.shape[2] % 2, prediction.shape[3] % 2)
            pad_to_use = [up_pad, down_pad, up_pad, down_pad]
            prediction = F.pad(prediction, pad=pad_to_use)
            target = F.pad(target, pad=pad_to_use)
            prediction = F.avg_pool2d(prediction, kernel_size=2, padding=0)
            target = F.avg_pool2d(target, kernel_size=2, padding=0)

        score = _gmsd(prediction[:, :1], target[:, :1], t=t)
        ms_gmds.append(score)

    # Stack results in different scales and multiply by weight
    ms_gmds_val = scale_weights.view(1, num_scales) * (torch.stack(ms_gmds,
                                                                   dim=1)**2)

    # Sum and take sqrt per-image
    ms_gmds_val = torch.sqrt(torch.sum(ms_gmds_val, dim=1))

    # Shape: (batch_size, )
    score = ms_gmds_val

    if chromatic:
        assert prediction.size(
            1) == 3, "Chromatic component can be computed only for RGB images!"

        prediction_iq = prediction[:, 1:]
        target_iq = target[:, 1:]

        rmse_iq = torch.sqrt(
            torch.mean((prediction_iq - target_iq)**2, dim=[2, 3]))
        rmse_chrome = torch.sqrt(torch.sum(rmse_iq**2, dim=1))
        gamma = 2 / (1 + beta2 * torch.exp(-beta3 * ms_gmds_val)) - 1

        score = gamma * ms_gmds_val + (1 - gamma) * beta1 * rmse_chrome

    if reduction == 'none':
        return score

    return {'mean': score.mean, 'sum': score.sum}[reduction](dim=0)
Example #15
def test_breaks_if_number_of_dim_greater_five() -> None:
    tensor_6d = torch.rand(1, 1, 1, 1, 1, 1)
    with pytest.raises(ValueError):
        _adjust_dimensions(tensor_6d)
Example #16
def vif_p(x: torch.Tensor,
          y: torch.Tensor,
          sigma_n_sq: float = 2.0,
          data_range: Union[int, float] = 1.0,
          reduction: str = 'mean') -> torch.Tensor:
    r"""Compute Visual Information Fidelity in **pixel** domain for a batch of images.
    This metric isn't symmetric, so make sure to place arguments in correct order.
    Both inputs are supposed to have RGB channel order.

    Args:
        x: Tensor with shape (H, W), (C, H, W) or (N, C, H, W).
        y: Tensor with shape (H, W), (C, H, W) or (N, C, H, W).
        sigma_n_sq: HVS model parameter (variance of the visual noise).
        data_range: Value range of input images (usually 1.0 or 255). Default: 1.0
        reduction: Reduction over samples in batch: "mean"|"sum"|"none"
        
    Returns:
        VIF: Index of similarity between two images. Usually in [0, 1] interval.
            Can be bigger than 1 for predicted images with higher contrast than the original ones.
    Note:
        In original paper this method was used for bands in discrete wavelet decomposition.
        Later on authors released code to compute VIF approximation in pixel domain.
        See https://live.ece.utexas.edu/research/Quality/VIF.htm for details.
        
    """
    _validate_input((x, y), allow_5d=False, data_range=data_range)
    x, y = _adjust_dimensions(input_tensors=(x, y))

    min_size = 41
    if x.size(-1) < min_size or x.size(-2) < min_size:
        raise ValueError(
            f'Invalid size of the input images, expected at least {min_size}x{min_size}.'
        )

    x = x / data_range * 255
    y = y / data_range * 255

    # Convert RGB image to YCbCr and take luminance: Y = 0.299 R + 0.587 G + 0.114 B
    num_channels = x.size(1)
    if num_channels == 3:
        x = 0.299 * x[:, 0, :, :] + 0.587 * x[:, 1, :, :] + 0.114 * x[:, 2, :, :]
        y = 0.299 * y[:, 0, :, :] + 0.587 * y[:, 1, :, :] + 0.114 * y[:, 2, :, :]

        # Add channel dimension
        x = x[:, None, :, :]
        y = y[:, None, :, :]

    # Constant for numerical stability
    EPS = 1e-8

    # Progressively downsample images and compute VIF on different scales
    x_vif, y_vif = 0, 0
    for scale in range(4):
        kernel_size = 2**(4 - scale) + 1
        kernel = gaussian_filter(kernel_size, sigma=kernel_size / 5)
        kernel = kernel.view(1, 1, kernel_size, kernel_size).to(x)

        if scale > 0:
            # Convolve and downsample
            x = F.conv2d(x, kernel)[:, :, ::2, ::2]  # valid padding
            y = F.conv2d(y, kernel)[:, :, ::2, ::2]  # valid padding

        mu_x, mu_y = F.conv2d(x, kernel), F.conv2d(y, kernel)  # valid padding
        mu_x_sq, mu_y_sq, mu_xy = mu_x * mu_x, mu_y * mu_y, mu_x * mu_y

        # Estimate local variances and covariance (valid padding)
        sigma_x_sq = F.conv2d(x**2, kernel) - mu_x_sq
        sigma_y_sq = F.conv2d(y**2, kernel) - mu_y_sq
        sigma_xy = F.conv2d(x * y, kernel) - mu_xy

        # Zero small negative values
        sigma_x_sq = torch.relu(sigma_x_sq)
        sigma_y_sq = torch.relu(sigma_y_sq)

        g = sigma_xy / (sigma_y_sq + EPS)
        sigma_v_sq = sigma_x_sq - g * sigma_xy

        g = torch.where(sigma_y_sq >= EPS, g, torch.zeros_like(g))
        sigma_v_sq = torch.where(sigma_y_sq >= EPS, sigma_v_sq, sigma_x_sq)
        sigma_y_sq = torch.where(sigma_y_sq >= EPS, sigma_y_sq,
                                 torch.zeros_like(sigma_y_sq))

        g = torch.where(sigma_x_sq >= EPS, g, torch.zeros_like(g))
        sigma_v_sq = torch.where(sigma_x_sq >= EPS, sigma_v_sq,
                                 torch.zeros_like(sigma_v_sq))

        sigma_v_sq = torch.where(g >= 0, sigma_v_sq, sigma_x_sq)
        g = torch.relu(g)

        sigma_v_sq = torch.where(sigma_v_sq > EPS, sigma_v_sq,
                                 torch.ones_like(sigma_v_sq) * EPS)

        x_vif_scale = torch.log10(1.0 + (g**2.) * sigma_y_sq /
                                  (sigma_v_sq + sigma_n_sq))
        x_vif = x_vif + torch.sum(x_vif_scale, dim=[1, 2, 3])
        y_vif = y_vif + torch.sum(torch.log10(1.0 + sigma_y_sq / sigma_n_sq),
                                  dim=[1, 2, 3])

    score: torch.Tensor = (x_vif + EPS) / (y_vif + EPS)

    # Reduce if needed
    if reduction == 'none':
        return score

    return {'mean': score.mean, 'sum': score.sum}[reduction](dim=0)
Example #17
def vsi(prediction: torch.Tensor,
        target: torch.Tensor,
        reduction: str = 'mean',
        data_range: Union[int, float] = 1.,
        c1: float = 1.27,
        c2: float = 386.,
        c3: float = 130.,
        alpha: float = 0.4,
        beta: float = 0.02,
        omega_0: float = 0.021,
        sigma_f: float = 1.34,
        sigma_d: float = 145.,
        sigma_c: float = 0.001) -> torch.Tensor:
    r"""Compute Visual Saliency-induced Index for a batch of images.

    Both inputs are supposed to have RGB channel order in accordance with the original approach.
    Nevertheless, the method supports greyscale images, which are converted to RGB by copying the grey
    channel 3 times.

    Args:
        prediction:  Tensor with shape (H, W), (C, H, W) or (N, C, H, W) holding a distorted image.
        target: Tensor with shape (H, W), (C, H, W) or (N, C, H, W) holding a target image.
        reduction: Reduction over samples in batch: "mean"|"sum"|"none"
        data_range: Value range of input images (usually 1.0 or 255). Default: 1.0
        c1: coefficient to calculate saliency component of VSI
        c2: coefficient to calculate gradient component of VSI
        c3: coefficient to calculate color component of VSI
        alpha: power for gradient component of VSI
        beta: power for color component of VSI
        omega_0: coefficient to get log Gabor filter at SDSP
        sigma_f: coefficient to get log Gabor filter at SDSP
        sigma_d: coefficient to get SDSP
        sigma_c: coefficient to get SDSP

    Returns:
        VSI: Index of similarity between two images. Usually in [0, 1] interval.

    Shape:
        - Input:  Required to be 2D (H, W), 3D (C, H, W) or 4D (N, C, H, W). RGB channel order for colour images.
        - Target: Required to be 2D (H, W), 3D (C, H, W) or 4D (N, C, H, W). RGB channel order for colour images.

    References:
        .. [1] Zhang, L., Shen, Y., & Li, H. (2014). VSI: A visual
           saliency-induced index for perceptual image quality assessment.
           IEEE Transactions on Image Processing, 23(10), 4270-4281.
           https://ieeexplore.ieee.org/document/6873260

    Note:
        The original method supports only RGB image.
        See https://ieeexplore.ieee.org/document/6873260 for details.
    """
    _validate_input(input_tensors=(prediction, target), allow_5d=False)
    prediction, target = _adjust_dimensions(input_tensors=(prediction, target))
    if prediction.size(1) == 1:
        prediction = prediction.repeat(1, 3, 1, 1)
        target = target.repeat(1, 3, 1, 1)
        warnings.warn(
            'The original VSI supports only RGB images. The input images were converted to RGB by copying '
            'the grey channel 3 times.')

    # Scale to [0, 255] range to match scale of constant
    prediction = prediction * 255. / data_range
    target = target * 255. / data_range

    vs_prediction = sdsp(prediction,
                         data_range=255,
                         omega_0=omega_0,
                         sigma_f=sigma_f,
                         sigma_d=sigma_d,
                         sigma_c=sigma_c)
    vs_target = sdsp(target,
                     data_range=255,
                     omega_0=omega_0,
                     sigma_f=sigma_f,
                     sigma_d=sigma_d,
                     sigma_c=sigma_c)

    # Convert to LMN colour space
    prediction_lmn = rgb2lmn(prediction)
    target_lmn = rgb2lmn(target)

    # Averaging image if the size is large enough
    kernel_size = max(1, round(min(vs_prediction.size()[-2:]) / 256))
    padding = kernel_size // 2

    if padding:
        upper_pad = padding
        bottom_pad = (kernel_size - 1) // 2
        pad_to_use = [upper_pad, bottom_pad, upper_pad, bottom_pad]
        mode = 'replicate'
        vs_prediction = pad(vs_prediction, pad=pad_to_use, mode=mode)
        vs_target = pad(vs_target, pad=pad_to_use, mode=mode)
        prediction_lmn = pad(prediction_lmn, pad=pad_to_use, mode=mode)
        target_lmn = pad(target_lmn, pad=pad_to_use, mode=mode)

    vs_prediction = avg_pool2d(vs_prediction, kernel_size=kernel_size)
    vs_target = avg_pool2d(vs_target, kernel_size=kernel_size)

    prediction_lmn = avg_pool2d(prediction_lmn, kernel_size=kernel_size)
    target_lmn = avg_pool2d(target_lmn, kernel_size=kernel_size)

    # Calculate gradient map
    kernels = torch.stack([scharr_filter(),
                           scharr_filter().transpose(1, 2)]).to(prediction_lmn)
    gm_prediction = gradient_map(prediction_lmn[:, :1], kernels)
    gm_target = gradient_map(target_lmn[:, :1], kernels)

    # Calculate all similarity maps
    s_vs = similarity_map(vs_prediction, vs_target, c1)
    s_gm = similarity_map(gm_prediction, gm_target, c2)
    s_m = similarity_map(prediction_lmn[:, 1:2], target_lmn[:, 1:2], c3)
    s_n = similarity_map(prediction_lmn[:, 2:], target_lmn[:, 2:], c3)
    s_c = s_m * s_n

    s_c_complex = [s_c.abs(), torch.atan2(torch.zeros_like(s_c), s_c)]
    s_c_complex_pow = [s_c_complex[0]**beta, s_c_complex[1] * beta]
    s_c_real_pow = s_c_complex_pow[0] * torch.cos(s_c_complex_pow[1])

    s = s_vs * s_gm.pow(alpha) * s_c_real_pow
    vs_max = torch.max(vs_prediction, vs_target)

    eps = torch.finfo(vs_max.dtype).eps
    output = s * vs_max
    output = ((output.sum(dim=(-1, -2)) + eps) /
              (vs_max.sum(dim=(-1, -2)) + eps)).squeeze(-1)
    if reduction == 'none':
        return output
    return {'mean': torch.mean, 'sum': torch.sum}[reduction](output, dim=0)
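
A hedged usage sketch (not from the source), assuming the vsi function above is in scope; greyscale inputs are accepted and internally replicated to three channels.

import torch

prediction = torch.rand(2, 3, 128, 128)
target = torch.rand(2, 3, 128, 128)
score = vsi(prediction, target, data_range=1.)            # mean over the batch
per_image = vsi(prediction, target, reduction='none')     # tensor of shape (2,)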
Example #18
def mdsi(x: torch.Tensor,
         y: torch.Tensor,
         data_range: Union[int, float] = 1.,
         reduction: str = 'mean',
         c1: float = 140.,
         c2: float = 55.,
         c3: float = 550.,
         combination: str = 'sum',
         alpha: float = 0.6,
         beta: float = 0.1,
         gamma: float = 0.2,
         rho: float = 1.,
         q: float = 0.25,
         o: float = 0.25):
    r"""Compute Mean Deviation Similarity Index (MDSI) for a batch of images.

    Note:
        Both inputs are supposed to have RGB channels order.
        Greyscale images converted to RGB by copying the grey channel 3 times.

    Args:
        x: Tensor with shape (H, W), (C, H, W) or (N, C, H, W).
        y: Tensor with shape (H, W), (C, H, W) or (N, C, H, W).
        data_range: Value range of input images (usually 1.0 or 255). Default: 1.0
        reduction: Reduction over samples in batch: "mean"|"sum"|"none"
        c1: coefficient to calculate gradient similarity. Default: 140.
        c2: coefficient to calculate gradient similarity. Default: 55.
        c3: coefficient to calculate chromaticity similarity. Default: 550.
        combination: mode to combine gradient similarity and chromaticity similarity: "sum"|"mult".
        alpha: coefficient to combine gradient similarity and chromaticity similarity using summation.
        beta: power to combine gradient similarity with chromaticity similarity using multiplication.
        gamma: power to combine gradient similarity with chromaticity similarity using multiplication.
        rho: order of the Minkowski distance
        q: coefficient that adjusts the emphasis of the values in image and MCT
        o: power of the pooling applied to the final value of the deviation

    Returns:
        torch.Tensor: the batch of Mean Deviation Similarity Index (MDSI) score reduced accordingly

    Note:
        The ratio between constants is usually equal c3 = 4c1 = 10c2
    """
    _validate_input(input_tensors=(x, y),
                    allow_5d=False,
                    data_range=data_range)
    x, y = _adjust_dimensions(input_tensors=(x, y))

    if x.size(1) == 1:
        x = x.repeat(1, 3, 1, 1)
        y = y.repeat(1, 3, 1, 1)
        warnings.warn(
            'The original MDSI supports only RGB images. The input images were converted to RGB by copying '
            'the grey channel 3 times.')

    x = x / data_range * 255
    y = y / data_range * 255

    # Averaging image if the size is large enough
    kernel_size = max(1, round(min(x.size()[-2:]) / 256))
    padding = kernel_size // 2

    if padding:
        up_pad = (kernel_size - 1) // 2
        down_pad = padding
        pad_to_use = [up_pad, down_pad, up_pad, down_pad]
        x = pad(x, pad=pad_to_use)
        y = pad(y, pad=pad_to_use)

    x = avg_pool2d(x, kernel_size=kernel_size)
    y = avg_pool2d(y, kernel_size=kernel_size)

    x_lhm = rgb2lhm(x)
    y_lhm = rgb2lhm(y)

    kernels = torch.stack([prewitt_filter(),
                           prewitt_filter().transpose(1, 2)]).to(x)
    gm_x = gradient_map(x_lhm[:, :1], kernels)
    gm_y = gradient_map(y_lhm[:, :1], kernels)
    gm_avg = gradient_map((x_lhm[:, :1] + y_lhm[:, :1]) / 2., kernels)

    gs_x_y = similarity_map(gm_x, gm_y, c1)
    gs_x_average = similarity_map(gm_x, gm_avg, c2)
    gs_y_average = similarity_map(gm_y, gm_avg, c2)

    gs_total = gs_x_y + gs_x_average - gs_y_average

    cs_total = (2 *
                (x_lhm[:, 1:2] * y_lhm[:, 1:2] + x_lhm[:, 2:] * y_lhm[:, 2:]) +
                c3) / (x_lhm[:, 1:2]**2 + y_lhm[:, 1:2]**2 + x_lhm[:, 2:]**2 +
                       y_lhm[:, 2:]**2 + c3)

    if combination == 'sum':
        gcs = (alpha * gs_total + (1 - alpha) * cs_total)
    elif combination == 'mult':
        gs_total_pow = pow_for_complex(base=gs_total, exp=gamma)
        cs_total_pow = pow_for_complex(base=cs_total, exp=beta)
        gcs = torch.stack((gs_total_pow[..., 0] * cs_total_pow[..., 0],
                           gs_total_pow[..., 1] + cs_total_pow[..., 1]),
                          dim=-1)
    else:
        raise ValueError(
            f'Expected combination method "sum" or "mult", got {combination}')

    mct_complex = pow_for_complex(base=gcs, exp=q)
    mct_complex = mct_complex.mean(dim=2, keepdim=True).mean(
        dim=3, keepdim=True)  # split to increase precision
    score = (pow_for_complex(base=gcs, exp=q) -
             mct_complex).pow(2).sum(dim=-1).sqrt()
    score = ((score**rho).mean(dim=(-1, -2))**(o / rho)).squeeze(1)
    if reduction == 'none':
        return score
    return {'mean': score.mean, 'sum': score.sum}[reduction](dim=0)
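
A hedged usage sketch (not from the source), assuming the mdsi function above is in scope; both combination modes from the docstring are shown.

import torch

x = torch.rand(2, 3, 128, 128)
y = torch.rand(2, 3, 128, 128)
score_sum = mdsi(x, y, combination='sum')     # deviation pooling with additive combination
score_mult = mdsi(x, y, combination='mult')   # multiplicative combination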
Example #19
def haarpsi(x: torch.Tensor, y: torch.Tensor, reduction: Optional[str] = 'mean',
            data_range: Union[int, float] = 1., scales: int = 3, subsample: bool = True,
            c: float = 30.0, alpha: float = 4.2) -> torch.Tensor:
    r"""Compute Haar Wavelet-Based Perceptual Similarity
    Input can by greyscale tensor of colour image with RGB channels order.
    Args:
        x: Tensor of shape :math:`(N, C, H, W)` holding a distorted image.
        y: Tensor of shape :math:`(N, C, H, W)` holding a target image.
        reduction: Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
            ``'mean'``: the sum of the output will be divided by the number of
            elements in the output, ``'sum'``: the output will be summed.
        data_range: The difference between the maximum and minimum of the pixel value,
            i.e., if for image x it holds min(x) = 0 and max(x) = 1, then data_range = 1.
            The pixel value interval of both input and output should remain the same.
        scales: Number of Haar wavelets used for image decomposition.
        subsample: Flag to apply average pooling before HaarPSI computation. See [1] for details.
        c: Constant from the paper. See [1] for details.
        alpha: Exponent used for similarity maps weighting. See [1] for details.

    Returns:
        HaarPSI: Haar Wavelet-Based Perceptual Similarity between two tensors.
    
    References:
        [1] R. Reisenhofer, S. Bosse, G. Kutyniok & T. Wiegand (2017)
            'A Haar Wavelet-Based Perceptual Similarity Index for Image Quality Assessment'
            http://www.math.uni-bremen.de/cda/HaarPSI/publications/HaarPSI_preprint_v4.pdf
        [2] Reference MATLAB and Python code from the authors
            https://github.com/rgcda/haarpsi
    """

    _validate_input(input_tensors=(x, y), allow_5d=False, scale_weights=None)
    x, y = _adjust_dimensions(input_tensors=(x, y))

    # Check that the input is large enough for the requested number of scales
    kernel_size = 2 ** (scales + 1)
    if x.size(-1) < kernel_size or x.size(-2) < kernel_size:
        raise ValueError(f'Kernel size can\'t be greater than actual input size. Input size: {x.size()}. '
                         f'Kernel size: {kernel_size}')

    # Scale images to [0, 255] range as in the paper
    x = x * 255.0 / float(data_range)
    y = y * 255.0 / float(data_range)

    num_channels = x.size(1)
    # Convert RGB to YIQ color space https://en.wikipedia.org/wiki/YIQ
    if num_channels == 3:
        x_yiq = rgb2yiq(x)
        y_yiq = rgb2yiq(y)
    else:
        x_yiq = x
        y_yiq = y

    # Downscale input to simulate the typical distance between an image and its viewer.
    if subsample:
        up_pad = 0
        down_pad = max(x.shape[2] % 2, x.shape[3] % 2)
        pad_to_use = [up_pad, down_pad, up_pad, down_pad]
        x_yiq = F.pad(x_yiq, pad=pad_to_use)
        y_yiq = F.pad(y_yiq, pad=pad_to_use)

        x_yiq = F.avg_pool2d(x_yiq, kernel_size=2, stride=2, padding=0)
        y_yiq = F.avg_pool2d(y_yiq, kernel_size=2, stride=2, padding=0)
    
    # Haar wavelet decomposition
    coefficients_x, coefficients_y = [], []
    for scale in range(scales):
        kernel_size = 2 ** (scale + 1)
        kernels = torch.stack([haar_filter(kernel_size), haar_filter(kernel_size).transpose(-1, -2)])
    
        # Asymmetric padding due to the even kernel size. Matches MATLAB conv2(A, B, 'same')
        upper_pad = kernel_size // 2 - 1
        bottom_pad = kernel_size // 2
        pad_to_use = [upper_pad, bottom_pad, upper_pad, bottom_pad]
        coeff_x = torch.nn.functional.conv2d(F.pad(x_yiq[:, : 1], pad=pad_to_use, mode='constant'), kernels.to(x))
        coeff_y = torch.nn.functional.conv2d(F.pad(y_yiq[:, : 1], pad=pad_to_use, mode='constant'), kernels.to(y))
    
        coefficients_x.append(coeff_x)
        coefficients_y.append(coeff_y)

    # Shape [B x {scales * 2} x H x W]
    coefficients_x = torch.cat(coefficients_x, dim=1)
    coefficients_y = torch.cat(coefficients_y, dim=1)

    # Low-frequency coefficients used as weights
    # Shape [B x 2 x H x W]
    weights = torch.max(torch.abs(coefficients_x[:, 4:]), torch.abs(coefficients_y[:, 4:]))
    
    # High-frequency coefficients used for similarity computation in 2 orientations (horizontal and vertical)
    sim_map = []
    for orientation in range(2):
        magnitude_x = torch.abs(coefficients_x[:, (orientation, orientation + 2)])
        magnitude_y = torch.abs(coefficients_y[:, (orientation, orientation + 2)])
        sim_map.append(similarity_map(magnitude_x, magnitude_y, constant=c).sum(dim=1, keepdims=True) / 2)

    if num_channels == 3:
        pad_to_use = [0, 1, 0, 1]
        x_yiq = F.pad(x_yiq, pad=pad_to_use)
        y_yiq = F.pad(y_yiq, pad=pad_to_use)
        coefficients_x_iq = torch.abs(F.avg_pool2d(x_yiq[:, 1:], kernel_size=2, stride=1, padding=0))
        coefficients_y_iq = torch.abs(F.avg_pool2d(y_yiq[:, 1:], kernel_size=2, stride=1, padding=0))
    
        # Compute weights and similarity for the chromatic (IQ) channels
        weights = torch.cat([weights, weights.mean(dim=1, keepdims=True)], dim=1)
        sim_map.append(
            similarity_map(coefficients_x_iq, coefficients_y_iq, constant=c).sum(dim=1, keepdims=True) / 2)

    sim_map = torch.cat(sim_map, dim=1)
    
    # Calculate the final score
    eps = torch.finfo(sim_map.dtype).eps
    score = (((sim_map * alpha).sigmoid() * weights).sum(dim=[1, 2, 3]) + eps) /\
        (torch.sum(weights, dim=[1, 2, 3]) + eps)
    # Logit of score
    score = (torch.log(score / (1 - score)) / alpha) ** 2

    if reduction == 'none':
        return score

    return {'mean': score.mean,
            'sum': score.sum
            }[reduction](dim=0)
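
A minimal usage sketch for the haarpsi function above (shapes, value ranges and batch sizes are
assumptions for illustration; each spatial dimension must be at least 2 ** (scales + 1) pixels):

import torch

x = torch.rand(2, 3, 96, 96)  # distorted RGB batch in [0, 1]
y = torch.rand(2, 3, 96, 96)  # reference RGB batch in [0, 1]

score = haarpsi(x, y)  # defaults: 3 scales, subsampling enabled, mean reduction over the batch

# Greyscale inputs in [0, 255]; per-image scores instead of the batch mean.
x_grey = torch.randint(0, 256, (2, 1, 96, 96)).float()
y_grey = torch.randint(0, 256, (2, 1, 96, 96)).float()
per_image = haarpsi(x_grey, y_grey, data_range=255, reduction='none')
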
Example #20
0
def mdsi(prediction: torch.Tensor, target: torch.Tensor, data_range: Union[int, float] = 1., reduction: str = 'mean',
         c1: float = 140., c2: float = 55., c3: float = 550., combination: str = 'sum', alpha: float = 0.6,
         beta: float = 0.1, gamma: float = 0.2, rho: float = 1., q: float = 0.25, o: float = 0.25):
    r"""Compute Mean Deviation Similarity Index (MDSI) for a batch of images.

    Note:
        Both inputs are supposed to have RGB order in accordance with the original approach.
        Nevertheless, the method supports greyscale images, which are converted to RGB by copying the grey
        channel 3 times.

    Args:
        prediction: Batch of predicted (distorted) images. Required to be 2D (H,W), 3D (C,H,W) or 4D (N,C,H,W),
            channels first.
        target: Batch of target (reference) images. Required to be 2D (H,W), 3D (C,H,W) or 4D (N,C,H,W),
            channels first.
        data_range: Value range of input images (usually 1.0 or 255). Default: 1.0
        reduction: Reduction over samples in batch: "mean"|"sum"|"none"
        c1: coefficient to calculate gradient similarity. Default: 140.
        c2: coefficient to calculate gradient similarity. Default: 55.
        c3: coefficient to calculate chromaticity similarity. Default: 550.
        combination: mode to combine gradient similarity and chromaticity similarity: "sum"|"mult".
        alpha: coefficient to combine gradient similarity and chromaticity similarity using summation.
        beta: power applied to chromaticity similarity when combining it with gradient similarity using
            multiplication.
        gamma: power applied to gradient similarity when combining it with chromaticity similarity using
            multiplication.
        rho: order of the Minkowski distance used for deviation pooling.
        q: coefficient that adjusts the emphasis of the values in the image and MCT.
        o: power used for power pooling of the final deviation value.

    Returns:
        torch.Tensor: the batch of Mean Deviation Similarity Index (MDSI) scores reduced according to ``reduction``.

    Note:
        The ratio between the constants is usually set so that c3 = 4c1 = 10c2.
    """
    _validate_input(input_tensors=(prediction, target), allow_5d=False)
    prediction, target = _adjust_dimensions(input_tensors=(prediction, target))

    if prediction.size(1) == 1:
        prediction = prediction.repeat(1, 3, 1, 1)
        target = target.repeat(1, 3, 1, 1)
        warnings.warn('The original MDSI supports only RGB images. The input images were converted to RGB by copying '
                      'the grey channel 3 times.')

    prediction = prediction * 255. / data_range
    target = target * 255. / data_range

    # Average the image if it is large enough
    kernel_size = max(1, round(min(prediction.size()[-2:]) / 256))
    padding = kernel_size // 2

    if padding:
        up_pad = (kernel_size - 1) // 2
        down_pad = padding
        pad_to_use = [up_pad, down_pad, up_pad, down_pad]
        prediction = pad(prediction, pad=pad_to_use)
        target = pad(target, pad=pad_to_use)

    prediction = avg_pool2d(prediction, kernel_size=kernel_size)
    target = avg_pool2d(target, kernel_size=kernel_size)

    prediction_lhm = rgb2lhm(prediction)
    target_lhm = rgb2lhm(target)

    kernels = torch.stack([prewitt_filter(), prewitt_filter().transpose(1, 2)]).to(prediction)
    gm_prediction = gradient_map(prediction_lhm[:, :1], kernels)
    gm_target = gradient_map(target_lhm[:, :1], kernels)
    gm_avg = gradient_map((prediction_lhm[:, :1] + target_lhm[:, :1]) / 2., kernels)

    gs_prediction_target = similarity_map(gm_prediction, gm_target, c1)
    gs_prediction_average = similarity_map(gm_prediction, gm_avg, c2)
    gs_target_average = similarity_map(gm_target, gm_avg, c2)

    gs_total = gs_prediction_target + gs_prediction_average - gs_target_average

    cs_total = (2 * (prediction_lhm[:, 1:2] * target_lhm[:, 1:2] +
                     prediction_lhm[:, 2:] * target_lhm[:, 2:]) + c3) / (prediction_lhm[:, 1:2] ** 2 +
                                                                         target_lhm[:, 1:2] ** 2 +
                                                                         prediction_lhm[:, 2:] ** 2 +
                                                                         target_lhm[:, 2:] ** 2 + c3)

    if combination == 'sum':
        gcs = (alpha * gs_total + (1 - alpha) * cs_total)
    elif combination == 'mult':
        gs_total_pow = pow_for_complex(base=gs_total, exp=gamma)
        cs_total_pow = pow_for_complex(base=cs_total, exp=beta)
        gcs = torch.stack((gs_total_pow[..., 0] * cs_total_pow[..., 0],
                           gs_total_pow[..., 1] + cs_total_pow[..., 1]), dim=-1)
    else:
        raise ValueError(f'Expected combination method "sum" or "mult", got {combination}')

    mct_complex = pow_for_complex(base=gcs, exp=q)
    mct_complex = mct_complex.mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)  # split to increase precision
    score = (pow_for_complex(base=gcs, exp=q) - mct_complex).pow(2).sum(dim=-1).sqrt()
    score = ((score ** rho).mean(dim=(-1, -2)) ** (o / rho)).squeeze(1)
    if reduction == 'none':
        return score
    return {'mean': score.mean,
            'sum': score.sum}[reduction](dim=0)
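
The q, rho and o parameters control the deviation pooling at the end of mdsi. A toy, real-valued sketch of
that step (for illustration only: it ignores the complex-valued bookkeeping done by pow_for_complex and uses
a random stand-in for the combined similarity map):

import torch

q, rho, o = 0.25, 1., 0.25
gcs = torch.rand(1, 1, 8, 8)  # stand-in for the combined gradient/chromaticity similarity map

mct = gcs.pow(q).mean(dim=(-2, -1), keepdim=True)  # mean of the q-powered map
score = (gcs.pow(q) - mct).abs().pow(rho).mean(dim=(-2, -1)).pow(o / rho)  # Minkowski deviation, then power pooling

A constant similarity map gives a deviation of zero, which is why identical prediction and target images are
expected to produce an MDSI score of (approximately) zero.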