Exemplo n.º 1
0
def _gmsd(x: torch.Tensor,
          y: torch.Tensor,
          t: float = 170 / (255.**2),
          alpha: float = 0.0) -> torch.Tensor:
    r"""Compute Gradient Magnitude Similarity Deviation (GMSD) per sample.

    Both inputs are expected to be single-channel images in range [0, 1],
    as implied by the documented ``(N, 1, H, W)`` shape.

    Args:
        x: Tensor with shape (N, 1, H, W).
        y: Tensor with shape (N, 1, H, W).
        t: Constant from the reference paper, used for numerical stability
            of the similarity map.
        alpha: Masking coefficient forwarded to the similarity map computation.

    Returns:
        gmsd : Gradient Magnitude Similarity Deviation between given tensors,
        one value per sample along dim 0.

    References:
        https://arxiv.org/pdf/1308.3052.pdf
    """
    # Prewitt kernel and its transpose, stacked into one filter bank. The bank
    # is cast to the device/dtype of ``x`` so the function behaves the same as
    # ``mdsi`` for non-default devices and dtypes.
    # NOTE(review): ``gradient_map`` may already perform this cast - confirm.
    prewitt = prewitt_filter()
    kernels = torch.stack([prewitt, prewitt.transpose(-1, -2)]).to(x)

    x_grad = gradient_map(x, kernels)
    y_grad = gradient_map(y, kernels)

    # Gradient similarity map between the two inputs and its per-sample mean.
    gms = similarity_map(x_grad, y_grad, constant=t, alpha=alpha)
    mean_gms = torch.mean(gms, dim=[1, 2, 3], keepdim=True)

    # Standard deviation of the similarity map over all non-batch dimensions.
    squared_deviation = torch.pow(gms - mean_gms, 2)
    return squared_deviation.mean(dim=[1, 2, 3]).sqrt()
Exemplo n.º 2
0
def _gmsd(prediction: torch.Tensor,
          target: torch.Tensor,
          t: float = 170 / (255.**2)) -> torch.Tensor:
    r"""Compute Gradient Magnitude Similarity Deviation (GMSD) per sample.

    Both inputs are expected to be grayscale images in range [0, 1].

    Args:
        prediction: Tensor of shape :math:`(N, 1, H, W)` holding a distorted grayscale image.
        target: Tensor of shape :math:`(N, 1, H, W)` holding a target grayscale image.
        t: Constant from the reference paper, used for numerical stability
            of the similarity map.

    Returns:
        gmsd : Gradient Magnitude Similarity Deviation between given tensors,
        one value per sample along dim 0.

    References:
        https://arxiv.org/pdf/1308.3052.pdf
    """
    # Prewitt kernel and its transpose, stacked into one filter bank. The bank
    # is cast to the device/dtype of ``prediction`` so the function behaves the
    # same as ``mdsi`` for non-default devices and dtypes.
    # NOTE(review): ``gradient_map`` may already perform this cast - confirm.
    prewitt = prewitt_filter()
    kernels = torch.stack([prewitt, prewitt.transpose(-1, -2)]).to(prediction)

    pred_grad = gradient_map(prediction, kernels)
    trgt_grad = gradient_map(target, kernels)

    # Gradient similarity map between the two inputs and its per-sample mean.
    gms = similarity_map(pred_grad, trgt_grad, t)
    mean_gms = torch.mean(gms, dim=[1, 2, 3], keepdim=True)

    # Standard deviation of the similarity map over all non-batch dimensions.
    squared_deviation = torch.pow(gms - mean_gms, 2)
    return squared_deviation.mean(dim=[1, 2, 3]).sqrt()
