def _natural_scene_statistics(luma: torch.Tensor, kernel_size: int = 7, sigma: float = 7. / 6) -> torch.Tensor:
    r"""Build a feature tensor from locally normalized luminance.

    The input is normalized with a Gaussian-weighted local mean and local deviation. The
    normalized map is passed to ``_ggd_parameters``; its products with four rolled copies of
    itself are passed to ``_aggd_parameters``.

    Args:
        luma: Input accepted by ``F.conv2d`` with a single-channel kernel
            (presumably shape (N, 1, H, W) - TODO confirm against callers).
        kernel_size: Side length of the Gaussian window.
        sigma: Standard deviation of the Gaussian window.

    Returns:
        Features stacked along a new last dimension: 2 entries from the first fit followed by
        4 entries for each of the 4 shifts (18 entries in total).
    """
    half_window = kernel_size // 2
    window = gaussian_filter(kernel_size=kernel_size, sigma=sigma)
    window = window.view(1, 1, kernel_size, kernel_size).to(luma)
    # Added to the local deviation so that the division below stays finite
    stabilizer = 1

    local_mean = F.conv2d(luma, window, padding=half_window)
    local_second_moment = F.conv2d(luma**2, window, padding=half_window)
    local_std = (local_second_moment - local_mean**2).abs().sqrt()
    normalized = (luma - local_mean) / (local_std + stabilizer)

    ggd_alpha, ggd_sigma = _ggd_parameters(normalized)
    features = [ggd_alpha, ggd_sigma.pow(2)]

    # Products of the normalized map with copies rolled along the last two dimensions
    for offset in ((0, 1), (1, 0), (1, 1), (-1, 1)):
        rolled = torch.roll(normalized, shifts=offset, dims=(-2, -1))
        alpha, sigma_left, sigma_right = _aggd_parameters(normalized * rolled)
        log_gamma_ratio = torch.lgamma(2. / alpha) - (torch.lgamma(1. / alpha) + torch.lgamma(3. / alpha)) / 2
        eta = (sigma_right - sigma_left) * torch.exp(log_gamma_ratio)
        features += [alpha, eta, sigma_left.pow(2), sigma_right.pow(2)]

    return torch.stack(features, dim=-1)
def _subband_similarity(x: torch.Tensor, y: torch.Tensor, first_term: bool,
                        kernel_size: int = 3, sigma: float = 1.5,
                        percentile: float = 0.05) -> torch.Tensor:
    r"""Compute similarity between 2 subbands

    Args:
        x: First input subband. Shape (N, 1, H, W).
        y: Second input subband. Shape (N, 1, H, W).
        first_term: Whether this is the first element of the subband similarity matrix to be calculated.
        kernel_size: Size of gaussian kernel for computing local variance. Kernels size must be in (0, input size].
            Default: 3
        sigma: STD of gaussian kernel for computing local variance. Default: 1.5
        percentile: % in [0,1] of worst similarity scores which should be kept. Default: 0.05

    Returns:
        DSS: Index of similarity between two images. In [0, 1] interval.

    Note:
        This implementation is based on the original MATLAB code (see header).
    """
    # The stabilising constant is the DC value for the first term and the AC value otherwise
    stabilizer = 1000 if first_term else 300

    half_window = kernel_size // 2
    window = gaussian_filter(kernel_size=kernel_size, sigma=sigma)
    window = window.view(1, 1, kernel_size, kernel_size).to(x)

    def smooth(tensor: torch.Tensor) -> torch.Tensor:
        # Gaussian-weighted local average that keeps the spatial size
        return F.conv2d(tensor, window, padding=half_window)

    # Local means and variances; negative variances are set to zero
    mean_x = smooth(x)
    mean_y = smooth(y)
    var_x = smooth(x * x) - mean_x**2
    var_y = smooth(y * y) - mean_y**2
    var_x = var_x.masked_fill(var_x < 0, 0)
    var_y = var_y.masked_fill(var_y < 0, 0)

    std_product = torch.sqrt(var_x * var_y)
    variance_term = (2 * std_product + stabilizer) / (var_x + var_y + stabilizer)

    # Spatial pooling: average only the lowest `percentile` share of the scores
    worst_count = round(percentile * (variance_term.size(-2) * variance_term.size(-1)))

    def worst_mean(term: torch.Tensor) -> torch.Tensor:
        ascending = torch.sort(term.flatten(start_dim=1)).values
        return torch.mean(ascending[:, :worst_count], dim=1)

    similarity = worst_mean(variance_term)

    # Only the first (DC) term is additionally multiplied by the covariance-based term
    if first_term:
        cov_xy = smooth(x * y) - mean_x * mean_y
        covariance_term = (cov_xy + stabilizer) / (std_product + stabilizer)
        similarity = similarity * worst_mean(covariance_term)

    return similarity
