def total_variation(x: torch.Tensor, reduction: str = 'mean', norm_type: str = 'l2') -> torch.Tensor:
    r"""Compute the Total Variation of a batch of images.

    Args:
        x: Tensor. Shape :math:`(N, C, H, W)`.
        reduction: Specifies the reduction type:
            ``'none'`` | ``'mean'`` | ``'sum'``. Default:``'mean'``
        norm_type: One of ``'l1'``, ``'l2'`` or ``'l2_squared'``. Selects how the horizontal and
            vertical finite differences are aggregated:
            ``'l1'`` sums their absolute values, ``'l2_squared'`` sums their squares and
            ``'l2'`` takes the square root of that sum of squares.

    Returns:
        score : Total variation of a given tensor

    Raises:
        ValueError: If ``norm_type`` is not one of the supported values.

    References:
        https://www.wikiwand.com/en/Total_variation_denoising

        https://remi.flamary.com/demos/proxtv.html
    """
    # NOTE(review): the (0, -1) data range presumably disables the upper-bound check -
    # confirm against `_validate_input`.
    _validate_input([x, ], dim_range=(4, 4), data_range=(0, -1))

    if norm_type not in ('l1', 'l2', 'l2_squared'):
        raise ValueError("Incorrect norm type, should be one of {'l1', 'l2', 'l2_squared'}")

    # Differences between neighbouring pixels along width and height
    diff_w = x[:, :, :, 1:] - x[:, :, :, :-1]
    diff_h = x[:, :, 1:, :] - x[:, :, :-1, :]
    per_image_dims = [1, 2, 3]

    if norm_type == 'l1':
        w_total = torch.sum(torch.abs(diff_w), dim=per_image_dims)
        h_total = torch.sum(torch.abs(diff_h), dim=per_image_dims)
        score = h_total + w_total
    else:
        w_total = torch.sum(torch.pow(diff_w, 2), dim=per_image_dims)
        h_total = torch.sum(torch.pow(diff_h, 2), dim=per_image_dims)
        score = h_total + w_total
        if norm_type == 'l2':
            score = torch.sqrt(score)

    return _reduce(score, reduction)