Exemplo n.º 3
0
def mdsi(x: torch.Tensor, y: torch.Tensor, data_range: Union[int, float] = 1., reduction: str = 'mean',
         c1: float = 140., c2: float = 55., c3: float = 550., combination: str = 'sum', alpha: float = 0.6,
         beta: float = 0.1, gamma: float = 0.2, rho: float = 1., q: float = 0.25, o: float = 0.25):
    r"""Compute Mean Deviation Similarity Index (MDSI) for a batch of images.
    Supports greyscale and colour images with RGB channel order.

    Args:
        x: An input tensor. Shape :math:`(N, C, H, W)`.
        y: A target tensor. Shape :math:`(N, C, H, W)`.
        data_range: Maximum value range of images (usually 1.0 or 255).
        reduction: Specifies the reduction type:
            ``'none'`` | ``'mean'`` | ``'sum'``. Default:``'mean'``
        c1: coefficient to calculate gradient similarity. Default: 140.
        c2: coefficient to calculate gradient similarity. Default: 55.
        c3: coefficient to calculate chromaticity similarity. Default: 550.
        combination: mode to combine gradient similarity and chromaticity similarity: ``'sum'`` | ``'mult'``.
        alpha: coefficient to combine gradient similarity and chromaticity similarity using summation.
        beta: power applied to the chromaticity similarity when combining using multiplication.
        gamma: power applied to the gradient similarity when combining using multiplication.
        rho: order of the Minkowski distance
        q: coefficient to adjusts the emphasis of the values in image and MCT
        o: the power pooling applied on the final value of the deviation

    Returns:
        Mean Deviation Similarity Index (MDSI) between 2 tensors, reduced by ``_reduce``.

    Raises:
        ValueError: If ``combination`` is neither ``'sum'`` nor ``'mult'``.

    References:
        Nafchi, Hossein Ziaei and Shahkolaei, Atena and Hedjam, Rachid and Cheriet, Mohamed (2016).
        Mean deviation similarity index: Efficient and reliable full-reference image quality evaluator.
        IEEE Ieee Access, 4, 5579--5590.
        https://arxiv.org/pdf/1608.07433.pdf,
        DOI:`10.1109/ACCESS.2016.2604042`

    Note:
        The ratio between constants is usually equal :math:`c_3 = 4c_1 = 10c_2`

    Note:
        Both inputs are supposed to have RGB channels order in accordance with the original approach.
        Nevertheless, the method supports greyscale images, which are converted to RGB by copying the grey
        channel 3 times.
    """
    _validate_input([x, y], dim_range=(4, 4), data_range=(0, data_range))

    if x.size(1) == 1:
        x = x.repeat(1, 3, 1, 1)
        y = y.repeat(1, 3, 1, 1)
        warnings.warn('The original MDSI supports only RGB images. The input images were converted to RGB by copying '
                      'the grey channel 3 times.')

    # Rescale from [0, data_range] to [0, 255]; the default constants c1..c3 are tuned for that scale.
    scale = float(data_range)
    x = x / scale * 255
    y = y / scale * 255

    # Average-pool large images: the window grows by one for every ~256 pixels of the smaller spatial side.
    kernel_size = max(1, round(min(x.size()[-2:]) / 256))
    padding = kernel_size // 2

    if padding:
        # For an even window the leading pad is one smaller than the trailing pad, so kernel_size - 1 in total.
        pad_before = (kernel_size - 1) // 2
        pad_after = padding
        pad_to_use = [pad_before, pad_after, pad_before, pad_after]
        x = pad(x, pad=pad_to_use)
        y = pad(y, pad=pad_to_use)

    x = avg_pool2d(x, kernel_size=kernel_size)
    y = avg_pool2d(y, kernel_size=kernel_size)

    x_lhm = rgb2lhm(x)
    y_lhm = rgb2lhm(y)

    # Channel 0 of the converted images feeds the gradient terms; the remaining channels feed chromaticity.
    x_first, y_first = x_lhm[:, :1], y_lhm[:, :1]
    x_second, y_second = x_lhm[:, 1:2], y_lhm[:, 1:2]
    x_rest, y_rest = x_lhm[:, 2:], y_lhm[:, 2:]

    prewitt = prewitt_filter()
    kernels = torch.stack([prewitt, prewitt.transpose(1, 2)]).to(x)
    gm_x = gradient_map(x_first, kernels)
    gm_y = gradient_map(y_first, kernels)
    gm_avg = gradient_map((x_first + y_first) / 2., kernels)

    gs_x_y = similarity_map(gm_x, gm_y, c1)
    gs_x_average = similarity_map(gm_x, gm_avg, c2)
    gs_y_average = similarity_map(gm_y, gm_avg, c2)

    gs_total = gs_x_y + gs_x_average - gs_y_average

    cs_numerator = 2 * (x_second * y_second + x_rest * y_rest) + c3
    cs_denominator = x_second ** 2 + y_second ** 2 + x_rest ** 2 + y_rest ** 2 + c3
    cs_total = cs_numerator / cs_denominator

    if combination == 'sum':
        gcs = (alpha * gs_total + (1 - alpha) * cs_total)
    elif combination == 'mult':
        # NOTE(review): presumably pow_for_complex returns a trailing dimension of size 2 - confirm.
        gs_total_pow = pow_for_complex(base=gs_total, exp=gamma)
        cs_total_pow = pow_for_complex(base=cs_total, exp=beta)
        gcs = torch.stack((gs_total_pow[..., 0] * cs_total_pow[..., 0],
                           gs_total_pow[..., 1] + cs_total_pow[..., 1]), dim=-1)
    else:
        raise ValueError(f'Expected combination method "sum" or "mult", got {combination}')

    # Evaluate the power once and reuse it for both the spatial mean and the deviation from it.
    gcs_pow = pow_for_complex(base=gcs, exp=q)
    mct_complex = gcs_pow.mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)  # split to increase precision
    score = (gcs_pow - mct_complex).pow(2).sum(dim=-1).sqrt()
    score = ((score ** rho).mean(dim=(-1, -2)) ** (o / rho)).squeeze(1)
    return _reduce(score, reduction)