def ssim(x: torch.Tensor, y: torch.Tensor, kernel_size: int = 11, kernel_sigma: float = 1.5,
         data_range: Union[int, float] = 1., reduction: str = 'mean', full: bool = False,
         k1: float = 0.01, k2: float = 0.03) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    r"""Interface of Structural Similarity (SSIM) index.

    Args:
        x: Batch of images. Required to be 2D (H, W), 3D (C,H,W), 4D (N,C,H,W) or 5D (N,C,H,W,2), channels first.
        y: Batch of images. Required to be 2D (H, W), 3D (C,H,W) 4D (N,C,H,W) or 5D (N,C,H,W,2), channels first.
        kernel_size: The side-length of the sliding window used in comparison. Must be an odd value.
        kernel_sigma: Sigma of normal distribution.
        data_range: Value range of input images (usually 1.0 or 255).
        reduction: Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
        full: If ``True``, the contrast-structure (cs) value is returned together with the SSIM value.
        k1: Algorithm parameter, K1 (small constant, see [1]).
        k2: Algorithm parameter, K2 (small constant, see [1]).
            Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results.

    Returns:
        Value of Structural Similarity (SSIM) index, or the tuple ``(ssim, cs)`` when ``full`` is set.
        In case of 5D input tensors, complex value is returned as a tensor of size 2.

    Raises:
        KeyError: If ``reduction`` is not one of ``'none'``, ``'mean'`` or ``'sum'``.

    References:
        .. [1] Wang, Z., Bovik, A. C., Sheikh, H. R., & Simoncelli, E. P.
           (2004). Image quality assessment: From error visibility to
           structural similarity. IEEE Transactions on Image Processing,
           13, 600-612.
           https://ece.uwaterloo.ca/~z70wang/publications/ssim.pdf,
           :DOI:`10.1109/TIP.2003.819861`
    """
    _validate_input(input_tensors=(x, y), allow_5d=True, kernel_size=kernel_size, scale_weights=None)
    x, y = _adjust_dimensions(input_tensors=(x, y))

    # Compare dtypes instead of `isinstance(..., torch.ByteTensor)`: the legacy tensor type only
    # matches CPU tensors, so uint8 images living on another device would skip the conversion.
    if x.dtype == torch.uint8 or y.dtype == torch.uint8:
        x = x.type(torch.float32)
        y = y.type(torch.float32)

    # One copy of the window per channel, matched to the dtype and device of `y`
    kernel = gaussian_filter(kernel_size, kernel_sigma).repeat(x.size(1), 1, 1, 1).to(y)

    # 5D inputs are handled by the complex-valued implementation
    _compute_ssim_per_channel = _ssim_per_channel_complex if x.dim() == 5 else _ssim_per_channel
    ssim_map, cs_map = _compute_ssim_per_channel(x=x, y=y, kernel=kernel, data_range=data_range, k1=k1, k2=k2)

    # Average over dimension 1 of the per-channel results
    ssim_val = ssim_map.mean(1)
    cs = cs_map.mean(1)

    if reduction != 'none':
        reduce_fn = {'mean': torch.mean, 'sum': torch.sum}[reduction]
        ssim_val = reduce_fn(ssim_val, dim=0)
        cs = reduce_fn(cs, dim=0)

    if full:
        return ssim_val, cs

    return ssim_val