def brisque(x: torch.Tensor,
            kernel_size: int = 7, kernel_sigma: float = 7 / 6,
            data_range: Union[int, float] = 1., reduction: str = 'mean',
            interpolation: str = 'nearest') -> torch.Tensor:
    r"""Interface of BRISQUE index.
    Supports greyscale and colour images with RGB channel order.

    Args:
        x: An input tensor. Shape :math:`(N, C, H, W)`.
        kernel_size: The side-length of the sliding window used in comparison. Must be an odd value.
        kernel_sigma: Sigma of normal distribution.
        data_range: Maximum value range of images (usually 1.0 or 255).
        reduction: Specifies the reduction type:
            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``
        interpolation: Interpolation to be used for scaling.

    Returns:
        Value of BRISQUE index.

    Raises:
        AssertionError: If ``kernel_size`` is even.

    Note:
        The back propagation is not available using torch=1.5.0 due to bug in argmin/argmax back propagation.
        Update the torch and torchvision to the latest versions.

    References:
        .. [1] Anish Mittal et al. "No-Reference Image Quality Assessment in the Spatial Domain",
            https://live.ece.utexas.edu/publications/2012/TIP%20BRISQUE.pdf
    """
    if '1.5.0' in torch.__version__:
        # Each fragment ends with a space so the sentences and URLs do not run together
        # once the adjacent literals are concatenated.
        warnings.warn(
            f'BRISQUE does not support back propagation due to bug in torch={torch.__version__}. '
            'Update torch to the latest version to access full functionality of the BRISQUE. '
            'More info is available at https://github.com/photosynthesis-team/piq/pull/79 and '
            'https://github.com/pytorch/pytorch/issues/38869.')

    # Kept as an assertion so that callers relying on AssertionError are unaffected.
    assert kernel_size % 2 == 1, f'Kernel size must be odd, got [{kernel_size}]'
    _validate_input([x, ], dim_range=(4, 4), data_range=(0, data_range))

    # Rescale to [0, 255]; float() matches the rescaling used by the other metrics in this file
    x = x / float(data_range) * 255

    # For 3-channel input only the first (luminance) channel of the YIQ representation is used
    if x.size(1) == 3:
        x = rgb2yiq(x)[:, :1]

    # Natural scene statistics are collected at the original and at the half resolution
    num_of_scales = 2
    scale_features = []
    for _ in range(num_of_scales):
        scale_features.append(_natural_scene_statistics(x, kernel_size, kernel_sigma))
        x = F.interpolate(x, size=(x.size(2) // 2, x.size(3) // 2), mode=interpolation)

    features = torch.cat(scale_features, dim=-1)
    scaled_features = _scale_features(features)
    score = _score_svr(scaled_features)

    return _reduce(score, reduction)
def gmsd(x: torch.Tensor, y: torch.Tensor, reduction: str = 'mean',
         data_range: Union[int, float] = 1., t: float = 170 / (255.**2)) -> torch.Tensor:
    r"""Compute Gradient Magnitude Similarity Deviation.
    Supports greyscale and colour images with RGB channel order.

    Args:
        x: An input tensor. Shape :math:`(N, C, H, W)`.
        y: A target tensor. Shape :math:`(N, C, H, W)`.
        reduction: Specifies the reduction type:
            ``'none'`` | ``'mean'`` | ``'sum'``. Default:``'mean'``
        data_range: Maximum value range of images (usually 1.0 or 255).
        t: Constant from the reference paper numerical stability of similarity map.

    Returns:
        Gradient Magnitude Similarity Deviation between given tensors.

    References:
        Wufeng Xue et al. Gradient Magnitude Similarity Deviation (2013)
        https://arxiv.org/pdf/1308.3052.pdf
    """
    _validate_input([x, y], dim_range=(4, 4), data_range=(0, data_range))

    # Bring both images to the [0, 1] range
    divisor = float(data_range)
    x = x / divisor
    y = y / divisor

    # For 3-channel input keep only the first (luminance) channel of the YIQ representation
    if x.size(1) == 3:
        x = rgb2yiq(x)[:, :1]
        y = rgb2yiq(y)[:, :1]

    # Pad the trailing edge of both spatial dims by one pixel if either of them is odd,
    # then halve the resolution with a 2x2 average
    trailing = max(x.shape[2] % 2, x.shape[3] % 2)
    padding = [0, trailing, 0, trailing]
    x = F.pad(x, pad=padding)
    y = F.pad(y, pad=padding)
    x = F.avg_pool2d(x, kernel_size=2, stride=2, padding=0)
    y = F.avg_pool2d(y, kernel_size=2, stride=2, padding=0)

    return _reduce(_gmsd(x=x, y=y, t=t), reduction)
def fsim(x: torch.Tensor, y: torch.Tensor, reduction: str = 'mean',
         data_range: Union[int, float] = 1.0, chromatic: bool = True,
         scales: int = 4, orientations: int = 4, min_length: int = 6,
         mult: int = 2, sigma_f: float = 0.55, delta_theta: float = 1.2,
         k: float = 2.0) -> torch.Tensor:
    r"""Compute Feature Similarity Index Measure for a batch of images.

    Args:
        x: An input tensor. Shape :math:`(N, C, H, W)`.
        y: A target tensor. Shape :math:`(N, C, H, W)`.
        reduction: Specifies the reduction type:
            ``'none'`` | ``'mean'`` | ``'sum'``. Default:``'mean'``
        data_range: Maximum value range of images (usually 1.0 or 255).
        chromatic: Flag to compute FSIMc, which also takes into account chromatic components
        scales: Number of wavelets used for computation of phase congruensy maps
        orientations: Number of filter orientations used for computation of phase congruensy maps
        min_length: Wavelength of smallest scale filter
        mult: Scaling factor between successive filters
        sigma_f: Ratio of the standard deviation of the Gaussian describing the log Gabor filter's
            transfer function in the frequency domain to the filter center frequency.
        delta_theta: Ratio of angular interval between filter orientations and the standard deviation
            of the angular Gaussian function used to construct filters in the frequency plane.
        k: No of standard deviations of the noise energy beyond the mean at which we set the noise
            threshold point, below which phase congruency values get penalized.

    Returns:
        Index of similarity between two images. Usually in [0, 1] interval.
        Can be bigger than 1 for predicted :math:`x` images with higher contrast than the original ones.

    Raises:
        AssertionError: If ``chromatic`` is set and the inputs do not have 3 channels.

    References:
        L. Zhang, L. Zhang, X. Mou and D. Zhang, "FSIM: A Feature Similarity Index for Image Quality Assessment,"
        IEEE Transactions on Image Processing, vol. 20, no. 8, pp. 2378-2386, Aug. 2011, doi: 10.1109/TIP.2011.2109730.
        https://ieeexplore.ieee.org/document/5705575

    Note:
        This implementation is based on the original MATLAB code.
        https://www4.comp.polyu.edu.hk/~cslzhang/IQA/FSIM/FSIM.htm
    """
    _validate_input([x, y], dim_range=(4, 4), data_range=(0, data_range))

    # All constants below are tuned for the [0, 255] range
    x = x / float(data_range) * 255
    y = y / float(data_range) * 255

    # Average-pool with a window that grows with the smaller spatial side
    pool_size = max(1, round(min(x.shape[-2:]) / 256))
    x = torch.nn.functional.avg_pool2d(x, pool_size)
    y = torch.nn.functional.avg_pool2d(y, pool_size)

    num_channels = x.size(1)

    if num_channels == 3:
        # RGB -> YIQ (https://en.wikipedia.org/wiki/YIQ): luminance plus two chromatic channels
        x_yiq, y_yiq = rgb2yiq(x), rgb2yiq(y)
        x_lum, y_lum = x_yiq[:, : 1], y_yiq[:, : 1]
        x_i, y_i = x_yiq[:, 1:2], y_yiq[:, 1:2]
        x_q, y_q = x_yiq[:, 2:], y_yiq[:, 2:]
    else:
        x_lum, y_lum = x, y

    # Phase congruency maps of both luminance images, computed with identical settings
    pc_settings = dict(scales=scales, orientations=orientations, min_length=min_length,
                       mult=mult, sigma_f=sigma_f, delta_theta=delta_theta, k=k)
    pc_x = _phase_congruency(x_lum, **pc_settings)
    pc_y = _phase_congruency(y_lum, **pc_settings)

    # Gradient maps from the Scharr kernel and its transpose
    kernels = torch.stack([scharr_filter(), scharr_filter().transpose(-1, -2)])
    grad_x = gradient_map(x_lum, kernels)
    grad_y = gradient_map(y_lum, kernels)

    # Constants from the paper
    T1, T2, T3, T4, lmbda = 0.85, 160, 200, 200, 0.03

    pc_similarity = similarity_map(pc_x, pc_y, T1)
    grad_similarity = similarity_map(grad_x, grad_y, T2)
    pc_max = torch.where(pc_x > pc_y, pc_x, pc_y)
    score = grad_similarity * pc_similarity * pc_max

    if chromatic:
        assert num_channels == 3, "Chromatic component can be computed only for RGB images!"
        sim_i = similarity_map(x_i, y_i, T3)
        sim_q = similarity_map(x_q, y_q, T4)
        # The absolute value is taken before the fractional power
        # (the original author noted complex gradients only arrive with PyTorch 1.6.0).
        score = score * torch.abs(sim_i * sim_q) ** lmbda

    # Similarity pooled per image, weighted by the maximal phase congruency
    per_image_dims = [1, 2, 3]
    result = score.sum(dim=per_image_dims) / pc_max.sum(dim=per_image_dims)

    return _reduce(result, reduction)
def information_weighted_ssim(x: torch.Tensor, y: torch.Tensor, data_range: Union[int, float] = 1.,
                              kernel_size: int = 11, kernel_sigma: float = 1.5, k1: float = 0.01, k2: float = 0.03,
                              parent: bool = True, blk_size: int = 3, sigma_nsq: float = 0.4,
                              scale_weights: Optional[torch.Tensor] = None,
                              reduction: str = 'mean') -> torch.Tensor:
    r"""Interface of Information Content Weighted Structural Similarity (IW-SSIM) index.
    Inputs supposed to be in range ``[0, data_range]``.

    Args:
        x: An input tensor. Shape :math:`(N, C, H, W)`.
        y: A target tensor. Shape :math:`(N, C, H, W)`.
        data_range: Maximum value range of images (usually 1.0 or 255).
        kernel_size: The side-length of the sliding window used in comparison. Must be an odd value.
        kernel_sigma: Sigma of normal distribution for sliding window used in comparison.
        k1: Algorithm parameter, K1 (small constant).
        k2: Algorithm parameter, K2 (small constant).
            Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results.
        parent: Flag to control dependency on previous layer of pyramid.
        blk_size: The side-length of the sliding window used in comparison for information content.
        sigma_nsq: Parameter of visual distortion model.
        scale_weights: Weights for scaling. They are normalized to sum to one. If ``None``,
            five default weights are used.
        reduction: Specifies the reduction type:
            ``'none'`` | ``'mean'`` | ``'sum'``. Default:``'mean'``

    Returns:
        Value of Information Content Weighted Structural Similarity (IW-SSIM) index.

    Raises:
        AssertionError: If ``kernel_size`` is even.
        ValueError: If ``scale_weights`` is not a vector, or the images are smaller than
            ``(kernel_size - 1) * 2 ** (levels - 1) + 1`` in either spatial dimension.

    References:
        Wang, Zhou, and Qiang Li..
        Information content weighting for perceptual image quality assessment.
        IEEE Transactions on image processing 20.5 (2011): 1185-1198.
        https://ece.uwaterloo.ca/~z70wang/publications/IWSSIM.pdf DOI:`10.1109/TIP.2010.2092435`

    Note:
        Lack of content in target image could lead to RuntimeError due to singular information content matrix,
        which cannot be inverted.
    """
    assert kernel_size % 2 == 1, f'Kernel size must be odd, got [{kernel_size}]'
    _validate_input(tensors=[x, y], dim_range=(4, 4), data_range=(0., data_range))

    # Rescale to [0, 255]; the SSIM maps below are computed with data_range=255
    x = x / float(data_range) * 255
    y = y / float(data_range) * 255

    # For 3-channel input keep only the first (luminance) channel of the YIQ representation
    if x.size(1) == 3:
        x = rgb2yiq(x)[:, :1]
        y = rgb2yiq(y)[:, :1]

    if scale_weights is None:
        scale_weights = torch.tensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333], dtype=x.dtype, device=x.device)
    else:
        # User weights may live on another device than the images; the default ones never do
        scale_weights = scale_weights.to(device=x.device)
    scale_weights = scale_weights / scale_weights.sum()
    if scale_weights.size(0) != scale_weights.numel():
        raise ValueError(f'Expected a vector of weights, got {scale_weights.dim()}D tensor')

    levels = scale_weights.size(0)

    min_size = (kernel_size - 1) * 2 ** (levels - 1) + 1
    if x.size(-1) < min_size or x.size(-2) < min_size:
        raise ValueError(f'Invalid size of the input images, expected at least {min_size}x{min_size}.')

    blur_pad = math.ceil((kernel_size - 1) / 2)  # Ceil
    iw_pad = blur_pad - math.floor((blk_size - 1) / 2)  # floor

    def _crop(weight_map: torch.Tensor) -> torch.Tensor:
        # `0:-0` is an empty slice, so a zero border must skip cropping instead of
        # producing an empty map (which would turn the weighted average into NaN).
        if iw_pad == 0:
            return weight_map
        return weight_map[:, :, iw_pad:-iw_pad, iw_pad:-iw_pad]

    gauss_kernel = gaussian_filter(kernel_size, kernel_sigma).repeat(x.size(1), 1, 1, 1).to(x)

    # Size of the kernel size to build Laplacian pyramid
    pyramid_kernel_size = 5
    bin_filter = binomial_filter1d(kernel_size=pyramid_kernel_size).to(x) * 2 ** 0.5

    # First pyramid step: a low-pass image to recurse on and the band compared at level 0
    lo_x, x_diff_old = _pyr_step(x, bin_filter)
    lo_y, y_diff_old = _pyr_step(y, bin_filter)

    x = lo_x
    y = lo_y
    wmcs = []

    for i in range(levels):
        # Prepare the band of the next level; the last two iterations reuse the low-pass residual
        if i < levels - 2:
            lo_x, x_diff = _pyr_step(x, bin_filter)
            lo_y, y_diff = _pyr_step(y, bin_filter)
            x = lo_x
            y = lo_y
        else:
            x_diff = x
            y_diff = y

        ssim_map, cs_map = _ssim_per_channel(x=x_diff_old, y=y_diff_old, kernel=gauss_kernel, data_range=255,
                                             k1=k1, k2=k2)

        if parent and i < levels - 2:
            # Information content conditioned on the band of the next pyramid level
            iw_map = _information_content(x=x_diff_old, y=y_diff_old, y_parent=y_diff, kernel_size=blk_size,
                                          sigma_nsq=sigma_nsq)
            iw_map = _crop(iw_map)
        elif i == levels - 1:
            # Last level: uniform weights applied to the full SSIM map
            iw_map = torch.ones_like(cs_map)
            cs_map = ssim_map
        else:
            iw_map = _information_content(x=x_diff_old, y=y_diff_old, y_parent=None, kernel_size=blk_size,
                                          sigma_nsq=sigma_nsq)
            iw_map = _crop(iw_map)

        # Information-weighted average of the similarity map over the spatial dims
        wmcs.append(torch.sum(cs_map * iw_map, dim=(-2, -1)) / torch.sum(iw_map, dim=(-2, -1)))

        x_diff_old = x_diff
        y_diff_old = y_diff

    # Weighted geometric combination of the per-level scores
    wmcs = torch.stack(wmcs, dim=0).abs()
    score = torch.prod((wmcs ** scale_weights.view(-1, 1, 1)), dim=0)[:, 0]

    return _reduce(x=score, reduction=reduction)
