def compute_distance(self, x_features: torch.Tensor,
                         y_features: torch.Tensor) -> List[torch.Tensor]:
        r"""Compute structure similarity between feature maps

        Args:
            x_features: Features of the input tensor.
            y_features: Features of the target tensor.

        Returns:
            List of structure and texture similarity maps between feature maps.
        """
        structure_distance, texture_distance = [], []
        # Small constant for numerical stability
        EPS = 1e-6

        for x, y in zip(x_features, y_features):
            x_mean = x.mean([2, 3], keepdim=True)
            y_mean = y.mean([2, 3], keepdim=True)
            structure_distance.append(
                similarity_map(x_mean, y_mean, constant=EPS))

            x_var = ((x - x_mean)**2).mean([2, 3], keepdim=True)
            y_var = ((y - y_mean)**2).mean([2, 3], keepdim=True)
            xy_cov = (x * y).mean([2, 3], keepdim=True) - x_mean * y_mean
            texture_distance.append((2 * xy_cov + EPS) / (x_var + y_var + EPS))

        return structure_distance + texture_distance
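
This method appears in piq's DISTS implementation. A minimal usage sketch through the public class (assuming the piq package and its DISTS interface are available) might look like:

import torch
import piq

x = torch.rand(4, 3, 128, 128, requires_grad=True)  # distorted batch in [0, 1]
y = torch.rand(4, 3, 128, 128)                      # reference batch in [0, 1]

# DISTS extracts VGG16 features internally, then combines the structure and
# texture similarity maps produced by compute_distance into a single distance.
loss = piq.DISTS(reduction='mean')(x, y)
loss.backward()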
Example #2
def _gmsd(prediction: torch.Tensor,
          target: torch.Tensor,
          t: float = 170 / (255.**2)) -> torch.Tensor:
    r"""Compute Gradient Magnitude Similarity Deviation
    Both inputs supposed to be in range [0, 1] with RGB order.
    Args:
        prediction: Tensor of shape :math:`(N, 1, H, W)` holding an distorted grayscale image.
        target: Tensor of shape :math:`(N, 1, H, W)` holding an target grayscale image
        t: Constant from the reference paper numerical stability of similarity map

    Returns:
        gmsd: Gradient Magnitude Similarity Deviation between the given tensors.

    References:
        https://arxiv.org/pdf/1308.3052.pdf
    """

    # Compute gradient maps
    kernels = torch.stack(
        [prewitt_filter(),
         prewitt_filter().transpose(-1, -2)])
    pred_grad = gradient_map(prediction, kernels)
    trgt_grad = gradient_map(target, kernels)

    # Compute GMS
    gms = similarity_map(pred_grad, trgt_grad, t)
    mean_gms = torch.mean(gms, dim=[1, 2, 3], keepdim=True)
    # Compute GMSD along spatial dimensions. Shape: (N,)
    score = torch.pow(gms - mean_gms, 2).mean(dim=[1, 2, 3]).sqrt()
    return score
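
A minimal sketch of calling this private helper directly, assuming prewitt_filter, gradient_map and similarity_map resolve as in piq.functional (the public piq.gmsd entry point additionally handles validation and colour conversion):

import torch

prediction = torch.rand(4, 1, 64, 64)  # distorted greyscale batch in [0, 1]
target = torch.rand(4, 1, 64, 64)      # reference greyscale batch in [0, 1]

score = _gmsd(prediction, target)  # shape (4,), one deviation per image
print(score)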
Example #3
def _gmsd(x: torch.Tensor,
          y: torch.Tensor,
          t: float = 170 / (255.**2),
          alpha: float = 0.0) -> torch.Tensor:
    r"""Compute Gradient Magnitude Similarity Deviation
    Both inputs supposed to be in range [0, 1] with RGB channels order.
    Args:
        x: Tensor with shape (N, 1, H, W).
        y: Tensor with shape (N, 1, H, W).
        t: Constant from the reference paper numerical stability of similarity map
        alpha: Masking coefficient for similarity masks computation

    Returns:
        gmsd : Gradient Magnitude Similarity Deviation between given tensors.

    References:
        https://arxiv.org/pdf/1308.3052.pdf
    """

    # Compute gradient maps
    kernels = torch.stack(
        [prewitt_filter(),
         prewitt_filter().transpose(-1, -2)])
    x_grad = gradient_map(x, kernels)
    y_grad = gradient_map(y, kernels)

    # Compute GMS
    gms = similarity_map(x_grad, y_grad, constant=t, alpha=alpha)
    mean_gms = torch.mean(gms, dim=[1, 2, 3], keepdim=True)

    # Compute GMSD along spatial dimensions. Shape: (N,)
    score = torch.pow(gms - mean_gms, 2).mean(dim=[1, 2, 3]).sqrt()
    return score
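
This variant adds an alpha masking term to the similarity map. A short sketch contrasting it with the default, under the same assumptions about the piq helpers:

import torch

x = torch.rand(4, 1, 64, 64)
y = torch.rand(4, 1, 64, 64)

plain = _gmsd(x, y)              # alpha=0.0 recovers the classic GMSD formulation
masked = _gmsd(x, y, alpha=0.5)  # non-zero alpha enables the masked similarity map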
Example #4
def haarpsi(x: torch.Tensor, y: torch.Tensor, reduction: Optional[str] = 'mean',
            data_range: Union[int, float] = 1., scales: int = 3, subsample: bool = True,
            c: float = 30.0, alpha: float = 4.2) -> torch.Tensor:
    r"""Compute Haar Wavelet-Based Perceptual Similarity
    Input can by greyscale tensor of colour image with RGB channels order.
    Args:
        x: Tensor of shape :math:`(N, C, H, W)` holding an distorted image.
        y: Tensor of shape :math:`(N, C, H, W)` holding an target image
        reduction: Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
            ``'mean'``: the sum of the output will be divided by the number of
            elements in the output, ``'sum'``: the output will be summed.
        data_range: The difference between the maximum and minimum of the pixel value,
            i.e., if for image x it holds min(x) = 0 and max(x) = 1, then data_range = 1.
            The pixel value interval of both input and output should remain the same.
        scales: Number of Haar wavelets used for image decomposition.
        subsample: Flag to apply average pooling before HaarPSI computation. See [1] for details.
        c: Constant from the paper. See [1] for details.
        alpha: Exponent used for similarity maps weighting. See [1] for details.

    Returns:
        HaarPSI: Wavelet-Based Perceptual Similarity between two tensors.
    
    References:
        [1] R. Reisenhofer, S. Bosse, G. Kutyniok & T. Wiegand (2017)
            'A Haar Wavelet-Based Perceptual Similarity Index for Image Quality Assessment'
            http://www.math.uni-bremen.de/cda/HaarPSI/publications/HaarPSI_preprint_v4.pdf
        [2] Code from authors on MATLAB and Python
            https://github.com/rgcda/haarpsi
    """

    _validate_input(input_tensors=(x, y), allow_5d=False, scale_weights=None)
    x, y = _adjust_dimensions(input_tensors=(x, y))

    # Assert minimal image size
    kernel_size = 2 ** (scales + 1)
    if x.size(-1) < kernel_size or x.size(-2) < kernel_size:
        raise ValueError(f'Kernel size can\'t be greater than actual input size. Input size: {x.size()}. '
                         f'Kernel size: {kernel_size}')

    # Scale images to [0, 255] range as in the paper
    x = x * 255.0 / float(data_range)
    y = y * 255.0 / float(data_range)

    num_channels = x.size(1)
    # Convert RGB to YIQ color space https://en.wikipedia.org/wiki/YIQ
    if num_channels == 3:
        x_yiq = rgb2yiq(x)
        y_yiq = rgb2yiq(y)
    else:
        x_yiq = x
        y_yiq = y

    # Downscale input to simulate the typical distance between an image and its viewer.
    if subsample:
        up_pad = 0
        down_pad = max(x.shape[2] % 2, x.shape[3] % 2)
        pad_to_use = [up_pad, down_pad, up_pad, down_pad]
        x_yiq = F.pad(x_yiq, pad=pad_to_use)
        y_yiq = F.pad(y_yiq, pad=pad_to_use)

        x_yiq = F.avg_pool2d(x_yiq, kernel_size=2, stride=2, padding=0)
        y_yiq = F.avg_pool2d(y_yiq, kernel_size=2, stride=2, padding=0)
    
    # Haar wavelet decomposition
    coefficients_x, coefficients_y = [], []
    for scale in range(scales):
        kernel_size = 2 ** (scale + 1)
        kernels = torch.stack([haar_filter(kernel_size), haar_filter(kernel_size).transpose(-1, -2)])
    
        # Asymmetrical padding due to even kernel size. Matches MATLAB conv2(A, B, 'same')
        upper_pad = kernel_size // 2 - 1
        bottom_pad = kernel_size // 2
        pad_to_use = [upper_pad, bottom_pad, upper_pad, bottom_pad]
        coeff_x = torch.nn.functional.conv2d(F.pad(x_yiq[:, : 1], pad=pad_to_use, mode='constant'), kernels.to(x))
        coeff_y = torch.nn.functional.conv2d(F.pad(y_yiq[:, : 1], pad=pad_to_use, mode='constant'), kernels.to(y))
    
        coefficients_x.append(coeff_x)
        coefficients_y.append(coeff_y)

    # Shape [B x {scales * 2} x H x W]
    coefficients_x = torch.cat(coefficients_x, dim=1)
    coefficients_y = torch.cat(coefficients_y, dim=1)

    # Low-frequency coefficients used as weights
    # Shape [B x 2 x H x W]
    weights = torch.max(torch.abs(coefficients_x[:, 4:]), torch.abs(coefficients_y[:, 4:]))
    
    # High-frequency coefficients used for similarity computation in 2 orientations (horizontal and vertical)
    sim_map = []
    for orientation in range(2):
        magnitude_x = torch.abs(coefficients_x[:, (orientation, orientation + 2)])
        magnitude_y = torch.abs(coefficients_y[:, (orientation, orientation + 2)])
        sim_map.append(similarity_map(magnitude_x, magnitude_y, constant=c).sum(dim=1, keepdim=True) / 2)

    if num_channels == 3:
        pad_to_use = [0, 1, 0, 1]
        x_yiq = F.pad(x_yiq, pad=pad_to_use)
        y_yiq = F.pad(y_yiq, pad=pad_to_use)
        coefficients_x_iq = torch.abs(F.avg_pool2d(x_yiq[:, 1:], kernel_size=2, stride=1, padding=0))
        coefficients_y_iq = torch.abs(F.avg_pool2d(y_yiq[:, 1:], kernel_size=2, stride=1, padding=0))
    
        # Compute weights and similarity
        weights = torch.cat([weights, weights.mean(dim=1, keepdim=True)], dim=1)
        sim_map.append(
            similarity_map(coefficients_x_iq, coefficients_y_iq, constant=c).sum(dim=1, keepdim=True) / 2)

    sim_map = torch.cat(sim_map, dim=1)
    
    # Calculate the final score
    eps = torch.finfo(sim_map.dtype).eps
    score = (((sim_map * alpha).sigmoid() * weights).sum(dim=[1, 2, 3]) + eps) /\
        (torch.sum(weights, dim=[1, 2, 3]) + eps)
    # Logit of score
    score = (torch.log(score / (1 - score)) / alpha) ** 2

    if reduction == 'none':
        return score

    return {'mean': score.mean,
            'sum': score.sum
            }[reduction](dim=0)
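
A minimal usage sketch for the function above, assuming its helpers (_validate_input, _adjust_dimensions, rgb2yiq, haar_filter, similarity_map) resolve as in piq:

import torch

x = torch.rand(4, 3, 96, 96)  # distorted RGB batch in [0, 1]
y = torch.rand(4, 3, 96, 96)  # reference RGB batch in [0, 1]

score = haarpsi(x, y, data_range=1.)         # scalar, reduction='mean'
per_image = haarpsi(x, y, reduction='none')  # shape (4,)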
Example #5
def fsim(x: torch.Tensor,
         y: torch.Tensor,
         reduction: str = 'mean',
         data_range: Union[int, float] = 1.0,
         chromatic: bool = True,
         scales: int = 4,
         orientations: int = 4,
         min_length: int = 6,
         mult: int = 2,
         sigma_f: float = 0.55,
         delta_theta: float = 1.2,
         k: float = 2.0) -> torch.Tensor:
    r"""Compute Feature Similarity Index Measure for a batch of images.

    Args:
        x: Predicted images set :math:`x`.
            Shape (H, W), (C, H, W) or (N, C, H, W).
        y: Target images set :math:`y`.
            Shape (H, W), (C, H, W) or (N, C, H, W).
        reduction: Reduction over samples in batch: "mean"|"sum"|"none"
        data_range: Value range of input images (usually 1.0 or 255). Default: 1.0
        chromatic: Flag to compute FSIMc, which also takes into account chromatic components
        scales: Number of wavelets used for computation of phase congruency maps
        orientations: Number of filter orientations used for computation of phase congruency maps
        min_length: Wavelength of smallest scale filter
        mult: Scaling factor between successive filters
        sigma_f: Ratio of the standard deviation of the Gaussian describing the log Gabor filter's
            transfer function in the frequency domain to the filter center frequency.
        delta_theta: Ratio of angular interval between filter orientations and the standard deviation
            of the angular Gaussian function used to construct filters in the frequency plane.
        k: Number of standard deviations of the noise energy beyond the mean at which the noise
            threshold is set; phase congruency values below it are penalized.
        
    Returns:
        FSIM: Index of similarity between two images. Usually in [0, 1] interval.
            Can be bigger than 1 for predicted (x) images with higher contrast than the original ones.
    Note:
        This implementation is based on the original MATLAB code.
        https://www4.comp.polyu.edu.hk/~cslzhang/IQA/FSIM/FSIM.htm
        
    """

    _validate_input(input_tensors=(x, y),
                    allow_5d=False,
                    data_range=data_range)
    x, y = _adjust_dimensions(input_tensors=(x, y))

    # Rescale to [0, 255] range, because all constants are calculated for this range
    x = x / data_range * 255
    y = y / data_range * 255

    # Apply average pooling
    kernel_size = max(1, round(min(x.shape[-2:]) / 256))
    x = torch.nn.functional.avg_pool2d(x, kernel_size)
    y = torch.nn.functional.avg_pool2d(y, kernel_size)

    num_channels = x.size(1)

    # Convert RGB to YIQ color space https://en.wikipedia.org/wiki/YIQ
    if num_channels == 3:
        x_yiq = rgb2yiq(x)
        y_yiq = rgb2yiq(y)

        x_lum = x_yiq[:, :1]
        y_lum = y_yiq[:, :1]

        x_i = x_yiq[:, 1:2]
        y_i = y_yiq[:, 1:2]
        x_q = x_yiq[:, 2:]
        y_q = y_yiq[:, 2:]

    else:
        x_lum = x
        y_lum = y

    # Compute phase congruency maps
    pc_x = _phase_congruency(x_lum,
                             scales=scales,
                             orientations=orientations,
                             min_length=min_length,
                             mult=mult,
                             sigma_f=sigma_f,
                             delta_theta=delta_theta,
                             k=k)
    pc_y = _phase_congruency(y_lum,
                             scales=scales,
                             orientations=orientations,
                             min_length=min_length,
                             mult=mult,
                             sigma_f=sigma_f,
                             delta_theta=delta_theta,
                             k=k)

    # Gradient maps
    kernels = torch.stack([scharr_filter(), scharr_filter().transpose(-1, -2)])
    grad_map_x = gradient_map(x_lum, kernels)
    grad_map_y = gradient_map(y_lum, kernels)

    # Constants from the paper
    T1, T2, T3, T4, lmbda = 0.85, 160, 200, 200, 0.03

    # Compute FSIM
    PC = similarity_map(pc_x, pc_y, T1)
    GM = similarity_map(grad_map_x, grad_map_y, T2)
    pc_max = torch.where(pc_x > pc_y, pc_x, pc_y)
    score = GM * PC * pc_max

    if chromatic:
        assert num_channels == 3, "Chromatic component can be computed only for RGB images!"
        S_I = similarity_map(x_i, y_i, T3)
        S_Q = similarity_map(x_q, y_q, T4)
        score = score * torch.abs(S_I * S_Q)**lmbda
        # Complex gradients will work in PyTorch 1.6.0
        # score = score * torch.real((S_I * S_Q).to(torch.complex64) ** lmbda)

    result = score.sum(dim=[1, 2, 3]) / pc_max.sum(dim=[1, 2, 3])

    if reduction == 'none':
        return result

    return {'mean': result.mean, 'sum': result.sum}[reduction](dim=0)
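
A minimal usage sketch, assuming _phase_congruency, scharr_filter and the other piq helpers are importable:

import torch

x = torch.rand(4, 3, 96, 96)  # distorted RGB batch in [0, 1]
y = torch.rand(4, 3, 96, 96)  # reference RGB batch in [0, 1]

fsimc = fsim(x, y, data_range=1., chromatic=True)  # FSIMc, uses I and Q channels
x_grey = x.mean(dim=1, keepdim=True)
y_grey = y.mean(dim=1, keepdim=True)
fsim_grey = fsim(x_grey, y_grey, chromatic=False)  # luminance-only FSIM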
Example #6
def vsi(prediction: torch.Tensor,
        target: torch.Tensor,
        reduction: str = 'mean',
        data_range: Union[int, float] = 1.,
        c1: float = 1.27,
        c2: float = 386.,
        c3: float = 130.,
        alpha: float = 0.4,
        beta: float = 0.02,
        omega_0: float = 0.021,
        sigma_f: float = 1.34,
        sigma_d: float = 145.,
        sigma_c: float = 0.001) -> torch.Tensor:
    r"""Compute Visual Saliency-induced Index for a batch of images.

    Both inputs are supposed to have RGB channel order in accordance with the original approach.
    Nevertheless, the method supports greyscale images, which are converted to RGB by copying the grey
    channel 3 times.

    Args:
        prediction:  Tensor with shape (H, W), (C, H, W) or (N, C, H, W) holding a distorted image.
        target: Tensor with shape (H, W), (C, H, W) or (N, C, H, W) holding a target image.
        reduction: Reduction over samples in batch: "mean"|"sum"|"none"
        data_range: Value range of input images (usually 1.0 or 255). Default: 1.0
        c1: coefficient to calculate saliency component of VSI
        c2: coefficient to calculate gradient component of VSI
        c3: coefficient to calculate color component of VSI
        alpha: power for gradient component of VSI
        beta: power for color component of VSI
        omega_0: coefficient to get log Gabor filter at SDSP
        sigma_f: coefficient to get log Gabor filter at SDSP
        sigma_d: coefficient to get SDSP
        sigma_c: coefficient to get SDSP

    Returns:
        VSI: Index of similarity between two images. Usually in [0, 1] interval.

    Shape:
        - Input:  Required to be 2D (H, W), 3D (C, H, W) or 4D (N, C, H, W). RGB channel order for colour images.
        - Target: Required to be 2D (H, W), 3D (C, H, W) or 4D (N, C, H, W). RGB channel order for colour images.

    References:
        L. Zhang, Y. Shen and H. Li, "VSI: A Visual Saliency-Induced Index for Perceptual Image
        Quality Assessment," IEEE Transactions on Image Processing, vol. 23, no. 10,
        pp. 4270-4281, Oct. 2014.
        https://ieeexplore.ieee.org/document/6873260,
        DOI:`10.1109/TIP.2014.2346028`

    Note:
        The original method supports only RGB images.
        See https://ieeexplore.ieee.org/document/6873260 for details.
    """
    _validate_input(input_tensors=(prediction, target), allow_5d=False)
    prediction, target = _adjust_dimensions(input_tensors=(prediction, target))
    if prediction.size(1) == 1:
        prediction = prediction.repeat(1, 3, 1, 1)
        target = target.repeat(1, 3, 1, 1)
        warnings.warn(
            'The original VSI supports only RGB images. The input images were converted to RGB by copying '
            'the grey channel 3 times.')

    # Scale to [0, 255] range to match the scale of the constants
    prediction = prediction * 255. / data_range
    target = target * 255. / data_range

    vs_prediction = sdsp(prediction,
                         data_range=255,
                         omega_0=omega_0,
                         sigma_f=sigma_f,
                         sigma_d=sigma_d,
                         sigma_c=sigma_c)
    vs_target = sdsp(target,
                     data_range=255,
                     omega_0=omega_0,
                     sigma_f=sigma_f,
                     sigma_d=sigma_d,
                     sigma_c=sigma_c)

    # Convert to LMN colour space
    prediction_lmn = rgb2lmn(prediction)
    target_lmn = rgb2lmn(target)

    # Averaging image if the size is large enough
    kernel_size = max(1, round(min(vs_prediction.size()[-2:]) / 256))
    padding = kernel_size // 2

    if padding:
        upper_pad = padding
        bottom_pad = (kernel_size - 1) // 2
        pad_to_use = [upper_pad, bottom_pad, upper_pad, bottom_pad]
        mode = 'replicate'
        vs_prediction = pad(vs_prediction, pad=pad_to_use, mode=mode)
        vs_target = pad(vs_target, pad=pad_to_use, mode=mode)
        prediction_lmn = pad(prediction_lmn, pad=pad_to_use, mode=mode)
        target_lmn = pad(target_lmn, pad=pad_to_use, mode=mode)

    vs_prediction = avg_pool2d(vs_prediction, kernel_size=kernel_size)
    vs_target = avg_pool2d(vs_target, kernel_size=kernel_size)

    prediction_lmn = avg_pool2d(prediction_lmn, kernel_size=kernel_size)
    target_lmn = avg_pool2d(target_lmn, kernel_size=kernel_size)

    # Calculate gradient map
    kernels = torch.stack([scharr_filter(),
                           scharr_filter().transpose(1, 2)]).to(prediction_lmn)
    gm_prediction = gradient_map(prediction_lmn[:, :1], kernels)
    gm_target = gradient_map(target_lmn[:, :1], kernels)

    # Calculate all similarity maps
    s_vs = similarity_map(vs_prediction, vs_target, c1)
    s_gm = similarity_map(gm_prediction, gm_target, c2)
    s_m = similarity_map(prediction_lmn[:, 1:2], target_lmn[:, 1:2], c3)
    s_n = similarity_map(prediction_lmn[:, 2:], target_lmn[:, 2:], c3)
    s_c = s_m * s_n

    s_c_complex = [s_c.abs(), torch.atan2(torch.zeros_like(s_c), s_c)]
    s_c_complex_pow = [s_c_complex[0]**beta, s_c_complex[1] * beta]
    s_c_real_pow = s_c_complex_pow[0] * torch.cos(s_c_complex_pow[1])

    s = s_vs * s_gm.pow(alpha) * s_c_real_pow
    vs_max = torch.max(vs_prediction, vs_target)

    eps = torch.finfo(vs_max.dtype).eps
    output = s * vs_max
    output = ((output.sum(dim=(-1, -2)) + eps) /
              (vs_max.sum(dim=(-1, -2)) + eps)).squeeze(-1)
    if reduction == 'none':
        return output
    return {'mean': torch.mean, 'sum': torch.sum}[reduction](output, dim=0)
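
A minimal usage sketch, assuming sdsp, rgb2lmn and the remaining piq helpers are importable:

import torch

prediction = torch.rand(4, 3, 96, 96)  # distorted RGB batch in [0, 1]
target = torch.rand(4, 3, 96, 96)      # reference RGB batch in [0, 1]

score = vsi(prediction, target, data_range=1.)  # scalar, reduction='mean'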
Example #7
def vsi(x: torch.Tensor,
        y: torch.Tensor,
        reduction: str = 'mean',
        data_range: Union[int, float] = 1.,
        c1: float = 1.27,
        c2: float = 386.,
        c3: float = 130.,
        alpha: float = 0.4,
        beta: float = 0.02,
        omega_0: float = 0.021,
        sigma_f: float = 1.34,
        sigma_d: float = 145.,
        sigma_c: float = 0.001) -> torch.Tensor:
    r"""Compute Visual Saliency-induced Index for a batch of images.

    Both inputs are supposed to have RGB channel order in accordance with the original approach.
    Nevertheless, the method supports greyscale images, which are converted to RGB by copying the grey
    channel 3 times.

    Args:
        x: An input tensor. Shape :math:`(N, C, H, W)`.
        y: A target tensor. Shape :math:`(N, C, H, W)`.
        reduction: Specifies the reduction type:
            ``'none'`` | ``'mean'`` | ``'sum'``. Default:``'mean'``
        data_range: Maximum value range of images (usually 1.0 or 255).
        c1: coefficient to calculate saliency component of VSI
        c2: coefficient to calculate gradient component of VSI
        c3: coefficient to calculate color component of VSI
        alpha: power for gradient component of VSI
        beta: power for color component of VSI
        omega_0: coefficient to get log Gabor filter at SDSP
        sigma_f: coefficient to get log Gabor filter at SDSP
        sigma_d: coefficient to get SDSP
        sigma_c: coefficient to get SDSP

    Returns:
        Index of similarity between two images. Usually in [0, 1] range.

    References:
        L. Zhang, Y. Shen and H. Li, "VSI: A Visual Saliency-Induced Index for Perceptual Image Quality Assessment,"
        IEEE Transactions on Image Processing, vol. 23, no. 10, pp. 4270-4281, Oct. 2014, doi: 10.1109/TIP.2014.2346028
        https://ieeexplore.ieee.org/document/6873260

    Note:
        The original method supports only RGB images.
        See https://ieeexplore.ieee.org/document/6873260 for details.
    """
    _validate_input([x, y], dim_range=(4, 4), data_range=(0, data_range))

    if x.size(1) == 1:
        x = x.repeat(1, 3, 1, 1)
        y = y.repeat(1, 3, 1, 1)
        warnings.warn(
            'The original VSI supports only RGB images. The input images were converted to RGB by copying '
            'the grey channel 3 times.')

    # Scale to [0, 255] range to match the scale of the constants
    x = x * 255. / float(data_range)
    y = y * 255. / float(data_range)

    vs_x = sdsp(x,
                data_range=255,
                omega_0=omega_0,
                sigma_f=sigma_f,
                sigma_d=sigma_d,
                sigma_c=sigma_c)
    vs_y = sdsp(y,
                data_range=255,
                omega_0=omega_0,
                sigma_f=sigma_f,
                sigma_d=sigma_d,
                sigma_c=sigma_c)

    # Convert to LMN colour space
    x_lmn = rgb2lmn(x)
    y_lmn = rgb2lmn(y)

    # Averaging image if the size is large enough
    kernel_size = max(1, round(min(vs_x.size()[-2:]) / 256))
    padding = kernel_size // 2

    if padding:
        upper_pad = padding
        bottom_pad = (kernel_size - 1) // 2
        pad_to_use = [upper_pad, bottom_pad, upper_pad, bottom_pad]
        mode = 'replicate'
        vs_x = pad(vs_x, pad=pad_to_use, mode=mode)
        vs_y = pad(vs_y, pad=pad_to_use, mode=mode)
        x_lmn = pad(x_lmn, pad=pad_to_use, mode=mode)
        y_lmn = pad(y_lmn, pad=pad_to_use, mode=mode)

    vs_x = avg_pool2d(vs_x, kernel_size=kernel_size)
    vs_y = avg_pool2d(vs_y, kernel_size=kernel_size)

    x_lmn = avg_pool2d(x_lmn, kernel_size=kernel_size)
    y_lmn = avg_pool2d(y_lmn, kernel_size=kernel_size)

    # Calculate gradient map
    kernels = torch.stack([scharr_filter(),
                           scharr_filter().transpose(1, 2)]).to(x_lmn)
    gm_x = gradient_map(x_lmn[:, :1], kernels)
    gm_y = gradient_map(y_lmn[:, :1], kernels)

    # Calculate all similarity maps
    s_vs = similarity_map(vs_x, vs_y, c1)
    s_gm = similarity_map(gm_x, gm_y, c2)
    s_m = similarity_map(x_lmn[:, 1:2], y_lmn[:, 1:2], c3)
    s_n = similarity_map(x_lmn[:, 2:], y_lmn[:, 2:], c3)
    s_c = s_m * s_n

    s_c_complex = [s_c.abs(), torch.atan2(torch.zeros_like(s_c), s_c)]
    s_c_complex_pow = [s_c_complex[0]**beta, s_c_complex[1] * beta]
    s_c_real_pow = s_c_complex_pow[0] * torch.cos(s_c_complex_pow[1])

    s = s_vs * s_gm.pow(alpha) * s_c_real_pow
    vs_max = torch.max(vs_x, vs_y)

    eps = torch.finfo(vs_max.dtype).eps
    output = s * vs_max
    output = ((output.sum(dim=(-1, -2)) + eps) /
              (vs_max.sum(dim=(-1, -2)) + eps)).squeeze(-1)

    return _reduce(output, reduction)
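
A sketch of the greyscale path, which triggers the RGB-conversion warning above (helper imports assumed as in piq):

import torch

x_grey = torch.rand(4, 1, 96, 96)
y_grey = torch.rand(4, 1, 96, 96)

# Single-channel inputs are repeated to 3 channels internally, with a warning.
per_image = vsi(x_grey, y_grey, reduction='none')  # shape (4,)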
Example #8
def haarpsi(x: torch.Tensor, y: torch.Tensor, reduction: str = 'mean',
            data_range: Union[int, float] = 1., scales: int = 3, subsample: bool = True,
            c: float = 30.0, alpha: float = 4.2) -> torch.Tensor:
    r"""Compute Haar Wavelet-Based Perceptual Similarity
    Inputs supposed to be in range [0, data_range] with RGB channels order for colour images.
    Args:
        x: An input tensor. Shape :math:`(N, C, H, W)`.
        y: A target tensor. Shape :math:`(N, C, H, W)`.
        reduction: Specifies the reduction type:
            ``'none'`` | ``'mean'`` | ``'sum'``. Default:``'mean'``
        data_range: Maximum value range of images (usually 1.0 or 255).
        scales: Number of Haar wavelets used for image decomposition.
        subsample: Flag to apply average pooling before HaarPSI computation. See [1] for details.
        c: Constant from the paper. See [1] for details.
        alpha: Exponent used for similarity maps weighting. See [1] for details.

    Returns:
        HaarPSI: Wavelet-Based Perceptual Similarity between two tensors.
    
    References:
        .. [1] R. Reisenhofer, S. Bosse, G. Kutyniok & T. Wiegand (2017)
           'A Haar Wavelet-Based Perceptual Similarity Index for Image Quality Assessment'
           http://www.math.uni-bremen.de/cda/HaarPSI/publications/HaarPSI_preprint_v4.pdf
        .. [2] Code from authors on MATLAB and Python
           https://github.com/rgcda/haarpsi
    """
    _validate_input([x, y], dim_range=(4, 4), data_range=(0, data_range))

    # Assert minimal image size
    kernel_size = 2 ** (scales + 1)
    if x.size(-1) < kernel_size or x.size(-2) < kernel_size:
        raise ValueError(f'Kernel size can\'t be greater than actual input size. Input size: {x.size()}. '
                         f'Kernel size: {kernel_size}')

    # Rescale images
    x = x / data_range * 255
    y = y / data_range * 255

    num_channels = x.size(1)
    # Convert RGB to YIQ color space https://en.wikipedia.org/wiki/YIQ
    if num_channels == 3:
        x_yiq = rgb2yiq(x)
        y_yiq = rgb2yiq(y)
    else:
        x_yiq = x
        y_yiq = y

    # Downscale input to simulate the typical distance between an image and its viewer.
    if subsample:
        up_pad = 0
        down_pad = max(x.shape[2] % 2, x.shape[3] % 2)
        pad_to_use = [up_pad, down_pad, up_pad, down_pad]
        x_yiq = F.pad(x_yiq, pad=pad_to_use)
        y_yiq = F.pad(y_yiq, pad=pad_to_use)

        x_yiq = F.avg_pool2d(x_yiq, kernel_size=2, stride=2, padding=0)
        y_yiq = F.avg_pool2d(y_yiq, kernel_size=2, stride=2, padding=0)
    
    # Haar wavelet decomposition
    coefficients_x, coefficients_y = [], []
    for scale in range(scales):
        kernel_size = 2 ** (scale + 1)
        kernels = torch.stack([haar_filter(kernel_size), haar_filter(kernel_size).transpose(-1, -2)])
    
        # Asymmetrical padding due to even kernel size. Matches MATLAB conv2(A, B, 'same')
        upper_pad = kernel_size // 2 - 1
        bottom_pad = kernel_size // 2
        pad_to_use = [upper_pad, bottom_pad, upper_pad, bottom_pad]
        coeff_x = torch.nn.functional.conv2d(F.pad(x_yiq[:, : 1], pad=pad_to_use, mode='constant'), kernels.to(x))
        coeff_y = torch.nn.functional.conv2d(F.pad(y_yiq[:, : 1], pad=pad_to_use, mode='constant'), kernels.to(y))
    
        coefficients_x.append(coeff_x)
        coefficients_y.append(coeff_y)

    # Shape (N, {scales * 2}, H, W)
    coefficients_x = torch.cat(coefficients_x, dim=1)
    coefficients_y = torch.cat(coefficients_y, dim=1)

    # Low-frequency coefficients used as weights
    # Shape (N, 2, H, W)
    weights = torch.max(torch.abs(coefficients_x[:, 4:]), torch.abs(coefficients_y[:, 4:]))
    
    # High-frequency coefficients used for similarity computation in 2 orientations (horizontal and vertical)
    sim_map = []
    for orientation in range(2):
        magnitude_x = torch.abs(coefficients_x[:, (orientation, orientation + 2)])
        magnitude_y = torch.abs(coefficients_y[:, (orientation, orientation + 2)])
        sim_map.append(similarity_map(magnitude_x, magnitude_y, constant=c).sum(dim=1, keepdim=True) / 2)

    if num_channels == 3:
        pad_to_use = [0, 1, 0, 1]
        x_yiq = F.pad(x_yiq, pad=pad_to_use)
        y_yiq = F.pad(y_yiq, pad=pad_to_use)
        coefficients_x_iq = torch.abs(F.avg_pool2d(x_yiq[:, 1:], kernel_size=2, stride=1, padding=0))
        coefficients_y_iq = torch.abs(F.avg_pool2d(y_yiq[:, 1:], kernel_size=2, stride=1, padding=0))
    
        # Compute weights and similarity
        weights = torch.cat([weights, weights.mean(dim=1, keepdim=True)], dim=1)
        sim_map.append(
            similarity_map(coefficients_x_iq, coefficients_y_iq, constant=c).sum(dim=1, keepdim=True) / 2)

    sim_map = torch.cat(sim_map, dim=1)
    
    # Calculate the final score
    eps = torch.finfo(sim_map.dtype).eps
    score = (((sim_map * alpha).sigmoid() * weights).sum(dim=[1, 2, 3]) + eps) /\
        (torch.sum(weights, dim=[1, 2, 3]) + eps)
    # Logit of score
    score = (torch.log(score / (1 - score)) / alpha) ** 2

    return _reduce(score, reduction)
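
A sketch of the greyscale path with a custom number of scales; note the minimal input size of 2 ** (scales + 1) enforced above:

import torch

x = torch.rand(4, 1, 64, 64)  # greyscale input skips the YIQ conversion
y = torch.rand(4, 1, 64, 64)

score = haarpsi(x, y, scales=3)              # needs inputs of at least 16 x 16
per_image = haarpsi(x, y, reduction='none')  # shape (4,)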
Example #9
def mdsi(x: torch.Tensor, y: torch.Tensor, data_range: Union[int, float] = 1., reduction: str = 'mean',
         c1: float = 140., c2: float = 55., c3: float = 550., combination: str = 'sum', alpha: float = 0.6,
         beta: float = 0.1, gamma: float = 0.2, rho: float = 1., q: float = 0.25, o: float = 0.25):
    r"""Compute Mean Deviation Similarity Index (MDSI) for a batch of images.
    Supports greyscale and colour images with RGB channel order.

    Args:
        x: An input tensor. Shape :math:`(N, C, H, W)`.
        y: A target tensor. Shape :math:`(N, C, H, W)`.
        data_range: Maximum value range of images (usually 1.0 or 255).
        reduction: Specifies the reduction type:
            ``'none'`` | ``'mean'`` | ``'sum'``. Default:``'mean'``
        c1: coefficient to calculate gradient similarity. Default: 140.
        c2: coefficient to calculate gradient similarity. Default: 55.
        c3: coefficient to calculate chromaticity similarity. Default: 550.
        combination: mode to combine gradient similarity and chromaticity similarity: ``'sum'`` | ``'mult'``.
        alpha: coefficient to combine gradient similarity and chromaticity similarity using summation.
        beta: power to combine gradient similarity with chromaticity similarity using multiplication.
        gamma: power to combine gradient similarity and chromaticity similarity using multiplication.
        rho: order of the Minkowski distance
        q: coefficient that adjusts the emphasis of the values in image and MCT
        o: the power pooling applied on the final value of the deviation

    Returns:
        Mean Deviation Similarity Index (MDSI) between 2 tensors.

    References:
        Nafchi, Hossein Ziaei and Shahkolaei, Atena and Hedjam, Rachid and Cheriet, Mohamed (2016).
        Mean deviation similarity index: Efficient and reliable full-reference image quality evaluator.
        IEEE Ieee Access, 4, 5579--5590.
        https://arxiv.org/pdf/1608.07433.pdf,
        DOI:`10.1109/ACCESS.2016.2604042`

    Note:
        The constants usually satisfy :math:`c_3 = 4c_1 = 10c_2`

    Note:
        Both inputs are supposed to have RGB channels order in accordance with the original approach.
        Nevertheless, the method supports greyscale images, which are converted to RGB by copying the grey
        channel 3 times.
    """
    _validate_input([x, y], dim_range=(4, 4), data_range=(0, data_range))

    if x.size(1) == 1:
        x = x.repeat(1, 3, 1, 1)
        y = y.repeat(1, 3, 1, 1)
        warnings.warn('The original MDSI supports only RGB images. The input images were converted to RGB by copying '
                      'the grey channel 3 times.')

    x = x / float(data_range) * 255
    y = y / float(data_range) * 255

    # Averaging image if the size is large enough
    kernel_size = max(1, round(min(x.size()[-2:]) / 256))
    padding = kernel_size // 2

    if padding:
        up_pad = (kernel_size - 1) // 2
        down_pad = padding
        pad_to_use = [up_pad, down_pad, up_pad, down_pad]
        x = pad(x, pad=pad_to_use)
        y = pad(y, pad=pad_to_use)

    x = avg_pool2d(x, kernel_size=kernel_size)
    y = avg_pool2d(y, kernel_size=kernel_size)

    x_lhm = rgb2lhm(x)
    y_lhm = rgb2lhm(y)

    kernels = torch.stack([prewitt_filter(), prewitt_filter().transpose(1, 2)]).to(x)
    gm_x = gradient_map(x_lhm[:, :1], kernels)
    gm_y = gradient_map(y_lhm[:, :1], kernels)
    gm_avg = gradient_map((x_lhm[:, :1] + y_lhm[:, :1]) / 2., kernels)

    gs_x_y = similarity_map(gm_x, gm_y, c1)
    gs_x_average = similarity_map(gm_x, gm_avg, c2)
    gs_y_average = similarity_map(gm_y, gm_avg, c2)

    gs_total = gs_x_y + gs_x_average - gs_y_average

    cs_total = (2 * (x_lhm[:, 1:2] * y_lhm[:, 1:2] +
                     x_lhm[:, 2:] * y_lhm[:, 2:]) + c3) / (x_lhm[:, 1:2] ** 2 +
                                                           y_lhm[:, 1:2] ** 2 +
                                                           x_lhm[:, 2:] ** 2 +
                                                           y_lhm[:, 2:] ** 2 + c3)

    if combination == 'sum':
        gcs = (alpha * gs_total + (1 - alpha) * cs_total)
    elif combination == 'mult':
        gs_total_pow = pow_for_complex(base=gs_total, exp=gamma)
        cs_total_pow = pow_for_complex(base=cs_total, exp=beta)
        gcs = torch.stack((gs_total_pow[..., 0] * cs_total_pow[..., 0],
                           gs_total_pow[..., 1] + cs_total_pow[..., 1]), dim=-1)
    else:
        raise ValueError(f'Expected combination method "sum" or "mult", got {combination}')

    mct_complex = pow_for_complex(base=gcs, exp=q)
    mct_complex = mct_complex.mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)  # split to increase precision
    score = (pow_for_complex(base=gcs, exp=q) - mct_complex).pow(2).sum(dim=-1).sqrt()
    score = ((score ** rho).mean(dim=(-1, -2)) ** (o / rho)).squeeze(1)
    return _reduce(score, reduction)
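
A short sketch of the two combination modes, assuming rgb2lhm, pow_for_complex and the other piq helpers are importable:

import torch

x = torch.rand(4, 3, 96, 96)
y = torch.rand(4, 3, 96, 96)

score_sum = mdsi(x, y, combination='sum')    # weighted sum, controlled by alpha
score_mult = mdsi(x, y, combination='mult')  # multiplicative, via powers beta and gamma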
Example #10
def mdsi(x: torch.Tensor,
         y: torch.Tensor,
         data_range: Union[int, float] = 1.,
         reduction: str = 'mean',
         c1: float = 140.,
         c2: float = 55.,
         c3: float = 550.,
         combination: str = 'sum',
         alpha: float = 0.6,
         beta: float = 0.1,
         gamma: float = 0.2,
         rho: float = 1.,
         q: float = 0.25,
         o: float = 0.25):
    r"""Compute Mean Deviation Similarity Index (MDSI) for a batch of images.

    Note:
        Both inputs are supposed to have RGB channels order.
        Greyscale images converted to RGB by copying the grey channel 3 times.

    Args:
        x: Tensor with shape (H, W), (C, H, W) or (N, C, H, W).
        y: Tensor with shape (H, W), (C, H, W) or (N, C, H, W).
        data_range: Value range of input images (usually 1.0 or 255). Default: 1.0
        reduction: Reduction over samples in batch: "mean"|"sum"|"none"
        c1: coefficient to calculate gradient similarity. Default: 140.
        c2: coefficient to calculate gradient similarity. Default: 55.
        c3: coefficient to calculate chromaticity similarity. Default: 550.
        combination: mode to combine gradient similarity and chromaticity similarity: "sum"|"mult".
        alpha: coefficient to combine gradient similarity and chromaticity similarity using summation.
        beta: power to combine gradient similarity with chromaticity similarity using multiplication.
        gamma: power to combine gradient similarity and chromaticity similarity using multiplication.
        rho: order of the Minkowski distance
        q: coefficient that adjusts the emphasis of the values in image and MCT
        o: the power pooling applied on the final value of the deviation

    Returns:
        torch.Tensor: the batch of Mean Deviation Similarity Index (MDSI) scores, reduced accordingly

    Note:
        The constants usually satisfy c3 = 4c1 = 10c2
    """
    _validate_input(input_tensors=(x, y),
                    allow_5d=False,
                    data_range=data_range)
    x, y = _adjust_dimensions(input_tensors=(x, y))

    if x.size(1) == 1:
        x = x.repeat(1, 3, 1, 1)
        y = y.repeat(1, 3, 1, 1)
        warnings.warn(
            'The original MDSI supports only RGB images. The input images were converted to RGB by copying '
            'the grey channel 3 times.')

    x = x / data_range * 255
    y = y / data_range * 255

    # Averaging image if the size is large enough
    kernel_size = max(1, round(min(x.size()[-2:]) / 256))
    padding = kernel_size // 2

    if padding:
        up_pad = (kernel_size - 1) // 2
        down_pad = padding
        pad_to_use = [up_pad, down_pad, up_pad, down_pad]
        x = pad(x, pad=pad_to_use)
        y = pad(y, pad=pad_to_use)

    x = avg_pool2d(x, kernel_size=kernel_size)
    y = avg_pool2d(y, kernel_size=kernel_size)

    x_lhm = rgb2lhm(x)
    y_lhm = rgb2lhm(y)

    kernels = torch.stack([prewitt_filter(),
                           prewitt_filter().transpose(1, 2)]).to(x)
    gm_x = gradient_map(x_lhm[:, :1], kernels)
    gm_y = gradient_map(y_lhm[:, :1], kernels)
    gm_avg = gradient_map((x_lhm[:, :1] + y_lhm[:, :1]) / 2., kernels)

    gs_x_y = similarity_map(gm_x, gm_y, c1)
    gs_x_average = similarity_map(gm_x, gm_avg, c2)
    gs_y_average = similarity_map(gm_y, gm_avg, c2)

    gs_total = gs_x_y + gs_x_average - gs_y_average

    cs_total = (2 *
                (x_lhm[:, 1:2] * y_lhm[:, 1:2] + x_lhm[:, 2:] * y_lhm[:, 2:]) +
                c3) / (x_lhm[:, 1:2]**2 + y_lhm[:, 1:2]**2 + x_lhm[:, 2:]**2 +
                       y_lhm[:, 2:]**2 + c3)

    if combination == 'sum':
        gcs = (alpha * gs_total + (1 - alpha) * cs_total)
    elif combination == 'mult':
        gs_total_pow = pow_for_complex(base=gs_total, exp=gamma)
        cs_total_pow = pow_for_complex(base=cs_total, exp=beta)
        gcs = torch.stack((gs_total_pow[..., 0] * cs_total_pow[..., 0],
                           gs_total_pow[..., 1] + cs_total_pow[..., 1]),
                          dim=-1)
    else:
        raise ValueError(
            f'Expected combination method "sum" or "mult", got {combination}')

    mct_complex = pow_for_complex(base=gcs, exp=q)
    mct_complex = mct_complex.mean(dim=2, keepdim=True).mean(
        dim=3, keepdim=True)  # split to increase precision
    score = (pow_for_complex(base=gcs, exp=q) -
             mct_complex).pow(2).sum(dim=-1).sqrt()
    score = ((score**rho).mean(dim=(-1, -2))**(o / rho)).squeeze(1)
    if reduction == 'none':
        return score
    return {'mean': score.mean, 'sum': score.sum}[reduction](dim=0)
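
Since a lower MDSI score means more similar images, the value can serve directly as a training loss. A minimal sketch, with the same assumptions about the piq helpers:

import torch

prediction = torch.rand(4, 3, 96, 96, requires_grad=True)
target = torch.rand(4, 3, 96, 96)

loss = mdsi(prediction, target, data_range=1.)
loss.backward()  # gradients flow back to prediction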
Example #11
def fsim(x: torch.Tensor, y: torch.Tensor, reduction: str = 'mean',
         data_range: Union[int, float] = 1.0, chromatic: bool = True,
         scales: int = 4, orientations: int = 4, min_length: int = 6,
         mult: int = 2, sigma_f: float = 0.55, delta_theta: float = 1.2,
         k: float = 2.0) -> torch.Tensor:
    r"""Compute Feature Similarity Index Measure for a batch of images.

    Args:
        x: An input tensor. Shape :math:`(N, C, H, W)`.
        y: A target tensor. Shape :math:`(N, C, H, W)`.
        reduction: Specifies the reduction type:
            ``'none'`` | ``'mean'`` | ``'sum'``. Default:``'mean'``
        data_range: Maximum value range of images (usually 1.0 or 255).
        chromatic: Flag to compute FSIMc, which also takes into account chromatic components
        scales: Number of wavelets used for computation of phase congruency maps
        orientations: Number of filter orientations used for computation of phase congruency maps
        min_length: Wavelength of smallest scale filter
        mult: Scaling factor between successive filters
        sigma_f: Ratio of the standard deviation of the Gaussian describing the log Gabor filter's
            transfer function in the frequency domain to the filter center frequency.
        delta_theta: Ratio of angular interval between filter orientations and the standard deviation
            of the angular Gaussian function used to construct filters in the frequency plane.
        k: Number of standard deviations of the noise energy beyond the mean at which the noise
            threshold is set; phase congruency values below it are penalized.

    Returns:
        Index of similarity between two images. Usually in [0, 1] interval.
        Can be bigger than 1 for predicted :math:`x` images with higher contrast than the original ones.

    References:
        L. Zhang, L. Zhang, X. Mou and D. Zhang, "FSIM: A Feature Similarity Index for Image Quality Assessment,"
        IEEE Transactions on Image Processing, vol. 20, no. 8, pp. 2378-2386, Aug. 2011, doi: 10.1109/TIP.2011.2109730.
        https://ieeexplore.ieee.org/document/5705575

    Note:
        This implementation is based on the original MATLAB code.
        https://www4.comp.polyu.edu.hk/~cslzhang/IQA/FSIM/FSIM.htm

    """
    _validate_input([x, y], dim_range=(4, 4), data_range=(0, data_range))

    # Rescale to [0, 255] range, because all constants are calculated for this range
    x = x / float(data_range) * 255
    y = y / float(data_range) * 255

    # Apply average pooling
    kernel_size = max(1, round(min(x.shape[-2:]) / 256))
    x = torch.nn.functional.avg_pool2d(x, kernel_size)
    y = torch.nn.functional.avg_pool2d(y, kernel_size)

    num_channels = x.size(1)

    # Convert RGB to YIQ color space https://en.wikipedia.org/wiki/YIQ
    if num_channels == 3:
        x_yiq = rgb2yiq(x)
        y_yiq = rgb2yiq(y)

        x_lum = x_yiq[:, : 1]
        y_lum = y_yiq[:, : 1]

        x_i = x_yiq[:, 1:2]
        y_i = y_yiq[:, 1:2]
        x_q = x_yiq[:, 2:]
        y_q = y_yiq[:, 2:]

    else:
        x_lum = x
        y_lum = y

    # Compute phase congruency maps
    pc_x = _phase_congruency(
        x_lum, scales=scales, orientations=orientations,
        min_length=min_length, mult=mult, sigma_f=sigma_f,
        delta_theta=delta_theta, k=k
    )
    pc_y = _phase_congruency(
        y_lum, scales=scales, orientations=orientations,
        min_length=min_length, mult=mult, sigma_f=sigma_f,
        delta_theta=delta_theta, k=k
    )

    # Gradient maps
    kernels = torch.stack([scharr_filter(), scharr_filter().transpose(-1, -2)])
    grad_map_x = gradient_map(x_lum, kernels)
    grad_map_y = gradient_map(y_lum, kernels)

    # Constants from the paper
    T1, T2, T3, T4, lmbda = 0.85, 160, 200, 200, 0.03

    # Compute FSIM
    PC = similarity_map(pc_x, pc_y, T1)
    GM = similarity_map(grad_map_x, grad_map_y, T2)
    pc_max = torch.where(pc_x > pc_y, pc_x, pc_y)
    score = GM * PC * pc_max

    if chromatic:
        assert num_channels == 3, "Chromatic component can be computed only for RGB images!"
        S_I = similarity_map(x_i, y_i, T3)
        S_Q = similarity_map(x_q, y_q, T4)
        score = score * torch.abs(S_I * S_Q) ** lmbda
        # Complex gradients will work in PyTorch 1.6.0
        # score = score * torch.real((S_I * S_Q).to(torch.complex64) ** lmbda)

    result = score.sum(dim=[1, 2, 3]) / pc_max.sum(dim=[1, 2, 3])

    return _reduce(result, reduction)
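
A sketch contrasting FSIM and FSIMc on the same batch, with per-image scores (helpers assumed as in piq):

import torch

x = torch.rand(4, 3, 96, 96)
y = torch.rand(4, 3, 96, 96)

fsim_scores = fsim(x, y, chromatic=False, reduction='none')  # luminance only, shape (4,)
fsimc_scores = fsim(x, y, chromatic=True, reduction='none')  # adds the I/Q chromatic terms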
Example #12
def mdsi(prediction: torch.Tensor, target: torch.Tensor, data_range: Union[int, float] = 1., reduction: str = 'mean',
         c1: float = 140., c2: float = 55., c3: float = 550., combination: str = 'sum', alpha: float = 0.6,
         beta: float = 0.1, gamma: float = 0.2, rho: float = 1., q: float = 0.25, o: float = 0.25):
    r"""Compute Mean Deviation Similarity Index (MDSI) for a batch of images.

    Note:
        Both inputs are supposed to have RGB order in accordance with the original approach.
        Nevertheless, the method supports greyscale images, which are converted to RGB by copying the grey
        channel 3 times.

    Args:
        prediction: Batch of predicted (distorted) images. Required to be 2D (H, W), 3D (C, H, W)
            or 4D (N, C, H, W), channels first.
        target: Batch of target (reference) images. Required to be 2D (H, W), 3D (C, H, W)
            or 4D (N, C, H, W), channels first.
        data_range: Value range of input images (usually 1.0 or 255). Default: 1.0
        reduction: Reduction over samples in batch: "mean"|"sum"|"none"
        c1: coefficient to calculate gradient similarity. Default: 140.
        c2: coefficient to calculate gradient similarity. Default: 55.
        c3: coefficient to calculate chromaticity similarity. Default: 550.
        combination: mode to combine gradient similarity and chromaticity similarity: "sum"|"mult".
        alpha: coefficient to combine gradient similarity and chromaticity similarity using summation.
        beta: power to combine gradient similarity with chromaticity similarity using multiplication.
        gamma: power to combine gradient similarity and chromaticity similarity using multiplication.
        rho: order of the Minkowski distance
        q: coefficient that adjusts the emphasis of the values in image and MCT
        o: the power pooling applied on the final value of the deviation

    Returns:
        torch.Tensor: the batch of Mean Deviation Similarity Index (MDSI) scores, reduced accordingly

    Note:
        The constants usually satisfy c3 = 4c1 = 10c2
    """
    _validate_input(input_tensors=(prediction, target), allow_5d=False)
    prediction, target = _adjust_dimensions(input_tensors=(prediction, target))

    if prediction.size(1) == 1:
        prediction = prediction.repeat(1, 3, 1, 1)
        target = target.repeat(1, 3, 1, 1)
        warnings.warn('The original MDSI supports only RGB images. The input images were converted to RGB by copying '
                      'the grey channel 3 times.')

    prediction = prediction * 255. / data_range
    target = target * 255. / data_range

    # Averaging image if the size is large enough
    kernel_size = max(1, round(min(prediction.size()[-2:]) / 256))
    padding = kernel_size // 2

    if padding:
        up_pad = (kernel_size - 1) // 2
        down_pad = padding
        pad_to_use = [up_pad, down_pad, up_pad, down_pad]
        prediction = pad(prediction, pad=pad_to_use)
        target = pad(target, pad=pad_to_use)

    prediction = avg_pool2d(prediction, kernel_size=kernel_size)
    target = avg_pool2d(target, kernel_size=kernel_size)

    prediction_lhm = rgb2lhm(prediction)
    target_lhm = rgb2lhm(target)

    kernels = torch.stack([prewitt_filter(), prewitt_filter().transpose(1, 2)]).to(prediction)
    gm_prediction = gradient_map(prediction_lhm[:, :1], kernels)
    gm_target = gradient_map(target_lhm[:, :1], kernels)
    gm_avg = gradient_map((prediction_lhm[:, :1] + target_lhm[:, :1]) / 2., kernels)

    gs_prediction_target = similarity_map(gm_prediction, gm_target, c1)
    gs_prediction_average = similarity_map(gm_prediction, gm_avg, c2)
    gs_target_average = similarity_map(gm_target, gm_avg, c2)

    gs_total = gs_prediction_target + gs_prediction_average - gs_target_average

    cs_total = (2 * (prediction_lhm[:, 1:2] * target_lhm[:, 1:2] +
                     prediction_lhm[:, 2:] * target_lhm[:, 2:]) + c3) / (prediction_lhm[:, 1:2] ** 2 +
                                                                         target_lhm[:, 1:2] ** 2 +
                                                                         prediction_lhm[:, 2:] ** 2 +
                                                                         target_lhm[:, 2:] ** 2 + c3)

    if combination == 'sum':
        gcs = (alpha * gs_total + (1 - alpha) * cs_total)
    elif combination == 'mult':
        gs_total_pow = pow_for_complex(base=gs_total, exp=gamma)
        cs_total_pow = pow_for_complex(base=cs_total, exp=beta)
        gcs = torch.stack((gs_total_pow[..., 0] * cs_total_pow[..., 0],
                           gs_total_pow[..., 1] + cs_total_pow[..., 1]), dim=-1)
    else:
        raise ValueError(f'Expected combination method "sum" or "mult", got {combination}')

    mct_complex = pow_for_complex(base=gcs, exp=q)
    mct_complex = mct_complex.mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)  # split to increase precision
    score = (pow_for_complex(base=gcs, exp=q) - mct_complex).pow(2).sum(dim=-1).sqrt()
    score = ((score ** rho).mean(dim=(-1, -2)) ** (o / rho)).squeeze(1)
    if reduction == 'none':
        return score
    return {'mean': score.mean,
            'sum': score.sum}[reduction](dim=0)
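
A sketch of the 255 data range path; inputs are rescaled internally before the paper's constants are applied (piq helpers assumed as above):

import torch

prediction = (torch.rand(4, 3, 96, 96) * 255).round()  # float tensor in [0, 255]
target = (torch.rand(4, 3, 96, 96) * 255).round()

score = mdsi(prediction, target, data_range=255)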
Example #13
def srsim(x: torch.Tensor,
          y: torch.Tensor,
          reduction: str = 'mean',
          data_range: Union[int, float] = 1.0,
          chromatic: bool = False,
          scale: float = 0.25,
          kernel_size: int = 3,
          sigma: float = 3.8,
          gaussian_size: int = 10) -> torch.Tensor:
    r"""Compute Spectral Residual based Similarity for a batch of images.

    Args:
        x: Predicted images. Shape (H, W), (C, H, W) or (N, C, H, W).
        y: Target images. Shape (H, W), (C, H, W) or (N, C, H, W).
        reduction: Reduction over samples in batch: "mean"|"sum"|"none"
        data_range: Value range of input images (usually 1.0 or 255). Default: 1.0
        chromatic: Flag to compute SR-SIMc, which also takes into account chromatic components
        scale: Resizing factor used in saliency map computation
        kernel_size: Kernel size of average blur filter used in saliency map computation
        sigma: Sigma of gaussian filter applied on saliency map
        gaussian_size: Size of gaussian filter applied on saliency map

    Returns:
        SR-SIM: Index of similarity between two images. Usually in [0, 1] interval.
            Can be bigger than 1 for predicted images with higher contrast than the original ones.
    Note:
        This implementation is based on the original MATLAB code.
        https://sse.tongji.edu.cn/linzhang/IQA/SR-SIM/Files/SR_SIM.m

    """
    _validate_input(tensors=[x, y],
                    dim_range=(4, 4),
                    data_range=(0, data_range))

    # Rescale to [0, 255] range, because all constants are calculated for this range
    x = (x / float(data_range)) * 255
    y = (y / float(data_range)) * 255

    # Averaging image if the size is large enough
    ksize = max(1, round(min(x.size()[-2:]) / 256))
    padding = ksize // 2

    if padding:
        up_pad = (ksize - 1) // 2
        down_pad = padding
        pad_to_use = [up_pad, down_pad, up_pad, down_pad]
        x = F.pad(x, pad=pad_to_use)
        y = F.pad(y, pad=pad_to_use)

    x = F.avg_pool2d(x, ksize)
    y = F.avg_pool2d(y, ksize)

    num_channels = x.size(1)

    # Convert RGB to YIQ color space https://en.wikipedia.org/wiki/YIQ
    if num_channels == 3:
        x_yiq = rgb2yiq(x)
        y_yiq = rgb2yiq(y)

        x_lum = x_yiq[:, :1]
        y_lum = y_yiq[:, :1]

        x_i = x_yiq[:, 1:2]
        y_i = y_yiq[:, 1:2]
        x_q = x_yiq[:, 2:]
        y_q = y_yiq[:, 2:]

    else:
        if chromatic:
            raise ValueError(
                'Chromatic component can be computed only for RGB images!')
        x_lum = x
        y_lum = y

    # Compute spectral residual visual saliency maps
    svrs_x = _spectral_residual_visual_saliency(x_lum,
                                                scale=scale,
                                                kernel_size=kernel_size,
                                                sigma=sigma,
                                                gaussian_size=gaussian_size)
    svrs_y = _spectral_residual_visual_saliency(y_lum,
                                                scale=scale,
                                                kernel_size=kernel_size,
                                                sigma=sigma,
                                                gaussian_size=gaussian_size)

    # Gradient maps
    kernels = torch.stack([scharr_filter(), scharr_filter().transpose(-1, -2)])
    grad_map_x = gradient_map(x_lum, kernels)
    grad_map_y = gradient_map(y_lum, kernels)

    # Constants from the paper
    C1, C2, alpha = 0.40, 225, 0.50

    # Compute SR-SIM
    SVRS = similarity_map(svrs_x, svrs_y, C1)
    GM = similarity_map(grad_map_x, grad_map_y, C2)
    svrs_max = torch.where(svrs_x > svrs_y, svrs_x, svrs_y)
    score = SVRS * (GM**alpha) * svrs_max

    if chromatic:
        # Constants from FSIM paper, use same method for color image
        T3, T4, lmbda = 200, 200, 0.03

        S_I = similarity_map(x_i, y_i, T3)
        S_Q = similarity_map(x_q, y_q, T4)
        score = score * torch.abs(S_I * S_Q)**lmbda
        # Complex gradients will work in PyTorch 1.6.0
        # score = score * torch.real((S_I * S_Q).to(torch.complex64) ** lmbda)

    eps = torch.finfo(score.dtype).eps
    result = score.sum(dim=[1, 2, 3]) / (svrs_max.sum(dim=[1, 2, 3]) + eps)

    return _reduce(result, reduction)
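
A minimal usage sketch, assuming _spectral_residual_visual_saliency, scharr_filter and the other piq helpers are importable:

import torch

x = torch.rand(4, 3, 96, 96)  # distorted RGB batch in [0, 1]
y = torch.rand(4, 3, 96, 96)  # reference RGB batch in [0, 1]

score = srsim(x, y, data_range=1.)     # scalar, reduction='mean'
score_c = srsim(x, y, chromatic=True)  # SR-SIMc, adds I/Q chromatic terms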