def vif_p(x: torch.Tensor, y: torch.Tensor, sigma_n_sq: float = 2.0,
          data_range: Union[int, float] = 1.0, reduction: str = 'mean') -> torch.Tensor:
    r"""Compute Visual Information Fidelity in **pixel** domain for a batch of images.
    This metric isn't symmetric, so make sure to place arguments in correct order.
    Both inputs supposed to have RGB channels order.

    Args:
        x: Tensor with shape (H, W), (C, H, W) or (N, C, H, W).
        y: Tensor with shape (H, W), (C, H, W) or (N, C, H, W).
        sigma_n_sq: HVS model parameter (variance of the visual noise).
        data_range: Value range of input images (usually 1.0 or 255). Default: 1.0
        reduction: Reduction over samples in batch: "mean"|"sum"|"none"

    Returns:
        VIF: Index of similarity between two images. Usually in [0, 1] interval.
            Can be bigger than 1 for predicted images with higher contrast than original one.

    Raises:
        ValueError: If either spatial dimension of the input is smaller than 41.

    Note:
        In original paper this method was used for bands in discrete wavelet decomposition.
        Later on authors released code to compute VIF approximation in pixel domain.
        See https://live.ece.utexas.edu/research/Quality/VIF.htm for details.
    """
    _validate_input((x, y), allow_5d=False, data_range=data_range)
    x, y = _adjust_dimensions(input_tensors=(x, y))

    min_size = 41
    if min(x.size(-1), x.size(-2)) < min_size:
        raise ValueError(f'Invalid size of the input images, expected at least {min_size}x{min_size}.')

    # Work on the 0..255 scale
    x = x / data_range * 255
    y = y / data_range * 255

    if x.size(1) == 3:
        def to_luma(rgb: torch.Tensor) -> torch.Tensor:
            # Luminance of YCbCr: Y = 0.299 R + 0.587 G + 0.114 B, with the channel axis restored
            weighted = 0.299 * rgb[:, 0, :, :] + 0.587 * rgb[:, 1, :, :] + 0.114 * rgb[:, 2, :, :]
            return weighted[:, None, :, :]

        x, y = to_luma(x), to_luma(y)

    # Constant for numerical stability
    eps = 1e-8

    # Accumulators for the numerator (x given y) and the denominator (y alone) of the ratio
    numerator, denominator = 0, 0
    for level in range(4):
        window_size = 2 ** (4 - level) + 1
        window = gaussian_filter(window_size, sigma=window_size / 5)
        window = window.view(1, 1, window_size, window_size).to(x)

        if level > 0:
            # Blur (no padding) and keep every second pixel
            x = F.conv2d(x, window)[:, :, ::2, ::2]
            y = F.conv2d(y, window)[:, :, ::2, ::2]

        # Local first and second order statistics (no padding)
        mean_x = F.conv2d(x, window)
        mean_y = F.conv2d(y, window)
        var_x = torch.relu(F.conv2d(x**2, window) - mean_x * mean_x)  # negative values become zero
        var_y = torch.relu(F.conv2d(y**2, window) - mean_y * mean_y)
        cov_xy = F.conv2d(x * y, window) - mean_x * mean_y

        gain = cov_xy / (var_y + eps)
        noise_var = var_x - gain * cov_xy
        zeros = torch.zeros_like(gain)

        # Where `y` has (almost) no variance: no gain, all of var_x counts as noise
        y_has_variance = var_y >= eps
        gain = torch.where(y_has_variance, gain, zeros)
        noise_var = torch.where(y_has_variance, noise_var, var_x)
        var_y = torch.where(y_has_variance, var_y, zeros)

        # Where `x` has (almost) no variance: neither gain nor noise
        x_has_variance = var_x >= eps
        gain = torch.where(x_has_variance, gain, zeros)
        noise_var = torch.where(x_has_variance, noise_var, zeros)

        # Negative gain: fall back to var_x as noise, then clip the gain at zero
        noise_var = torch.where(gain >= 0, noise_var, var_x)
        gain = torch.relu(gain)

        # Keep the noise variance at least `eps`
        noise_var = torch.where(noise_var > eps, noise_var, torch.ones_like(noise_var) * eps)

        level_numerator = torch.log10(1.0 + (gain**2.) * var_y / (noise_var + sigma_n_sq))
        numerator = numerator + torch.sum(level_numerator, dim=[1, 2, 3])
        denominator = denominator + torch.sum(torch.log10(1.0 + var_y / sigma_n_sq), dim=[1, 2, 3])

    score: torch.Tensor = (numerator + eps) / (denominator + eps)

    # Reduce over the batch if requested
    if reduction != 'none':
        reducer = {'mean': score.mean, 'sum': score.sum}[reduction]
        score = reducer(dim=0)

    return score