def multi_scale_gmsd(x: torch.Tensor, y: torch.Tensor, data_range: Union[int, float] = 1., reduction: str = 'mean',
                     scale_weights: Optional[torch.Tensor] = None,
                     chromatic: bool = False, alpha: float = 0.5, beta1: float = 0.01, beta2: float = 0.32,
                     beta3: float = 15., t: float = 170) -> torch.Tensor:
    r"""Computation of Multi scale GMSD.
    Supports greyscale and colour images with RGB channel order.
    The height and width should be at least 2 ** scales + 1.

    Args:
        x: An input tensor. Shape :math:`(N, C, H, W)`.
        y: A target tensor. Shape :math:`(N, C, H, W)`.
        data_range: Maximum value range of images (usually 1.0 or 255).
        reduction: Specifies the reduction type:
            ``'none'`` | ``'mean'`` | ``'sum'``. Default:``'mean'``
        scale_weights: Weights for different scales. Can contain any number of floating point values.
            User-provided weights are normalized to sum to one; if ``None``, four default weights are used.
        chromatic: Flag to use MS-GMSDc algorithm from paper.
            It also evaluates chromatic components of the image. Default: False
        alpha: Masking coefficient. See [1] for details.
        beta1: Algorithm parameter. Weight of chromatic component in the loss.
        beta2: Algorithm parameter. Small constant, see [1].
        beta3: Algorithm parameter. Small constant, see [1].
        t: Constant from the reference paper numerical stability of similarity map

    Returns:
        Value of MS-GMSD. 0 <= GMSD loss <= 1.

    Raises:
        ValueError: If either spatial size is smaller than ``2 ** num_scales + 1``.
        AssertionError: If ``chromatic`` is set and the inputs do not have 3 channels.
    """
    _validate_input([x, y], dim_range=(4, 4), data_range=(0, data_range))

    # Rescale to [0, 255]
    x = x / data_range * 255
    y = y / data_range * 255

    if scale_weights is None:
        # Values from the paper
        scale_weights = torch.tensor([0.096, 0.596, 0.289, 0.019], device=x.device)
    else:
        # Normalize the user weights and match them to the input
        scale_weights = (scale_weights / scale_weights.sum()).to(x)

    # The input has to survive `num_scales` halvings
    num_scales = scale_weights.size(0)
    min_size = 2 ** num_scales + 1
    if x.size(-1) < min_size or x.size(-2) < min_size:
        raise ValueError(f'Invalid size of the input images, expected at least {min_size}x{min_size}.')

    if x.size(1) == 3:
        x = rgb2yiq(x)
        y = rgb2yiq(y)

    per_scale = []
    for level in range(num_scales):
        if level > 0:
            # Pad the trailing edges when a side is odd, then average 2x2 blocks and downsample
            trailing = max(x.shape[2] % 2, x.shape[3] % 2)
            padding = [0, trailing, 0, trailing]
            x = F.pad(x, pad=padding)
            y = F.pad(y, pad=padding)
            x = F.avg_pool2d(x, kernel_size=2, padding=0)
            y = F.avg_pool2d(y, kernel_size=2, padding=0)

        # GMSD is evaluated on the first channel only
        per_scale.append(_gmsd(x[:, :1], y[:, :1], t=t, alpha=alpha))

    # Weighted sum of squared per-scale values, then the square root per image. Shape: (batch_size, )
    weighted_squares = scale_weights.view(1, num_scales) * (torch.stack(per_scale, dim=1) ** 2)
    ms_gmsd_val = torch.sqrt(torch.sum(weighted_squares, dim=1))

    if not chromatic:
        return _reduce(ms_gmsd_val, reduction)

    assert x.size(1) == 3, "Chromatic component can be computed only for RGB images!"

    # RMSE of the two chromatic channels, taken from the images as left by the last scale
    chroma_diff = x[:, 1:] - y[:, 1:]
    rmse_iq = torch.sqrt(torch.mean(chroma_diff ** 2, dim=[2, 3]))
    rmse_chrome = torch.sqrt(torch.sum(rmse_iq ** 2, dim=1))

    # Blend the structural and the chromatic terms
    gamma = 2 / (1 + beta2 * torch.exp(-beta3 * ms_gmsd_val)) - 1
    score = gamma * ms_gmsd_val + (1 - gamma) * beta1 * rmse_chrome

    return _reduce(score, reduction)
