Example #1
def total_variation(x: torch.Tensor, reduction: str = 'mean', norm_type: str = 'l2') -> torch.Tensor:
    r"""Compute Total Variation metric

    Args:
        x: Tensor. Shape :math:`(N, C, H, W)`.
        reduction: Specifies the reduction type:
            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``
        norm_type: One of ``'l1'`` | ``'l2'`` | ``'l2_squared'``. Defines which norm to use:
            anisotropic (``'l1'``) or isotropic (``'l2'``).

    Returns:
        score: Total variation of a given tensor.

    References:
        https://www.wikiwand.com/en/Total_variation_denoising
        https://remi.flamary.com/demos/proxtv.html
    """
    _validate_input([x], dim_range=(4, 4), data_range=(0, -1))

    if norm_type == 'l1':
        w_variance = torch.sum(torch.abs(x[:, :, :, 1:] - x[:, :, :, :-1]), dim=[1, 2, 3])
        h_variance = torch.sum(torch.abs(x[:, :, 1:, :] - x[:, :, :-1, :]), dim=[1, 2, 3])
        score = (h_variance + w_variance)
    elif norm_type == 'l2':
        w_variance = torch.sum(torch.pow(x[:, :, :, 1:] - x[:, :, :, :-1], 2), dim=[1, 2, 3])
        h_variance = torch.sum(torch.pow(x[:, :, 1:, :] - x[:, :, :-1, :], 2), dim=[1, 2, 3])
        score = torch.sqrt(h_variance + w_variance)
    elif norm_type == 'l2_squared':
        w_variance = torch.sum(torch.pow(x[:, :, :, 1:] - x[:, :, :, :-1], 2), dim=[1, 2, 3])
        h_variance = torch.sum(torch.pow(x[:, :, 1:, :] - x[:, :, :-1, :], 2), dim=[1, 2, 3])
        score = (h_variance + w_variance)
    else:
        raise ValueError("Incorrect norm type, should be one of {'l1', 'l2', 'l2_squared'}")

    return _reduce(score, reduction)
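
A minimal usage sketch, assuming total_variation is importable from the top-level piq package (as in upstream piq); the random input is purely illustrative:

import torch
from piq import total_variation

x = torch.rand(4, 3, 64, 64)  # batch of 4 RGB images in [0, 1]

tv_mean = total_variation(x)  # isotropic 'l2' norm, averaged over the batch
tv_per_image = total_variation(x, reduction='none', norm_type='l1')  # anisotropic, shape (4,)
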
Example #2
File: brisque.py  Project: happy20200/piq
def brisque(x: torch.Tensor,
            kernel_size: int = 7,
            kernel_sigma: float = 7 / 6,
            data_range: Union[int, float] = 1.,
            reduction: str = 'mean',
            interpolation: str = 'nearest') -> torch.Tensor:
    r"""Interface of BRISQUE index.
    Supports greyscale and colour images with RGB channel order.

    Args:
        x: An input tensor. Shape :math:`(N, C, H, W)`.
        kernel_size: The side-length of the sliding window used in comparison. Must be an odd value.
        kernel_sigma: Sigma of normal distribution.
        data_range: Maximum value range of images (usually 1.0 or 255).
        reduction: Specifies the reduction type:
            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``
        interpolation: Interpolation to be used for scaling.

    Returns:
        Value of BRISQUE index.

    Note:
        Backpropagation is not available with torch==1.5.0 due to a bug in the argmin/argmax backward pass.
        Update torch and torchvision to the latest versions.

    References:
        .. [1] Anish Mittal et al. "No-Reference Image Quality Assessment in the Spatial Domain",
           https://live.ece.utexas.edu/publications/2012/TIP%20BRISQUE.pdf
    """
    if '1.5.0' in torch.__version__:
        warnings.warn(
            f'BRISQUE does not support backpropagation due to a bug in torch={torch.__version__}. '
            f'Update torch to the latest version to access the full functionality of BRISQUE. '
            f'More info is available at https://github.com/photosynthesis-team/piq/pull/79 and '
            f'https://github.com/pytorch/pytorch/issues/38869.')

    assert kernel_size % 2 == 1, f'Kernel size must be odd, got [{kernel_size}]'
    _validate_input([x], dim_range=(4, 4), data_range=(0, data_range))

    x = x / data_range * 255

    if x.size(1) == 3:
        x = rgb2yiq(x)[:, :1]
    features = []
    num_of_scales = 2
    for _ in range(num_of_scales):
        features.append(_natural_scene_statistics(x, kernel_size,
                                                  kernel_sigma))
        x = F.interpolate(x,
                          size=(x.size(2) // 2, x.size(3) // 2),
                          mode=interpolation)

    features = torch.cat(features, dim=-1)
    scaled_features = _scale_features(features)
    score = _score_svr(scaled_features)

    return _reduce(score, reduction)
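
A usage sketch, assuming brisque is importable from piq as in the upstream package; upstream fetches pretrained SVR weights on first use, so this sketch needs network access:

import torch
from piq import brisque

x = torch.rand(4, 3, 96, 96)  # RGB images in [0, 1]; no reference image is needed

score = brisque(x, data_range=1.)         # scalar, lower is better
per_image = brisque(x, reduction='none')  # shape (4,)
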
Example #3
def gmsd(
    x: torch.Tensor,
    y: torch.Tensor,
    reduction: str = 'mean',
    data_range: Union[int, float] = 1.,
    t: float = 170 / (255.**2)
) -> torch.Tensor:
    r"""Compute Gradient Magnitude Similarity Deviation.

    Supports greyscale and colour images with RGB channel order.

    Args:
        x: An input tensor. Shape :math:`(N, C, H, W)`.
        y: A target tensor. Shape :math:`(N, C, H, W)`.
        reduction: Specifies the reduction type:
            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``
        data_range: Maximum value range of images (usually 1.0 or 255).
        t: Constant from the reference paper that improves the numerical stability of the similarity map.

    Returns:
        Gradient Magnitude Similarity Deviation between given tensors.

    References:
        Wufeng Xue et al. Gradient Magnitude Similarity Deviation (2013)
        https://arxiv.org/pdf/1308.3052.pdf
    """
    _validate_input([x, y], dim_range=(4, 4), data_range=(0, data_range))

    # Rescale
    x = x / float(data_range)
    y = y / float(data_range)

    num_channels = x.size(1)
    if num_channels == 3:
        x = rgb2yiq(x)[:, :1]
        y = rgb2yiq(y)[:, :1]
    up_pad = 0
    down_pad = max(x.shape[2] % 2, x.shape[3] % 2)
    pad_to_use = [up_pad, down_pad, up_pad, down_pad]
    x = F.pad(x, pad=pad_to_use)
    y = F.pad(y, pad=pad_to_use)

    x = F.avg_pool2d(x, kernel_size=2, stride=2, padding=0)
    y = F.avg_pool2d(y, kernel_size=2, stride=2, padding=0)

    score = _gmsd(x=x, y=y, t=t)
    return _reduce(score, reduction)
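
A usage sketch, assuming gmsd is importable from piq as upstream; the tensors are illustrative:

import torch
from piq import gmsd

x = torch.rand(4, 3, 128, 128)  # distorted images in [0, 1]
y = torch.rand(4, 3, 128, 128)  # reference images

score = gmsd(x, y, data_range=1.)  # 0 for identical gradient structure, larger is worse
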
Example #4
def fsim(x: torch.Tensor, y: torch.Tensor, reduction: str = 'mean',
         data_range: Union[int, float] = 1.0, chromatic: bool = True,
         scales: int = 4, orientations: int = 4, min_length: int = 6,
         mult: int = 2, sigma_f: float = 0.55, delta_theta: float = 1.2,
         k: float = 2.0) -> torch.Tensor:
    r"""Compute Feature Similarity Index Measure for a batch of images.

    Args:
        x: An input tensor. Shape :math:`(N, C, H, W)`.
        y: A target tensor. Shape :math:`(N, C, H, W)`.
        reduction: Specifies the reduction type:
            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``
        data_range: Maximum value range of images (usually 1.0 or 255).
        chromatic: Flag to compute FSIMc, which also takes into account chromatic components
        scales: Number of wavelets used for computation of phase congruency maps
        orientations: Number of filter orientations used for computation of phase congruency maps
        min_length: Wavelength of the smallest scale filter
        mult: Scaling factor between successive filters
        sigma_f: Ratio of the standard deviation of the Gaussian describing the log Gabor filter's
            transfer function in the frequency domain to the filter center frequency.
        delta_theta: Ratio of the angular interval between filter orientations and the standard deviation
            of the angular Gaussian function used to construct filters in the frequency plane.
        k: Number of standard deviations of the noise energy beyond the mean at which the noise
            threshold point is set; phase congruency values below it get penalized.

    Returns:
        Index of similarity between two images. Usually in [0, 1] interval.
        Can be bigger than 1 for predicted :math:`x` images with higher contrast than the original ones.

    References:
        L. Zhang, L. Zhang, X. Mou and D. Zhang, "FSIM: A Feature Similarity Index for Image Quality Assessment,"
        IEEE Transactions on Image Processing, vol. 20, no. 8, pp. 2378-2386, Aug. 2011, doi: 10.1109/TIP.2011.2109730.
        https://ieeexplore.ieee.org/document/5705575

    Note:
        This implementation is based on the original MATLAB code.
        https://www4.comp.polyu.edu.hk/~cslzhang/IQA/FSIM/FSIM.htm

    """
    _validate_input([x, y], dim_range=(4, 4), data_range=(0, data_range))

    # Rescale to [0, 255] range, because all constants are calculated for this range
    x = x / float(data_range) * 255
    y = y / float(data_range) * 255

    # Apply average pooling
    kernel_size = max(1, round(min(x.shape[-2:]) / 256))
    x = torch.nn.functional.avg_pool2d(x, kernel_size)
    y = torch.nn.functional.avg_pool2d(y, kernel_size)

    num_channels = x.size(1)

    # Convert RGB to YIQ color space https://en.wikipedia.org/wiki/YIQ
    if num_channels == 3:
        x_yiq = rgb2yiq(x)
        y_yiq = rgb2yiq(y)

        x_lum = x_yiq[:, :1]
        y_lum = y_yiq[:, :1]

        x_i = x_yiq[:, 1:2]
        y_i = y_yiq[:, 1:2]
        x_q = x_yiq[:, 2:]
        y_q = y_yiq[:, 2:]

    else:
        x_lum = x
        y_lum = y

    # Compute phase congruency maps
    pc_x = _phase_congruency(
        x_lum, scales=scales, orientations=orientations,
        min_length=min_length, mult=mult, sigma_f=sigma_f,
        delta_theta=delta_theta, k=k
    )
    pc_y = _phase_congruency(
        y_lum, scales=scales, orientations=orientations,
        min_length=min_length, mult=mult, sigma_f=sigma_f,
        delta_theta=delta_theta, k=k
    )

    # Gradient maps
    kernels = torch.stack([scharr_filter(), scharr_filter().transpose(-1, -2)])
    grad_map_x = gradient_map(x_lum, kernels)
    grad_map_y = gradient_map(y_lum, kernels)

    # Constants from the paper
    T1, T2, T3, T4, lmbda = 0.85, 160, 200, 200, 0.03

    # Compute FSIM
    PC = similarity_map(pc_x, pc_y, T1)
    GM = similarity_map(grad_map_x, grad_map_y, T2)
    pc_max = torch.where(pc_x > pc_y, pc_x, pc_y)
    score = GM * PC * pc_max

    if chromatic:
        assert num_channels == 3, "Chromatic component can be computed only for RGB images!"
        S_I = similarity_map(x_i, y_i, T3)
        S_Q = similarity_map(x_q, y_q, T4)
        score = score * torch.abs(S_I * S_Q) ** lmbda
        # Complex gradients will work in PyTorch 1.6.0
        # score = score * torch.real((S_I * S_Q).to(torch.complex64) ** lmbda)

    result = score.sum(dim=[1, 2, 3]) / pc_max.sum(dim=[1, 2, 3])

    return _reduce(result, reduction)
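
A usage sketch, assuming fsim is importable from piq as upstream. Since chromatic=True (the default) requires RGB input, the greyscale call below switches it off:

import torch
from piq import fsim

x = torch.rand(4, 3, 128, 128)
y = torch.rand(4, 3, 128, 128)

score_fsimc = fsim(x, y, data_range=1., chromatic=True)  # FSIMc on RGB input
score_gray = fsim(x.mean(dim=1, keepdim=True), y.mean(dim=1, keepdim=True), chromatic=False)
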
Example #5
def information_weighted_ssim(x: torch.Tensor, y: torch.Tensor, data_range: Union[int, float] = 1.,
                              kernel_size: int = 11, kernel_sigma: float = 1.5, k1: float = 0.01, k2: float = 0.03,
                              parent: bool = True, blk_size: int = 3, sigma_nsq: float = 0.4,
                              scale_weights: Optional[torch.Tensor] = None,
                              reduction: str = 'mean') -> torch.Tensor:
    r"""Interface of Information Content Weighted Structural Similarity (IW-SSIM) index.
    Inputs supposed to be in range ``[0, data_range]``.

    Args:
        x: An input tensor. Shape :math:`(N, C, H, W)`.
        y: A target tensor. Shape :math:`(N, C, H, W)`.
        data_range: Maximum value range of images (usually 1.0 or 255).
        kernel_size: The side-length of the sliding window used in comparison. Must be an odd value.
        kernel_sigma: Sigma of normal distribution for sliding window used in comparison.
        k1: Algorithm parameter, K1 (small constant).
        k2: Algorithm parameter, K2 (small constant).
            Try a larger K2 constant (e.g. 0.4) if you get negative or NaN results.
        parent: Flag to control dependency on previous layer of pyramid.
        blk_size: The side-length of the sliding window used in comparison for information content.
        sigma_nsq: Parameter of visual distortion model.
        scale_weights: Weights for scaling.
        reduction: Specifies the reduction type:
            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``

    Returns:
        Value of Information Content Weighted Structural Similarity (IW-SSIM) index.

    References:
        Wang, Zhou, and Qiang Li.
        Information content weighting for perceptual image quality assessment.
        IEEE Transactions on Image Processing 20.5 (2011): 1185-1198.
        https://ece.uwaterloo.ca/~z70wang/publications/IWSSIM.pdf
        DOI: `10.1109/TIP.2010.2092435`

    Note:
        Lack of content in target image could lead to RuntimeError due to singular information content matrix,
        which cannot be inverted.
    """
    assert kernel_size % 2 == 1, f'Kernel size must be odd, got [{kernel_size}]'

    _validate_input(tensors=[x, y], dim_range=(4, 4), data_range=(0., data_range))

    x = x / float(data_range) * 255
    y = y / float(data_range) * 255

    if x.size(1) == 3:
        x = rgb2yiq(x)[:, :1]
        y = rgb2yiq(y)[:, :1]

    if scale_weights is None:
        scale_weights = torch.tensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333], dtype=x.dtype, device=x.device)
    scale_weights = scale_weights / scale_weights.sum()
    if scale_weights.size(0) != scale_weights.numel():
        raise ValueError(f'Expected a vector of weights, got {scale_weights.dim()}D tensor')

    levels = scale_weights.size(0)

    min_size = (kernel_size - 1) * 2 ** (levels - 1) + 1
    if x.size(-1) < min_size or x.size(-2) < min_size:
        raise ValueError(f'Invalid size of the input images, expected at least {min_size}x{min_size}.')

    # Paddings introduced by the Gaussian blur and the information-content window
    blur_pad = math.ceil((kernel_size - 1) / 2)
    iw_pad = blur_pad - math.floor((blk_size - 1) / 2)
    gauss_kernel = gaussian_filter(kernel_size, kernel_sigma).repeat(x.size(1), 1, 1, 1).to(x)

    # Kernel size used to build the Laplacian pyramid
    pyramid_kernel_size = 5
    bin_filter = binomial_filter1d(kernel_size=pyramid_kernel_size).to(x) * 2 ** 0.5

    lo_x, x_diff_old = _pyr_step(x, bin_filter)
    lo_y, y_diff_old = _pyr_step(y, bin_filter)

    x = lo_x
    y = lo_y
    wmcs = []

    for i in range(levels):
        if i < levels - 2:
            lo_x, x_diff = _pyr_step(x, bin_filter)
            lo_y, y_diff = _pyr_step(y, bin_filter)
            x = lo_x
            y = lo_y

        else:
            x_diff = x
            y_diff = y

        ssim_map, cs_map = _ssim_per_channel(x=x_diff_old, y=y_diff_old, kernel=gauss_kernel, data_range=255,
                                             k1=k1, k2=k2)

        if parent and i < levels - 2:
            iw_map = _information_content(x=x_diff_old, y=y_diff_old, y_parent=y_diff, kernel_size=blk_size,
                                          sigma_nsq=sigma_nsq)

            iw_map = iw_map[:, :, iw_pad:-iw_pad, iw_pad:-iw_pad]

        elif i == levels - 1:
            iw_map = torch.ones_like(cs_map)
            cs_map = ssim_map

        else:
            iw_map = _information_content(x=x_diff_old, y=y_diff_old, y_parent=None, kernel_size=blk_size,
                                          sigma_nsq=sigma_nsq)
            iw_map = iw_map[:, :, iw_pad:-iw_pad, iw_pad:-iw_pad]

        wmcs.append(torch.sum(cs_map * iw_map, dim=(-2, -1)) / torch.sum(iw_map, dim=(-2, -1)))

        x_diff_old = x_diff
        y_diff_old = y_diff

    wmcs = torch.stack(wmcs, dim=0).abs()

    score = torch.prod((wmcs ** scale_weights.view(-1, 1, 1)), dim=0)[:, 0]

    return _reduce(x=score, reduction=reduction)
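
A usage sketch, assuming information_weighted_ssim is importable from piq as upstream. With the default 11x11 kernel and five scale weights, inputs must be at least 161x161:

import torch
from piq import information_weighted_ssim

x = torch.rand(2, 3, 192, 192)
y = x * 0.9 + 0.05 * torch.rand_like(x)  # mildly distorted copy, still in [0, 1]

score = information_weighted_ssim(x, y, data_range=1.)
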
Example #6
def multi_scale_gmsd(x: torch.Tensor, y: torch.Tensor, data_range: Union[int, float] = 1., reduction: str = 'mean',
                     scale_weights: Optional[torch.Tensor] = None,
                     chromatic: bool = False, alpha: float = 0.5, beta1: float = 0.01, beta2: float = 0.32,
                     beta3: float = 15., t: float = 170) -> torch.Tensor:
    r"""Computation of Multi scale GMSD.

    Supports greyscale and colour images with RGB channel order.
    The height and width should be at least 2 ** scales + 1.

    Args:
        x: An input tensor. Shape :math:`(N, C, H, W)`.
        y: A target tensor. Shape :math:`(N, C, H, W)`.
        data_range: Maximum value range of images (usually 1.0 or 255).
        reduction: Specifies the reduction type:
            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``
        scale_weights: Weights for different scales. Can contain any number of floating point values.
        chromatic: Flag to use the MS-GMSDc algorithm from the paper,
            which also evaluates chromatic components of the image. Default: False
        alpha: Masking coefficient. See [1] for details.
        beta1: Algorithm parameter. Weight of chromatic component in the loss.
        beta2: Algorithm parameter. Small constant, see [1].
        beta3: Algorithm parameter. Small constant, see [1].
        t: Constant from the reference paper that improves the numerical stability of the similarity map.

    Returns:
        Value of MS-GMSD. 0 <= GMSD loss <= 1.
    """
    _validate_input([x, y], dim_range=(4, 4), data_range=(0, data_range))
    
    # Rescale
    x = x / data_range * 255
    y = y / data_range * 255
    
    # Values from the paper
    if scale_weights is None:
        scale_weights = torch.tensor([0.096, 0.596, 0.289, 0.019], device=x.device)
    else:
        # Normalize scale weights
        scale_weights = (scale_weights / scale_weights.sum()).to(x)

    # Check that input is big enough
    num_scales = scale_weights.size(0)
    min_size = 2 ** num_scales + 1

    if x.size(-1) < min_size or x.size(-2) < min_size:
        raise ValueError(f'Invalid size of the input images, expected at least {min_size}x{min_size}.')

    num_channels = x.size(1)
    if num_channels == 3:
        x = rgb2yiq(x)
        y = rgb2yiq(y)

    ms_gmds = []
    for scale in range(num_scales):
        if scale > 0:

            # Average by 2x2 filter and downsample
            up_pad = 0
            down_pad = max(x.shape[2] % 2, x.shape[3] % 2)
            pad_to_use = [up_pad, down_pad, up_pad, down_pad]
            x = F.pad(x, pad=pad_to_use)
            y = F.pad(y, pad=pad_to_use)
            x = F.avg_pool2d(x, kernel_size=2, padding=0)
            y = F.avg_pool2d(y, kernel_size=2, padding=0)

        score = _gmsd(x[:, :1], y[:, :1], t=t, alpha=alpha)
        ms_gmds.append(score)

    # Stack results in different scales and multiply by weight
    ms_gmds_val = scale_weights.view(1, num_scales) * (torch.stack(ms_gmds, dim=1) ** 2)

    # Sum and take sqrt per-image
    ms_gmds_val = torch.sqrt(torch.sum(ms_gmds_val, dim=1))

    # Shape: (batch_size, )
    score = ms_gmds_val

    if chromatic:
        assert x.size(1) == 3, "Chromatic component can be computed only for RGB images!"

        x_iq = x[:, 1:]
        y_iq = y[:, 1:]

        rmse_iq = torch.sqrt(torch.mean((x_iq - y_iq) ** 2, dim=[2, 3]))
        rmse_chrome = torch.sqrt(torch.sum(rmse_iq ** 2, dim=1))
        gamma = 2 / (1 + beta2 * torch.exp(-beta3 * ms_gmds_val)) - 1

        score = gamma * ms_gmds_val + (1 - gamma) * beta1 * rmse_chrome

    return _reduce(score, reduction)
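
A usage sketch, assuming multi_scale_gmsd is importable from piq as upstream. The four default scale weights require inputs of at least 17x17:

import torch
from piq import multi_scale_gmsd

x = torch.rand(4, 3, 64, 64)
y = torch.rand(4, 3, 64, 64)

score = multi_scale_gmsd(x, y, data_range=1.)     # MS-GMSD
score_c = multi_scale_gmsd(x, y, chromatic=True)  # MS-GMSDc, RGB input only
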
Example #7
def dss(x: torch.Tensor,
        y: torch.Tensor,
        reduction: str = 'mean',
        data_range: Union[int, float] = 1.0,
        dct_size: int = 8,
        sigma_weight: float = 1.55,
        kernel_size: int = 3,
        sigma_similarity: float = 1.5,
        percentile: float = 0.05) -> torch.Tensor:
    r"""Compute DCT Subband Similarity index for a batch of images.

    Args:
        x: An input tensor. Shape :math:`(N, C, H, W)`.
        y: A target tensor. Shape :math:`(N, C, H, W)`.
        reduction: Specifies the reduction type:
            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``
        data_range: Maximum value range of images (usually 1.0 or 255).
        dct_size: Size of blocks in 2D Discrete Cosine Transform. DCT sizes must be in (0, input size].
        sigma_weight: STD of the Gaussian that determines the proportion of weight given to low and high frequencies.
            Default: 1.55
        kernel_size: Size of the Gaussian kernel for computing subband similarity. Kernel size must be in (0, input size].
            Default: 3
        sigma_similarity: STD of the Gaussian kernel for computing subband similarity. Default: 1.5
        percentile: Fraction in (0, 1] of the worst similarity scores which should be kept. Default: 0.05

    Returns:
        DSS: Index of similarity between two images. In [0, 1] interval.

    Note:
        This implementation is based on the original MATLAB code (see header).
        Images will be scaled to [0, 255] because all constants are computed for this range.
        Make sure you know what you are doing when changing default coefficient values.
    """
    if sigma_weight == 0 or sigma_similarity == 0:
        raise ValueError(
            f'Gaussian sigmas must not be 0, got sigma_weight: {sigma_weight} and '
            f'sigma_similarity: {sigma_similarity}')

    if percentile <= 0 or percentile > 1:
        raise ValueError(f'Percentile must be in (0,1], got {percentile}')

    _validate_input(tensors=[x, y], dim_range=(4, 4))

    for size in (dct_size, kernel_size):
        if size <= 0 or size > min(x.size(-1), x.size(-2)):
            raise ValueError(
                'DCT and kernels sizes must be included in (0, input size]')

    # Rescale to [0, 255] range, because all constants are calculated for this range
    x = (x / float(data_range)) * 255
    y = (y / float(data_range)) * 255

    num_channels = x.size(1)

    # Use luminance channel in case of RGB images (Y from YIQ or YCrCb)
    if num_channels == 3:
        x_lum = rgb2yiq(x)[:, :1]
        y_lum = rgb2yiq(y)[:, :1]
    else:
        x_lum = x
        y_lum = y

    # Crop images size to the closest multiplication of `dct_size`
    rows, cols = x_lum.size()[-2:]
    rows = dct_size * (rows // dct_size)
    cols = dct_size * (cols // dct_size)
    x_lum = x_lum[:, :, 0:rows, 0:cols]
    y_lum = y_lum[:, :, 0:rows, 0:cols]

    # Channel decomposition for both images by `dct_size`x`dct_size` 2D DCT
    dct_x = _dct_decomp(x_lum, dct_size)
    dct_y = _dct_decomp(y_lum, dct_size)

    # Create a Gaussian window that will be used to weight subbands scores
    coords = torch.arange(1, dct_size + 1).to(x)
    weight = (coords - 0.5)**2
    weight = (-(weight.unsqueeze(0) + weight.unsqueeze(1)) / (2 * sigma_weight ** 2)).exp()

    # Compute similarity between each subband in img1 and img2
    subband_sim_matrix = torch.zeros((x.size(0), dct_size, dct_size),
                                     device=x.device)
    threshold = 1e-2
    for m in range(dct_size):
        for n in range(dct_size):
            first_term = (m == 0 and n == 0)  # boolean

            # Skip subbands with very small weight
            if weight[m, n] < threshold:
                weight[m, n] = 0
                continue

            subband_sim_matrix[:, m, n] = _subband_similarity(
                dct_x[:, :, m::dct_size, n::dct_size],
                dct_y[:, :, m::dct_size, n::dct_size],
                first_term, kernel_size, sigma_similarity, percentile)

    # Weight subband similarity scores; eps guards the division against a zero weight sum
    eps = torch.finfo(weight.dtype).eps
    similarity_scores = torch.sum(subband_sim_matrix * (weight / (torch.sum(weight) + eps)),
                                  dim=[1, 2])
    dss_val = _reduce(similarity_scores, reduction)
    return dss_val
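
A usage sketch, assuming dss is importable from piq as upstream. Inputs must be at least dct_size x dct_size (8x8 by default):

import torch
from piq import dss

x = torch.rand(4, 3, 96, 96)
y = torch.rand(4, 3, 96, 96)

score = dss(x, y, data_range=1.)     # closer to 1 means more similar
score_p = dss(x, y, percentile=0.1)  # keep the worst 10% of subband similarity scores
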
Example #8
def srsim(x: torch.Tensor,
          y: torch.Tensor,
          reduction: str = 'mean',
          data_range: Union[int, float] = 1.0,
          chromatic: bool = False,
          scale: float = 0.25,
          kernel_size: int = 3,
          sigma: float = 3.8,
          gaussian_size: int = 10) -> torch.Tensor:
    r"""Compute Spectral Residual based Similarity for a batch of images.

    Args:
        x: Predicted images. Shape :math:`(N, C, H, W)`.
        y: Target images. Shape :math:`(N, C, H, W)`.
        reduction: Specifies the reduction type:
            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``
        data_range: Value range of input images (usually 1.0 or 255). Default: 1.0
        chromatic: Flag to compute SR-SIMc, which also takes into account chromatic components
        scale: Resizing factor used in saliency map computation
        kernel_size: Kernel size of the average blur filter used in saliency map computation
        sigma: Sigma of the Gaussian filter applied to the saliency map
        gaussian_size: Size of the Gaussian filter applied to the saliency map

    Returns:
        SR-SIM: Index of similarity between two images. Usually in [0, 1] interval.
        Can be bigger than 1 for predicted images with higher contrast than the original ones.

    Note:
        This implementation is based on the original MATLAB code.
        https://sse.tongji.edu.cn/linzhang/IQA/SR-SIM/Files/SR_SIM.m

    """
    _validate_input(tensors=[x, y],
                    dim_range=(4, 4),
                    data_range=(0, data_range))

    # Rescale to [0, 255] range, because all constants are calculated for this range
    x = (x / float(data_range)) * 255
    y = (y / float(data_range)) * 255

    # Downsample by average pooling if the image is large enough
    ksize = max(1, round(min(x.size()[-2:]) / 256))
    padding = ksize // 2

    if padding:
        up_pad = (ksize - 1) // 2
        down_pad = padding
        pad_to_use = [up_pad, down_pad, up_pad, down_pad]
        x = F.pad(x, pad=pad_to_use)
        y = F.pad(y, pad=pad_to_use)

    x = F.avg_pool2d(x, ksize)
    y = F.avg_pool2d(y, ksize)

    num_channels = x.size(1)

    # Convert RGB to YIQ color space https://en.wikipedia.org/wiki/YIQ
    if num_channels == 3:
        x_yiq = rgb2yiq(x)
        y_yiq = rgb2yiq(y)

        x_lum = x_yiq[:, :1]
        y_lum = y_yiq[:, :1]

        x_i = x_yiq[:, 1:2]
        y_i = y_yiq[:, 1:2]
        x_q = x_yiq[:, 2:]
        y_q = y_yiq[:, 2:]

    else:
        if chromatic:
            raise ValueError(
                'Chromatic component can be computed only for RGB images!')
        x_lum = x
        y_lum = y

    # Compute spectral residual visual saliency maps
    svrs_x = _spectral_residual_visual_saliency(x_lum,
                                                scale=scale,
                                                kernel_size=kernel_size,
                                                sigma=sigma,
                                                gaussian_size=gaussian_size)
    svrs_y = _spectral_residual_visual_saliency(y_lum,
                                                scale=scale,
                                                kernel_size=kernel_size,
                                                sigma=sigma,
                                                gaussian_size=gaussian_size)

    # Gradient maps
    kernels = torch.stack([scharr_filter(), scharr_filter().transpose(-1, -2)])
    grad_map_x = gradient_map(x_lum, kernels)
    grad_map_y = gradient_map(y_lum, kernels)

    # Constants from the paper
    C1, C2, alpha = 0.40, 225, 0.50

    # Compute SR-SIM
    SVRS = similarity_map(svrs_x, svrs_y, C1)
    GM = similarity_map(grad_map_x, grad_map_y, C2)
    svrs_max = torch.where(svrs_x > svrs_y, svrs_x, svrs_y)
    score = SVRS * (GM**alpha) * svrs_max

    if chromatic:
        # Constants from FSIM paper, use same method for color image
        T3, T4, lmbda = 200, 200, 0.03

        S_I = similarity_map(x_i, y_i, T3)
        S_Q = similarity_map(x_q, y_q, T4)
        score = score * torch.abs(S_I * S_Q)**lmbda
        # Complex gradients will work in PyTorch 1.6.0
        # score = score * torch.real((S_I * S_Q).to(torch.complex64) ** lmbda)

    eps = torch.finfo(score.dtype).eps
    result = score.sum(dim=[1, 2, 3]) / (svrs_max.sum(dim=[1, 2, 3]) + eps)

    return _reduce(result, reduction)
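
A usage sketch, assuming srsim is importable from piq as upstream; chromatic=True requires RGB input:

import torch
from piq import srsim

x = torch.rand(4, 3, 128, 128)
y = torch.rand(4, 3, 128, 128)

score = srsim(x, y, data_range=1.)     # SR-SIM
score_c = srsim(x, y, chromatic=True)  # SR-SIMc with chromatic components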