def vif_p(x: torch.Tensor, y: torch.Tensor, sigma_n_sq: float = 2.0,
          data_range: Union[int, float] = 1.0, reduction: str = 'mean') -> torch.Tensor:
    r"""Compute Visual Information Fidelity in **pixel** domain for a batch of images.
    This metric isn't symmetric, so make sure to place arguments in correct order.
    Both inputs supposed to have RGB channels order.

    Args:
        x: An input tensor. Shape :math:`(N, C, H, W)`.
        y: A target tensor. Shape :math:`(N, C, H, W)`.
        sigma_n_sq: HVS model parameter (variance of the visual noise).
        data_range: Maximum value range of images (usually 1.0 or 255).
        reduction: Specifies the reduction type:
            ``'none'`` | ``'mean'`` | ``'sum'``. Default:``'mean'``

    Returns:
        VIF Index of similarity between two images. Usually in [0, 1] interval.
        Can be bigger than 1 for predicted :math:`x` images with higher contrast than original one.

    Raises:
        ValueError: If either spatial dimension of the input is smaller than 41.

    References:
        H. R. Sheikh and A. C. Bovik, "Image information and visual quality,"
        IEEE Transactions on Image Processing, vol. 15, no. 2, pp. 430-444, Feb. 2006
        https://ieeexplore.ieee.org/abstract/document/1576816/
        DOI: 10.1109/TIP.2005.859378.

    Note:
        In original paper this method was used for bands in discrete wavelet decomposition.
        Later on authors released code to compute VIF approximation in pixel domain.
        See https://live.ece.utexas.edu/research/Quality/VIF.htm for details.
    """
    _validate_input([x, y], dim_range=(4, 4), data_range=(0, data_range))

    min_size = 41
    height, width = x.size(-2), x.size(-1)
    if width < min_size or height < min_size:
        raise ValueError(f'Invalid size of the input images, expected at least {min_size}x{min_size}.')

    # Bring both images to the 0..255 scale
    x = x / float(data_range) * 255
    y = y / float(data_range) * 255

    # Three-channel input is reduced to the luminance of YCbCr: Y = 0.299 R + 0.587 G + 0.114 B
    if x.size(1) == 3:
        x = (0.299 * x[:, 0, :, :] + 0.587 * x[:, 1, :, :] + 0.114 * x[:, 2, :, :]).unsqueeze(1)
        y = (0.299 * y[:, 0, :, :] + 0.587 * y[:, 1, :, :] + 0.114 * y[:, 2, :, :]).unsqueeze(1)

    # Constant for numerical stability
    tiny = 1e-8

    # Per-scale contributions to the numerator (x given y) and the denominator (y alone)
    x_terms, y_terms = [], []
    for scale in range(4):
        side = 2 ** (4 - scale) + 1
        blur = gaussian_filter(side, sigma=side / 5).view(1, 1, side, side).to(x)

        if scale:
            # Valid-padding blur followed by 2x subsampling
            x, y = F.conv2d(x, blur)[:, :, ::2, ::2], F.conv2d(y, blur)[:, :, ::2, ::2]

        # Local moments, all with valid padding
        mu_x, mu_y = F.conv2d(x, blur), F.conv2d(y, blur)
        var_x = torch.relu(F.conv2d(x**2, blur) - mu_x * mu_x)  # clip small negative values
        var_y = torch.relu(F.conv2d(y**2, blur) - mu_y * mu_y)
        cov = F.conv2d(x * y, blur) - mu_x * mu_y

        gain = cov / (var_y + tiny)
        residual = var_x - gain * cov
        nothing = torch.zeros_like(gain)

        # Flat regions of `y`: zero gain, var_x is treated as pure noise
        y_active = var_y >= tiny
        gain = torch.where(y_active, gain, nothing)
        residual = torch.where(y_active, residual, var_x)
        var_y = torch.where(y_active, var_y, nothing)

        # Flat regions of `x`: zero gain and zero noise
        x_active = var_x >= tiny
        gain = torch.where(x_active, gain, nothing)
        residual = torch.where(x_active, residual, nothing)

        # Negative gain: use var_x as noise and clip the gain at zero
        residual = torch.where(gain >= 0, residual, var_x)
        gain = torch.relu(gain)

        # Lower-bound the noise variance
        residual = torch.where(residual > tiny, residual, torch.ones_like(residual) * tiny)

        x_info = torch.log10(1.0 + (gain**2.) * var_y / (residual + sigma_n_sq))
        x_terms.append(torch.sum(x_info, dim=[1, 2, 3]))
        y_terms.append(torch.sum(torch.log10(1.0 + var_y / sigma_n_sq), dim=[1, 2, 3]))

    # `sum` starts from 0 and adds the scales in order
    score: torch.Tensor = (sum(x_terms) + tiny) / (sum(y_terms) + tiny)

    return _reduce(score, reduction)
def multi_scale_ssim(x: torch.Tensor, y: torch.Tensor, kernel_size: int = 11, kernel_sigma: float = 1.5,
                     data_range: Union[int, float] = 1., reduction: str = 'mean',
                     scale_weights: Optional[Union[Tuple[float], List[float], torch.Tensor]] = None,
                     k1: float = 0.01, k2: float = 0.03) -> torch.Tensor:
    r""" Interface of Multi-scale Structural Similarity (MS-SSIM) index.

    Args:
        x: Batch of images. Required to be 2D (H, W), 3D (C,H,W), 4D (N,C,H,W) or 5D (N,C,H,W,2), channels first.
            The size of the image should be (kernel_size - 1) * 2 ** (levels - 1) + 1.
        y: Batch of images. Required to be 2D (H, W), 3D (C,H,W), 4D (N,C,H,W) or 5D (N,C,H,W,2), channels first.
            The size of the image should be (kernel_size - 1) * 2 ** (levels - 1) + 1.
        kernel_size: The side-length of the sliding window used in comparison. Must be an odd value.
        kernel_sigma: Sigma of normal distribution.
        data_range: Value range of input images (usually 1.0 or 255).
        reduction: Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
        scale_weights: Weights for different scales, given as a sequence of floats or a tensor.
            If None, default weights from the paper [1] will be used.
            Default weights: (0.0448, 0.2856, 0.3001, 0.2363, 0.1333).
        k1: Algorithm parameter, K1 (small constant, see [2]).
        k2: Algorithm parameter, K2 (small constant, see [2]).
            Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results.

    Returns:
        Value of Multi-scale Structural Similarity (MS-SSIM) index. In case of 5D input tensors,
        complex value is returned as a tensor of size 2.

    Raises:
        KeyError: If ``reduction`` is not one of ``'none'``, ``'mean'`` or ``'sum'``.

    References:
        .. [1] Wang, Z., Simoncelli, E. P., Bovik, A. C. (2003).
           Multi-scale Structural Similarity for Image Quality Assessment.
           IEEE Asilomar Conference on Signals, Systems and Computers, 37,
           https://ieeexplore.ieee.org/document/1292216
           :DOI:`10.1109/ACSSC.2003.1292216`
        .. [2] Wang, Z., Bovik, A. C., Sheikh, H. R., & Simoncelli, E. P.
           (2004). Image quality assessment: From error visibility to
           structural similarity. IEEE Transactions on Image Processing,
           13, 600-612.
           https://ece.uwaterloo.ca/~z70wang/publications/ssim.pdf,
           :DOI:`10.1109/TIP.2003.819861`
    """
    _validate_input(input_tensors=(x, y), allow_5d=True, kernel_size=kernel_size, scale_weights=scale_weights)
    x, y = _adjust_dimensions(input_tensors=(x, y))

    # Compare dtypes instead of `isinstance(..., torch.ByteTensor)`: the legacy tensor type only
    # matches CPU tensors, so uint8 images living on another device would skip the conversion.
    if x.dtype == torch.uint8 or y.dtype == torch.uint8:
        x = x.type(torch.float32)
        y = y.type(torch.float32)

    if scale_weights is None:
        # Default weights from the MS-SSIM paper
        scale_weights = [0.0448, 0.2856, 0.3001, 0.2363, 0.1333]

    # Accept plain sequences as well as tensors, then match dtype and device of `y`
    if not isinstance(scale_weights, torch.Tensor):
        scale_weights = torch.tensor(scale_weights)
    scale_weights_tensor = scale_weights.to(y)

    # One copy of the window per channel, matched to the dtype and device of `y`
    kernel = gaussian_filter(kernel_size, kernel_sigma).repeat(x.size(1), 1, 1, 1).to(y)

    # 5D inputs are handled by the complex-valued implementation
    _compute_msssim = _multi_scale_ssim_complex if x.dim() == 5 else _multi_scale_ssim
    msssim_val = _compute_msssim(
        x=x,
        y=y,
        data_range=data_range,
        kernel=kernel,
        scale_weights_tensor=scale_weights_tensor,
        k1=k1,
        k2=k2
    )

    if reduction == 'none':
        return msssim_val

    return {'mean': torch.mean, 'sum': torch.sum}[reduction](msssim_val, dim=0)
