Example #1
    def compute_metric(self, x_features: torch.Tensor,
                       y_features: torch.Tensor) -> torch.Tensor:
        r"""
        Fits multivariate Gaussians: :math:`X \sim \mathcal{N}(\mu_x, \sigma_x)` and
        :math:`Y \sim \mathcal{N}(\mu_y, \sigma_y)` to image stacks.
        Then computes FID as :math:`d^2 = ||\mu_x - \mu_y||^2 + Tr(\sigma_x + \sigma_y - 2\sqrt{\sigma_x \sigma_y})`.

        Args:
            x_features: Samples from data distribution. Shape :math:`(N_x, D)`
            y_features: Samples from data distribution. Shape :math:`(N_y, D)`

        Returns:
            The Frechet Distance.
        """
        _validate_input([x_features, y_features],
                        dim_range=(2, 2),
                        size_range=(1, 2))
        # GPU -> CPU
        mu_x, sigma_x = _compute_statistics(
            x_features.detach().to(dtype=torch.float64))
        mu_y, sigma_y = _compute_statistics(
            y_features.detach().to(dtype=torch.float64))

        score = _compute_fid(mu_x, sigma_x, mu_y, sigma_y)

        return score
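A minimal usage sketch for the snippet above, assuming the method lives on piq's FID class (the class itself is not shown) and using random tensors as stand-ins for real encoder features:

import torch
import piq

# Stand-in features of shape (N, D); in practice these come from a feature
# extractor such as InceptionV3.
x_features = torch.rand(1000, 2048)
y_features = torch.rand(1000, 2048)

fid_metric = piq.FID()  # assumed class name
score = fid_metric.compute_metric(x_features, y_features)
print(score.item())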
Example #2
    def compute_metric(self, x_features: torch.Tensor,
                       y_features: torch.Tensor) -> torch.Tensor:
        r"""Implements Algorithm 2 from the paper.

        Args:
            x_features: Samples from data distribution. Shape :math:`(N_x, D)`
            y_features: Samples from data distribution. Shape :math:`(N_y, D)`

        Returns:
            score: Scalar value of the distance between distributions.
        """
        _validate_input([x_features, y_features],
                        dim_range=(2, 2),
                        size_range=(1, 2))
        with Pool(self.num_workers) as p:
            self.features = x_features.detach().cpu().numpy()
            pool_results = p.map(self._relative_living_times,
                                 range(self.num_iters))
            mean_rlt_x = np.vstack(pool_results).mean(axis=0)

            self.features = y_features.detach().cpu().numpy()
            pool_results = p.map(self._relative_living_times,
                                 range(self.num_iters))
            mean_rlt_y = np.vstack(pool_results).mean(axis=0)

        score = np.sum((mean_rlt_x - mean_rlt_y)**2)

        return torch.tensor(score, device=x_features.device) * 1000
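A hedged usage sketch, assuming the method belongs to piq's GS (Geometry Score) class; the num_iters and num_workers arguments are inferred from the self.num_iters / self.num_workers attributes used above:

import torch
import piq

x_features = torch.rand(1000, 64)
y_features = torch.rand(1000, 64)

# Assumed class name and constructor arguments (inferred from the attributes
# referenced in the method body).
gs_metric = piq.GS(num_iters=100, num_workers=4)
score = gs_metric.compute_metric(x_features, y_features)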
Example #3
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        r"""Computation of Content loss between feature representations of prediction :math:`x` and
        target :math:`y` tensors.

        Args:
            x: An input tensor. Shape :math:`(N, C, H, W)`.
            y: A target tensor. Shape :math:`(N, C, H, W)`.

        Returns:
            Content loss between feature representations
        """
        _validate_input([x, y], dim_range=(4, 4), data_range=(0, -1))

        self.model.to(x)
        x_features = self.get_features(x)
        y_features = self.get_features(y)

        distances = self.compute_distance(x_features, y_features)

        # Scale distances, then average in spatial dimensions, then stack and sum in channels dimension
        loss = torch.cat([(d * w.to(d)).mean(dim=[2, 3])
                          for d, w in zip(distances, self.weights)],
                         dim=1).sum(dim=1)

        return _reduce(loss, self.reduction)
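A usage sketch, assuming this forward belongs to piq's ContentLoss module and that inputs are images in [0, 1]:

import torch
import piq

x = torch.rand(4, 3, 96, 96, requires_grad=True)  # prediction batch (N, C, H, W)
y = torch.rand(4, 3, 96, 96)                      # target batch

loss_fn = piq.ContentLoss(reduction='mean')  # assumed class name
loss = loss_fn(x, y)
loss.backward()  # differentiable, so it can be used as a training loss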
Example #4
    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""
        Forward pass a batch of square patches with shape :math:`(N, C, F, F)`.

        Returns:
            features: Concatenation of model features from different scales
            x11: Outputs of the last convolutional layer used as weights
        """
        _validate_input([
            x,
        ], dim_range=(4, 4), data_range=(0, -1))

        assert x.shape[2] == x.shape[3] == self.FEATURES, \
            f"Expected square input with shape {self.FEATURES, self.FEATURES}, got {x.shape}"

        # conv1 -> relu -> conv2 -> relu -> pool -> conv3 -> relu
        x3 = F.relu(
            self.conv3(self.pool(F.relu(self.conv2(F.relu(self.conv1(x)))))))
        # conv4 -> relu -> pool -> conv5 -> relu
        x5 = F.relu(self.conv5(self.pool(F.relu(self.conv4(x3)))))
        # conv6 -> relu -> pool -> conv7 -> relu
        x7 = F.relu(self.conv7(self.pool(F.relu(self.conv6(x5)))))
        # conv8 -> relu -> pool -> conv9 -> relu
        x9 = F.relu(self.conv9(self.pool(F.relu(self.conv8(x7)))))
        # conv10 -> relu -> pool -> conv11 -> relu
        x11 = self.flatten(
            F.relu(self.conv11(self.pool(F.relu(self.conv10(x9))))))
        # flatten and concatenate
        features = torch.cat((self.flatten(x3), self.flatten(x5),
                              self.flatten(x7), self.flatten(x9), x11),
                             dim=1)
        return features, x11
Example #5
def test_breaks_if_too_many_tensors_provided(tensor_2d: torch.Tensor) -> None:
    max_number_of_tensors = 2
    for n_tensors in range(max_number_of_tensors + 1,
                           (max_number_of_tensors + 1) * 10):
        inp = tuple(tensor_2d.clone() for _ in range(n_tensors))
        with pytest.raises(AssertionError):
            _validate_input(inp, allow_5d=False)
Example #6
def test_works_on_single_5d_tensor(tensor_5d: torch.Tensor) -> None:
    try:
        _validate_input([
            tensor_5d,
        ], dim_range=(5, 5))
    except Exception as e:
        pytest.fail(f"Unexpected error occurred: {e}")
Example #7
def total_variation(x: torch.Tensor, reduction: str = 'mean', norm_type: str = 'l2') -> torch.Tensor:
    r"""Compute Total Variation metric

    Args:
        x: Tensor. Shape :math:`(N, C, H, W)`.
        reduction: Specifies the reduction type:
            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``
        norm_type: {'l1', 'l2', 'l2_squared'}, defines which type of norm to implement, isotropic or anisotropic.

    Returns:
        score : Total variation of a given tensor

    References:
        https://www.wikiwand.com/en/Total_variation_denoising
        https://remi.flamary.com/demos/proxtv.html
    """
    _validate_input([x, ], dim_range=(4, 4), data_range=(0, -1))

    if norm_type == 'l1':
        w_variance = torch.sum(torch.abs(x[:, :, :, 1:] - x[:, :, :, :-1]), dim=[1, 2, 3])
        h_variance = torch.sum(torch.abs(x[:, :, 1:, :] - x[:, :, :-1, :]), dim=[1, 2, 3])
        score = (h_variance + w_variance)
    elif norm_type == 'l2':
        w_variance = torch.sum(torch.pow(x[:, :, :, 1:] - x[:, :, :, :-1], 2), dim=[1, 2, 3])
        h_variance = torch.sum(torch.pow(x[:, :, 1:, :] - x[:, :, :-1, :], 2), dim=[1, 2, 3])
        score = torch.sqrt(h_variance + w_variance)
    elif norm_type == 'l2_squared':
        w_variance = torch.sum(torch.pow(x[:, :, :, 1:] - x[:, :, :, :-1], 2), dim=[1, 2, 3])
        h_variance = torch.sum(torch.pow(x[:, :, 1:, :] - x[:, :, :-1, :], 2), dim=[1, 2, 3])
        score = (h_variance + w_variance)
    else:
        raise ValueError("Incorrect norm type, should be one of {'l1', 'l2', 'l2_squared'}")

    return _reduce(score, reduction)
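Since total_variation is a module-level function, calling it is direct; a small sketch:

import torch

x = torch.rand(4, 3, 64, 64)  # batch of images in [0, 1]

tv_mean = total_variation(x)                                  # scalar, 'mean' reduction
tv_l1 = total_variation(x, reduction='none', norm_type='l1')  # per-image scores, shape (4,)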
Example #8
def test_breaks_if_scale_weight_wrong_n_dims_provided(
        tensor_2d: torch.Tensor) -> None:
    wrong_scale_weights = tensor_2d.clone()
    with pytest.raises(AssertionError):
        _validate_input(tensor_2d,
                        allow_5d=False,
                        scale_weights=wrong_scale_weights)
Example #9
    def compute_metric(self, x_features: torch.Tensor,
                       y_features: torch.Tensor) -> torch.Tensor:
        r"""
        Fits multivariate Gaussians: X ~ N(mu_x, sigma_x) and Y ~ N(mu_y, sigma_y) to image stacks.
        Then computes FID as d^2 = ||mu_x - mu_y||^2 + Tr(sigma_x + sigma_y - 2*sqrt(sigma_x * sigma_y)).

        Args:
            x_features: Samples from data distribution. Shape :math:`(N_x, D)`
            y_features: Samples from data distribution. Shape :math:`(N_y, D)`

        Returns:
            The Frechet Distance.
        """
        _validate_input([x_features, y_features],
                        dim_range=(2, 2),
                        size_range=(1, 2))
        # GPU -> CPU
        mu_x, sigma_x = _compute_statistics(
            x_features.detach().to(dtype=torch.float64))
        mu_y, sigma_y = _compute_statistics(
            y_features.detach().to(dtype=torch.float64))

        score = _compute_fid(mu_x, sigma_x, mu_y, sigma_y)

        return score
Example #10
    def forward(self, prediction: torch.Tensor,
                target: torch.Tensor) -> torch.Tensor:
        r"""Computation of Content loss between feature representations of prediction and target tensors.
        Args:
            prediction: Tensor with shape (H, W), (C, H, W) or (N, C, H, W).
            target: Tensor with shape (H, W), (C, H, W) or (N, C, H, W).
        """
        _validate_input(input_tensors=(prediction, target),
                        allow_5d=False,
                        allow_negative=True)
        prediction, target = _adjust_dimensions(input_tensors=(prediction,
                                                               target))

        self.model.to(prediction)
        prediction_features = self.get_features(prediction)
        target_features = self.get_features(target)

        distances = self.compute_distance(prediction_features, target_features)

        # Scale distances, then average in spatial dimensions, then stack and sum in channels dimension
        loss = torch.cat([(d * w.to(d)).mean(dim=[2, 3])
                          for d, w in zip(distances, self.weights)],
                         dim=1).sum(dim=1)

        if self.reduction == 'none':
            return loss

        return {'mean': loss.mean, 'sum': loss.sum}[self.reduction](dim=0)
Example #11
    def compute_metric(self, real_features: torch.Tensor, fake_features: torch.Tensor) \
            -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Creates non-parametric representations of the manifolds of real and generated data and computes
        the precision and recall between them.

        Args:
            real_features: Samples from data distribution. Shape :math:`(N_x, D)`
            fake_features: Samples from fake distribution. Shape :math:`(N_y, D)`
        Returns:
            Scalar value of the precision of the generated images.
            
            Scalar value of the recall of the generated images.
        """
        _validate_input([real_features, fake_features],
                        dim_range=(2, 2),
                        size_range=(1, 2))
        real_nearest_neighbour_distances = _compute_nearest_neighbour_distances(
            real_features, self.nearest_k)
        fake_nearest_neighbour_distances = _compute_nearest_neighbour_distances(
            fake_features, self.nearest_k)
        distance_real_fake = _compute_pairwise_distance(
            real_features, fake_features)

        precision = (distance_real_fake <
                     real_nearest_neighbour_distances.unsqueeze(1)).any(
                         dim=0).float().mean()

        recall = (distance_real_fake <
                  fake_nearest_neighbour_distances.unsqueeze(0)).any(
                      dim=1).float().mean()

        return precision, recall
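A usage sketch, assuming the method belongs to piq's PR class; the nearest_k argument is inferred from the self.nearest_k attribute used above:

import torch
import piq

real_features = torch.rand(1000, 2048)
fake_features = torch.rand(1000, 2048)

pr_metric = piq.PR(nearest_k=5)  # assumed class name and constructor
precision, recall = pr_metric.compute_metric(real_features, fake_features)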
Example #12
def brisque(x: torch.Tensor,
            kernel_size: int = 7,
            kernel_sigma: float = 7 / 6,
            data_range: Union[int, float] = 1.,
            reduction: str = 'mean',
            interpolation: str = 'nearest') -> torch.Tensor:
    r"""Interface of BRISQUE index.

    Args:
        x: Tensor with shape (H, W), (C, H, W) or (N, C, H, W). RGB channel order for colour images.
        kernel_size: The side-length of the sliding window used in comparison. Must be an odd value.
        kernel_sigma: Sigma of normal distribution.
        data_range: Maximum value range of input images (usually 1.0 or 255).
        reduction: Reduction over samples in batch: "mean"|"sum"|"none".
        interpolation: Interpolation to be used for scaling.

    Returns:
        Value of BRISQUE index.

    Note:
        The back propagation is not available using torch=1.5.0 due to bug in argmin/argmax back propagation.
        Update the torch and torchvision to the latest versions.

    References:
        .. [1] Anish Mittal et al. "No-Reference Image Quality Assessment in the Spatial Domain",
        https://live.ece.utexas.edu/publications/2012/TIP%20BRISQUE.pdf
    """
    if '1.5.0' in torch.__version__:
        warnings.warn(
            f'BRISQUE does not support back propagation due to bug in torch={torch.__version__}. '
            f'Update torch to the latest version to access full functionality of BRISQUE. '
            f'More info is available at https://github.com/photosynthesis-team/piq/pull/79 and '
            f'https://github.com/pytorch/pytorch/issues/38869.')

    _validate_input(input_tensors=x, allow_5d=False, kernel_size=kernel_size)
    x = _adjust_dimensions(input_tensors=x)

    assert data_range >= x.max(), \
        f'Expected data range greater or equal maximum value, got {data_range} and {x.max()}.'
    x = x * 255. / data_range

    if x.size(1) == 3:
        x = rgb2yiq(x)[:, :1]
    features = []
    num_of_scales = 2
    for _ in range(num_of_scales):
        features.append(_natural_scene_statistics(x, kernel_size,
                                                  kernel_sigma))
        x = F.interpolate(x,
                          size=(x.size(2) // 2, x.size(3) // 2),
                          mode=interpolation)

    features = torch.cat(features, dim=-1)
    scaled_features = _scale_features(features)
    score = _score_svr(scaled_features)
    if reduction == 'none':
        return score

    return {'mean': score.mean, 'sum': score.sum}[reduction](dim=0)
Example #13
def brisque(x: torch.Tensor,
            kernel_size: int = 7,
            kernel_sigma: float = 7 / 6,
            data_range: Union[int, float] = 1.,
            reduction: str = 'mean',
            interpolation: str = 'nearest') -> torch.Tensor:
    r"""Interface of BRISQUE index.
    Supports greyscale and colour images with RGB channel order.

    Args:
        x: An input tensor. Shape :math:`(N, C, H, W)`.
        kernel_size: The side-length of the sliding window used in comparison. Must be an odd value.
        kernel_sigma: Sigma of normal distribution.
        data_range: Maximum value range of images (usually 1.0 or 255).
        reduction: Specifies the reduction type:
            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``
        interpolation: Interpolation to be used for scaling.

    Returns:
        Value of BRISQUE index.

    Note:
        The back propagation is not available using torch=1.5.0 due to bug in argmin/argmax back propagation.
        Update the torch and torchvision to the latest versions.

    References:
        .. [1] Anish Mittal et al. "No-Reference Image Quality Assessment in the Spatial Domain",
           https://live.ece.utexas.edu/publications/2012/TIP%20BRISQUE.pdf
    """
    if '1.5.0' in torch.__version__:
        warnings.warn(
            f'BRISQUE does not support back propagation due to bug in torch={torch.__version__}. '
            f'Update torch to the latest version to access full functionality of BRISQUE. '
            f'More info is available at https://github.com/photosynthesis-team/piq/pull/79 and '
            f'https://github.com/pytorch/pytorch/issues/38869.')

    assert kernel_size % 2 == 1, f'Kernel size must be odd, got [{kernel_size}]'
    _validate_input([
        x,
    ], dim_range=(4, 4), data_range=(0, data_range))

    x = x / data_range * 255

    if x.size(1) == 3:
        x = rgb2yiq(x)[:, :1]
    features = []
    num_of_scales = 2
    for _ in range(num_of_scales):
        features.append(_natural_scene_statistics(x, kernel_size,
                                                  kernel_sigma))
        x = F.interpolate(x,
                          size=(x.size(2) // 2, x.size(3) // 2),
                          mode=interpolation)

    features = torch.cat(features, dim=-1)
    scaled_features = _scale_features(features)
    score = _score_svr(scaled_features)

    return _reduce(score, reduction)
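BRISQUE is no-reference, so only the distorted batch is passed; a small sketch:

import torch

x = torch.rand(4, 3, 96, 96)  # images in [0, 1], RGB channel order

score = brisque(x, data_range=1.)         # batch mean by default
per_image = brisque(x, reduction='none')  # shape (4,)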
Example #14
def gmsd(
    x: torch.Tensor,
    y: torch.Tensor,
    reduction: str = 'mean',
    data_range: Union[int, float] = 1.,
    t: float = 170 / (255.**2)
) -> torch.Tensor:
    r"""Compute Gradient Magnitude Similarity Deviation.

    Inputs are supposed to be in range [0, data_range] with RGB channel order for colour images.

    Args:
        x: Tensor with shape (H, W), (C, H, W) or (N, C, H, W).
        y: Tensor with shape (H, W), (C, H, W) or (N, C, H, W).
        reduction: Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
            ``'mean'``: the sum of the output will be divided by the number of
            elements in the output, ``'sum'``: the output will be summed.
        data_range: The difference between the maximum and minimum of the pixel value,
            i.e., if for image x it holds min(x) = 0 and max(x) = 1, then data_range = 1.
            The pixel value interval of both input and output should remain the same.
        t: Constant from the reference paper for numerical stability of the similarity map.

    Returns:
        gmsd : Gradient Magnitude Similarity Deviation between given tensors.

    References:
        Wufeng Xue et al. Gradient Magnitude Similarity Deviation (2013)
        https://arxiv.org/pdf/1308.3052.pdf
    """

    _validate_input(input_tensors=(x, y),
                    allow_5d=False,
                    scale_weights=None,
                    data_range=data_range)
    x, y = _adjust_dimensions(input_tensors=(x, y))

    # Rescale
    x = x / data_range
    y = y / data_range

    num_channels = x.size(1)
    if num_channels == 3:
        x = rgb2yiq(x)[:, :1]
        y = rgb2yiq(y)[:, :1]
    up_pad = 0
    down_pad = max(x.shape[2] % 2, x.shape[3] % 2)
    pad_to_use = [up_pad, down_pad, up_pad, down_pad]
    x = F.pad(x, pad=pad_to_use)
    y = F.pad(y, pad=pad_to_use)

    x = F.avg_pool2d(x, kernel_size=2, stride=2, padding=0)
    y = F.avg_pool2d(y, kernel_size=2, stride=2, padding=0)

    score = _gmsd(x=x, y=y, t=t)
    if reduction == 'none':
        return score

    return {'mean': score.mean, 'sum': score.sum}[reduction](dim=0)
Example #15
def test_breaks_if_not_supported_data_types_provided() -> None:
    inputs_of_wrong_types = [[], (), {}, 42, '42', True, 42.,
                             np.array([42]), None]
    for inp in inputs_of_wrong_types:
        with pytest.raises(AssertionError):
            _validate_input([
                inp,
            ])
Example #16
    def compute_metric(self, x_features: torch.Tensor, y_features: torch.Tensor) \
            -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Computes KID (polynomial MMD) for given sets of features, obtained from Inception net
        or any other feature extractor.
        Samples must be in range [0, 1].

        Args:
            x_features: Samples from data distribution. Shape :math:`(N_x, D)`
            y_features: Samples from data distribution. Shape :math:`(N_y, D)`

        Returns:
            KID score and variance (optional).
        """
        _validate_input([x_features, y_features], dim_range=(2, 2), size_range=(1, 2))
        var_at_m = min(x_features.size(0), y_features.size(0))
        if self.subset_size is None:
            subset_size = x_features.size(0)
        else:
            subset_size = self.subset_size

        results = []
        for _ in range(self.n_subsets):
            x_subset = x_features[torch.randperm(len(x_features))[:subset_size]]
            y_subset = y_features[torch.randperm(len(y_features))[:subset_size]]

            # use k(x, y) = (gamma <x, y> + coef0)^degree
            # default gamma is 1 / dim
            K_XX = _polynomial_kernel(
                x_subset,
                None,
                degree=self.degree,
                gamma=self.gamma,
                coef0=self.coef0)
            K_YY = _polynomial_kernel(
                y_subset,
                None,
                degree=self.degree,
                gamma=self.gamma,
                coef0=self.coef0)
            K_XY = _polynomial_kernel(
                x_subset,
                y_subset,
                degree=self.degree,
                gamma=self.gamma,
                coef0=self.coef0)

            out = _mmd2_and_variance(K_XX, K_XY, K_YY, var_at_m=var_at_m, ret_var=self.ret_var)
            results.append(out)

        if self.ret_var:
            score = torch.mean(torch.stack([p[0] for p in results], dim=0))
            variance = torch.mean(torch.stack([p[1] for p in results], dim=0))
            return (score, variance)
        else:
            score = torch.mean(torch.stack(results, dim=0))
            return score
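A usage sketch, assuming the method belongs to piq's KID class; subset_size, n_subsets and ret_var are inferred from the attributes referenced above:

import torch
import piq

x_features = torch.rand(1000, 2048)
y_features = torch.rand(1000, 2048)

kid_metric = piq.KID(subset_size=100, n_subsets=10, ret_var=True)  # assumed constructor
score, variance = kid_metric.compute_metric(x_features, y_features)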
Example #17
def gmsd(
    prediction: torch.Tensor,
    target: torch.Tensor,
    reduction: Optional[str] = 'mean',
    data_range: Union[int, float] = 1.,
    t: float = 170 / (255.**2)
) -> torch.Tensor:
    r"""Compute Gradient Magnitude Similarity Deviation
    Both inputs supposed to be in range [0, 1] with RGB order.
    Args:
        prediction: Tensor of shape :math:`(N, C, H, W)` holding an distorted image.
        target: Tensor of shape :math:`(N, C, H, W)` holding an target image
        reduction: Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
            ``'mean'``: the sum of the output will be divided by the number of
            elements in the output, ``'sum'``: the output will be summed.
        data_range: The difference between the maximum and minimum of the pixel value,
            i.e., if for image x it holds min(x) = 0 and max(x) = 1, then data_range = 1.
            The pixel value interval of both input and output should remain the same.
        t: Constant from the reference paper for numerical stability of the similarity map.

    Returns:
        gmsd : Gradient Magnitude Similarity Deviation between given tensors.

    References:
        https://arxiv.org/pdf/1308.3052.pdf
    """

    _validate_input(input_tensors=(prediction, target),
                    allow_5d=False,
                    scale_weights=None)
    prediction, target = _adjust_dimensions(input_tensors=(prediction, target))

    prediction = prediction / float(data_range)
    target = target / float(data_range)

    num_channels = prediction.size(1)
    if num_channels == 3:
        prediction = rgb2yiq(prediction)[:, :1]
        target = rgb2yiq(target)[:, :1]
    up_pad = 0
    down_pad = max(prediction.shape[2] % 2, prediction.shape[3] % 2)
    pad_to_use = [up_pad, down_pad, up_pad, down_pad]
    prediction = F.pad(prediction, pad=pad_to_use)
    target = F.pad(target, pad=pad_to_use)

    prediction = F.avg_pool2d(prediction, kernel_size=2, stride=2, padding=0)
    target = F.avg_pool2d(target, kernel_size=2, stride=2, padding=0)

    score = _gmsd(prediction=prediction, target=target, t=t)
    if reduction == 'none':
        return score

    return {'mean': score.mean, 'sum': score.sum}[reduction](dim=0)
Example #18
    def compute_metric(self, x_features: torch.Tensor,
                       y_features: torch.Tensor) -> torch.Tensor:
        r"""Compute MSID score between two sets of samples.
        Args:
            x_features: Samples from data distribution. Shape :math:`(N_x, D)`
            y_features: Samples from data distribution. Shape :math:`(N_y, D)`
            ts: Temperature values.
            k: Number of neighbours for graph construction.
            m: Lanczos steps in SLQ.
            niters: Number of starting random vectors for SLQ.
            rademacher: True to use Rademacher distribution,
                False - standard normal for random vectors in Hutchinson.
            msid_mode: 'l2' to compute the l2 norm of the distance between `msid1` and `msid2`;
                'max' to find the maximum absolute difference between two descriptors over temperature.
            normalized_laplacian: if True, use normalized Laplacian.
            normalize: 'empty' for average heat kernel (corresponds to the empty graph normalization of NetLSD),
                'complete' for the complete graph, 'er' for Erdos-Renyi normalization, 'none' for no normalization.

        Returns:
            score: Scalar value of the distance between distributions.
        """
        _validate_input([x_features, y_features],
                        dim_range=(2, 2),
                        size_range=(1, 2))
        normed_msid_x = _msid_descriptor(
            x_features.detach().cpu().numpy(),
            ts=self.ts,
            k=self.k,
            m=self.m,
            niters=self.niters,
            rademacher=self.rademacher,
            normalized_laplacian=self.normalized_laplacian,
            normalize=self.normalize)
        normed_msid_y = _msid_descriptor(
            y_features.detach().cpu().numpy(),
            ts=self.ts,
            k=self.k,
            m=self.m,
            niters=self.niters,
            rademacher=self.rademacher,
            normalized_laplacian=self.normalized_laplacian,
            normalize=self.normalize)

        c = np.exp(-2 * (self.ts + 1 / self.ts))
        if self.msid_mode == 'l2':
            score = np.linalg.norm(normed_msid_x - normed_msid_y)
        elif self.msid_mode == 'max':
            score = np.amax(c * np.abs(normed_msid_x - normed_msid_y))
        else:
            raise ValueError('Mode must be in {`l2`, `max`}')

        return torch.tensor(score, device=x_features.device)
Example #19
def ssim(x: torch.Tensor, y: torch.Tensor, kernel_size: int = 11, kernel_sigma: float = 1.5,
         data_range: Union[int, float] = 1., reduction: str = 'mean', full: bool = False,
         k1: float = 0.01, k2: float = 0.03) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    r"""Interface of Structural Similarity (SSIM) index.

    Args:
        x: Batch of images. Required to be 2D (H, W), 3D (C,H,W), 4D (N,C,H,W) or 5D (N,C,H,W,2), channels first.
        y: Batch of images. Required to be 2D (H, W), 3D (C,H,W), 4D (N,C,H,W) or 5D (N,C,H,W,2), channels first.
        kernel_size: The side-length of the sliding window used in comparison. Must be an odd value.
        kernel_sigma: Sigma of normal distribution.
        data_range: Value range of input images (usually 1.0 or 255).
        reduction: Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
            ``'mean'``: the sum of the output will be divided by the number of
            elements in the output, ``'sum'``: the output will be summed.
        full: Return cs map or not.
        k1: Algorithm parameter, K1 (small constant, see [1]).
        k2: Algorithm parameter, K2 (small constant, see [1]).
            Try a larger K2 constant (e.g. 0.4) if you get negative or NaN results.

    Returns:
        Value of Structural Similarity (SSIM) index. In case of 5D input tensors, complex value is returned
        as a tensor of size 2.

    References:
        .. [1] Wang, Z., Bovik, A. C., Sheikh, H. R., & Simoncelli, E. P.
           (2004). Image quality assessment: From error visibility to
           structural similarity. IEEE Transactions on Image Processing,
           13, 600-612.
           https://ece.uwaterloo.ca/~z70wang/publications/ssim.pdf,
           :DOI:`10.1109/TIP.2003.819861`
    """
    _validate_input(input_tensors=(x, y), allow_5d=True, kernel_size=kernel_size, scale_weights=None)
    x, y = _adjust_dimensions(input_tensors=(x, y))
    if isinstance(x, torch.ByteTensor) or isinstance(y, torch.ByteTensor):
        x = x.type(torch.float32)
        y = y.type(torch.float32)

    kernel = gaussian_filter(kernel_size, kernel_sigma).repeat(x.size(1), 1, 1, 1).to(y)
    _compute_ssim_per_channel = _ssim_per_channel_complex if x.dim() == 5 else _ssim_per_channel
    ssim_map, cs_map = _compute_ssim_per_channel(x=x, y=y, kernel=kernel, data_range=data_range, k1=k1, k2=k2)
    ssim_val = ssim_map.mean(1)
    cs = cs_map.mean(1)

    if reduction != 'none':
        reduction_operation = {'mean': torch.mean,
                               'sum': torch.sum}
        ssim_val = reduction_operation[reduction](ssim_val, dim=0)
        cs = reduction_operation[reduction](cs, dim=0)

    if full:
        return ssim_val, cs

    return ssim_val
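A small usage sketch of the function above:

import torch

x = torch.rand(4, 3, 128, 128)
y = torch.rand(4, 3, 128, 128)

score = ssim(x, y, data_range=1.)                        # batch mean
per_image, cs = ssim(x, y, reduction='none', full=True)  # per-sample SSIM and CS values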
Example #20
def test_works_on_two_not_5d_tensors(tensor_1d: torch.Tensor) -> None:
    tensor = tensor_1d.clone()
    max_num_dims = 10
    for _ in range(max_num_dims):
        another_tensor = tensor.clone()
        if 1 < tensor.dim() < 5:
            try:
                _validate_input([tensor, another_tensor], dim_range=(2, 4))
            except Exception as e:
                pytest.fail(f"Unexpected error occurred: {e}")
        else:
            with pytest.raises(AssertionError):
                _validate_input([tensor, another_tensor], dim_range=(2, 4))

        tensor.unsqueeze_(0)
Example #21
def test_works_on_single_not_5d_tensor(tensor_1d: torch.Tensor) -> None:
    tensor = tensor_1d.clone()
    # 1D -> max_num_dims
    max_num_dims = 10
    for _ in range(max_num_dims):
        if 1 < tensor.dim() < 5:
            try:
                _validate_input(tensor, allow_5d=False)
            except Exception as e:
                pytest.fail(f"Unexpected error occurred: {e}")
        else:
            with pytest.raises(AssertionError):
                _validate_input(tensor, allow_5d=False)

        tensor.unsqueeze_(0)
Example #22
def psnr(x: torch.Tensor,
         y: torch.Tensor,
         data_range: Union[int, float] = 1.0,
         reduction: str = 'mean',
         convert_to_greyscale: bool = False) -> torch.Tensor:
    r"""Compute Peak Signal-to-Noise Ratio for a batch of images.
    Supports both greyscale and color images with RGB channel order.

    Args:
        x: Predicted images set :math:`x`.
            Shape (H, W), (C, H, W) or (N, C, H, W).
        y: Target images set :math:`y`.
            Shape (H, W), (C, H, W) or (N, C, H, W).
        data_range: Value range of input images (usually 1.0 or 255). Default: 1.0
        reduction: Reduction over samples in batch: "mean"|"sum"|"none"
        convert_to_greyscale: Convert RGB image to YCbCr format and compute PSNR
            only on luminance channel if `True`. Compute on all 3 channels otherwise.

    Returns:
        PSNR: Index of similarity between two images.

    References:
        https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
    """
    _validate_input((x, y), allow_5d=False, data_range=data_range)
    x, y = _adjust_dimensions(input_tensors=(x, y))

    # Constant for numerical stability
    EPS = 1e-8

    x = x / data_range
    y = y / data_range

    if (x.size(1) == 3) and convert_to_greyscale:
        # Convert RGB image to YCbCr and take luminance: Y = 0.299 R + 0.587 G + 0.114 B
        rgb_to_grey = torch.tensor([0.299, 0.587, 0.114]).view(1, -1, 1,
                                                               1).to(x)
        x = torch.sum(x * rgb_to_grey, dim=1, keepdim=True)
        y = torch.sum(y * rgb_to_grey, dim=1, keepdim=True)

    mse = torch.mean((x - y)**2, dim=[1, 2, 3])
    score: torch.Tensor = -10 * torch.log10(mse + EPS)

    if reduction == 'none':
        return score

    return {'mean': score.mean, 'sum': score.sum}[reduction](dim=0)
Example #23
File: tv.py Project: Linaom1214/piq
def total_variation(x: torch.Tensor,
                    reduction: str = 'mean',
                    norm_type: str = 'l2') -> torch.Tensor:
    r"""Compute Total Variation metric

    Args:
        x: Tensor with shape (N, C, H, W).
        reduction: Reduction over samples in batch: "mean"|"sum"|"none"
        norm_type: {'l1', 'l2', 'l2_squared'}, defines which type of norm to implement, isotropic or anisotropic.

    Returns:
        score : Total variation of a given tensor

    References:
        https://www.wikiwand.com/en/Total_variation_denoising
        https://remi.flamary.com/demos/proxtv.html
    """
    _validate_input(x, allow_5d=False)
    x = _adjust_dimensions(x)

    if norm_type == 'l1':
        w_variance = torch.sum(torch.abs(x[:, :, :, 1:] - x[:, :, :, :-1]),
                               dim=[1, 2, 3])
        h_variance = torch.sum(torch.abs(x[:, :, 1:, :] - x[:, :, :-1, :]),
                               dim=[1, 2, 3])
        score = (h_variance + w_variance)
    elif norm_type == 'l2':
        w_variance = torch.sum(torch.pow(x[:, :, :, 1:] - x[:, :, :, :-1], 2),
                               dim=[1, 2, 3])
        h_variance = torch.sum(torch.pow(x[:, :, 1:, :] - x[:, :, :-1, :], 2),
                               dim=[1, 2, 3])
        score = torch.sqrt(h_variance + w_variance)
    elif norm_type == 'l2_squared':
        w_variance = torch.sum(torch.pow(x[:, :, :, 1:] - x[:, :, :, :-1], 2),
                               dim=[1, 2, 3])
        h_variance = torch.sum(torch.pow(x[:, :, 1:, :] - x[:, :, :-1, :], 2),
                               dim=[1, 2, 3])
        score = (h_variance + w_variance)
    else:
        raise ValueError(
            "Incorrect reduction type, should be one of {'l1', 'l2', 'l2_squared'}"
        )

    if reduction == 'none':
        return score

    return {'mean': score.mean, 'sum': score.sum}[reduction](dim=0)
Example #24
def gmsd(
    x: torch.Tensor,
    y: torch.Tensor,
    reduction: str = 'mean',
    data_range: Union[int, float] = 1.,
    t: float = 170 / (255.**2)
) -> torch.Tensor:
    r"""Compute Gradient Magnitude Similarity Deviation.

    Supports greyscale and colour images with RGB channel order.

    Args:
        x: An input tensor. Shape :math:`(N, C, H, W)`.
        y: A target tensor. Shape :math:`(N, C, H, W)`.
        reduction: Specifies the reduction type:
            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``
        data_range: Maximum value range of images (usually 1.0 or 255).
        t: Constant from the reference paper for numerical stability of the similarity map.

    Returns:
        Gradient Magnitude Similarity Deviation between given tensors.

    References:
        Wufeng Xue et al. Gradient Magnitude Similarity Deviation (2013)
        https://arxiv.org/pdf/1308.3052.pdf
    """
    _validate_input([x, y], dim_range=(4, 4), data_range=(0, data_range))

    # Rescale
    x = x / float(data_range)
    y = y / float(data_range)

    num_channels = x.size(1)
    if num_channels == 3:
        x = rgb2yiq(x)[:, :1]
        y = rgb2yiq(y)[:, :1]
    up_pad = 0
    down_pad = max(x.shape[2] % 2, x.shape[3] % 2)
    pad_to_use = [up_pad, down_pad, up_pad, down_pad]
    x = F.pad(x, pad=pad_to_use)
    y = F.pad(y, pad=pad_to_use)

    x = F.avg_pool2d(x, kernel_size=2, stride=2, padding=0)
    y = F.avg_pool2d(y, kernel_size=2, stride=2, padding=0)

    score = _gmsd(x=x, y=y, t=t)
    return _reduce(score, reduction)
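A small usage sketch of the function above:

import torch

x = torch.rand(4, 3, 128, 128)  # distorted images in [0, 1]
y = torch.rand(4, 3, 128, 128)  # reference images

score = gmsd(x, y, data_range=1.)         # scalar, 'mean' reduction
per_image = gmsd(x, y, reduction='none')  # shape (4,)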
Example #25
def psnr(x: torch.Tensor,
         y: torch.Tensor,
         data_range: Union[int, float] = 1.0,
         reduction: str = 'mean',
         convert_to_greyscale: bool = False) -> torch.Tensor:
    r"""Compute Peak Signal-to-Noise Ratio for a batch of images.
    Supports both greyscale and color images with RGB channel order.

    Args:
        x: An input tensor. Shape :math:`(N, C, H, W)`.
        y: A target tensor. Shape :math:`(N, C, H, W)`.
        data_range: Maximum value range of images (usually 1.0 or 255).
        reduction: Specifies the reduction type:
            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``
        convert_to_greyscale: Convert RGB image to YCbCr format and compute PSNR
            only on luminance channel if `True`. Compute on all 3 channels otherwise.

    Returns:
        PSNR: Index of similarity between two images.

    References:
        https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
    """
    _validate_input([x, y], dim_range=(4, 5), data_range=(0, data_range))

    # Constant for numerical stability
    EPS = 1e-8

    x = x / data_range
    y = y / data_range

    if (x.size(1) == 3) and convert_to_greyscale:
        # Convert RGB image to YCbCr and take luminance: Y = 0.299 R + 0.587 G + 0.114 B
        rgb_to_grey = torch.tensor([0.299, 0.587, 0.114]).view(1, -1, 1,
                                                               1).to(x)
        x = torch.sum(x * rgb_to_grey, dim=1, keepdim=True)
        y = torch.sum(y * rgb_to_grey, dim=1, keepdim=True)

    mse = torch.mean((x - y)**2, dim=[1, 2, 3])
    score: torch.Tensor = -10 * torch.log10(mse + EPS)

    return _reduce(score, reduction)
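A small usage sketch of the function above:

import torch

x = torch.rand(4, 3, 64, 64)
y = (x + 0.05 * torch.randn_like(x)).clamp(0, 1)  # noisy copy of x

score = psnr(x, y)                                 # batch mean, in dB
luma_only = psnr(x, y, convert_to_greyscale=True)  # PSNR on the luminance channel only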
Example #26
    def compute_metric(self, x_features: torch.Tensor,
                       y_features: torch.Tensor) -> torch.Tensor:
        r"""Compute MSID score between two sets of samples.

        Args:
            x_features: Samples from data distribution. Shape :math:`(N_x, D_x)`
            y_features: Samples from data distribution. Shape :math:`(N_y, D_y)`

        Returns:
            Scalar value of the distance between distributions.
        """
        _validate_input([x_features, y_features],
                        dim_range=(2, 2),
                        size_range=(1, 2))
        normed_msid_x = _msid_descriptor(
            x_features.detach().cpu().numpy(),
            ts=self.ts,
            k=self.k,
            m=self.m,
            niters=self.niters,
            rademacher=self.rademacher,
            normalized_laplacian=self.normalized_laplacian,
            normalize=self.normalize)
        normed_msid_y = _msid_descriptor(
            y_features.detach().cpu().numpy(),
            ts=self.ts,
            k=self.k,
            m=self.m,
            niters=self.niters,
            rademacher=self.rademacher,
            normalized_laplacian=self.normalized_laplacian,
            normalize=self.normalize)

        c = np.exp(-2 * (self.ts + 1 / self.ts))
        if self.msid_mode == 'l2':
            score = np.linalg.norm(normed_msid_x - normed_msid_y)
        elif self.msid_mode == 'max':
            score = np.amax(c * np.abs(normed_msid_x - normed_msid_y))
        else:
            raise ValueError('Mode must be in {`l2`, `max`}')

        return torch.tensor(score, device=x_features.device)
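A usage sketch, assuming the method belongs to piq's MSID class with its default constructor (ts, k, m, niters, etc. are attributes set there):

import torch
import piq

x_features = torch.rand(1000, 512)
y_features = torch.rand(1000, 512)

msid_metric = piq.MSID()  # assumed class name
score = msid_metric.compute_metric(x_features, y_features)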
Example #27
    def compute_metric(self, x_features: torch.Tensor, y_features: torch.Tensor) -> torch.Tensor:
        r"""Compute IS.

        Both features should have shape (N_samples, encoder_dim).

        Args:
            x_features: Samples from data distribution. Shape :math:`(N_x, D)`
            y_features: Samples from data distribution. Shape :math:`(N_y, D)`

        Returns:
            L1 or L2 distance between scores for datasets :math:`x` and :math:`y`.
        """
        _validate_input([x_features, y_features], dim_range=(2, 2), size_range=(0, 2))
        x_is, _ = inception_score(x_features, num_splits=self.num_splits)
        y_is, _ = inception_score(y_features, num_splits=self.num_splits)
        if self.distance == 'l1':
            return torch.dist(x_is, y_is, 1)
        elif self.distance == 'l2':
            return torch.dist(x_is, y_is, 2)
        else:
            raise ValueError("Distance should be one of {`l1`, `l2`}")
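A usage sketch, assuming the method belongs to piq's IS class; the distance argument is inferred from the self.distance attribute used above:

import torch
import piq

# IS operates on class-probability-like features, e.g. classifier logits.
x_features = torch.rand(1000, 1008)
y_features = torch.rand(1000, 1008)

is_metric = piq.IS(distance='l1')  # assumed class name and constructor
score = is_metric.compute_metric(x_features, y_features)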
Example #28
    def forward(self, prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        r"""
        Computation of PieAPP between feature representations of prediction and target tensors.

        Args:
            prediction: Tensor with shape (H, W), (C, H, W) or (N, C, H, W).
            target: Tensor with shape (H, W), (C, H, W) or (N, C, H, W).
        """
        _validate_input(
            input_tensors=(prediction, target), allow_5d=False, allow_negative=True, data_range=self.data_range)
        prediction, target = _adjust_dimensions(input_tensors=(prediction, target))

        N, C, _, _ = prediction.shape
        if C == 1:
            prediction = prediction.repeat(1, 3, 1, 1)
            target = target.repeat(1, 3, 1, 1)
            warnings.warn('The original PieAPP supports only RGB images. '
                          'The input images were converted to RGB by copying the grey channel 3 times.')

        self.model.to(device=prediction.device)
        prediction_features, prediction_weights = self.get_features(prediction)
        target_features, target_weights = self.get_features(target)

        distances, weights = self.model.compute_difference(
            target_features - prediction_features,
            target_weights - prediction_weights
        )

        distances = distances.reshape(N, -1)
        weights = weights.reshape(N, -1)

        # Scale scores, then average across patches
        loss = torch.stack([(d * w).sum() / w.sum() for d, w in zip(distances, weights)])

        if self.reduction == 'none':
            return loss

        return {'mean': loss.mean,
                'sum': loss.sum
                }[self.reduction](dim=0)
Example #29
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        r"""
        Computation of PieAPP between feature representations of prediction :math:`x` and target :math:`y` tensors.

        Args:
            x: An input tensor. Shape :math:`(N, C, H, W)`.
            y: A target tensor. Shape :math:`(N, C, H, W)`.

        Returns:
            Perceptual Image-Error Assessment through Pairwise Preference
        """
        _validate_input([x, y],
                        dim_range=(4, 4),
                        data_range=(0, self.data_range))

        N, C, _, _ = x.shape
        if C == 1:
            x = x.repeat(1, 3, 1, 1)
            y = y.repeat(1, 3, 1, 1)
            warnings.warn(
                'The original PieAPP supports only RGB images. '
                'The input images were converted to RGB by copying the grey channel 3 times.'
            )

        self.model.to(device=x.device)
        x_features, x_weights = self.get_features(x)
        y_features, y_weights = self.get_features(y)

        distances, weights = self.model.compute_difference(
            y_features - x_features, y_weights - x_weights)

        distances = distances.reshape(N, -1)
        weights = weights.reshape(N, -1)

        # Scale scores, then average across patches
        loss = torch.stack([(d * w).sum() / w.sum()
                            for d, w in zip(distances, weights)])

        return _reduce(loss, self.reduction)
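A usage sketch, assuming this forward belongs to piq's PieAPP module; data_range and reduction are inferred from the attributes used above:

import torch
import piq

x = torch.rand(2, 3, 256, 256)  # input images in [0, 1]
y = torch.rand(2, 3, 256, 256)  # reference images

loss_fn = piq.PieAPP(reduction='mean', data_range=1.)  # assumed class name
loss = loss_fn(x, y)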
Example #30
def haarpsi(x: torch.Tensor, y: torch.Tensor, reduction: Optional[str] = 'mean',
            data_range: Union[int, float] = 1., scales: int = 3, subsample: bool = True,
            c: float = 30.0, alpha: float = 4.2) -> torch.Tensor:
    r"""Compute Haar Wavelet-Based Perceptual Similarity
    Input can by greyscale tensor of colour image with RGB channels order.
    Args:
        x: Tensor of shape :math:`(N, C, H, W)` holding an distorted image.
        y: Tensor of shape :math:`(N, C, H, W)` holding an target image
        reduction: Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
            ``'mean'``: the sum of the output will be divided by the number of
            elements in the output, ``'sum'``: the output will be summed.
        data_range: The difference between the maximum and minimum of the pixel value,
            i.e., if for image x it holds min(x) = 0 and max(x) = 1, then data_range = 1.
            The pixel value interval of both input and output should remain the same.
        scales: Number of Haar wavelets used for image decomposition.
        subsample: Flag to apply average pooling before HaarPSI computation. See [1] for details.
        c: Constant from the paper. See [1] for details.
        alpha: Exponent used for similarity maps weighting. See [1] for details.

    Returns:
        HaarPSI : Wavelet-Based Perceptual Similarity between two tensors
    
    References:
        [1] R. Reisenhofer, S. Bosse, G. Kutyniok & T. Wiegand (2017)
            'A Haar Wavelet-Based Perceptual Similarity Index for Image Quality Assessment'
            http://www.math.uni-bremen.de/cda/HaarPSI/publications/HaarPSI_preprint_v4.pdf
        [2] Code from authors on MATLAB and Python
            https://github.com/rgcda/haarpsi
    """

    _validate_input(input_tensors=(x, y), allow_5d=False, scale_weights=None)
    x, y = _adjust_dimensions(input_tensors=(x, y))

    # Assert minimal image size
    kernel_size = 2 ** (scales + 1)
    if x.size(-1) < kernel_size or x.size(-2) < kernel_size:
        raise ValueError(f'Kernel size can\'t be greater than actual input size. Input size: {x.size()}. '
                         f'Kernel size: {kernel_size}')

    # Scale images to [0, 255] range as in the paper
    x = x * 255.0 / float(data_range)
    y = y * 255.0 / float(data_range)

    num_channels = x.size(1)
    # Convert RGB to YIQ color space https://en.wikipedia.org/wiki/YIQ
    if num_channels == 3:
        x_yiq = rgb2yiq(x)
        y_yiq = rgb2yiq(y)
    else:
        x_yiq = x
        y_yiq = y

    # Downscale input to simulate the typical distance between an image and its viewer.
    if subsample:
        up_pad = 0
        down_pad = max(x.shape[2] % 2, x.shape[3] % 2)
        pad_to_use = [up_pad, down_pad, up_pad, down_pad]
        x_yiq = F.pad(x_yiq, pad=pad_to_use)
        y_yiq = F.pad(y_yiq, pad=pad_to_use)

        x_yiq = F.avg_pool2d(x_yiq, kernel_size=2, stride=2, padding=0)
        y_yiq = F.avg_pool2d(y_yiq, kernel_size=2, stride=2, padding=0)
    
    # Haar wavelet decomposition
    coefficients_x, coefficients_y = [], []
    for scale in range(scales):
        kernel_size = 2 ** (scale + 1)
        kernels = torch.stack([haar_filter(kernel_size), haar_filter(kernel_size).transpose(-1, -2)])
    
        # Asymmetrical padding due to even kernel size. Matches MATLAB conv2(A, B, 'same')
        upper_pad = kernel_size // 2 - 1
        bottom_pad = kernel_size // 2
        pad_to_use = [upper_pad, bottom_pad, upper_pad, bottom_pad]
        coeff_x = torch.nn.functional.conv2d(F.pad(x_yiq[:, : 1], pad=pad_to_use, mode='constant'), kernels.to(x))
        coeff_y = torch.nn.functional.conv2d(F.pad(y_yiq[:, : 1], pad=pad_to_use, mode='constant'), kernels.to(y))
    
        coefficients_x.append(coeff_x)
        coefficients_y.append(coeff_y)

    # Shape [B x {scales * 2} x H x W]
    coefficients_x = torch.cat(coefficients_x, dim=1)
    coefficients_y = torch.cat(coefficients_y, dim=1)

    # Low-frequency coefficients used as weights
    # Shape [B x 2 x H x W]
    weights = torch.max(torch.abs(coefficients_x[:, 4:]), torch.abs(coefficients_y[:, 4:]))
    
    # High-frequency coefficients used for similarity computation in 2 orientations (horizontal and vertical)
    sim_map = []
    for orientation in range(2):
        magnitude_x = torch.abs(coefficients_x[:, (orientation, orientation + 2)])
        magnitude_y = torch.abs(coefficients_y[:, (orientation, orientation + 2)])
        sim_map.append(similarity_map(magnitude_x, magnitude_y, constant=c).sum(dim=1, keepdims=True) / 2)

    if num_channels == 3:
        pad_to_use = [0, 1, 0, 1]
        x_yiq = F.pad(x_yiq, pad=pad_to_use)
        y_yiq = F.pad(y_yiq, pad=pad_to_use)
        coefficients_x_iq = torch.abs(F.avg_pool2d(x_yiq[:, 1:], kernel_size=2, stride=1, padding=0))
        coefficients_y_iq = torch.abs(F.avg_pool2d(y_yiq[:, 1:], kernel_size=2, stride=1, padding=0))
    
        # Compute weights and similarity
        weights = torch.cat([weights, weights.mean(dim=1, keepdims=True)], dim=1)
        sim_map.append(
            similarity_map(coefficients_x_iq, coefficients_y_iq, constant=c).sum(dim=1, keepdims=True) / 2)

    sim_map = torch.cat(sim_map, dim=1)
    
    # Calculate the final score
    eps = torch.finfo(sim_map.dtype).eps
    score = (((sim_map * alpha).sigmoid() * weights).sum(dim=[1, 2, 3]) + eps) /\
        (torch.sum(weights, dim=[1, 2, 3]) + eps)
    # Logit of score
    score = (torch.log(score / (1 - score)) / alpha) ** 2

    if reduction == 'none':
        return score

    return {'mean': score.mean,
            'sum': score.sum
            }[reduction](dim=0)
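A small usage sketch of the function above:

import torch

x = torch.rand(4, 3, 128, 128)  # distorted images in [0, 1]
y = torch.rand(4, 3, 128, 128)  # reference images

score = haarpsi(x, y, data_range=1.)         # batch mean
per_image = haarpsi(x, y, reduction='none')  # shape (4,)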