def compute_metric(self, x_features: torch.Tensor, y_features: torch.Tensor) -> torch.Tensor:
    r"""
    Fits multivariate Gaussians: :math:`X \sim \mathcal{N}(\mu_x, \sigma_x)` and
    :math:`Y \sim \mathcal{N}(\mu_y, \sigma_y)` to image stacks.
    Then computes FID as :math:`d^2 = ||\mu_x - \mu_y||^2 + Tr(\sigma_x + \sigma_y - 2\sqrt{\sigma_x \sigma_y})`.

    Args:
        x_features: Samples from data distribution. Shape :math:`(N_x, D)`
        y_features: Samples from data distribution. Shape :math:`(N_y, D)`

    Returns:
        The Frechet Distance.
    """
    _validate_input([x_features, y_features], dim_range=(2, 2), size_range=(1, 2))

    # GPU -> CPU
    mu_x, sigma_x = _compute_statistics(x_features.detach().to(dtype=torch.float64))
    mu_y, sigma_y = _compute_statistics(y_features.detach().to(dtype=torch.float64))

    score = _compute_fid(mu_x, sigma_x, mu_y, sigma_y)
    return score
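# Usage sketch for the method above, assuming the surrounding class is piq's
# FID (the library these snippets come from, per the URLs below). The random
# tensors stand in for real encoder features such as InceptionV3 embeddings.
import torch
from piq import FID

x_feats = torch.rand(512, 2048)  # N_x samples, D-dimensional features
y_feats = torch.rand(512, 2048)  # N_y samples, same feature dimension
fid_metric = FID()
score = fid_metric.compute_metric(x_feats, y_feats)  # scalar Frechet Distance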
def compute_metric(self, x_features: torch.Tensor, y_features: torch.Tensor) -> torch.Tensor:
    r"""Implements Algorithm 2 from the paper.

    Args:
        x_features: Samples from data distribution. Shape :math:`(N_x, D)`
        y_features: Samples from data distribution. Shape :math:`(N_y, D)`

    Returns:
        score: Scalar value of the distance between distributions.
    """
    _validate_input([x_features, y_features], dim_range=(2, 2), size_range=(1, 2))
    with Pool(self.num_workers) as p:
        self.features = x_features.detach().cpu().numpy()
        pool_results = p.map(self._relative_living_times, range(self.num_iters))
        mean_rlt_x = np.vstack(pool_results).mean(axis=0)

        self.features = y_features.detach().cpu().numpy()
        pool_results = p.map(self._relative_living_times, range(self.num_iters))
        mean_rlt_y = np.vstack(pool_results).mean(axis=0)

    score = np.sum((mean_rlt_x - mean_rlt_y) ** 2)
    return torch.tensor(score, device=x_features.device) * 1000
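# Usage sketch, assuming the method above belongs to piq's GS (Geometry Score)
# class, which I believe also requires the optional gudhi dependency. GS builds
# topology descriptors (relative living times) on CPU worker processes, so it
# is slow; a small num_iters keeps this sketch cheap at the cost of accuracy.
import torch
from piq import GS

x_feats = torch.rand(256, 64)
y_feats = torch.rand(256, 64)
gs_metric = GS(num_iters=100, num_workers=4)
score = gs_metric.compute_metric(x_feats, y_feats)  # distance scaled by 1000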
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    r"""Computation of Content loss between feature representations of prediction
    :math:`x` and target :math:`y` tensors.

    Args:
        x: An input tensor. Shape :math:`(N, C, H, W)`.
        y: A target tensor. Shape :math:`(N, C, H, W)`.

    Returns:
        Content loss between feature representations
    """
    _validate_input([x, y], dim_range=(4, 4), data_range=(0, -1))

    self.model.to(x)
    x_features = self.get_features(x)
    y_features = self.get_features(y)

    distances = self.compute_distance(x_features, y_features)

    # Scale distances, then average in spatial dimensions, then stack and sum in channels dimension
    loss = torch.cat([(d * w.to(d)).mean(dim=[2, 3]) for d, w in zip(distances, self.weights)], dim=1).sum(dim=1)

    return _reduce(loss, self.reduction)
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""
    Forward pass a batch of square patches with shape :math:`(N, C, F, F)`.

    Returns:
        features: Concatenation of model features from different scales
        x11: Outputs of the last convolutional layer used as weights
    """
    _validate_input([x, ], dim_range=(4, 4), data_range=(0, -1))
    assert x.shape[2] == x.shape[3] == self.FEATURES, \
        f"Expected square input with shape {self.FEATURES, self.FEATURES}, got {x.shape}"

    # conv1 -> relu -> conv2 -> relu -> pool -> conv3 -> relu
    x3 = F.relu(self.conv3(self.pool(F.relu(self.conv2(F.relu(self.conv1(x)))))))

    # conv4 -> relu -> pool -> conv5 -> relu
    x5 = F.relu(self.conv5(self.pool(F.relu(self.conv4(x3)))))

    # conv6 -> relu -> pool -> conv7 -> relu
    x7 = F.relu(self.conv7(self.pool(F.relu(self.conv6(x5)))))

    # conv8 -> relu -> pool -> conv9 -> relu
    x9 = F.relu(self.conv9(self.pool(F.relu(self.conv8(x7)))))

    # conv10 -> relu -> pool -> conv11 -> relu
    x11 = self.flatten(F.relu(self.conv11(self.pool(F.relu(self.conv10(x9))))))

    # flatten and concatenate
    features = torch.cat((self.flatten(x3), self.flatten(x5), self.flatten(x7), self.flatten(x9), x11), dim=1)

    return features, x11
def test_breaks_if_too_many_tensors_provided(tensor_2d: torch.Tensor) -> None:
    max_number_of_tensors = 2
    for n_tensors in range(max_number_of_tensors + 1, (max_number_of_tensors + 1) * 10):
        inp = tuple(tensor_2d.clone() for _ in range(n_tensors))

        with pytest.raises(AssertionError):
            _validate_input(inp, allow_5d=False)
def test_works_on_single_5d_tensor(tensor_5d: torch.Tensor) -> None:
    try:
        _validate_input([tensor_5d, ], dim_range=(5, 5))
    except Exception as e:
        pytest.fail(f"Unexpected error occurred: {e}")
def total_variation(x: torch.Tensor, reduction: str = 'mean', norm_type: str = 'l2') -> torch.Tensor:
    r"""Compute Total Variation metric

    Args:
        x: Tensor. Shape :math:`(N, C, H, W)`.
        reduction: Specifies the reduction type:
            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``
        norm_type: {'l1', 'l2', 'l2_squared'}, defines which type of norm to implement,
            isotropic or anisotropic.

    Returns:
        score : Total variation of a given tensor

    References:
        https://www.wikiwand.com/en/Total_variation_denoising
        https://remi.flamary.com/demos/proxtv.html
    """
    _validate_input([x, ], dim_range=(4, 4), data_range=(0, -1))

    if norm_type == 'l1':
        w_variance = torch.sum(torch.abs(x[:, :, :, 1:] - x[:, :, :, :-1]), dim=[1, 2, 3])
        h_variance = torch.sum(torch.abs(x[:, :, 1:, :] - x[:, :, :-1, :]), dim=[1, 2, 3])
        score = (h_variance + w_variance)
    elif norm_type == 'l2':
        w_variance = torch.sum(torch.pow(x[:, :, :, 1:] - x[:, :, :, :-1], 2), dim=[1, 2, 3])
        h_variance = torch.sum(torch.pow(x[:, :, 1:, :] - x[:, :, :-1, :], 2), dim=[1, 2, 3])
        score = torch.sqrt(h_variance + w_variance)
    elif norm_type == 'l2_squared':
        w_variance = torch.sum(torch.pow(x[:, :, :, 1:] - x[:, :, :, :-1], 2), dim=[1, 2, 3])
        h_variance = torch.sum(torch.pow(x[:, :, 1:, :] - x[:, :, :-1, :], 2), dim=[1, 2, 3])
        score = (h_variance + w_variance)
    else:
        raise ValueError("Incorrect norm type, should be one of {'l1', 'l2', 'l2_squared'}")

    return _reduce(score, reduction)
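# Usage sketch for total_variation above (assuming it is exposed as
# piq.total_variation): a constant image has zero variation under every
# norm_type, while a noisy image scores high because neighbouring pixels differ.
import torch
from piq import total_variation

noisy = torch.rand(1, 3, 64, 64)
flat = torch.full((1, 3, 64, 64), 0.5)
print(total_variation(noisy, norm_type='l1'))  # large: many abrupt changes
print(total_variation(flat, norm_type='l1'))   # tensor(0.): no variation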
def test_breaks_if_scale_weight_wrong_n_dims_provided(tensor_2d: torch.Tensor) -> None:
    wrong_scale_weights = tensor_2d.clone()

    with pytest.raises(AssertionError):
        _validate_input(tensor_2d, allow_5d=False, scale_weights=wrong_scale_weights)
def compute_metric(self, x_features: torch.Tensor, y_features: torch.Tensor) -> torch.Tensor:
    r"""
    Fits multivariate Gaussians: X ~ N(mu_x, sigma_x) and Y ~ N(mu_y, sigma_y) to image stacks.
    Then computes FID as d^2 = ||mu_x - mu_y||^2 + Tr(sigma_x + sigma_y - 2 * sqrt(sigma_x * sigma_y)).

    Args:
        x_features: Samples from data distribution. Shape :math:`(N_x, D)`
        y_features: Samples from data distribution. Shape :math:`(N_y, D)`

    Returns:
        The Frechet Distance.
    """
    _validate_input([x_features, y_features], dim_range=(2, 2), size_range=(1, 2))

    # GPU -> CPU
    mu_x, sigma_x = _compute_statistics(x_features.detach().to(dtype=torch.float64))
    mu_y, sigma_y = _compute_statistics(y_features.detach().to(dtype=torch.float64))

    score = _compute_fid(mu_x, sigma_x, mu_y, sigma_y)
    return score
def forward(self, prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    r"""Computation of Content loss between feature representations of prediction and target tensors.

    Args:
        prediction: Tensor with shape (H, W), (C, H, W) or (N, C, H, W).
        target: Tensor with shape (H, W), (C, H, W) or (N, C, H, W).
    """
    _validate_input(input_tensors=(prediction, target), allow_5d=False, allow_negative=True)
    prediction, target = _adjust_dimensions(input_tensors=(prediction, target))

    self.model.to(prediction)
    prediction_features = self.get_features(prediction)
    target_features = self.get_features(target)

    distances = self.compute_distance(prediction_features, target_features)

    # Scale distances, then average in spatial dimensions, then stack and sum in channels dimension
    loss = torch.cat([(d * w.to(d)).mean(dim=[2, 3]) for d, w in zip(distances, self.weights)], dim=1).sum(dim=1)

    if self.reduction == 'none':
        return loss

    return {'mean': loss.mean, 'sum': loss.sum}[self.reduction](dim=0)
def compute_metric(self, real_features: torch.Tensor, fake_features: torch.Tensor) \
        -> Tuple[torch.Tensor, torch.Tensor]:
    r"""Creates non-parametric representations of the manifolds of real and generated data
    and computes the precision and recall between them.

    Args:
        real_features: Samples from data distribution. Shape :math:`(N_x, D)`
        fake_features: Samples from fake distribution. Shape :math:`(N_y, D)`

    Returns:
        Scalar value of the precision of the generated images.
        Scalar value of the recall of the generated images.
    """
    _validate_input([real_features, fake_features], dim_range=(2, 2), size_range=(1, 2))
    real_nearest_neighbour_distances = _compute_nearest_neighbour_distances(real_features, self.nearest_k)
    fake_nearest_neighbour_distances = _compute_nearest_neighbour_distances(fake_features, self.nearest_k)
    distance_real_fake = _compute_pairwise_distance(real_features, fake_features)

    precision = (
        distance_real_fake < real_nearest_neighbour_distances.unsqueeze(1)
    ).any(dim=0).float().mean()

    recall = (
        distance_real_fake < fake_nearest_neighbour_distances.unsqueeze(0)
    ).any(dim=1).float().mean()

    return precision, recall
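# Usage sketch, assuming the method above belongs to piq's PR class. Precision
# is the fraction of fake samples that fall within some real sample's
# k-nearest-neighbour ball; recall is the fraction of real samples covered by
# the fake manifold in the same sense.
import torch
from piq import PR

real_feats = torch.rand(1000, 64)
fake_feats = torch.rand(1000, 64)
pr_metric = PR(nearest_k=5)
precision, recall = pr_metric.compute_metric(real_feats, fake_feats)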
def brisque(x: torch.Tensor,
            kernel_size: int = 7, kernel_sigma: float = 7 / 6,
            data_range: Union[int, float] = 1., reduction: str = 'mean',
            interpolation: str = 'nearest') -> torch.Tensor:
    r"""Interface of BRISQUE index.

    Args:
        x: Tensor with shape (H, W), (C, H, W) or (N, C, H, W). RGB channel order for colour images.
        kernel_size: The side-length of the sliding window used in comparison. Must be an odd value.
        kernel_sigma: Sigma of normal distribution.
        data_range: Maximum value range of input images (usually 1.0 or 255).
        reduction: Reduction over samples in batch: "mean"|"sum"|"none".
        interpolation: Interpolation to be used for scaling.

    Returns:
        Value of BRISQUE index.

    Note:
        The back propagation is not available using torch=1.5.0 due to bug in argmin/argmax back propagation.
        Update the torch and torchvision to the latest versions.

    References:
        .. [1] Anish Mittal et al. "No-Reference Image Quality Assessment in the Spatial Domain",
           https://live.ece.utexas.edu/publications/2012/TIP%20BRISQUE.pdf
    """
    if '1.5.0' in torch.__version__:
        warnings.warn(
            f'BRISQUE does not support back propagation due to bug in torch={torch.__version__}. '
            f'Update torch to the latest version to access full functionality of the BRISQUE. '
            f'More info is available at https://github.com/photosynthesis-team/piq/pull/79 and '
            f'https://github.com/pytorch/pytorch/issues/38869.')

    _validate_input(input_tensors=x, allow_5d=False, kernel_size=kernel_size)
    x = _adjust_dimensions(input_tensors=x)

    assert data_range >= x.max(), \
        f'Expected data range to be greater than or equal to the maximum value, got {data_range} and {x.max()}.'
    x = x * 255. / data_range

    if x.size(1) == 3:
        x = rgb2yiq(x)[:, :1]

    features = []
    num_of_scales = 2
    for _ in range(num_of_scales):
        features.append(_natural_scene_statistics(x, kernel_size, kernel_sigma))
        x = F.interpolate(x, size=(x.size(2) // 2, x.size(3) // 2), mode=interpolation)

    features = torch.cat(features, dim=-1)
    scaled_features = _scale_features(features)
    score = _score_svr(scaled_features)

    if reduction == 'none':
        return score

    return {'mean': score.mean, 'sum': score.sum}[reduction](dim=0)
def brisque(x: torch.Tensor,
            kernel_size: int = 7, kernel_sigma: float = 7 / 6,
            data_range: Union[int, float] = 1., reduction: str = 'mean',
            interpolation: str = 'nearest') -> torch.Tensor:
    r"""Interface of BRISQUE index.
    Supports greyscale and colour images with RGB channel order.

    Args:
        x: An input tensor. Shape :math:`(N, C, H, W)`.
        kernel_size: The side-length of the sliding window used in comparison. Must be an odd value.
        kernel_sigma: Sigma of normal distribution.
        data_range: Maximum value range of images (usually 1.0 or 255).
        reduction: Specifies the reduction type:
            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``
        interpolation: Interpolation to be used for scaling.

    Returns:
        Value of BRISQUE index.

    Note:
        The back propagation is not available using torch=1.5.0 due to bug in argmin/argmax back propagation.
        Update the torch and torchvision to the latest versions.

    References:
        .. [1] Anish Mittal et al. "No-Reference Image Quality Assessment in the Spatial Domain",
           https://live.ece.utexas.edu/publications/2012/TIP%20BRISQUE.pdf
    """
    if '1.5.0' in torch.__version__:
        warnings.warn(
            f'BRISQUE does not support back propagation due to bug in torch={torch.__version__}. '
            f'Update torch to the latest version to access full functionality of the BRISQUE. '
            f'More info is available at https://github.com/photosynthesis-team/piq/pull/79 and '
            f'https://github.com/pytorch/pytorch/issues/38869.')

    assert kernel_size % 2 == 1, f'Kernel size must be odd, got [{kernel_size}]'
    _validate_input([x, ], dim_range=(4, 4), data_range=(0, data_range))

    x = x / data_range * 255

    if x.size(1) == 3:
        x = rgb2yiq(x)[:, :1]

    features = []
    num_of_scales = 2
    for _ in range(num_of_scales):
        features.append(_natural_scene_statistics(x, kernel_size, kernel_sigma))
        x = F.interpolate(x, size=(x.size(2) // 2, x.size(3) // 2), mode=interpolation)

    features = torch.cat(features, dim=-1)
    scaled_features = _scale_features(features)
    score = _score_svr(scaled_features)

    return _reduce(score, reduction)
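# Usage sketch for the no-reference brisque above (assuming it is exposed as
# piq.brisque): unlike the full-reference metrics in this file, it scores a
# single batch of images with no target; lower values indicate better quality.
import torch
from piq import brisque

x = torch.rand(4, 3, 96, 96)
score = brisque(x, data_range=1.)  # one scalar after the default 'mean' reduction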
def gmsd(x: torch.Tensor, y: torch.Tensor, reduction: str = 'mean',
         data_range: Union[int, float] = 1., t: float = 170 / (255. ** 2)) -> torch.Tensor:
    r"""Compute Gradient Magnitude Similarity Deviation.

    Inputs are supposed to be in range [0, data_range] with RGB channels order for colour images.

    Args:
        x: Tensor with shape (H, W), (C, H, W) or (N, C, H, W).
        y: Tensor with shape (H, W), (C, H, W) or (N, C, H, W).
        reduction: Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
            ``'mean'``: the sum of the output will be divided by the number of elements in the output,
            ``'sum'``: the output will be summed.
        data_range: The difference between the maximum and minimum of the pixel value,
            i.e., if for image x it holds min(x) = 0 and max(x) = 1, then data_range = 1.
            The pixel value interval of both input and output should remain the same.
        t: Constant from the reference paper for numerical stability of the similarity map.

    Returns:
        gmsd : Gradient Magnitude Similarity Deviation between given tensors.

    References:
        Wufeng Xue et al. Gradient Magnitude Similarity Deviation (2013)
        https://arxiv.org/pdf/1308.3052.pdf
    """
    _validate_input(input_tensors=(x, y), allow_5d=False, scale_weights=None, data_range=data_range)
    x, y = _adjust_dimensions(input_tensors=(x, y))

    # Rescale
    x = x / data_range
    y = y / data_range

    num_channels = x.size(1)
    if num_channels == 3:
        x = rgb2yiq(x)[:, :1]
        y = rgb2yiq(y)[:, :1]

    up_pad = 0
    down_pad = max(x.shape[2] % 2, x.shape[3] % 2)
    pad_to_use = [up_pad, down_pad, up_pad, down_pad]
    x = F.pad(x, pad=pad_to_use)
    y = F.pad(y, pad=pad_to_use)

    x = F.avg_pool2d(x, kernel_size=2, stride=2, padding=0)
    y = F.avg_pool2d(y, kernel_size=2, stride=2, padding=0)

    score = _gmsd(x=x, y=y, t=t)

    if reduction == 'none':
        return score

    return {'mean': score.mean, 'sum': score.sum}[reduction](dim=0)
def test_breaks_if_not_supported_data_types_provided() -> None:
    inputs_of_wrong_types = [[], (), {}, 42, '42', True, 42., np.array([42]), None]
    for inp in inputs_of_wrong_types:
        with pytest.raises(AssertionError):
            _validate_input([inp, ])
def compute_metric(self, x_features: torch.Tensor, y_features: torch.Tensor) \
        -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """Computes KID (polynomial MMD) for given sets of features, obtained from Inception net
    or any other feature extractor.
    Samples must be in range [0, 1].

    Args:
        x_features: Samples from data distribution. Shape :math:`(N_x, D)`
        y_features: Samples from data distribution. Shape :math:`(N_y, D)`

    Returns:
        KID score and variance (optional).
    """
    _validate_input([x_features, y_features], dim_range=(2, 2), size_range=(1, 2))
    var_at_m = min(x_features.size(0), y_features.size(0))
    if self.subset_size is None:
        subset_size = x_features.size(0)
    else:
        subset_size = self.subset_size

    results = []
    for _ in range(self.n_subsets):
        x_subset = x_features[torch.randperm(len(x_features))[:subset_size]]
        y_subset = y_features[torch.randperm(len(y_features))[:subset_size]]

        # use k(x, y) = (gamma <x, y> + coef0)^degree
        # default gamma is 1 / dim
        K_XX = _polynomial_kernel(x_subset, None, degree=self.degree, gamma=self.gamma, coef0=self.coef0)
        K_YY = _polynomial_kernel(y_subset, None, degree=self.degree, gamma=self.gamma, coef0=self.coef0)
        K_XY = _polynomial_kernel(x_subset, y_subset, degree=self.degree, gamma=self.gamma, coef0=self.coef0)

        out = _mmd2_and_variance(K_XX, K_XY, K_YY, var_at_m=var_at_m, ret_var=self.ret_var)
        results.append(out)

    if self.ret_var:
        score = torch.mean(torch.stack([p[0] for p in results], dim=0))
        variance = torch.mean(torch.stack([p[1] for p in results], dim=0))
        return (score, variance)
    else:
        score = torch.mean(torch.stack(results, dim=0))
        return score
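# A minimal sketch of the polynomial kernel named in the comment above,
# assuming _polynomial_kernel follows the standard definition with the default
# gamma = 1 / D. The helper name and defaults here are illustrative, not the
# library's actual implementation.
import torch

def polynomial_kernel_sketch(x: torch.Tensor, y: torch.Tensor = None,
                             degree: int = 3, gamma: float = None,
                             coef0: float = 1.0) -> torch.Tensor:
    # Pairwise k(x, y) = (gamma * <x, y> + coef0) ** degree
    y = x if y is None else y
    gamma = 1.0 / x.size(1) if gamma is None else gamma
    return (gamma * x @ y.t() + coef0) ** degree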
def gmsd(prediction: torch.Tensor, target: torch.Tensor, reduction: Optional[str] = 'mean',
         data_range: Union[int, float] = 1., t: float = 170 / (255. ** 2)) -> torch.Tensor:
    r"""Compute Gradient Magnitude Similarity Deviation.

    Both inputs are supposed to be in range [0, 1] with RGB order.

    Args:
        prediction: Tensor of shape :math:`(N, C, H, W)` holding a distorted image.
        target: Tensor of shape :math:`(N, C, H, W)` holding a target image.
        reduction: Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
            ``'mean'``: the sum of the output will be divided by the number of elements in the output,
            ``'sum'``: the output will be summed.
        data_range: The difference between the maximum and minimum of the pixel value,
            i.e., if for image x it holds min(x) = 0 and max(x) = 1, then data_range = 1.
            The pixel value interval of both input and output should remain the same.
        t: Constant from the reference paper for numerical stability of the similarity map.

    Returns:
        gmsd : Gradient Magnitude Similarity Deviation between given tensors.

    References:
        https://arxiv.org/pdf/1308.3052.pdf
    """
    _validate_input(input_tensors=(prediction, target), allow_5d=False, scale_weights=None)
    prediction, target = _adjust_dimensions(input_tensors=(prediction, target))

    prediction = prediction / float(data_range)
    target = target / float(data_range)

    num_channels = prediction.size(1)
    if num_channels == 3:
        prediction = rgb2yiq(prediction)[:, :1]
        target = rgb2yiq(target)[:, :1]

    up_pad = 0
    down_pad = max(prediction.shape[2] % 2, prediction.shape[3] % 2)
    pad_to_use = [up_pad, down_pad, up_pad, down_pad]
    prediction = F.pad(prediction, pad=pad_to_use)
    target = F.pad(target, pad=pad_to_use)

    prediction = F.avg_pool2d(prediction, kernel_size=2, stride=2, padding=0)
    target = F.avg_pool2d(target, kernel_size=2, stride=2, padding=0)

    score = _gmsd(prediction=prediction, target=target, t=t)

    if reduction == 'none':
        return score

    return {'mean': score.mean, 'sum': score.sum}[reduction](dim=0)
def compute_metric(self, x_features: torch.Tensor, y_features: torch.Tensor) -> torch.Tensor:
    r"""Compute MSID score between two sets of samples.

    Args:
        x_features: Samples from data distribution. Shape :math:`(N_x, D)`
        y_features: Samples from data distribution. Shape :math:`(N_y, D)`
        ts: Temperature values.
        k: Number of neighbours for graph construction.
        m: Lanczos steps in SLQ.
        niters: Number of starting random vectors for SLQ.
        rademacher: True to use Rademacher distribution,
            False - standard normal for random vectors in Hutchinson.
        msid_mode: 'l2' to compute the l2 norm of the distance between `msid1` and `msid2`;
            'max' to find the maximum absolute difference between two descriptors over temperature
        normalized_laplacian: if True, use normalized Laplacian.
        normalize: 'empty' for average heat kernel (corresponds to the empty graph normalization of NetLSD),
            'complete' for the complete, 'er' for erdos-renyi normalization, 'none' for no normalization

    Returns:
        score: Scalar value of the distance between distributions.
    """
    _validate_input([x_features, y_features], dim_range=(2, 2), size_range=(1, 2))
    normed_msid_x = _msid_descriptor(
        x_features.detach().cpu().numpy(),
        ts=self.ts,
        k=self.k,
        m=self.m,
        niters=self.niters,
        rademacher=self.rademacher,
        normalized_laplacian=self.normalized_laplacian,
        normalize=self.normalize
    )
    normed_msid_y = _msid_descriptor(
        y_features.detach().cpu().numpy(),
        ts=self.ts,
        k=self.k,
        m=self.m,
        niters=self.niters,
        rademacher=self.rademacher,
        normalized_laplacian=self.normalized_laplacian,
        normalize=self.normalize
    )

    c = np.exp(-2 * (self.ts + 1 / self.ts))

    if self.msid_mode == 'l2':
        score = np.linalg.norm(normed_msid_x - normed_msid_y)
    elif self.msid_mode == 'max':
        score = np.amax(c * np.abs(normed_msid_x - normed_msid_y))
    else:
        raise ValueError('Mode must be in {`l2`, `max`}')

    return torch.tensor(score, device=x_features.device)
def ssim(x: torch.Tensor, y: torch.Tensor, kernel_size: int = 11, kernel_sigma: float = 1.5,
         data_range: Union[int, float] = 1., reduction: str = 'mean', full: bool = False,
         k1: float = 0.01, k2: float = 0.03) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    r"""Interface of Structural Similarity (SSIM) index.

    Args:
        x: Batch of images. Required to be 2D (H, W), 3D (C, H, W), 4D (N, C, H, W) or
            5D (N, C, H, W, 2), channels first.
        y: Batch of images. Required to be 2D (H, W), 3D (C, H, W), 4D (N, C, H, W) or
            5D (N, C, H, W, 2), channels first.
        kernel_size: The side-length of the sliding window used in comparison. Must be an odd value.
        kernel_sigma: Sigma of normal distribution.
        data_range: Value range of input images (usually 1.0 or 255).
        reduction: Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied.
        full: Return cs map or not.
        k1: Algorithm parameter, K1 (small constant, see [1]).
        k2: Algorithm parameter, K2 (small constant, see [1]).
            Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results.

    Returns:
        Value of Structural Similarity (SSIM) index. In case of 5D input tensors,
        complex value is returned as a tensor of size 2.

    References:
        .. [1] Wang, Z., Bovik, A. C., Sheikh, H. R., & Simoncelli, E. P.
           (2004). Image quality assessment: From error visibility to
           structural similarity. IEEE Transactions on Image Processing,
           13, 600-612.
           https://ece.uwaterloo.ca/~z70wang/publications/ssim.pdf,
           :DOI:`10.1109/TIP.2003.819861`
    """
    _validate_input(input_tensors=(x, y), allow_5d=True, kernel_size=kernel_size, scale_weights=None)
    x, y = _adjust_dimensions(input_tensors=(x, y))
    if isinstance(x, torch.ByteTensor) or isinstance(y, torch.ByteTensor):
        x = x.type(torch.float32)
        y = y.type(torch.float32)

    kernel = gaussian_filter(kernel_size, kernel_sigma).repeat(x.size(1), 1, 1, 1).to(y)
    _compute_ssim_per_channel = _ssim_per_channel_complex if x.dim() == 5 else _ssim_per_channel
    ssim_map, cs_map = _compute_ssim_per_channel(x=x, y=y, kernel=kernel, data_range=data_range, k1=k1, k2=k2)
    ssim_val = ssim_map.mean(1)
    cs = cs_map.mean(1)

    if reduction != 'none':
        reduction_operation = {'mean': torch.mean, 'sum': torch.sum}
        ssim_val = reduction_operation[reduction](ssim_val, dim=0)
        cs = reduction_operation[reduction](cs, dim=0)

    if full:
        return ssim_val, cs

    return ssim_val
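# Sanity-check sketch for ssim above (assuming it is exposed as piq.ssim):
# SSIM of an image with itself is 1, and any distortion pulls it towards 0.
import torch
from piq import ssim

x = torch.rand(4, 3, 96, 96)
noisy = (x + 0.1 * torch.randn_like(x)).clamp(0, 1)
print(ssim(x, x.clone()))  # ~1.0: identical images
print(ssim(x, noisy))      # noticeably below 1.0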
def test_works_on_two_not_5d_tensors(tensor_1d: torch.Tensor) -> None:
    tensor = tensor_1d.clone()
    max_num_dims = 10
    for _ in range(max_num_dims):
        another_tensor = tensor.clone()
        if 1 < tensor.dim() < 5:
            try:
                _validate_input([tensor, another_tensor], dim_range=(2, 4))
            except Exception as e:
                pytest.fail(f"Unexpected error occurred: {e}")
        else:
            with pytest.raises(AssertionError):
                _validate_input([tensor, another_tensor], dim_range=(2, 4))

        tensor.unsqueeze_(0)
def test_works_on_single_not_5d_tensor(tensor_1d: torch.Tensor) -> None:
    tensor = tensor_1d.clone()
    # 1D -> max_num_dims
    max_num_dims = 10
    for _ in range(max_num_dims):
        if 1 < tensor.dim() < 5:
            try:
                _validate_input(tensor, allow_5d=False)
            except Exception as e:
                pytest.fail(f"Unexpected error occurred: {e}")
        else:
            with pytest.raises(AssertionError):
                _validate_input(tensor, allow_5d=False)

        tensor.unsqueeze_(0)
def psnr(x: torch.Tensor, y: torch.Tensor, data_range: Union[int, float] = 1.0,
         reduction: str = 'mean', convert_to_greyscale: bool = False) -> torch.Tensor:
    r"""Compute Peak Signal-to-Noise Ratio for a batch of images.

    Supports both greyscale and color images with RGB channel order.

    Args:
        x: Predicted images set :math:`x`. Shape (H, W), (C, H, W) or (N, C, H, W).
        y: Target images set :math:`y`. Shape (H, W), (C, H, W) or (N, C, H, W).
        data_range: Value range of input images (usually 1.0 or 255). Default: 1.0
        reduction: Reduction over samples in batch: "mean"|"sum"|"none"
        convert_to_greyscale: Convert RGB image to YCbCr format and computes PSNR
            only on luminance channel if `True`. Compute on all 3 channels otherwise.

    Returns:
        PSNR: Index of similarity between two images.

    References:
        https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
    """
    _validate_input((x, y), allow_5d=False, data_range=data_range)
    x, y = _adjust_dimensions(input_tensors=(x, y))

    # Constant for numerical stability
    EPS = 1e-8

    x = x / data_range
    y = y / data_range

    if (x.size(1) == 3) and convert_to_greyscale:
        # Convert RGB image to YCbCr and take luminance: Y = 0.299 R + 0.587 G + 0.114 B
        rgb_to_grey = torch.tensor([0.299, 0.587, 0.114]).view(1, -1, 1, 1).to(x)
        x = torch.sum(x * rgb_to_grey, dim=1, keepdim=True)
        y = torch.sum(y * rgb_to_grey, dim=1, keepdim=True)

    mse = torch.mean((x - y) ** 2, dim=[1, 2, 3])
    score: torch.Tensor = -10 * torch.log10(mse + EPS)

    if reduction == 'none':
        return score

    return {'mean': score.mean, 'sum': score.sum}[reduction](dim=0)
def total_variation(x: torch.Tensor, reduction: str = 'mean', norm_type: str = 'l2') -> torch.Tensor:
    r"""Compute Total Variation metric

    Args:
        x: Tensor with shape (N, C, H, W).
        reduction: Reduction over samples in batch: "mean"|"sum"|"none"
        norm_type: {'l1', 'l2', 'l2_squared'}, defines which type of norm to implement,
            isotropic or anisotropic.

    Returns:
        score : Total variation of a given tensor

    References:
        https://www.wikiwand.com/en/Total_variation_denoising
        https://remi.flamary.com/demos/proxtv.html
    """
    _validate_input(x, allow_5d=False)
    x = _adjust_dimensions(x)

    if norm_type == 'l1':
        w_variance = torch.sum(torch.abs(x[:, :, :, 1:] - x[:, :, :, :-1]), dim=[1, 2, 3])
        h_variance = torch.sum(torch.abs(x[:, :, 1:, :] - x[:, :, :-1, :]), dim=[1, 2, 3])
        score = (h_variance + w_variance)
    elif norm_type == 'l2':
        w_variance = torch.sum(torch.pow(x[:, :, :, 1:] - x[:, :, :, :-1], 2), dim=[1, 2, 3])
        h_variance = torch.sum(torch.pow(x[:, :, 1:, :] - x[:, :, :-1, :], 2), dim=[1, 2, 3])
        score = torch.sqrt(h_variance + w_variance)
    elif norm_type == 'l2_squared':
        w_variance = torch.sum(torch.pow(x[:, :, :, 1:] - x[:, :, :, :-1], 2), dim=[1, 2, 3])
        h_variance = torch.sum(torch.pow(x[:, :, 1:, :] - x[:, :, :-1, :], 2), dim=[1, 2, 3])
        score = (h_variance + w_variance)
    else:
        raise ValueError("Incorrect norm type, should be one of {'l1', 'l2', 'l2_squared'}")

    if reduction == 'none':
        return score

    return {'mean': score.mean, 'sum': score.sum}[reduction](dim=0)
def gmsd(x: torch.Tensor, y: torch.Tensor, reduction: str = 'mean',
         data_range: Union[int, float] = 1., t: float = 170 / (255. ** 2)) -> torch.Tensor:
    r"""Compute Gradient Magnitude Similarity Deviation.
    Supports greyscale and colour images with RGB channel order.

    Args:
        x: An input tensor. Shape :math:`(N, C, H, W)`.
        y: A target tensor. Shape :math:`(N, C, H, W)`.
        reduction: Specifies the reduction type:
            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``
        data_range: Maximum value range of images (usually 1.0 or 255).
        t: Constant from the reference paper for numerical stability of the similarity map.

    Returns:
        Gradient Magnitude Similarity Deviation between given tensors.

    References:
        Wufeng Xue et al. Gradient Magnitude Similarity Deviation (2013)
        https://arxiv.org/pdf/1308.3052.pdf
    """
    _validate_input([x, y], dim_range=(4, 4), data_range=(0, data_range))

    # Rescale
    x = x / float(data_range)
    y = y / float(data_range)

    num_channels = x.size(1)
    if num_channels == 3:
        x = rgb2yiq(x)[:, :1]
        y = rgb2yiq(y)[:, :1]

    up_pad = 0
    down_pad = max(x.shape[2] % 2, x.shape[3] % 2)
    pad_to_use = [up_pad, down_pad, up_pad, down_pad]
    x = F.pad(x, pad=pad_to_use)
    y = F.pad(y, pad=pad_to_use)

    x = F.avg_pool2d(x, kernel_size=2, stride=2, padding=0)
    y = F.avg_pool2d(y, kernel_size=2, stride=2, padding=0)

    score = _gmsd(x=x, y=y, t=t)
    return _reduce(score, reduction)
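# Sanity-check sketch for gmsd above (assuming it is exposed as piq.gmsd):
# identical images have identical gradient magnitude maps, so the deviation
# is 0, and it grows as distortion is added.
import torch
from piq import gmsd

x = torch.rand(4, 3, 128, 128)
print(gmsd(x, x.clone()))  # tensor(0.)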
def psnr(x: torch.Tensor, y: torch.Tensor, data_range: Union[int, float] = 1.0,
         reduction: str = 'mean', convert_to_greyscale: bool = False) -> torch.Tensor:
    r"""Compute Peak Signal-to-Noise Ratio for a batch of images.

    Supports both greyscale and color images with RGB channel order.

    Args:
        x: An input tensor. Shape :math:`(N, C, H, W)`.
        y: A target tensor. Shape :math:`(N, C, H, W)`.
        data_range: Maximum value range of images (usually 1.0 or 255).
        reduction: Specifies the reduction type:
            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``
        convert_to_greyscale: Convert RGB image to YCbCr format and computes PSNR
            only on luminance channel if `True`. Compute on all 3 channels otherwise.

    Returns:
        PSNR: Index of similarity between two images.

    References:
        https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
    """
    _validate_input([x, y], dim_range=(4, 5), data_range=(0, data_range))

    # Constant for numerical stability
    EPS = 1e-8

    x = x / data_range
    y = y / data_range

    if (x.size(1) == 3) and convert_to_greyscale:
        # Convert RGB image to YCbCr and take luminance: Y = 0.299 R + 0.587 G + 0.114 B
        rgb_to_grey = torch.tensor([0.299, 0.587, 0.114]).view(1, -1, 1, 1).to(x)
        x = torch.sum(x * rgb_to_grey, dim=1, keepdim=True)
        y = torch.sum(y * rgb_to_grey, dim=1, keepdim=True)

    mse = torch.mean((x - y) ** 2, dim=[1, 2, 3])
    score: torch.Tensor = -10 * torch.log10(mse + EPS)

    return _reduce(score, reduction)
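# Worked check for psnr above (assuming it is exposed as piq.psnr): with
# identical inputs mse = 0, so the score saturates at
# -10 * log10(EPS) = -10 * log10(1e-8) = 80 dB; any difference lowers it.
import torch
from piq import psnr

x = torch.rand(4, 3, 64, 64)
print(psnr(x, x.clone()))  # tensor(80.), the EPS-bounded maximum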
def compute_metric(self, x_features: torch.Tensor, y_features: torch.Tensor) -> torch.Tensor:
    r"""Compute MSID score between two sets of samples.

    Args:
        x_features: Samples from data distribution. Shape :math:`(N_x, D_x)`
        y_features: Samples from data distribution. Shape :math:`(N_y, D_y)`

    Returns:
        Scalar value of the distance between distributions.
    """
    _validate_input([x_features, y_features], dim_range=(2, 2), size_range=(1, 2))
    normed_msid_x = _msid_descriptor(
        x_features.detach().cpu().numpy(),
        ts=self.ts,
        k=self.k,
        m=self.m,
        niters=self.niters,
        rademacher=self.rademacher,
        normalized_laplacian=self.normalized_laplacian,
        normalize=self.normalize
    )
    normed_msid_y = _msid_descriptor(
        y_features.detach().cpu().numpy(),
        ts=self.ts,
        k=self.k,
        m=self.m,
        niters=self.niters,
        rademacher=self.rademacher,
        normalized_laplacian=self.normalized_laplacian,
        normalize=self.normalize
    )

    c = np.exp(-2 * (self.ts + 1 / self.ts))

    if self.msid_mode == 'l2':
        score = np.linalg.norm(normed_msid_x - normed_msid_y)
    elif self.msid_mode == 'max':
        score = np.amax(c * np.abs(normed_msid_x - normed_msid_y))
    else:
        raise ValueError('Mode must be in {`l2`, `max`}')

    return torch.tensor(score, device=x_features.device)
def compute_metric(self, x_features: torch.Tensor, y_features: torch.Tensor) -> torch.Tensor:
    r"""Compute IS.

    Both features should have shape (N_samples, encoder_dim).

    Args:
        x_features: Samples from data distribution. Shape :math:`(N_x, D)`
        y_features: Samples from data distribution. Shape :math:`(N_y, D)`

    Returns:
        L1 or L2 distance between scores for datasets :math:`x` and :math:`y`.
    """
    _validate_input([x_features, y_features], dim_range=(2, 2), size_range=(0, 2))
    x_is, _ = inception_score(x_features, num_splits=self.num_splits)
    y_is, _ = inception_score(y_features, num_splits=self.num_splits)

    if self.distance == 'l1':
        return torch.dist(x_is, y_is, 1)
    elif self.distance == 'l2':
        return torch.dist(x_is, y_is, 2)
    else:
        raise ValueError("Distance should be one of {`l1`, `l2`}")
def forward(self, prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    r"""
    Computation of PieAPP between feature representations of prediction and target tensors.

    Args:
        prediction: Tensor with shape (H, W), (C, H, W) or (N, C, H, W).
        target: Tensor with shape (H, W), (C, H, W) or (N, C, H, W).
    """
    _validate_input(input_tensors=(prediction, target), allow_5d=False,
                    allow_negative=True, data_range=self.data_range)
    prediction, target = _adjust_dimensions(input_tensors=(prediction, target))

    N, C, _, _ = prediction.shape
    if C == 1:
        prediction = prediction.repeat(1, 3, 1, 1)
        target = target.repeat(1, 3, 1, 1)
        warnings.warn('The original PieAPP supports only RGB images. '
                      'The input images were converted to RGB by copying the grey channel 3 times.')

    self.model.to(device=prediction.device)
    prediction_features, prediction_weights = self.get_features(prediction)
    target_features, target_weights = self.get_features(target)

    distances, weights = self.model.compute_difference(
        target_features - prediction_features,
        target_weights - prediction_weights
    )

    distances = distances.reshape(N, -1)
    weights = weights.reshape(N, -1)

    # Scale scores, then average across patches
    loss = torch.stack([(d * w).sum() / w.sum() for d, w in zip(distances, weights)])

    if self.reduction == 'none':
        return loss

    return {'mean': loss.mean, 'sum': loss.sum}[self.reduction](dim=0)
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    r"""
    Computation of PieAPP between feature representations of prediction :math:`x`
    and target :math:`y` tensors.

    Args:
        x: An input tensor. Shape :math:`(N, C, H, W)`.
        y: A target tensor. Shape :math:`(N, C, H, W)`.

    Returns:
        Perceptual Image-Error Assessment through Pairwise Preference
    """
    _validate_input([x, y], dim_range=(4, 4), data_range=(0, self.data_range))

    N, C, _, _ = x.shape
    if C == 1:
        x = x.repeat(1, 3, 1, 1)
        y = y.repeat(1, 3, 1, 1)
        warnings.warn('The original PieAPP supports only RGB images. '
                      'The input images were converted to RGB by copying the grey channel 3 times.')

    self.model.to(device=x.device)
    x_features, x_weights = self.get_features(x)
    y_features, y_weights = self.get_features(y)

    distances, weights = self.model.compute_difference(
        y_features - x_features,
        y_weights - x_weights
    )

    distances = distances.reshape(N, -1)
    weights = weights.reshape(N, -1)

    # Scale scores, then average across patches
    loss = torch.stack([(d * w).sum() / w.sum() for d, w in zip(distances, weights)])

    return _reduce(loss, self.reduction)
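# Usage sketch, assuming this forward belongs to piq's PieAPP loss class, which
# I believe downloads pretrained weights on first use and crops inputs into
# square patches internally, so the caller only provides plain 4D batches.
import torch
from piq import PieAPP

x = torch.rand(2, 3, 256, 256)  # distorted images
y = torch.rand(2, 3, 256, 256)  # reference images
loss = PieAPP(data_range=1.)(x, y)  # lower is better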
def haarpsi(x: torch.Tensor, y: torch.Tensor, reduction: Optional[str] = 'mean',
            data_range: Union[int, float] = 1., scales: int = 3, subsample: bool = True,
            c: float = 30.0, alpha: float = 4.2) -> torch.Tensor:
    r"""Compute Haar Wavelet-Based Perceptual Similarity.

    Input can be a greyscale tensor or a colour image with RGB channels order.

    Args:
        x: Tensor of shape :math:`(N, C, H, W)` holding a distorted image.
        y: Tensor of shape :math:`(N, C, H, W)` holding a target image.
        reduction: Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
            ``'mean'``: the sum of the output will be divided by the number of elements in the output,
            ``'sum'``: the output will be summed.
        data_range: The difference between the maximum and minimum of the pixel value,
            i.e., if for image x it holds min(x) = 0 and max(x) = 1, then data_range = 1.
            The pixel value interval of both input and output should remain the same.
        scales: Number of Haar wavelets used for image decomposition.
        subsample: Flag to apply average pooling before HaarPSI computation. See [1] for details.
        c: Constant from the paper. See [1] for details.
        alpha: Exponent used for similarity maps weighting. See [1] for details.

    Returns:
        HaarPSI : Wavelet-Based Perceptual Similarity between two tensors.

    References:
        [1] R. Reisenhofer, S. Bosse, G. Kutyniok & T. Wiegand (2017)
            'A Haar Wavelet-Based Perceptual Similarity Index for Image Quality Assessment'
            http://www.math.uni-bremen.de/cda/HaarPSI/publications/HaarPSI_preprint_v4.pdf
        [2] Code from authors on MATLAB and Python
            https://github.com/rgcda/haarpsi
    """
    _validate_input(input_tensors=(x, y), allow_5d=False, scale_weights=None)
    x, y = _adjust_dimensions(input_tensors=(x, y))

    # Assert minimal image size
    kernel_size = 2 ** (scales + 1)
    if x.size(-1) < kernel_size or x.size(-2) < kernel_size:
        raise ValueError(f'Kernel size can\'t be greater than actual input size. Input size: {x.size()}. '
                         f'Kernel size: {kernel_size}')

    # Scale images to [0, 255] range as in the paper
    x = x * 255.0 / float(data_range)
    y = y * 255.0 / float(data_range)

    num_channels = x.size(1)
    # Convert RGB to YIQ color space https://en.wikipedia.org/wiki/YIQ
    if num_channels == 3:
        x_yiq = rgb2yiq(x)
        y_yiq = rgb2yiq(y)
    else:
        x_yiq = x
        y_yiq = y

    # Downscale input to simulate the typical distance between an image and its viewer.
    if subsample:
        up_pad = 0
        down_pad = max(x.shape[2] % 2, x.shape[3] % 2)
        pad_to_use = [up_pad, down_pad, up_pad, down_pad]
        x_yiq = F.pad(x_yiq, pad=pad_to_use)
        y_yiq = F.pad(y_yiq, pad=pad_to_use)

        x_yiq = F.avg_pool2d(x_yiq, kernel_size=2, stride=2, padding=0)
        y_yiq = F.avg_pool2d(y_yiq, kernel_size=2, stride=2, padding=0)

    # Haar wavelet decomposition
    coefficients_x, coefficients_y = [], []
    for scale in range(scales):
        kernel_size = 2 ** (scale + 1)
        kernels = torch.stack([haar_filter(kernel_size), haar_filter(kernel_size).transpose(-1, -2)])

        # Asymmetrical padding due to even kernel size. Matches MATLAB conv2(A, B, 'same')
        upper_pad = kernel_size // 2 - 1
        bottom_pad = kernel_size // 2
        pad_to_use = [upper_pad, bottom_pad, upper_pad, bottom_pad]
        coeff_x = torch.nn.functional.conv2d(F.pad(x_yiq[:, : 1], pad=pad_to_use, mode='constant'), kernels.to(x))
        coeff_y = torch.nn.functional.conv2d(F.pad(y_yiq[:, : 1], pad=pad_to_use, mode='constant'), kernels.to(y))

        coefficients_x.append(coeff_x)
        coefficients_y.append(coeff_y)

    # Shape [B x {scales * 2} x H x W]
    coefficients_x = torch.cat(coefficients_x, dim=1)
    coefficients_y = torch.cat(coefficients_y, dim=1)

    # Low-frequency coefficients used as weights
    # Shape [B x 2 x H x W]
    weights = torch.max(torch.abs(coefficients_x[:, 4:]), torch.abs(coefficients_y[:, 4:]))

    # High-frequency coefficients used for similarity computation in 2 orientations (horizontal and vertical)
    sim_map = []
    for orientation in range(2):
        magnitude_x = torch.abs(coefficients_x[:, (orientation, orientation + 2)])
        magnitude_y = torch.abs(coefficients_y[:, (orientation, orientation + 2)])
        sim_map.append(similarity_map(magnitude_x, magnitude_y, constant=c).sum(dim=1, keepdims=True) / 2)

    if num_channels == 3:
        pad_to_use = [0, 1, 0, 1]
        x_yiq = F.pad(x_yiq, pad=pad_to_use)
        y_yiq = F.pad(y_yiq, pad=pad_to_use)
        coefficients_x_iq = torch.abs(F.avg_pool2d(x_yiq[:, 1:], kernel_size=2, stride=1, padding=0))
        coefficients_y_iq = torch.abs(F.avg_pool2d(y_yiq[:, 1:], kernel_size=2, stride=1, padding=0))

        # Compute weights and similarity
        weights = torch.cat([weights, weights.mean(dim=1, keepdims=True)], dim=1)
        sim_map.append(
            similarity_map(coefficients_x_iq, coefficients_y_iq, constant=c).sum(dim=1, keepdims=True) / 2)

    sim_map = torch.cat(sim_map, dim=1)

    # Calculate the final score
    eps = torch.finfo(sim_map.dtype).eps
    score = (((sim_map * alpha).sigmoid() * weights).sum(dim=[1, 2, 3]) + eps) / \
            (torch.sum(weights, dim=[1, 2, 3]) + eps)
    # Logit of score
    score = (torch.log(score / (1 - score)) / alpha) ** 2

    if reduction == 'none':
        return score

    return {'mean': score.mean, 'sum': score.sum}[reduction](dim=0)
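# Usage sketch for haarpsi above (assuming it is exposed as piq.haarpsi): the
# score lies in [0, 1], with identical images scoring 1 and similarity falling
# as distortions accumulate. Inputs must be at least 2 ** (scales + 1) pixels
# per side so the largest Haar kernel fits.
import torch
from piq import haarpsi

x = torch.rand(4, 3, 128, 128)
noisy = (x + 0.05 * torch.randn_like(x)).clamp(0, 1)
print(haarpsi(x, x.clone()))  # ~1.0
print(haarpsi(x, noisy))      # below 1.0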