def ssim(x: torch.Tensor, y: torch.Tensor, kernel_size: int = 11, kernel_sigma: float = 1.5,
         data_range: Union[int, float] = 1., reduction: str = 'mean', full: bool = False,
         downsample: bool = True, k1: float = 0.01, k2: float = 0.03) -> List[torch.Tensor]:
    r"""Interface of Structural Similarity (SSIM) index.
    Inputs supposed to be in range [0, data_range].
    To match performance with skimage and tensorflow set `downsample` = True.

    Args:
        x: An input tensor. Shape :math:`(N, C, H, W)` or :math:`(N, C, H, W, 2)`.
        y: A target tensor. Shape :math:`(N, C, H, W)` or :math:`(N, C, H, W, 2)`.
        kernel_size: The side-length of the sliding window used in comparison. Must be an odd value.
        kernel_sigma: Sigma of normal distribution.
        data_range: Maximum value range of images (usually 1.0 or 255).
        reduction: Specifies the reduction type:
            ``'none'`` | ``'mean'`` | ``'sum'``. Default:``'mean'``
        full: Return cs map or not.
        downsample: Perform average pool before SSIM computation. Default: True
        k1: Algorithm parameter, K1 (small constant, see [1]).
        k2: Algorithm parameter, K2 (small constant, see [1]).
            Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results.

    Returns:
        Value of Structural Similarity (SSIM) index; the list ``[ssim, cs]`` when ``full`` is set.
        In case of 5D input tensors, complex value is returned as a tensor of size 2.

    References:
        .. [1] Wang, Z., Bovik, A. C., Sheikh, H. R., & Simoncelli, E. P.
           (2004). Image quality assessment: From error visibility to
           structural similarity. IEEE Transactions on Image Processing,
           13, 600-612.
           https://ece.uwaterloo.ca/~z70wang/publications/ssim.pdf,
           DOI: `10.1109/TIP.2003.819861`
    """
    assert kernel_size % 2 == 1, f'Kernel size must be odd, got [{kernel_size}]'
    _validate_input([x, y], dim_range=(4, 5), data_range=(0, data_range))

    # Work in float32 on inputs divided by `data_range`
    x = x.type(torch.float32) / data_range
    y = y.type(torch.float32) / data_range

    # Average-pool when the smaller of the last two dimensions is large enough
    pool_factor = max(1, round(min(x.shape[-2:]) / 256))
    if downsample and pool_factor > 1:
        x = F.avg_pool2d(x, kernel_size=pool_factor)
        y = F.avg_pool2d(y, kernel_size=pool_factor)

    # One copy of the window per channel, matched to the dtype and device of `y`
    window = gaussian_filter(kernel_size, kernel_sigma).repeat(x.size(1), 1, 1, 1).to(y)

    # 5D inputs are handled by the complex-valued implementation
    if x.dim() == 5:
        per_channel = _ssim_per_channel_complex
    else:
        per_channel = _ssim_per_channel

    # NOTE(review): x and y were already divided by `data_range` above, yet `data_range` is still
    # forwarded here - confirm the helper does not rescale its constants a second time.
    ssim_map, cs_map = per_channel(x=x, y=y, kernel=window, data_range=data_range, k1=k1, k2=k2)

    ssim_val = _reduce(ssim_map.mean(1), reduction)
    cs = _reduce(cs_map.mean(1), reduction)

    return [ssim_val, cs] if full else ssim_val