Exemplo n.º 4
0
def mdsi(x: torch.Tensor,
         y: torch.Tensor,
         data_range: Union[int, float] = 1.,
         reduction: str = 'mean',
         c1: float = 140.,
         c2: float = 55.,
         c3: float = 550.,
         combination: str = 'sum',
         alpha: float = 0.6,
         beta: float = 0.1,
         gamma: float = 0.2,
         rho: float = 1.,
         q: float = 0.25,
         o: float = 0.25) -> torch.Tensor:
    r"""Compute Mean Deviation Similarity Index (MDSI) for a batch of images.

    Note:
        Both inputs are supposed to have RGB channels order.
        Greyscale images are converted to RGB by copying the grey channel 3 times.

    Args:
        x: Tensor with shape (H, W), (C, H, W) or (N, C, H, W).
        y: Tensor with shape (H, W), (C, H, W) or (N, C, H, W).
        data_range: Value range of input images (usually 1.0 or 255). Default: 1.0
        reduction: Reduction over samples in batch: "mean"|"sum"|"none"
        c1: coefficient to calculate gradient similarity. Default: 140.
        c2: coefficient to calculate gradient similarity. Default: 55.
        c3: coefficient to calculate chromaticity similarity. Default: 550.
        combination: mode to combine gradient similarity and chromaticity similarity: "sum"|"mult".
        alpha: coefficient to combine gradient similarity and chromaticity similarity using summation.
        beta: power applied to the chromaticity similarity when combining using multiplication.
        gamma: power applied to the gradient similarity when combining using multiplication.
        rho: order of the Minkowski distance
        q: coefficient to adjusts the emphasis of the values in image and MCT
        o: the power pooling applied on the final value of the deviation

    Returns:
        torch.Tensor: the batch of Mean Deviation Similarity Index (MDSI) score reduced accordingly

    Raises:
        ValueError: If ``combination`` is neither "sum" nor "mult".
        KeyError: If ``reduction`` is not one of "mean", "sum" or "none".

    Note:
        The ratio between constants is usually equal c3 = 4c1 = 10c2
    """
    _validate_input(input_tensors=(x, y),
                    allow_5d=False,
                    data_range=data_range)
    x, y = _adjust_dimensions(input_tensors=(x, y))

    if x.size(1) == 1:
        x = x.repeat(1, 3, 1, 1)
        y = y.repeat(1, 3, 1, 1)
        warnings.warn(
            'The original MDSI supports only RGB images. The input images were converted to RGB by copying '
            'the grey channel 3 times.')

    # Rescale from [0, data_range] to [0, 255]; the default constants
    # c1..c3 are tuned for that scale.
    x = x / data_range * 255
    y = y / data_range * 255

    # Average-pool large images: the window grows by one for every ~256
    # pixels of the smaller spatial side.
    kernel_size = max(1, round(min(x.size()[-2:]) / 256))
    padding = kernel_size // 2

    if padding:
        # For an even window the leading pad is one smaller than the trailing
        # pad, so kernel_size - 1 in total.
        pad_before = (kernel_size - 1) // 2
        pad_after = padding
        pad_to_use = [pad_before, pad_after, pad_before, pad_after]
        x = pad(x, pad=pad_to_use)
        y = pad(y, pad=pad_to_use)

    x = avg_pool2d(x, kernel_size=kernel_size)
    y = avg_pool2d(y, kernel_size=kernel_size)

    x_lhm = rgb2lhm(x)
    y_lhm = rgb2lhm(y)

    # Channel 0 of the converted images feeds the gradient terms; the
    # remaining channels feed the chromaticity term.
    x_first, y_first = x_lhm[:, :1], y_lhm[:, :1]
    x_second, y_second = x_lhm[:, 1:2], y_lhm[:, 1:2]
    x_rest, y_rest = x_lhm[:, 2:], y_lhm[:, 2:]

    prewitt = prewitt_filter()
    kernels = torch.stack([prewitt, prewitt.transpose(1, 2)]).to(x)
    gm_x = gradient_map(x_first, kernels)
    gm_y = gradient_map(y_first, kernels)
    gm_avg = gradient_map((x_first + y_first) / 2., kernels)

    gs_x_y = similarity_map(gm_x, gm_y, c1)
    gs_x_average = similarity_map(gm_x, gm_avg, c2)
    gs_y_average = similarity_map(gm_y, gm_avg, c2)

    gs_total = gs_x_y + gs_x_average - gs_y_average

    cs_numerator = 2 * (x_second * y_second + x_rest * y_rest) + c3
    cs_denominator = x_second**2 + y_second**2 + x_rest**2 + y_rest**2 + c3
    cs_total = cs_numerator / cs_denominator

    if combination == 'sum':
        gcs = (alpha * gs_total + (1 - alpha) * cs_total)
    elif combination == 'mult':
        # NOTE(review): presumably pow_for_complex returns a trailing
        # dimension of size 2 - confirm.
        gs_total_pow = pow_for_complex(base=gs_total, exp=gamma)
        cs_total_pow = pow_for_complex(base=cs_total, exp=beta)
        gcs = torch.stack((gs_total_pow[..., 0] * cs_total_pow[..., 0],
                           gs_total_pow[..., 1] + cs_total_pow[..., 1]),
                          dim=-1)
    else:
        raise ValueError(
            f'Expected combination method "sum" or "mult", got {combination}')

    # Evaluate the power once and reuse it for both the spatial mean and the
    # deviation from it.
    gcs_pow = pow_for_complex(base=gcs, exp=q)
    mct_complex = gcs_pow.mean(dim=2, keepdim=True).mean(
        dim=3, keepdim=True)  # split to increase precision
    score = (gcs_pow - mct_complex).pow(2).sum(dim=-1).sqrt()
    score = ((score**rho).mean(dim=(-1, -2))**(o / rho)).squeeze(1)
    if reduction == 'none':
        return score
    # An unsupported reduction name fails this lookup with KeyError.
    return {'mean': score.mean, 'sum': score.sum}[reduction](dim=0)