def dss(x: torch.Tensor, y: torch.Tensor, reduction: str = 'mean',
        data_range: Union[int, float] = 1.0, dct_size: int = 8,
        sigma_weight: float = 1.55, kernel_size: int = 3,
        sigma_similarity: float = 1.5, percentile: float = 0.05) -> torch.Tensor:
    r"""Compute DCT Subband Similarity index for a batch of images.

    Args:
        x: An input tensor. Shape :math:`(N, C, H, W)`.
        y: A target tensor. Shape :math:`(N, C, H, W)`.
        reduction: Specifies the reduction type:
            ``'none'`` | ``'mean'`` | ``'sum'``. Default:``'mean'``
        data_range: Maximum value range of images (usually 1.0 or 255).
        dct_size: Size of blocks in 2D Discrete Cosine Transform. DCT sizes must be in (0, input size].
        sigma_weight: STD of gaussian that determines the proportion of weight given to low freq and high freq.
            Default: 1.55
        kernel_size: Size of gaussian kernel for computing subband similarity. Kernels size must be in (0, input size].
            Default: 3
        sigma_similarity: STD of gaussian kernel for computing subband similarity. Default: 1.5
        percentile: % in (0, 1] of the worst similarity scores which should be kept. Default: 0.05

    Returns:
        DSS: Index of similarity between two images. In [0, 1] interval.

    Raises:
        ValueError: If one of the sigmas is 0, ``percentile`` is outside (0, 1], or ``dct_size`` /
            ``kernel_size`` is outside (0, input size].

    Note:
        This implementation is based on the original MATLAB code (see header).
        Image will be scaled to [0, 255] because all constants are computed for this range.
        Make sure you know what you are doing when changing default coefficient values.
    """
    if sigma_weight == 0 or sigma_similarity == 0:
        raise ValueError(
            f'Gaussian sigmas must not be 0, got sigma_weight: {sigma_weight} and '
            f'sigma_similarity: {sigma_similarity}')

    if percentile <= 0 or percentile > 1:
        raise ValueError(f'Percentile must be in (0,1], got {percentile}')

    _validate_input(tensors=[x, y], dim_range=(4, 4))

    for size in (dct_size, kernel_size):
        if size <= 0 or size > min(x.size(-1), x.size(-2)):
            raise ValueError(
                'DCT and kernels sizes must be included in (0, input size]')

    # Rescale to [0, 255] range, because all constant are calculated for this factor
    x = (x / float(data_range)) * 255
    y = (y / float(data_range)) * 255

    num_channels = x.size(1)

    # Use luminance channel in case of RGB images (Y from YIQ or YCrCb)
    if num_channels == 3:
        x_lum = rgb2yiq(x)[:, :1]
        y_lum = rgb2yiq(y)[:, :1]
    else:
        x_lum = x
        y_lum = y

    # Crop images size to the closest multiplication of `dct_size`
    rows, cols = x_lum.size()[-2:]
    rows = dct_size * (rows // dct_size)
    cols = dct_size * (cols // dct_size)
    x_lum = x_lum[:, :, 0:rows, 0:cols]
    y_lum = y_lum[:, :, 0:rows, 0:cols]

    # Channel decomposition for both images by `dct_size`x`dct_size` 2D DCT
    dct_x = _dct_decomp(x_lum, dct_size)
    dct_y = _dct_decomp(y_lum, dct_size)

    # Create a Gaussian window that will be used to weight subbands scores
    coords = torch.arange(1, dct_size + 1).to(x)
    weight = (coords - 0.5) ** 2
    weight = (-(weight.unsqueeze(0) + weight.unsqueeze(1)) / (2 * sigma_weight ** 2)).exp()

    # Compute similarity between each subband in img1 and img2.
    # The accumulator follows the input dtype so that the scores are not silently cast to float32.
    subband_sim_matrix = torch.zeros((x.size(0), dct_size, dct_size), dtype=x.dtype, device=x.device)
    threshold = 1e-2
    for m in range(dct_size):
        for n in range(dct_size):
            first_term = (m == 0 and n == 0)  # boolean, marks the (0, 0) subband

            # Skip subbands with very small weight
            if weight[m, n] < threshold:
                weight[m, n] = 0
                continue

            # Strided slicing gathers the (m, n) coefficient of every DCT block
            subband_sim_matrix[:, m, n] = _subband_similarity(
                dct_x[:, :, m::dct_size, n::dct_size],
                dct_y[:, :, m::dct_size, n::dct_size],
                first_term, kernel_size, sigma_similarity, percentile)

    # Weight subbands similarity scores. `eps` guards the normalisation against an all-zero
    # weight window (every subband below the threshold); it must sit inside the denominator.
    eps = torch.finfo(weight.dtype).eps
    similarity_scores = torch.sum(subband_sim_matrix * (weight / (torch.sum(weight) + eps)), dim=[1, 2])
    dss_val = _reduce(similarity_scores, reduction)
    return dss_val
def srsim(x: torch.Tensor, y: torch.Tensor, reduction: str = 'mean',
          data_range: Union[int, float] = 1.0, chromatic: bool = False,
          scale: float = 0.25, kernel_size: int = 3, sigma: float = 3.8,
          gaussian_size: int = 10) -> torch.Tensor:
    r"""Compute Spectral Residual based Similarity for a batch of images.

    Args:
        x: Predicted images. Shape :math:`(N, C, H, W)`.
        y: Target images. Shape :math:`(N, C, H, W)`.
        reduction: Reduction over samples in batch: "mean"|"sum"|"none"
        data_range: Value range of input images (usually 1.0 or 255). Default: 1.0
        chromatic: Flag to compute SR-SIMc, which also takes into account chromatic components
        scale: Resizing factor used in saliency map computation
        kernel_size: Kernel size of average blur filter used in saliency map computation
        sigma: Sigma of gaussian filter applied on saliency map
        gaussian_size: Size of gaussian filter applied on saliency map

    Returns:
        SR-SIM: Index of similarity between two images. Usually in [0, 1] interval.
        Can be bigger than 1 for predicted images with higher contrast than the original ones.

    Raises:
        ValueError: If ``chromatic`` is set and the inputs do not have 3 channels.

    Note:
        This implementation is based on the original MATLAB code.
        https://sse.tongji.edu.cn/linzhang/IQA/SR-SIM/Files/SR_SIM.m
    """
    _validate_input(tensors=[x, y], dim_range=(4, 4), data_range=(0, data_range))

    # All constants below are tuned for the [0, 255] range
    x = (x / float(data_range)) * 255
    y = (y / float(data_range)) * 255

    # Average the images when they are large enough; the window grows with the smaller side
    ksize = max(1, round(min(x.size()[-2:]) / 256))
    half = ksize // 2
    if half:
        leading = (ksize - 1) // 2
        padding = [leading, half, leading, half]
        x = F.pad(x, pad=padding)
        y = F.pad(y, pad=padding)

    x = F.avg_pool2d(x, ksize)
    y = F.avg_pool2d(y, ksize)

    is_rgb = x.size(1) == 3
    if not is_rgb and chromatic:
        raise ValueError(
            'Chromatic component can be computed only for RGB images!')

    if is_rgb:
        # RGB -> YIQ (https://en.wikipedia.org/wiki/YIQ): luminance plus two chromatic channels
        x_yiq, y_yiq = rgb2yiq(x), rgb2yiq(y)
        x_lum, y_lum = x_yiq[:, :1], y_yiq[:, :1]
        x_i, y_i = x_yiq[:, 1:2], y_yiq[:, 1:2]
        x_q, y_q = x_yiq[:, 2:], y_yiq[:, 2:]
    else:
        x_lum, y_lum = x, y

    # Spectral residual visual saliency maps of both luminance images
    saliency_settings = dict(scale=scale, kernel_size=kernel_size, sigma=sigma, gaussian_size=gaussian_size)
    svrs_x = _spectral_residual_visual_saliency(x_lum, **saliency_settings)
    svrs_y = _spectral_residual_visual_saliency(y_lum, **saliency_settings)

    # Gradient maps from the Scharr kernel and its transpose
    kernels = torch.stack([scharr_filter(), scharr_filter().transpose(-1, -2)])
    grad_x = gradient_map(x_lum, kernels)
    grad_y = gradient_map(y_lum, kernels)

    # Constants from the paper
    C1, C2, alpha = 0.40, 225, 0.50

    saliency_similarity = similarity_map(svrs_x, svrs_y, C1)
    grad_similarity = similarity_map(grad_x, grad_y, C2)
    svrs_max = torch.where(svrs_x > svrs_y, svrs_x, svrs_y)
    score = saliency_similarity * (grad_similarity ** alpha) * svrs_max

    if chromatic:
        # Constants from FSIM paper, use same method for color image
        T3, T4, lmbda = 200, 200, 0.03

        sim_i = similarity_map(x_i, y_i, T3)
        sim_q = similarity_map(x_q, y_q, T4)
        # The absolute value is taken before the fractional power
        # (the original author noted complex gradients only arrive with PyTorch 1.6.0).
        score = score * torch.abs(sim_i * sim_q) ** lmbda

    # Similarity pooled per image, weighted by the maximal saliency; eps guards the division
    eps = torch.finfo(score.dtype).eps
    per_image_dims = [1, 2, 3]
    result = score.sum(dim=per_image_dims) / (svrs_max.sum(dim=per_image_dims) + eps)

    return _reduce(result, reduction)