def multi_scale_ssim(x: torch.Tensor, y: torch.Tensor, kernel_size: int = 11, kernel_sigma: float = 1.5,
                     data_range: Union[int, float] = 1., reduction: str = 'mean',
                     scale_weights: Optional[torch.Tensor] = None,
                     k1: float = 0.01, k2: float = 0.03) -> torch.Tensor:
    r""" Interface of Multi-scale Structural Similarity (MS-SSIM) index.
    Inputs supposed to be in range ``[0, data_range]`` with RGB channels order for colour images.

    Args:
        x: An input tensor. Shape :math:`(N, C, H, W)` or :math:`(N, C, H, W, 2)`.
        y: A target tensor. Shape :math:`(N, C, H, W)` or :math:`(N, C, H, W, 2)`.
        kernel_size: The side-length of the sliding window used in comparison. Must be an odd value.
        kernel_sigma: Sigma of normal distribution.
        data_range: Maximum value range of images (usually 1.0 or 255).
        reduction: Specifies the reduction type:
            ``'none'`` | ``'mean'`` | ``'sum'``. Default:``'mean'``
        scale_weights: Weights for different scales; user-supplied weights are normalized to sum to one.
            If ``None``, default weights from the paper will be used.
            Default weights: (0.0448, 0.2856, 0.3001, 0.2363, 0.1333).
        k1: Algorithm parameter, K1 (small constant).
        k2: Algorithm parameter, K2 (small constant).
            Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results.

    Returns:
        Value of Multi-scale Structural Similarity (MS-SSIM) index. In case of 5D input tensors,
        complex value is returned as a tensor of size 2.

    References:
        Wang, Z., Simoncelli, E. P., Bovik, A. C. (2003).
        Multi-scale Structural Similarity for Image Quality Assessment.
        IEEE Asilomar Conference on Signals, Systems and Computers, 37,
        https://ieeexplore.ieee.org/document/1292216 DOI:`10.1109/ACSSC.2003.1292216`

        Wang, Z., Bovik, A. C., Sheikh, H. R., & Simoncelli, E. P. (2004).
        Image quality assessment: From error visibility to structural similarity.
        IEEE Transactions on Image Processing, 13, 600-612.
        https://ece.uwaterloo.ca/~z70wang/publications/ssim.pdf, DOI: `10.1109/TIP.2003.819861`

    Note:
        The size of the image should be at least ``(kernel_size - 1) * 2 ** (levels - 1) + 1``.
    """
    assert kernel_size % 2 == 1, f'Kernel size must be odd, got [{kernel_size}]'
    _validate_input([x, y], dim_range=(4, 5), data_range=(0, data_range))

    # Work in float32 on inputs divided by `data_range`
    x = x.type(torch.float32) / data_range
    y = y.type(torch.float32) / data_range

    if scale_weights is not None:
        # User-supplied weights are normalized to sum to one
        scale_weights = (scale_weights / scale_weights.sum()).to(x)
    else:
        # Values from the MS-SSIM paper (used as they are)
        scale_weights = torch.tensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(x)

    # One copy of the window per channel, matched to the dtype and device of `x`
    window = gaussian_filter(kernel_size, kernel_sigma).repeat(x.size(1), 1, 1, 1).to(x)

    # 5D inputs are handled by the complex-valued implementation
    if x.dim() == 5:
        compute = _multi_scale_ssim_complex
    else:
        compute = _multi_scale_ssim

    # NOTE(review): x and y were already divided by `data_range` above, yet `data_range` is still
    # forwarded here - confirm the helper does not rescale its constants a second time.
    msssim_val = compute(x=x, y=y, data_range=data_range, kernel=window,
                         scale_weights=scale_weights, k1=k1, k2=k2)

    return _reduce(msssim_val, reduction)
