Example #1
def _natural_scene_statistics(luma: torch.Tensor,
                              kernel_size: int = 7,
                              sigma: float = 7. / 6) -> torch.Tensor:
    kernel = gaussian_filter(kernel_size=kernel_size,
                             sigma=sigma).view(1, 1, kernel_size,
                                               kernel_size).to(luma)
    C = 1
    # Local mean and standard deviation via Gaussian-weighted moments
    mu = F.conv2d(luma, kernel, padding=kernel_size // 2)
    mu_sq = mu**2
    std = F.conv2d(luma**2, kernel, padding=kernel_size // 2)
    std = (std - mu_sq).abs().sqrt()

    # MSCN (mean-subtracted contrast-normalized) coefficients
    luma_nrmlzd = (luma - mu) / (std + C)
    alpha, sigma = _ggd_parameters(luma_nrmlzd)
    features = [alpha, sigma.pow(2)]

    # Products of neighbouring MSCN coefficients along 4 orientations
    shifts = [(0, 1), (1, 0), (1, 1), (-1, 1)]

    for shift in shifts:
        shifted_luma_nrmlzd = torch.roll(luma_nrmlzd,
                                         shifts=shift,
                                         dims=(-2, -1))
        alpha, sigma_l, sigma_r = _aggd_parameters(luma_nrmlzd *
                                                   shifted_luma_nrmlzd)
        eta = (sigma_r - sigma_l) * torch.exp(
            torch.lgamma(2. / alpha) -
            (torch.lgamma(1. / alpha) + torch.lgamma(3. / alpha)) / 2)
        features.extend((alpha, eta, sigma_l.pow(2), sigma_r.pow(2)))

    return torch.stack(features, dim=-1)
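
A minimal smoke test for this helper (not part of the original source), assuming it and piq's internal gaussian_filter, _ggd_parameters and _aggd_parameters are in scope:

import torch

# Hypothetical usage: a random grayscale batch of shape (N, 1, H, W).
luma = torch.rand(2, 1, 96, 96)
features = _natural_scene_statistics(luma)
# Expected: torch.Size([2, 18]), i.e. 2 GGD parameters plus 4 shifts * 4 AGGD parameters.
print(features.shape)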
Example #2
def _subband_similarity(x: torch.Tensor,
                        y: torch.Tensor,
                        first_term: bool,
                        kernel_size: int = 3,
                        sigma: float = 1.5,
                        percentile: float = 0.05) -> torch.Tensor:
    r"""Compute similarity between 2 subbands

    Args:
        x: First input subband. Shape (N, 1, H, W).
        y: Second input subband. Shape (N, 1, H, W).
        first_term: Whether this is the first element of the subband similarity matrix to be calculated.
        kernel_size: Size of gaussian kernel for computing local variance. Kernel size must be in (0, input size].
            Default: 3
        sigma: STD of gaussian kernel for computing local variance. Default: 1.5
        percentile: Fraction in [0, 1] of the worst similarity scores to keep. Default: 0.05
    Returns:
        DSS: Similarity score between the two subbands, in the [0, 1] interval.
    Note:
        This implementation is based on the original MATLAB code (see header).
    """
    # `c` takes value of DC or AC coefficient depending on stage
    dc_coeff, ac_coeff = (1000, 300)
    c = dc_coeff if first_term else ac_coeff

    # Compute local variance
    kernel = gaussian_filter(kernel_size=kernel_size, sigma=sigma)
    kernel = kernel.view(1, 1, kernel_size, kernel_size).to(x)
    mu_x = F.conv2d(x, kernel, padding=kernel_size // 2)
    mu_y = F.conv2d(y, kernel, padding=kernel_size // 2)

    sigma_xx = F.conv2d(x * x, kernel, padding=kernel_size // 2) - mu_x**2
    sigma_yy = F.conv2d(y * y, kernel, padding=kernel_size // 2) - mu_y**2

    sigma_xx[sigma_xx < 0] = 0
    sigma_yy[sigma_yy < 0] = 0
    left_term = (2 * torch.sqrt(sigma_xx * sigma_yy) + c) / (sigma_xx +
                                                             sigma_yy + c)

    # Spatial pooling of worst scores
    percentile_index = round(percentile *
                             (left_term.size(-2) * left_term.size(-1)))
    sorted_left = torch.sort(left_term.flatten(start_dim=1)).values
    similarity = torch.mean(sorted_left[:, :percentile_index], dim=1)

    # For DC, multiply by a right term
    if first_term:
        sigma_xy = F.conv2d(x * y, kernel,
                            padding=kernel_size // 2) - mu_x * mu_y
        right_term = ((sigma_xy + c) / (torch.sqrt(sigma_xx * sigma_yy) + c))
        sorted_right = torch.sort(right_term.flatten(start_dim=1)).values
        similarity *= torch.mean(sorted_right[:, :percentile_index], dim=1)

    return similarity
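
A sketch of how this helper could be exercised in isolation (not from the original source), with random tensors standing in for real DCT subbands; first_term=True pools both the variance and the cross-correlation terms:

import torch

# Hypothetical smoke test with random stand-ins for DC subbands.
x_dc = torch.rand(4, 1, 64, 64)
y_dc = torch.rand(4, 1, 64, 64)
sim = _subband_similarity(x_dc, y_dc, first_term=True)
print(sim.shape)  # expected: torch.Size([4]), one score per image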
Example #3
def ssim(x: torch.Tensor, y: torch.Tensor, kernel_size: int = 11, kernel_sigma: float = 1.5,
         data_range: Union[int, float] = 1., reduction: str = 'mean', full: bool = False,
         k1: float = 0.01, k2: float = 0.03) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    r"""Interface of Structural Similarity (SSIM) index.

    Args:
        x: Batch of images. Required to be 2D (H, W), 3D (C,H,W), 4D (N,C,H,W) or 5D (N,C,H,W,2), channels first.
        y: Batch of images. Required to be 2D (H, W), 3D (C,H,W) 4D (N,C,H,W) or 5D (N,C,H,W,2), channels first.
        kernel_size: The side-length of the sliding window used in comparison. Must be an odd value.
        kernel_sigma: Sigma of normal distribution.
        data_range: Value range of input images (usually 1.0 or 255).
        reduction: Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied. Default: ``'mean'``
        full: Return cs map or not.
        k1: Algorithm parameter, K1 (small constant, see [1]).
        k2: Algorithm parameter, K2 (small constant, see [1]).
            Try a larger K2 constant (e.g. 0.4) if you get negative or NaN results.

    Returns:
        Value of Structural Similarity (SSIM) index. In case of 5D input tensors, complex value is returned
        as a tensor of size 2.

    References:
        .. [1] Wang, Z., Bovik, A. C., Sheikh, H. R., & Simoncelli, E. P.
           (2004). Image quality assessment: From error visibility to
           structural similarity. IEEE Transactions on Image Processing,
           13, 600-612.
           https://ece.uwaterloo.ca/~z70wang/publications/ssim.pdf,
           :DOI:`10.1109/TIP.2003.819861`
    """
    _validate_input(input_tensors=(x, y), allow_5d=True, kernel_size=kernel_size, scale_weights=None)
    x, y = _adjust_dimensions(input_tensors=(x, y))
    if isinstance(x, torch.ByteTensor) or isinstance(y, torch.ByteTensor):
        x = x.type(torch.float32)
        y = y.type(torch.float32)

    kernel = gaussian_filter(kernel_size, kernel_sigma).repeat(x.size(1), 1, 1, 1).to(y)
    _compute_ssim_per_channel = _ssim_per_channel_complex if x.dim() == 5 else _ssim_per_channel
    ssim_map, cs_map = _compute_ssim_per_channel(x=x, y=y, kernel=kernel, data_range=data_range, k1=k1, k2=k2)
    ssim_val = ssim_map.mean(1)
    cs = cs_map.mean(1)

    if reduction != 'none':
        reduction_operation = {'mean': torch.mean,
                               'sum': torch.sum}
        ssim_val = reduction_operation[reduction](ssim_val, dim=0)
        cs = reduction_operation[reduction](cs, dim=0)

    if full:
        return ssim_val, cs

    return ssim_val
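
A minimal usage sketch, assuming the piq package is installed and exposes this function as piq.ssim:

import torch
from piq import ssim

x = torch.rand(4, 3, 128, 128)
y = torch.rand(4, 3, 128, 128)
score = ssim(x, y, data_range=1.)
print(score)  # scalar tensor, since reduction defaults to 'mean'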
Example #4
def vif_p(x: torch.Tensor,
          y: torch.Tensor,
          sigma_n_sq: float = 2.0,
          data_range: Union[int, float] = 1.0,
          reduction: str = 'mean') -> torch.Tensor:
    r"""Compute Visiual Information Fidelity in **pixel** domain for a batch of images.
    This metric isn't symmetric, so make sure to place arguments in correct order.
    Both inputs supposed to have RGB channels order.

    Args:
        x: Tensor with shape (H, W), (C, H, W) or (N, C, H, W).
        y: Tensor with shape (H, W), (C, H, W) or (N, C, H, W).
        sigma_n_sq: HVS model parameter (variance of the visual noise).
        data_range: Value range of input images (usually 1.0 or 255). Default: 1.0
        reduction: Reduction over samples in batch: "mean"|"sum"|"none"
        
    Returns:
        VIF: Index of similarity between two images, usually in the [0, 1] interval.
            Can be bigger than 1 for predicted images with higher contrast than the original one.
    Note:
        In the original paper this method was used for bands of a discrete wavelet decomposition.
        Later the authors released code to compute a VIF approximation in the pixel domain.
        See https://live.ece.utexas.edu/research/Quality/VIF.htm for details.
    """
    _validate_input((x, y), allow_5d=False, data_range=data_range)
    x, y = _adjust_dimensions(input_tensors=(x, y))

    min_size = 41
    if x.size(-1) < min_size or x.size(-2) < min_size:
        raise ValueError(
            f'Invalid size of the input images, expected at least {min_size}x{min_size}.'
        )

    x = x / data_range * 255
    y = y / data_range * 255

    # Convert RGB image to YCbCr and take luminance: Y = 0.299 R + 0.587 G + 0.114 B
    num_channels = x.size(1)
    if num_channels == 3:
        x = 0.299 * x[:, 0, :, :] + 0.587 * x[:, 1, :, :] + 0.114 * x[:, 2, :, :]
        y = 0.299 * y[:, 0, :, :] + 0.587 * y[:, 1, :, :] + 0.114 * y[:, 2, :, :]

        # Add channel dimension
        x = x[:, None, :, :]
        y = y[:, None, :, :]

    # Constant for numerical stability
    EPS = 1e-8

    # Progressively downsample images and compute VIF on different scales
    x_vif, y_vif = 0, 0
    for scale in range(4):
        kernel_size = 2**(4 - scale) + 1
        kernel = gaussian_filter(kernel_size, sigma=kernel_size / 5)
        kernel = kernel.view(1, 1, kernel_size, kernel_size).to(x)

        if scale > 0:
            # Convolve and downsample
            x = F.conv2d(x, kernel)[:, :, ::2, ::2]  # valid padding
            y = F.conv2d(y, kernel)[:, :, ::2, ::2]  # valid padding

        mu_x, mu_y = F.conv2d(x, kernel), F.conv2d(y, kernel)  # valid padding
        mu_x_sq, mu_y_sq, mu_xy = mu_x * mu_x, mu_y * mu_y, mu_x * mu_y

        # Local variances and covariance via valid-padded convolutions
        sigma_x_sq = F.conv2d(x**2, kernel) - mu_x_sq
        sigma_y_sq = F.conv2d(y**2, kernel) - mu_y_sq
        sigma_xy = F.conv2d(x * y, kernel) - mu_xy

        # Zero small negative values
        sigma_x_sq = torch.relu(sigma_x_sq)
        sigma_y_sq = torch.relu(sigma_y_sq)

        g = sigma_xy / (sigma_y_sq + EPS)
        sigma_v_sq = sigma_x_sq - g * sigma_xy

        g = torch.where(sigma_y_sq >= EPS, g, torch.zeros_like(g))
        sigma_v_sq = torch.where(sigma_y_sq >= EPS, sigma_v_sq, sigma_x_sq)
        sigma_y_sq = torch.where(sigma_y_sq >= EPS, sigma_y_sq,
                                 torch.zeros_like(sigma_y_sq))

        g = torch.where(sigma_x_sq >= EPS, g, torch.zeros_like(g))
        sigma_v_sq = torch.where(sigma_x_sq >= EPS, sigma_v_sq,
                                 torch.zeros_like(sigma_v_sq))

        sigma_v_sq = torch.where(g >= 0, sigma_v_sq, sigma_x_sq)
        g = torch.relu(g)

        sigma_v_sq = torch.where(sigma_v_sq > EPS, sigma_v_sq,
                                 torch.ones_like(sigma_v_sq) * EPS)

        x_vif_scale = torch.log10(1.0 + (g**2.) * sigma_y_sq /
                                  (sigma_v_sq + sigma_n_sq))
        x_vif = x_vif + torch.sum(x_vif_scale, dim=[1, 2, 3])
        y_vif = y_vif + torch.sum(torch.log10(1.0 + sigma_y_sq / sigma_n_sq),
                                  dim=[1, 2, 3])

    score: torch.Tensor = (x_vif + EPS) / (y_vif + EPS)

    # Reduce if needed
    if reduction == 'none':
        return score

    return {'mean': score.mean, 'sum': score.sum}[reduction](dim=0)
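
A minimal usage sketch, assuming the piq package is installed and exposes this function as piq.vif_p; inputs must be at least 41x41 for the four-scale decomposition:

import torch
from piq import vif_p

x = torch.rand(2, 3, 96, 96)
y = torch.rand(2, 3, 96, 96)
print(vif_p(x, y))  # mean VIF over the batch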
Example #5
def vif_p(x: torch.Tensor,
          y: torch.Tensor,
          sigma_n_sq: float = 2.0,
          data_range: Union[int, float] = 1.0,
          reduction: str = 'mean') -> torch.Tensor:
    r"""Compute Visiual Information Fidelity in **pixel** domain for a batch of images.
    This metric isn't symmetric, so make sure to place arguments in correct order.
    Both inputs supposed to have RGB channels order.

    Args:
        x: An input tensor. Shape :math:`(N, C, H, W)`.
        y: A target tensor. Shape :math:`(N, C, H, W)`.
        sigma_n_sq: HVS model parameter (variance of the visual noise).
        data_range: Maximum value range of images (usually 1.0 or 255).
        reduction: Specifies the reduction type:
            ``'none'`` | ``'mean'`` | ``'sum'``. Default:``'mean'``
        
    Returns:
        VIF Index of similarity between two images, usually in the [0, 1] interval.
        Can be bigger than 1 for a predicted image :math:`x` with higher contrast than the original one.

    References:
        H. R. Sheikh and A. C. Bovik, "Image information and visual quality,"
        IEEE Transactions on Image Processing, vol. 15, no. 2, pp. 430-444, Feb. 2006
        https://ieeexplore.ieee.org/abstract/document/1576816/
        DOI: 10.1109/TIP.2005.859378.

    Note:
        In the original paper this method was used for bands of a discrete wavelet decomposition.
        Later the authors released code to compute a VIF approximation in the pixel domain.
        See https://live.ece.utexas.edu/research/Quality/VIF.htm for details.
    """
    _validate_input([x, y], dim_range=(4, 4), data_range=(0, data_range))

    min_size = 41
    if x.size(-1) < min_size or x.size(-2) < min_size:
        raise ValueError(
            f'Invalid size of the input images, expected at least {min_size}x{min_size}.'
        )

    x = x / float(data_range) * 255
    y = y / float(data_range) * 255

    # Convert RGB image to YCbCr and take luminance: Y = 0.299 R + 0.587 G + 0.114 B
    num_channels = x.size(1)
    if num_channels == 3:
        x = 0.299 * x[:, 0, :, :] + 0.587 * x[:, 1, :, :] + 0.114 * x[:, 2, :, :]
        y = 0.299 * y[:, 0, :, :] + 0.587 * y[:, 1, :, :] + 0.114 * y[:, 2, :, :]

        # Add channel dimension
        x = x[:, None, :, :]
        y = y[:, None, :, :]

    # Constant for numerical stability
    EPS = 1e-8

    # Progressively downsample images and compute VIF on different scales
    x_vif, y_vif = 0, 0
    for scale in range(4):
        kernel_size = 2**(4 - scale) + 1
        kernel = gaussian_filter(kernel_size, sigma=kernel_size / 5)
        kernel = kernel.view(1, 1, kernel_size, kernel_size).to(x)

        if scale > 0:
            # Convolve and downsample
            x = F.conv2d(x, kernel)[:, :, ::2, ::2]  # valid padding
            y = F.conv2d(y, kernel)[:, :, ::2, ::2]  # valid padding

        mu_x, mu_y = F.conv2d(x, kernel), F.conv2d(y, kernel)  # valid padding
        mu_x_sq, mu_y_sq, mu_xy = mu_x * mu_x, mu_y * mu_y, mu_x * mu_y

        # Local variances and covariance via valid-padded convolutions
        sigma_x_sq = F.conv2d(x**2, kernel) - mu_x_sq
        sigma_y_sq = F.conv2d(y**2, kernel) - mu_y_sq
        sigma_xy = F.conv2d(x * y, kernel) - mu_xy

        # Zero small negative values
        sigma_x_sq = torch.relu(sigma_x_sq)
        sigma_y_sq = torch.relu(sigma_y_sq)

        g = sigma_xy / (sigma_y_sq + EPS)
        sigma_v_sq = sigma_x_sq - g * sigma_xy

        g = torch.where(sigma_y_sq >= EPS, g, torch.zeros_like(g))
        sigma_v_sq = torch.where(sigma_y_sq >= EPS, sigma_v_sq, sigma_x_sq)
        sigma_y_sq = torch.where(sigma_y_sq >= EPS, sigma_y_sq,
                                 torch.zeros_like(sigma_y_sq))

        g = torch.where(sigma_x_sq >= EPS, g, torch.zeros_like(g))
        sigma_v_sq = torch.where(sigma_x_sq >= EPS, sigma_v_sq,
                                 torch.zeros_like(sigma_v_sq))

        sigma_v_sq = torch.where(g >= 0, sigma_v_sq, sigma_x_sq)
        g = torch.relu(g)

        sigma_v_sq = torch.where(sigma_v_sq > EPS, sigma_v_sq,
                                 torch.ones_like(sigma_v_sq) * EPS)

        x_vif_scale = torch.log10(1.0 + (g**2.) * sigma_y_sq /
                                  (sigma_v_sq + sigma_n_sq))
        x_vif = x_vif + torch.sum(x_vif_scale, dim=[1, 2, 3])
        y_vif = y_vif + torch.sum(torch.log10(1.0 + sigma_y_sq / sigma_n_sq),
                                  dim=[1, 2, 3])

    score: torch.Tensor = (x_vif + EPS) / (y_vif + EPS)

    return _reduce(score, reduction)
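
The same call with per-sample output, assuming piq.vif_p and inputs in the [0, 255] range:

import torch
from piq import vif_p

x = torch.randint(0, 256, (2, 3, 96, 96)).float()
y = torch.randint(0, 256, (2, 3, 96, 96)).float()
scores = vif_p(x, y, data_range=255, reduction='none')
print(scores.shape)  # torch.Size([2]), one score per image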
Example #6
def multi_scale_ssim(x: torch.Tensor, y: torch.Tensor, kernel_size: int = 11, kernel_sigma: float = 1.5,
                     data_range: Union[int, float] = 1., reduction: str = 'mean',
                     scale_weights: Optional[Union[Tuple[float], List[float], torch.Tensor]] = None,
                     k1: float = 0.01, k2: float = 0.03) -> torch.Tensor:
    r""" Interface of Multi-scale Structural Similarity (MS-SSIM) index.

    Args:
        x: Batch of images. Required to be 2D (H, W), 3D (C,H,W), 4D (N,C,H,W) or 5D (N,C,H,W,2), channels first.
            The size of the image should be at least (kernel_size - 1) * 2 ** (levels - 1) + 1.
        y: Batch of images. Required to be 2D (H, W), 3D (C,H,W), 4D (N,C,H,W) or 5D (N,C,H,W,2), channels first.
            The size of the image should be at least (kernel_size - 1) * 2 ** (levels - 1) + 1.
        kernel_size: The side-length of the sliding window used in comparison. Must be an odd value.
        kernel_sigma: Sigma of normal distribution.
        data_range: Value range of input images (usually 1.0 or 255).
        reduction: Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied. Default: ``'mean'``
        scale_weights: Weights for different scales.
            If None, default weights from the paper [1] will be used.
            Default weights: (0.0448, 0.2856, 0.3001, 0.2363, 0.1333).
        k1: Algorithm parameter, K1 (small constant, see [2]).
        k2: Algorithm parameter, K2 (small constant, see [2]).
            Try a larger K2 constant (e.g. 0.4) if you get negative or NaN results.

    Returns:
        Value of Multi-scale Structural Similarity (MS-SSIM) index. In case of 5D input tensors,
        complex value is returned as a tensor of size 2.

    References:
        .. [1] Wang, Z., Simoncelli, E. P., Bovik, A. C. (2003).
           Multi-scale Structural Similarity for Image Quality Assessment.
           IEEE Asilomar Conference on Signals, Systems and Computers, 37,
           https://ieeexplore.ieee.org/document/1292216
           :DOI:`10.1109/ACSSC.2003.1292216`
        .. [2] Wang, Z., Bovik, A. C., Sheikh, H. R., & Simoncelli, E. P.
           (2004). Image quality assessment: From error visibility to
           structural similarity. IEEE Transactions on Image Processing,
           13, 600-612.
           https://ece.uwaterloo.ca/~z70wang/publications/ssim.pdf,
           :DOI:`10.1109/TIP.2003.819861`
    """
    _validate_input(input_tensors=(x, y), allow_5d=True, kernel_size=kernel_size, scale_weights=scale_weights)
    x, y = _adjust_dimensions(input_tensors=(x, y))
    if isinstance(x, torch.ByteTensor) or isinstance(y, torch.ByteTensor):
        x = x.type(torch.float32)
        y = y.type(torch.float32)

    if scale_weights is None:
        scale_weights_from_ms_ssim_paper = [0.0448, 0.2856, 0.3001, 0.2363, 0.1333]
        scale_weights = scale_weights_from_ms_ssim_paper

    scale_weights_tensor = scale_weights if isinstance(scale_weights, torch.Tensor) else torch.tensor(scale_weights)
    scale_weights_tensor = scale_weights_tensor.to(y)
    kernel = gaussian_filter(kernel_size, kernel_sigma).repeat(x.size(1), 1, 1, 1).to(y)
    
    _compute_msssim = _multi_scale_ssim_complex if x.dim() == 5 else _multi_scale_ssim
    msssim_val = _compute_msssim(
        x=x,
        y=y,
        data_range=data_range,
        kernel=kernel,
        scale_weights_tensor=scale_weights_tensor,
        k1=k1,
        k2=k2
    )

    if reduction == 'none':
        return msssim_val

    return {'mean': torch.mean,
            'sum': torch.sum}[reduction](msssim_val, dim=0)
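
A minimal usage sketch, assuming piq.multi_scale_ssim; with the default 5 scale weights and kernel_size=11, each image side must be at least (11 - 1) * 2 ** 4 + 1 = 161:

import torch
from piq import multi_scale_ssim

x = torch.rand(2, 3, 161, 161)
y = torch.rand(2, 3, 161, 161)
print(multi_scale_ssim(x, y, data_range=1.))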
Example #7
def ssim(x: torch.Tensor,
         y: torch.Tensor,
         kernel_size: int = 11,
         kernel_sigma: float = 1.5,
         data_range: Union[int, float] = 1.,
         reduction: str = 'mean',
         full: bool = False,
         downsample: bool = True,
         k1: float = 0.01,
         k2: float = 0.03) -> List[torch.Tensor]:
    r"""Interface of Structural Similarity (SSIM) index.
    Inputs are supposed to be in the range [0, data_range].
    To match performance with skimage and tensorflow, set ``downsample = True``.

    Args:
        x: An input tensor. Shape :math:`(N, C, H, W)` or :math:`(N, C, H, W, 2)`.
        y: A target tensor. Shape :math:`(N, C, H, W)` or :math:`(N, C, H, W, 2)`.
        kernel_size: The side-length of the sliding window used in comparison. Must be an odd value.
        kernel_sigma: Sigma of normal distribution.
        data_range: Maximum value range of images (usually 1.0 or 255).
        reduction: Specifies the reduction type:
            ``'none'`` | ``'mean'`` | ``'sum'``. Default:``'mean'``
        full: Return cs map or not.
        downsample: Perform average pool before SSIM computation. Default: True
        k1: Algorithm parameter, K1 (small constant, see [1]).
        k2: Algorithm parameter, K2 (small constant, see [1]).
            Try a larger K2 constant (e.g. 0.4) if you get negative or NaN results.

    Returns:
        Value of Structural Similarity (SSIM) index. In case of 5D input tensors, complex value is returned
        as a tensor of size 2.

    References:
        .. [1] Wang, Z., Bovik, A. C., Sheikh, H. R., & Simoncelli, E. P.
           (2004). Image quality assessment: From error visibility to
           structural similarity. IEEE Transactions on Image Processing,
           13, 600-612.
           https://ece.uwaterloo.ca/~z70wang/publications/ssim.pdf,
           DOI: `10.1109/TIP.2003.819861`
    """
    assert kernel_size % 2 == 1, f'Kernel size must be odd, got [{kernel_size}]'
    _validate_input([x, y], dim_range=(4, 5), data_range=(0, data_range))

    x = x.type(torch.float32)
    y = y.type(torch.float32)

    x = x / data_range
    y = y / data_range

    # Averagepool image if the size is large enough
    f = max(1, round(min(x.size()[-2:]) / 256))
    if (f > 1) and downsample:
        x = F.avg_pool2d(x, kernel_size=f)
        y = F.avg_pool2d(y, kernel_size=f)

    kernel = gaussian_filter(kernel_size,
                             kernel_sigma).repeat(x.size(1), 1, 1, 1).to(y)
    _compute_ssim_per_channel = _ssim_per_channel_complex if x.dim() == 5 else _ssim_per_channel
    ssim_map, cs_map = _compute_ssim_per_channel(x=x,
                                                 y=y,
                                                 kernel=kernel,
                                                 data_range=data_range,
                                                 k1=k1,
                                                 k2=k2)
    ssim_val = ssim_map.mean(1)
    cs = cs_map.mean(1)

    ssim_val = _reduce(ssim_val, reduction)
    cs = _reduce(cs, reduction)

    if full:
        return [ssim_val, cs]

    return ssim_val
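
A usage sketch for the full and downsample options, assuming piq.ssim matches the signature above:

import torch
from piq import ssim

x = torch.rand(1, 3, 512, 512)
y = (x * 0.9).clamp(0, 1)  # slightly darker copy of x
# full=True also returns the contrast-structure (cs) term.
ssim_val, cs = ssim(x, y, data_range=1., full=True)
print(ssim_val, cs)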
Example #8
def multi_scale_ssim(x: torch.Tensor,
                     y: torch.Tensor,
                     kernel_size: int = 11,
                     kernel_sigma: float = 1.5,
                     data_range: Union[int, float] = 1.,
                     reduction: str = 'mean',
                     scale_weights: Optional[torch.Tensor] = None,
                     k1: float = 0.01,
                     k2: float = 0.03) -> torch.Tensor:
    r""" Interface of Multi-scale Structural Similarity (MS-SSIM) index.
    Inputs are supposed to be in the range ``[0, data_range]`` with RGB channel order for colour images.

    Args:
        x: An input tensor. Shape :math:`(N, C, H, W)` or :math:`(N, C, H, W, 2)`.
        y: A target tensor. Shape :math:`(N, C, H, W)` or :math:`(N, C, H, W, 2)`.
        kernel_size: The side-length of the sliding window used in comparison. Must be an odd value.
        kernel_sigma: Sigma of normal distribution.
        data_range: Maximum value range of images (usually 1.0 or 255).
        reduction: Specifies the reduction type:
            ``'none'`` | ``'mean'`` | ``'sum'``. Default:``'mean'``
        scale_weights: Weights for different scales.
            If ``None``, default weights from the paper will be used.
            Default weights: (0.0448, 0.2856, 0.3001, 0.2363, 0.1333).
        k1: Algorithm parameter, K1 (small constant).
        k2: Algorithm parameter, K2 (small constant).
            Try a larger K2 constant (e.g. 0.4) if you get negative or NaN results.

    Returns:
        Value of Multi-scale Structural Similarity (MS-SSIM) index. In case of 5D input tensors,
        complex value is returned as a tensor of size 2.

    References:
        Wang, Z., Simoncelli, E. P., Bovik, A. C. (2003).
        Multi-scale Structural Similarity for Image Quality Assessment.
        IEEE Asilomar Conference on Signals, Systems and Computers, 37,
        https://ieeexplore.ieee.org/document/1292216 DOI:`10.1109/ACSSC.2003.1292216`

        Wang, Z., Bovik, A. C., Sheikh, H. R., & Simoncelli, E. P. (2004).
        Image quality assessment: From error visibility to structural similarity.
        IEEE Transactions on Image Processing, 13, 600-612.
        https://ece.uwaterloo.ca/~z70wang/publications/ssim.pdf,
        DOI: `10.1109/TIP.2003.819861`
    Note:
        The size of the image should be at least ``(kernel_size - 1) * 2 ** (levels - 1) + 1``.
    """
    assert kernel_size % 2 == 1, f'Kernel size must be odd, got [{kernel_size}]'
    _validate_input([x, y], dim_range=(4, 5), data_range=(0, data_range))

    x = x.type(torch.float32)
    y = y.type(torch.float32)

    x = x / data_range
    y = y / data_range

    if scale_weights is None:
        # Values from the MS-SSIM paper
        scale_weights = torch.tensor([0.0448, 0.2856, 0.3001, 0.2363,
                                      0.1333]).to(x)
    else:
        # Normalize scale weights
        scale_weights = (scale_weights / scale_weights.sum()).to(x)

    kernel = gaussian_filter(kernel_size,
                             kernel_sigma).repeat(x.size(1), 1, 1, 1).to(x)

    _compute_msssim = _multi_scale_ssim_complex if x.dim() == 5 else _multi_scale_ssim
    msssim_val = _compute_msssim(x=x,
                                 y=y,
                                 data_range=data_range,
                                 kernel=kernel,
                                 scale_weights=scale_weights,
                                 k1=k1,
                                 k2=k2)
    return _reduce(msssim_val, reduction)
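
A sketch with custom scale weights, assuming piq.multi_scale_ssim matches the signature above; three weights mean three pyramid levels, so the minimum image side drops to (11 - 1) * 2 ** 2 + 1 = 41:

import torch
from piq import multi_scale_ssim

x = torch.rand(2, 3, 64, 64)
y = torch.rand(2, 3, 64, 64)
weights = torch.tensor([0.2, 0.5, 0.3])  # renormalized internally to sum to 1
print(multi_scale_ssim(x, y, scale_weights=weights, reduction='none'))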
Example #9
def information_weighted_ssim(x: torch.Tensor, y: torch.Tensor, data_range: Union[int, float] = 1.,
                              kernel_size: int = 11, kernel_sigma: float = 1.5, k1: float = 0.01, k2: float = 0.03,
                              parent: bool = True, blk_size: int = 3, sigma_nsq: float = 0.4,
                              scale_weights: Optional[torch.Tensor] = None,
                              reduction: str = 'mean') -> torch.Tensor:
    r"""Interface of Information Content Weighted Structural Similarity (IW-SSIM) index.
    Inputs are supposed to be in the range ``[0, data_range]``.

    Args:
        x: An input tensor. Shape :math:`(N, C, H, W)`.
        y: A target tensor. Shape :math:`(N, C, H, W)`.
        data_range: Maximum value range of images (usually 1.0 or 255).
        kernel_size: The side-length of the sliding window used in comparison. Must be an odd value.
        kernel_sigma: Sigma of normal distribution for sliding window used in comparison.
        k1: Algorithm parameter, K1 (small constant).
        k2: Algorithm parameter, K2 (small constant).
            Try a larger K2 constant (e.g. 0.4) if you get negative or NaN results.
        parent: Flag to control dependency on previous layer of pyramid.
        blk_size: The side-length of the sliding window used in comparison for information content.
        sigma_nsq: Parameter of visual distortion model.
        scale_weights: Weights for scaling.
        reduction: Specifies the reduction type:
            ``'none'`` | ``'mean'`` | ``'sum'``. Default:``'mean'``

    Returns:
        Value of Information Content Weighted Structural Similarity (IW-SSIM) index.

    References:
        Wang, Zhou, and Qiang Li.
        Information content weighting for perceptual image quality assessment.
        IEEE Transactions on Image Processing 20.5 (2011): 1185-1198.
        https://ece.uwaterloo.ca/~z70wang/publications/IWSSIM.pdf DOI:`10.1109/TIP.2010.2092435`

    Note:
        Lack of content in the target image could lead to a RuntimeError due to a singular
        information content matrix, which cannot be inverted.
    """
    assert kernel_size % 2 == 1, f'Kernel size must be odd, got [{kernel_size}]'

    _validate_input(tensors=[x, y], dim_range=(4, 4), data_range=(0., data_range))

    x = x / float(data_range) * 255
    y = y / float(data_range) * 255

    if x.size(1) == 3:
        x = rgb2yiq(x)[:, :1]
        y = rgb2yiq(y)[:, :1]

    if scale_weights is None:
        scale_weights = torch.tensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333], dtype=x.dtype, device=x.device)
    scale_weights = scale_weights / scale_weights.sum()
    if scale_weights.size(0) != scale_weights.numel():
        raise ValueError(f'Expected a vector of weights, got {scale_weights.dim()}D tensor')

    levels = scale_weights.size(0)

    min_size = (kernel_size - 1) * 2 ** (levels - 1) + 1
    if x.size(-1) < min_size or x.size(-2) < min_size:
        raise ValueError(f'Invalid size of the input images, expected at least {min_size}x{min_size}.')

    blur_pad = math.ceil((kernel_size - 1) / 2)
    iw_pad = blur_pad - math.floor((blk_size - 1) / 2)
    gauss_kernel = gaussian_filter(kernel_size, kernel_sigma).repeat(x.size(1), 1, 1, 1).to(x)

    # Size of the kernel size to build Laplacian pyramid
    pyramid_kernel_size = 5
    bin_filter = binomial_filter1d(kernel_size=pyramid_kernel_size).to(x) * 2 ** 0.5

    lo_x, x_diff_old = _pyr_step(x, bin_filter)
    lo_y, y_diff_old = _pyr_step(y, bin_filter)

    x = lo_x
    y = lo_y
    wmcs = []

    for i in range(levels):
        if i < levels - 2:
            lo_x, x_diff = _pyr_step(x, bin_filter)
            lo_y, y_diff = _pyr_step(y, bin_filter)
            x = lo_x
            y = lo_y

        else:
            x_diff = x
            y_diff = y

        ssim_map, cs_map = _ssim_per_channel(x=x_diff_old, y=y_diff_old, kernel=gauss_kernel, data_range=255,
                                             k1=k1, k2=k2)

        if parent and i < levels - 2:
            iw_map = _information_content(x=x_diff_old, y=y_diff_old, y_parent=y_diff, kernel_size=blk_size,
                                          sigma_nsq=sigma_nsq)

            iw_map = iw_map[:, :, iw_pad:-iw_pad, iw_pad:-iw_pad]

        elif i == levels - 1:
            iw_map = torch.ones_like(cs_map)
            cs_map = ssim_map

        else:
            iw_map = _information_content(x=x_diff_old, y=y_diff_old, y_parent=None, kernel_size=blk_size,
                                          sigma_nsq=sigma_nsq)
            iw_map = iw_map[:, :, iw_pad:-iw_pad, iw_pad:-iw_pad]

        wmcs.append(torch.sum(cs_map * iw_map, dim=(-2, -1)) / torch.sum(iw_map, dim=(-2, -1)))

        x_diff_old = x_diff
        y_diff_old = y_diff

    wmcs = torch.stack(wmcs, dim=0).abs()

    score = torch.prod((wmcs ** scale_weights.view(-1, 1, 1)), dim=0)[:, 0]

    return _reduce(x=score, reduction=reduction)
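
A minimal usage sketch, assuming the piq package exposes this function as piq.information_weighted_ssim; the default five scale weights with kernel_size=11 require inputs of at least 161x161:

import torch
from piq import information_weighted_ssim

x = torch.rand(2, 3, 192, 192)
y = torch.rand(2, 3, 192, 192)
print(information_weighted_ssim(x, y, data_range=1.))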
Example #10
def _spectral_residual_visual_saliency(
        x: torch.Tensor,
        scale: float = 0.25,
        kernel_size: int = 3,
        sigma: float = 3.8,
        gaussian_size: int = 10) -> torch.Tensor:
    r"""Compute Spectral Residual Visual Saliency
    Credits X. Hou and L. Zhang, CVPR 07, 2007
    Reference:
        http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.125.5641&rep=rep1&type=pdf

    Args:
        x: Tensor with shape (N, 1, H, W).
        scale: Resizing factor.
        kernel_size: Kernel size of the average blur filter.
        sigma: Sigma of the gaussian filter applied on the saliency map.
        gaussian_size: Size of the gaussian filter applied on the saliency map.
    Returns:
        saliency_map: Tensor with shape (N, 1, H, W).

    """
    eps = torch.finfo(x.dtype).eps
    for kernel in kernel_size, gaussian_size:
        if x.size(-1) * scale < kernel or x.size(-2) * scale < kernel:
            raise ValueError(
                f'Kernel size can\'t be greater than actual input size. '
                f'Input size: {x.size()} x {scale}. Kernel size: {kernel}')

    # Downsize image
    in_img = imresize(x, scale=scale)

    # Fourier transform (use complex format [a,b] instead of a + ib
    # because torch<1.8.0 autograd does not support the latter)
    recommended_torch_version = _parse_version('1.8.0')
    torch_version = _parse_version(torch.__version__)
    if len(torch_version) != 0 and torch_version >= recommended_torch_version:
        imagefft = torch.fft.fft2(in_img)
        log_amplitude = torch.log(imagefft.abs() + eps)
        phase = torch.angle(imagefft)
    else:
        imagefft = torch.rfft(in_img, 2, onesided=False)
        # Compute log of absolute value and angle of fourier transform
        log_amplitude = torch.log(imagefft.pow(2).sum(dim=-1).sqrt() + eps)
        phase = torch.atan2(imagefft[..., 1], imagefft[..., 0] + eps)

    # Compute spectral residual using average filtering
    padding = kernel_size // 2
    if padding:
        up_pad = (kernel_size - 1) // 2
        down_pad = padding
        pad_to_use = [up_pad, down_pad, up_pad, down_pad]
        # replicate padding before average filtering
        spectral_residual = F.pad(log_amplitude,
                                  pad=pad_to_use,
                                  mode='replicate')
    else:
        spectral_residual = log_amplitude

    spectral_residual = log_amplitude - F.avg_pool2d(
        spectral_residual, kernel_size=kernel_size, stride=1)
    # Saliency map
    # representation of complex exp(spectral_residual + j * phase)
    compx = torch.stack((torch.exp(spectral_residual) * torch.cos(phase),
                         torch.exp(spectral_residual) * torch.sin(phase)), -1)

    if len(torch_version) != 0 and torch_version >= recommended_torch_version:
        saliency_map = torch.abs(torch.fft.ifft2(
            torch.view_as_complex(compx)))**2

    else:
        saliency_map = torch.sum(torch.ifft(compx, 2)**2, dim=-1)

    # After effect for SR-SIM
    # Apply gaussian blur
    kernel = gaussian_filter(gaussian_size, sigma)
    if gaussian_size % 2 == 0:  # matlab pads upper and lower borders with 0s for even kernels
        kernel = torch.cat((torch.zeros(1, 1, gaussian_size), kernel), 1)
        kernel = torch.cat((torch.zeros(1, gaussian_size + 1, 1), kernel), 2)
        gaussian_size += 1
    kernel = kernel.view(1, 1, gaussian_size, gaussian_size).to(saliency_map)
    saliency_map = F.conv2d(saliency_map,
                            kernel,
                            padding=(gaussian_size - 1) // 2)

    # normalize between [0, 1]
    min_sal = saliency_map.min()
    max_sal = saliency_map.max()
    saliency_map = (saliency_map - min_sal) / (max_sal - min_sal + eps)

    # scale to original size
    saliency_map = imresize(saliency_map, sizes=x.size()[-2:])
    return saliency_map
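
A smoke test sketch for this helper (not from the original source), assuming it and piq's imresize and gaussian_filter are in scope; after downsizing by scale=0.25 the image must remain larger than both filter sizes:

import torch

# Hypothetical usage: 256 * 0.25 = 64, which exceeds max(kernel_size=3, gaussian_size=10).
x = torch.rand(1, 1, 256, 256)
saliency = _spectral_residual_visual_saliency(x)
print(saliency.shape)  # torch.Size([1, 1, 256, 256]), values normalized to [0, 1]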