Exemplo n.º 5
0
def mdsi(prediction: torch.Tensor, target: torch.Tensor, data_range: Union[int, float] = 1., reduction: str = 'mean',
         c1: float = 140., c2: float = 55., c3: float = 550., combination: str = 'sum', alpha: float = 0.6,
         beta: float = 0.1, gamma: float = 0.2, rho: float = 1., q: float = 0.25,
         o: float = 0.25) -> torch.Tensor:
    r"""Compute Mean Deviation Similarity Index (MDSI) for a batch of images.

    Note:
        Both inputs are supposed to have RGB order in accordance with the original approach.
        Nevertheless, the method supports greyscale images, which are converted to RGB by copying the grey
        channel 3 times.

    Args:
        prediction: Batch of predicted (distorted) images. Required to be 2D (H,W), 3D (C,H,W), 4D (N,C,H,W),
            channels first.
        target: Batch of target (reference) images. Required to be 2D (H,W), 3D (C,H,W), 4D (N,C,H,W),
            channels first.
        data_range: Value range of input images (usually 1.0 or 255). Default: 1.0
        reduction: Reduction over samples in batch: "mean"|"sum"|"none"
        c1: coefficient to calculate gradient similarity. Default: 140.
        c2: coefficient to calculate gradient similarity. Default: 55.
        c3: coefficient to calculate chromaticity similarity. Default: 550.
        combination: mode to combine gradient similarity and chromaticity similarity: "sum"|"mult".
        alpha: coefficient to combine gradient similarity and chromaticity similarity using summation.
        beta: power applied to the chromaticity similarity when combining using multiplication.
        gamma: power applied to the gradient similarity when combining using multiplication.
        rho: order of the Minkowski distance
        q: coefficient to adjusts the emphasis of the values in image and MCT
        o: the power pooling applied on the final value of the deviation

    Returns:
        torch.Tensor: the batch of Mean Deviation Similarity Index (MDSI) score reduced accordingly

    Raises:
        ValueError: If ``combination`` is neither "sum" nor "mult".
        KeyError: If ``reduction`` is not one of "mean", "sum" or "none".

    Note:
        The ratio between constants is usually equal c3 = 4c1 = 10c2
    """
    _validate_input(input_tensors=(prediction, target), allow_5d=False)
    prediction, target = _adjust_dimensions(input_tensors=(prediction, target))

    if prediction.size(1) == 1:
        prediction = prediction.repeat(1, 3, 1, 1)
        target = target.repeat(1, 3, 1, 1)
        warnings.warn('The original MDSI supports only RGB images. The input images were converted to RGB by copying '
                      'the grey channel 3 times.')

    # Rescale from [0, data_range] to [0, 255]; the default constants c1..c3 are tuned for that scale.
    prediction = prediction * 255. / data_range
    target = target * 255. / data_range

    # Average-pool large images: the window grows by one for every ~256 pixels of the smaller spatial side.
    kernel_size = max(1, round(min(prediction.size()[-2:]) / 256))
    padding = kernel_size // 2

    if padding:
        # For an even window the leading pad is one smaller than the trailing pad, so kernel_size - 1 in total.
        pad_before = (kernel_size - 1) // 2
        pad_after = padding
        pad_to_use = [pad_before, pad_after, pad_before, pad_after]
        prediction = pad(prediction, pad=pad_to_use)
        target = pad(target, pad=pad_to_use)

    prediction = avg_pool2d(prediction, kernel_size=kernel_size)
    target = avg_pool2d(target, kernel_size=kernel_size)

    prediction_lhm = rgb2lhm(prediction)
    target_lhm = rgb2lhm(target)

    # Channel 0 of the converted images feeds the gradient terms; the remaining channels feed chromaticity.
    pred_first, trgt_first = prediction_lhm[:, :1], target_lhm[:, :1]
    pred_second, trgt_second = prediction_lhm[:, 1:2], target_lhm[:, 1:2]
    pred_rest, trgt_rest = prediction_lhm[:, 2:], target_lhm[:, 2:]

    prewitt = prewitt_filter()
    kernels = torch.stack([prewitt, prewitt.transpose(1, 2)]).to(prediction)
    gm_prediction = gradient_map(pred_first, kernels)
    gm_target = gradient_map(trgt_first, kernels)
    gm_avg = gradient_map((pred_first + trgt_first) / 2., kernels)

    gs_prediction_target = similarity_map(gm_prediction, gm_target, c1)
    gs_prediction_average = similarity_map(gm_prediction, gm_avg, c2)
    gs_target_average = similarity_map(gm_target, gm_avg, c2)

    gs_total = gs_prediction_target + gs_prediction_average - gs_target_average

    cs_numerator = 2 * (pred_second * trgt_second + pred_rest * trgt_rest) + c3
    cs_denominator = pred_second ** 2 + trgt_second ** 2 + pred_rest ** 2 + trgt_rest ** 2 + c3
    cs_total = cs_numerator / cs_denominator

    if combination == 'sum':
        gcs = (alpha * gs_total + (1 - alpha) * cs_total)
    elif combination == 'mult':
        # NOTE(review): presumably pow_for_complex returns a trailing dimension of size 2 - confirm.
        gs_total_pow = pow_for_complex(base=gs_total, exp=gamma)
        cs_total_pow = pow_for_complex(base=cs_total, exp=beta)
        gcs = torch.stack((gs_total_pow[..., 0] * cs_total_pow[..., 0],
                           gs_total_pow[..., 1] + cs_total_pow[..., 1]), dim=-1)
    else:
        raise ValueError(f'Expected combination method "sum" or "mult", got {combination}')

    # Evaluate the power once and reuse it for both the spatial mean and the deviation from it.
    gcs_pow = pow_for_complex(base=gcs, exp=q)
    mct_complex = gcs_pow.mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)  # split to increase precision
    score = (gcs_pow - mct_complex).pow(2).sum(dim=-1).sqrt()
    score = ((score ** rho).mean(dim=(-1, -2)) ** (o / rho)).squeeze(1)
    if reduction == 'none':
        return score
    # An unsupported reduction name fails this lookup with KeyError.
    return {'mean': score.mean,
            'sum': score.sum}[reduction](dim=0)