def information_weighted_ssim(x: torch.Tensor, y: torch.Tensor, data_range: Union[int, float] = 1.,
                              kernel_size: int = 11, kernel_sigma: float = 1.5, k1: float = 0.01, k2: float = 0.03,
                              parent: bool = True, blk_size: int = 3, sigma_nsq: float = 0.4,
                              scale_weights: Optional[torch.Tensor] = None,
                              reduction: str = 'mean') -> torch.Tensor:
    r"""Interface of Information Content Weighted Structural Similarity (IW-SSIM) index.
    Inputs supposed to be in range ``[0, data_range]``.

    Args:
        x: An input tensor. Shape :math:`(N, C, H, W)`.
        y: A target tensor. Shape :math:`(N, C, H, W)`.
        data_range: Maximum value range of images (usually 1.0 or 255).
        kernel_size: The side-length of the sliding window used in comparison. Must be an odd value.
        kernel_sigma: Sigma of normal distribution for sliding window used in comparison.
        k1: Algorithm parameter, K1 (small constant).
        k2: Algorithm parameter, K2 (small constant).
            Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results.
        parent: Flag to control dependency on previous layer of pyramid.
        blk_size: The side-length of the sliding window used in comparison for information content.
        sigma_nsq: Parameter of visual distortion model.
        scale_weights: Weights for scaling. Must be a vector; it is normalized to sum to one.
            If ``None``, (0.0448, 0.2856, 0.3001, 0.2363, 0.1333) is used.
        reduction: Specifies the reduction type:
            ``'none'`` | ``'mean'`` | ``'sum'``. Default:``'mean'``

    Returns:
        Value of Information Content Weighted Structural Similarity (IW-SSIM) index.

    Raises:
        ValueError: If ``scale_weights`` is not a vector, or if the images are smaller than
            ``(kernel_size - 1) * 2 ** (levels - 1) + 1`` in either spatial dimension.

    References:
        Wang, Zhou, and Qiang Li..
        Information content weighting for perceptual image quality assessment.
        IEEE Transactions on image processing 20.5 (2011): 1185-1198.
        https://ece.uwaterloo.ca/~z70wang/publications/IWSSIM.pdf DOI:`10.1109/TIP.2010.2092435`

    Note:
        Lack of content in target image could lead to RuntimeError due to singular information content matrix,
        which cannot be inverted.
    """
    assert kernel_size % 2 == 1, f'Kernel size must be odd, got [{kernel_size}]'
    _validate_input(tensors=[x, y], dim_range=(4, 4), data_range=(0., data_range))

    # Work on the 0..255 scale
    x = x / float(data_range) * 255
    y = y / float(data_range) * 255

    # Three-channel input: keep only the first channel of the YIQ representation
    if x.size(1) == 3:
        x = rgb2yiq(x)[:, :1]
        y = rgb2yiq(y)[:, :1]

    if scale_weights is None:
        scale_weights = torch.tensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333], dtype=x.dtype, device=x.device)
    scale_weights = scale_weights / scale_weights.sum()
    if scale_weights.size(0) != scale_weights.numel():
        raise ValueError(f'Expected a vector of weights, got {scale_weights.dim()}D tensor')

    levels = scale_weights.size(0)

    min_size = (kernel_size - 1) * 2 ** (levels - 1) + 1
    if x.size(-1) < min_size or x.size(-2) < min_size:
        raise ValueError(f'Invalid size of the input images, expected at least {min_size}x{min_size}.')

    blur_pad = math.ceil((kernel_size - 1) / 2)  # Ceil
    iw_pad = blur_pad - math.floor((blk_size - 1) / 2)  # floor
    # Border crop applied to the information content map. With `iw_pad == 0` (for example
    # kernel_size == blk_size == 3) a literal `0:-0` slice would be empty, so the upper bound is
    # left open in that case; every other value slices exactly as `iw_pad:-iw_pad`.
    iw_crop = slice(iw_pad, -iw_pad if iw_pad else None)

    gauss_kernel = gaussian_filter(kernel_size, kernel_sigma).repeat(x.size(1), 1, 1, 1).to(x)

    # Size of the kernel size to build Laplacian pyramid
    pyramid_kernel_size = 5
    bin_filter = binomial_filter1d(kernel_size=pyramid_kernel_size).to(x) * 2 ** 0.5

    # First pyramid step: keep the low-pass part for the next step and the detail part for scoring
    x, x_diff_old = _pyr_step(x, bin_filter)
    y, y_diff_old = _pyr_step(y, bin_filter)

    wmcs = []
    for i in range(levels):
        has_next_step = i < levels - 2
        if has_next_step:
            x, x_diff = _pyr_step(x, bin_filter)
            y, y_diff = _pyr_step(y, bin_filter)
        else:
            # No further decomposition: the remaining low-pass images are used directly
            x_diff = x
            y_diff = y

        ssim_map, cs_map = _ssim_per_channel(x=x_diff_old, y=y_diff_old, kernel=gauss_kernel, data_range=255,
                                             k1=k1, k2=k2)

        if i == levels - 1:
            # Last level: uniform weights, and the full SSIM map replaces the cs map
            iw_map = torch.ones_like(cs_map)
            cs_map = ssim_map
        else:
            # The next band acts as parent only when requested and when one was actually computed
            y_parent = y_diff if (parent and has_next_step) else None
            iw_map = _information_content(x=x_diff_old, y=y_diff_old, y_parent=y_parent,
                                          kernel_size=blk_size, sigma_nsq=sigma_nsq)
            iw_map = iw_map[:, :, iw_crop, iw_crop]

        # Information-weighted average of the map over the two spatial dimensions
        wmcs.append(torch.sum(cs_map * iw_map, dim=(-2, -1)) / torch.sum(iw_map, dim=(-2, -1)))

        x_diff_old = x_diff
        y_diff_old = y_diff

    # Weighted geometric combination across levels; channel 0 is returned
    wmcs = torch.stack(wmcs, dim=0).abs()
    score = torch.prod((wmcs ** scale_weights.view(-1, 1, 1)), dim=0)[:, 0]

    return _reduce(x=score, reduction=reduction)
def _spectral_residual_visual_saliency(
        x: torch.Tensor, scale: float = 0.25, kernel_size: int = 3,
        sigma: float = 3.8, gaussian_size: int = 10) -> torch.Tensor:
    r"""Compute Spectral Residual Visual Saliency

    The input is downsized, its log-amplitude spectrum is reduced by its local average
    (the "spectral residual"), the result is transformed back with the original phase,
    blurred with a gaussian filter, min-max normalized and resized to the input size.

    Credits X. Hou and L. Zhang, CVPR 07, 2007

    Reference:
        http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.125.5641&rep=rep1&type=pdf

    Args:
        x: Tensor with shape (N, 1, H, W).
        scale: Resizing factor
        kernel_size: Kernel size of average blur filter
        sigma: Sigma of gaussian filter applied on saliency map
        gaussian_size: Size of gaussian filter applied on saliency map

    Returns:
        saliency_map: Saliency map resized to the spatial size of ``x``
            (presumably shape (N, 1, H, W) - TODO confirm against ``imresize``).

    Raises:
        ValueError: If ``kernel_size`` or ``gaussian_size`` exceeds a spatial dimension of
            ``x`` multiplied by ``scale``.
    """
    eps = torch.finfo(x.dtype).eps
    # Both filters must fit into the downsized image
    for kernel in kernel_size, gaussian_size:
        if x.size(-1) * scale < kernel or x.size(-2) * scale < kernel:
            raise ValueError(
                f'Kernel size can\'t be greater than actual input size. '
                f'Input size: {x.size()} x {scale}. Kernel size: {kernel}')

    # Downsize image
    in_img = imresize(x, scale=scale)

    # Fourier transform (use complex format [a,b] instead of a + ib
    #                    because torch<1.8.0 autograd does not support the latter)
    recommended_torch_version = _parse_version('1.8.0')
    torch_version = _parse_version(torch.__version__)
    # An empty parsed version falls through to the legacy branch
    if len(torch_version) != 0 and torch_version >= recommended_torch_version:
        imagefft = torch.fft.fft2(in_img)
        log_amplitude = torch.log(imagefft.abs() + eps)
        phase = torch.angle(imagefft)
    else:
        # Legacy API: real and imaginary parts are stored in the last dimension
        imagefft = torch.rfft(in_img, 2, onesided=False)
        # Compute log of absolute value and angle of fourier transform
        log_amplitude = torch.log(imagefft.pow(2).sum(dim=-1).sqrt() + eps)
        phase = torch.atan2(imagefft[..., 1], imagefft[..., 0] + eps)

    # Compute spectral residual using average filtering
    padding = kernel_size // 2
    if padding:
        # For an even kernel_size the leading pad is one smaller than the trailing pad
        up_pad = (kernel_size - 1) // 2
        down_pad = padding
        pad_to_use = [up_pad, down_pad, up_pad, down_pad]
        # replicate padding before average filtering
        spectral_residual = F.pad(log_amplitude, pad=pad_to_use, mode='replicate')
    else:
        spectral_residual = log_amplitude

    # Residual = log-amplitude minus its local (stride 1) average
    spectral_residual = log_amplitude - F.avg_pool2d(spectral_residual, kernel_size=kernel_size, stride=1)

    # Saliency map
    # representation of complex exp(spectral_residual + j * phase)
    compx = torch.stack((torch.exp(spectral_residual) * torch.cos(phase),
                         torch.exp(spectral_residual) * torch.sin(phase)), -1)

    if len(torch_version) != 0 and torch_version >= recommended_torch_version:
        saliency_map = torch.abs(torch.fft.ifft2(torch.view_as_complex(compx)))**2
    else:
        # Legacy API: sum of squared real and imaginary parts
        saliency_map = torch.sum(torch.ifft(compx, 2)**2, dim=-1)

    # After effect for SR-SIM
    # Apply gaussian blur
    kernel = gaussian_filter(gaussian_size, sigma)
    if gaussian_size % 2 == 0:  # matlab pads upper and lower borders with 0s for even kernels
        # One row and one column of zeros are prepended (the zeros come first in each `cat`),
        # which makes the kernel side odd
        kernel = torch.cat((torch.zeros(1, 1, gaussian_size), kernel), 1)
        kernel = torch.cat((torch.zeros(1, gaussian_size + 1, 1), kernel), 2)
        gaussian_size += 1
    kernel = kernel.view(1, 1, gaussian_size, gaussian_size).to(saliency_map)
    saliency_map = F.conv2d(saliency_map, kernel, padding=(gaussian_size - 1) // 2)

    # normalize between [0, 1]
    # NOTE(review): min and max are taken over the whole tensor (`[:]` is a full slice), so the
    # normalization is shared by all samples of the batch - confirm this is intended.
    min_sal = torch.min(saliency_map[:])
    max_sal = torch.max(saliency_map[:])
    saliency_map = (saliency_map - min_sal) / (max_sal - min_sal + eps)

    # scale to original size
    saliency_map = imresize(saliency_map, sizes=x.size()[-2:])

    return saliency_map