def test_reduce():
    """Check every supported reduction of ``reduce`` and that an unknown one is rejected."""
    data = torch.rand(50, 40, 30)

    # reduction name -> the tensor that ``reduce`` is expected to produce for it
    expectations = [
        ('elementwise_mean', torch.mean(data)),
        ('sum', torch.sum(data)),
        ('none', data),
    ]
    for reduction, expected in expectations:
        assert torch.allclose(reduce(data, reduction), expected)

    with pytest.raises(ValueError):
        reduce(data, 'error_reduction')
def _sam_compute(
    preds: Tensor,
    target: Tensor,
    reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
) -> Tensor:
    """Computes Spectral Angle Mapper.

    At every position, the vectors running along dimension 1 of ``preds`` and ``target`` are compared,
    and the angle between them (``acos`` of their cosine similarity, in radians) is the score.

    Args:
        preds: estimated image
        target: ground truth image
        reduction: a method to reduce metric score over labels.

            - ``'elementwise_mean'``: takes the mean (default)
            - ``'sum'``: takes the sum
            - ``'none'`` or ``None``: no reduction will be applied

    Returns:
        The per-position angles, reduced according to ``reduction``.

    Example:
        >>> preds = torch.rand([16, 3, 16, 16], generator=torch.manual_seed(42))
        >>> target = torch.rand([16, 3, 16, 16], generator=torch.manual_seed(123))
        >>> preds, target = _sam_update(preds, target)
        >>> _sam_compute(preds, target)
        tensor(0.5943)
    """
    # NOTE: a zero-norm vector makes the denominator 0; the resulting NaN is not removed by the clamp.
    cosine = (preds * target).sum(dim=1) / (preds.norm(dim=1) * target.norm(dim=1))
    # Rounding may push the ratio marginally outside [-1, 1], where acos is undefined.
    angles = cosine.clamp(min=-1, max=1).acos()
    return reduce(angles, reduction)
def dice_score(
    preds: Tensor,
    target: Tensor,
    bg: bool = False,
    nan_score: float = 0.0,
    no_fg_score: float = 0.0,
    reduction: str = 'elementwise_mean',
) -> Tensor:
    """
    Compute dice score from prediction scores

    Args:
        preds: estimated probabilities
        target: ground-truth labels
        bg: whether to also compute dice for the background
        nan_score: score to return, if a NaN occurs during computation
        no_fg_score: score to return, if no foreground pixel was found in target
        reduction: a method to reduce metric score over labels.

            - ``'elementwise_mean'``: takes the mean (default)
            - ``'sum'``: takes the sum
            - ``'none'``: no reduction will be applied

    Return:
        Tensor containing dice score

    Example:

        >>> from torchmetrics.functional import dice_score
        >>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05],
        ...                      [0.05, 0.85, 0.05, 0.05],
        ...                      [0.05, 0.05, 0.85, 0.05],
        ...                      [0.05, 0.05, 0.05, 0.85]])
        >>> target = torch.tensor([0, 1, 3, 2])
        >>> dice_score(pred, target)
        tensor(0.3333)

    """
    n_classes = preds.shape[1]
    # First class index that gets scored: 0 when the background is included, 1 otherwise.
    first_cls = 1 - int(bg)
    scores = torch.zeros(n_classes - first_cls, device=preds.device, dtype=torch.float32)

    for cls_idx in range(first_cls, n_classes):
        slot = cls_idx - first_cls

        if not (target == cls_idx).any():
            # The class never appears in the ground truth (no foreground pixel).
            scores[slot] += no_fg_score
            continue

        # TODO: rewrite to use general `stat_scores`
        tp, fp, _, fn, _ = _stat_scores(preds=preds, target=target, class_index=cls_idx)
        denom = (2 * tp + fp + fn).to(torch.float)
        if torch.is_nonzero(denom):
            scores[slot] += (2 * tp).to(torch.float) / denom
        else:
            # 0 / 0 would be NaN: substitute the configured score instead.
            scores[slot] += nan_score

    return reduce(scores, reduction=reduction)
def _ssim_compute(
    preds: Tensor,
    target: Tensor,
    kernel_size: Sequence[int] = (11, 11),
    sigma: Sequence[float] = (1.5, 1.5),
    reduction: str = "elementwise_mean",
    data_range: Optional[float] = None,
    k1: float = 0.01,
    k2: float = 0.03,
) -> Tensor:
    """Computes the Structural Similarity Index Measure map and reduces it.

    Args:
        preds: estimated image; batch along dim 0, channels along dim 1 (``preds.size(1)``)
        target: ground truth image, same layout as ``preds``
        kernel_size: size of the gaussian kernel, exactly two odd positive ints
        sigma: standard deviation of the gaussian kernel, exactly two positive floats
        reduction: how ``reduce`` collapses the SSIM map (e.g. ``'elementwise_mean'``, ``'sum'``, ``'none'``)
        data_range: range of the image. If ``None``, it is taken as the largest ``max - min`` of the two inputs
        k1: SSIM constant for the luminance term
        k2: SSIM constant for the contrast/structure term

    Returns:
        The SSIM map over positions whose window lies fully inside the image, reduced by ``reduction``.

    Raises:
        ValueError: if ``kernel_size``/``sigma`` do not have length two, a kernel size is even or
            non-positive, or a sigma is non-positive.
    """
    if len(kernel_size) != 2 or len(sigma) != 2:
        raise ValueError(
            "Expected `kernel_size` and `sigma` to have the length of two."
            f" Got kernel_size: {len(kernel_size)} and sigma: {len(sigma)}."
        )

    if any(x % 2 == 0 or x <= 0 for x in kernel_size):
        raise ValueError(f"Expected `kernel_size` to have odd positive number. Got {kernel_size}.")

    if any(y <= 0 for y in sigma):
        raise ValueError(f"Expected `sigma` to have positive number. Got {sigma}.")

    if data_range is None:
        data_range = max(preds.max() - preds.min(), target.max() - target.min())

    c1 = pow(k1 * data_range, 2)
    c2 = pow(k2 * data_range, 2)
    device = preds.device
    channel = preds.size(1)
    dtype = preds.dtype
    kernel = _gaussian_kernel(channel, kernel_size, sigma, dtype, device)

    # kernel_size[0] is taken to span dim -2 (height) and kernel_size[1] dim -1 (width).
    # NOTE(review): assumes _gaussian_kernel returns (C, 1, kernel_size[0], kernel_size[1]) - confirm.
    # Each dim is padded and later cropped by the same amount, so only positions whose window lies
    # entirely inside the original image survive the crop.
    pad_h = (kernel_size[0] - 1) // 2
    pad_w = (kernel_size[1] - 1) // 2

    # F.pad lists the last dimension first: (left, right, top, bottom).
    preds = F.pad(preds, (pad_w, pad_w, pad_h, pad_h), mode='reflect')
    target = F.pad(target, (pad_w, pad_w, pad_h, pad_h), mode='reflect')

    input_list = torch.cat((preds, target, preds * preds, target * target, preds * target))  # (5 * B, C, H, W)
    outputs = F.conv2d(input_list, kernel, groups=channel)
    # Five equally sized chunks of B samples, one per stacked input above.
    output_list = outputs.split(preds.size(0))

    mu_pred_sq = output_list[0].pow(2)
    mu_target_sq = output_list[1].pow(2)
    mu_pred_target = output_list[0] * output_list[1]

    sigma_pred_sq = output_list[2] - mu_pred_sq
    sigma_target_sq = output_list[3] - mu_target_sq
    sigma_pred_target = output_list[4] - mu_pred_target

    upper = 2 * sigma_pred_target + c2
    lower = sigma_pred_sq + sigma_target_sq + c2

    ssim_idx = ((2 * mu_pred_target + c1) * upper) / ((mu_pred_sq + mu_target_sq + c1) * lower)

    # Explicit stop indices: ``p:-p`` would select nothing when p == 0 (kernel size 1).
    height, width = ssim_idx.shape[-2:]
    ssim_idx = ssim_idx[..., pad_h:height - pad_h, pad_w:width - pad_w]

    return reduce(ssim_idx, reduction)
def _spectral_distortion_index_compute(
    preds: Tensor,
    target: Tensor,
    p: int = 1,
    reduction: Literal["elementwise_mean", "sum", "none"] = "elementwise_mean",
) -> Tensor:
    """Computes Spectral Distortion Index (SpectralDistortionIndex_)

    Args:
        preds: Low resolution multispectral image
        target: High resolution fused image
        p: a parameter to emphasize large spectral difference
        reduction: a method to reduce metric score over labels.

            - ``'elementwise_mean'``: takes the mean (default)
            - ``'sum'``: takes the sum
            - ``'none'``: no reduction will be applied

    Returns:
        The distortion index, reduced according to ``reduction``, on the same device as ``preds``.

    Example:
        >>> _ = torch.manual_seed(42)
        >>> preds = torch.rand([16, 3, 16, 16])
        >>> target = torch.rand([16, 3, 16, 16])
        >>> preds, target = _spectral_distortion_index_update(preds, target)
        >>> _spectral_distortion_index_compute(preds, target)
        tensor(0.0234)
    """
    length = preds.shape[1]
    # Allocate on the inputs' device: a bare ``torch.zeros`` uses the default device (normally CPU),
    # which would move the whole computation, and the returned score, off the inputs' device.
    m1 = torch.zeros((length, length), device=preds.device)
    m2 = torch.zeros((length, length), device=preds.device)

    # Band-pair UQI matrices are symmetric: evaluate each unordered pair (k <= r) once and mirror it.
    for k in range(length):
        for r in range(k, length):
            m1[k, r] = m1[r, k] = universal_image_quality_index(target[:, k : k + 1, :, :], target[:, r : r + 1, :, :])
            m2[k, r] = m2[r, k] = universal_image_quality_index(preds[:, k : k + 1, :, :], preds[:, r : r + 1, :, :])
    diff = torch.pow(torch.abs(m1 - m2), p)
    # Special case: when number of channels (L) is 1, there will be only one element in M1 and M2. Hence no need to sum.
    if length == 1:
        output = torch.pow(diff, (1.0 / p))
    else:
        output = torch.pow(1.0 / (length * (length - 1)) * torch.sum(diff), (1.0 / p))
    return reduce(output, reduction)
def _jaccard_from_confmat(
    confmat: Tensor,
    num_classes: int,
    ignore_index: Optional[int] = None,
    absent_score: float = 0.0,
    reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
) -> Tensor:
    """Computes the intersection over union from confusion matrix.

    Args:
        confmat: Confusion matrix without normalization. It is not modified by this function.
        num_classes: Number of classes for a given prediction and target tensor
        ignore_index: optional int specifying a target class to ignore. If given, this class index does not contribute
            to the returned score, regardless of reduction method.
        absent_score: score to use for an individual class, if no instances of the class index were present in
            ``preds`` AND no instances of the class index were present in ``target``.
        reduction: a method to reduce metric score over labels.

            - ``'elementwise_mean'``: takes the mean (default)
            - ``'sum'``: takes the sum
            - ``'none'`` or ``None``: no reduction will be applied

    Returns:
        The per-class IoU scores (without the ignored class), reduced according to ``reduction``.
    """
    ignore = ignore_index is not None and 0 <= ignore_index < num_classes

    if ignore:
        # Zero the ignored class's row on a copy; doing it in place would silently alter the caller's
        # tensor (e.g. an accumulated metric state).
        # NOTE(review): presumably rows index the target class - confirm against the confmat builder.
        confmat = confmat.clone()
        confmat[ignore_index] = 0.0

    intersection = torch.diag(confmat)
    union = confmat.sum(0) + confmat.sum(1) - intersection

    # If this class is absent in both target AND pred (union == 0), then use the absent_score for this class.
    scores = intersection.float() / union.float()
    scores[union == 0] = absent_score

    if ignore:
        # Remove the ignored class's own score.
        scores = torch.cat([scores[:ignore_index], scores[ignore_index + 1:]])

    return reduce(scores, reduction=reduction)
def _iou_from_confmat(
    confmat: Tensor,
    num_classes: int,
    ignore_index: Optional[int] = None,
    absent_score: float = 0.0,
    reduction: str = 'elementwise_mean',
) -> Tensor:
    """Computes per-class intersection over union from a confusion matrix and reduces it.

    Args:
        confmat: confusion matrix without normalization. It is not modified by this function.
        num_classes: number of classes for a given prediction and target tensor
        ignore_index: optional class index to ignore. Its row of ``confmat`` is zeroed (on a copy) before
            scoring, so its samples do not inflate other classes' unions, and its own score is dropped.
        absent_score: score used for a class whose union is zero (absent in both preds and target)
        reduction: a method to reduce metric score over labels
            (``'elementwise_mean'``, ``'sum'`` or ``'none'``)

    Returns:
        The per-class IoU scores (without the ignored class), reduced according to ``reduction``.
    """
    ignore = ignore_index is not None and 0 <= ignore_index < num_classes

    if ignore:
        # Without zeroing this row, samples of the ignored class would still count in the union of
        # whatever class they were confused with. Work on a copy so the caller's tensor is untouched.
        # NOTE(review): presumably rows index the target class - confirm against the confmat builder.
        confmat = confmat.clone()
        confmat[ignore_index] = 0.0

    intersection = torch.diag(confmat)
    union = confmat.sum(0) + confmat.sum(1) - intersection

    # If this class is absent in both target AND pred (union == 0), then use the absent_score for this class.
    scores = intersection.float() / union.float()
    scores[union == 0] = absent_score

    if ignore:
        # Remove the ignored class index from the scores.
        scores = torch.cat([scores[:ignore_index], scores[ignore_index + 1:]])

    return reduce(scores, reduction=reduction)
def _ergas_compute(
    preds: Tensor,
    target: Tensor,
    ratio: Union[int, float] = 4,
    reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
) -> Tensor:
    """Erreur Relative Globale Adimensionnelle de Synthèse.

    For each sample, the per-band RMSE is divided by the band's mean target value. The squared ratios
    are averaged over bands, square-rooted and scaled by ``100 * ratio``.

    Args:
        preds: estimated image, shaped ``(B, C, H, W)``
        target: ground truth image, same shape as ``preds``
        ratio: ratio of high resolution to low resolution
        reduction: a method to reduce metric score over labels.

            - ``'elementwise_mean'``: takes the mean (default)
            - ``'sum'``: takes the sum
            - ``'none'`` or ``None``: no reduction will be applied

    Returns:
        One ERGAS value per sample, reduced according to ``reduction``.

    Example:
        >>> preds = torch.rand([16, 1, 16, 16], generator=torch.manual_seed(42))
        >>> target = preds * 0.75
        >>> preds, target = _ergas_update(preds, target)
        >>> torch.round(_ergas_compute(preds, target))
        tensor(154.)
    """
    batch, channels, height, width = preds.shape
    n_pixels = height * width

    # Flatten the spatial dims so every band becomes a single vector of pixels.
    flat_preds = preds.reshape(batch, channels, n_pixels)
    flat_target = target.reshape(batch, channels, n_pixels)

    residual = flat_preds - flat_target
    band_rmse = torch.sqrt(torch.sum(residual * residual, dim=2) / n_pixels)
    band_mean = torch.mean(flat_target, dim=2)

    relative_sq = (band_rmse / band_mean) ** 2
    ergas_score = 100 * ratio * torch.sqrt(torch.sum(relative_sq, dim=1) / channels)
    return reduce(ergas_score, reduction)
def _ssim_compute(
    preds: Tensor,
    target: Tensor,
    gaussian_kernel: bool = True,
    sigma: Union[float, Sequence[float]] = 1.5,
    kernel_size: Union[int, Sequence[int]] = 11,
    reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
    data_range: Optional[float] = None,
    k1: float = 0.01,
    k2: float = 0.03,
    return_full_image: bool = False,
    return_contrast_sensitivity: bool = False,
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
    """Computes Structural Similarity Index Measure.

    Args:
        preds: estimated image (4-D, or 5-D for volumetric SSIM)
        target: ground truth image
        gaussian_kernel: If true (default), a gaussian kernel is used, if false a uniform kernel is used
        sigma: Standard deviation of the gaussian kernel, anisotropic kernels are possible.
            Ignored if a uniform kernel is used
        kernel_size: the size of the uniform kernel, anisotropic kernels are possible.
            Ignored if a Gaussian kernel is used
        reduction: a method to reduce metric score over labels.

            - ``'elementwise_mean'``: takes the mean
            - ``'sum'``: takes the sum
            - ``'none'`` or ``None``: no reduction will be applied

        data_range: Range of the image. If ``None``, it is determined from the image (max - min)
        k1: Parameter of SSIM.
        k2: Parameter of SSIM.
        return_full_image: If true, the full ``ssim`` image is returned as a second argument.
            Mutually exclusive with ``return_contrast_sensitivity``
        return_contrast_sensitivity: If true, the contrast term is returned as a second argument.
            The luminance term can be obtained with luminance=ssim/contrast
            Mutually exclusive with ``return_full_image``

    Returns:
        The per-sample mean SSIM, reduced by ``reduction``. If ``return_contrast_sensitivity`` is set
        (checked first), a tuple with the reduced per-sample mean contrast sensitivity as second item.
        Otherwise, if ``return_full_image`` is set, a tuple with the reduced, uncropped SSIM image as
        second item.

    Raises:
        ValueError: if ``kernel_size``/``sigma`` do not match the image dimensionality or are not 2-D/3-D,
            if a kernel size is even or non-positive, or if a sigma is non-positive.

    Example:
        >>> preds = torch.rand([16, 1, 16, 16])
        >>> target = preds * 0.75
        >>> preds, target = _ssim_update(preds, target)
        >>> _ssim_compute(preds, target)
        tensor(0.9219)
    """
    is_3d = len(preds.shape) == 5

    if not isinstance(kernel_size, Sequence):
        kernel_size = 3 * [kernel_size] if is_3d else 2 * [kernel_size]
    if not isinstance(sigma, Sequence):
        sigma = 3 * [sigma] if is_3d else 2 * [sigma]

    if len(kernel_size) != len(target.shape) - 2:
        raise ValueError(
            f"`kernel_size` has dimension {len(kernel_size)}, but expected to be two less that target dimensionality,"
            f" which is: {len(target.shape)}"
        )
    if len(kernel_size) not in (2, 3):
        raise ValueError(
            f"Expected `kernel_size` dimension to be 2 or 3. `kernel_size` dimensionality: {len(kernel_size)}"
        )
    if len(sigma) != len(target.shape) - 2:
        raise ValueError(
            f"`sigma` has dimension {len(sigma)}, but expected to be two less than target dimensionality,"
            f" which is: {len(target.shape)}"
        )
    if len(sigma) not in (2, 3):
        raise ValueError(f"Expected `sigma` dimension to be 2 or 3. `sigma` dimensionality: {len(sigma)}")

    if any(x % 2 == 0 or x <= 0 for x in kernel_size):
        raise ValueError(f"Expected `kernel_size` to have odd positive number. Got {kernel_size}.")

    if any(y <= 0 for y in sigma):
        raise ValueError(f"Expected `sigma` to have positive number. Got {sigma}.")

    if data_range is None:
        data_range = max(preds.max() - preds.min(), target.max() - target.min())

    c1 = pow(k1 * data_range, 2)
    c2 = pow(k2 * data_range, 2)
    device = preds.device
    channel = preds.size(1)
    dtype = preds.dtype

    # Padding is always derived from sigma (about 3.5 sigma per side), for the uniform kernel too.
    gauss_kernel_size = [int(3.5 * s + 0.5) * 2 + 1 for s in sigma]
    pad_h = (gauss_kernel_size[0] - 1) // 2
    pad_w = (gauss_kernel_size[1] - 1) // 2
    pad_d = (gauss_kernel_size[2] - 1) // 2 if is_3d else 0

    if is_3d:
        preds = _reflection_pad_3d(preds, pad_d, pad_w, pad_h)
        target = _reflection_pad_3d(target, pad_d, pad_w, pad_h)
    else:
        preds = F.pad(preds, (pad_w, pad_w, pad_h, pad_h), mode="reflect")
        target = F.pad(target, (pad_w, pad_w, pad_h, pad_h), mode="reflect")

    if gaussian_kernel:
        if is_3d:
            kernel = _gaussian_kernel_3d(channel, gauss_kernel_size, sigma, dtype, device)
        else:
            kernel = _gaussian_kernel_2d(channel, gauss_kernel_size, sigma, dtype, device)
    else:
        # Grouped convolution (groups=channel) needs one filter per channel, i.e. a
        # (channel, 1, *kernel_size) weight that matches the inputs' dtype and device.
        uniform_weight = 1.0 / float(torch.tensor(kernel_size).prod())
        kernel = torch.full((channel, 1, *kernel_size), uniform_weight, dtype=dtype, device=device)

    input_list = torch.cat((preds, target, preds * preds, target * target, preds * target))  # (5 * B, C, ...)

    if is_3d:
        outputs = F.conv3d(input_list, kernel, groups=channel)
    else:
        outputs = F.conv2d(input_list, kernel, groups=channel)

    # Five chunks of B samples, one per stacked input above.
    output_list = outputs.split(preds.shape[0])

    mu_pred_sq = output_list[0].pow(2)
    mu_target_sq = output_list[1].pow(2)
    mu_pred_target = output_list[0] * output_list[1]

    sigma_pred_sq = output_list[2] - mu_pred_sq
    sigma_target_sq = output_list[3] - mu_target_sq
    sigma_pred_target = output_list[4] - mu_pred_target

    upper = 2 * sigma_pred_target + c2
    lower = sigma_pred_sq + sigma_target_sq + c2

    ssim_idx_full_image = ((2 * mu_pred_target + c1) * upper) / ((mu_pred_sq + mu_target_sq + c1) * lower)

    def _crop(image: Tensor) -> Tensor:
        # Strip the padded border from every spatial dim. Explicit stop indices are used because
        # ``p:-p`` selects nothing when p == 0 (which happens for very small sigma).
        if is_3d:
            return image[
                ...,
                pad_h : image.shape[-3] - pad_h,
                pad_w : image.shape[-2] - pad_w,
                pad_d : image.shape[-1] - pad_d,
            ]
        return image[..., pad_h : image.shape[-2] - pad_h, pad_w : image.shape[-1] - pad_w]

    ssim_idx = _crop(ssim_idx_full_image)
    ssim_per_sample = ssim_idx.reshape(ssim_idx.shape[0], -1).mean(-1)

    if return_contrast_sensitivity:
        # Cropped exactly like the SSIM map (all spatial dims, including depth for 3-D inputs).
        contrast_sensitivity = _crop(upper / lower)
        cs_per_sample = contrast_sensitivity.reshape(contrast_sensitivity.shape[0], -1).mean(-1)
        return reduce(ssim_per_sample, reduction), reduce(cs_per_sample, reduction)
    if return_full_image:
        return reduce(ssim_per_sample, reduction), reduce(ssim_idx_full_image, reduction)
    return reduce(ssim_per_sample, reduction)
def _uqi_compute(
    preds: Tensor,
    target: Tensor,
    kernel_size: Sequence[int] = (11, 11),
    sigma: Sequence[float] = (1.5, 1.5),
    reduction: Optional[Literal["elementwise_mean", "sum", "none"]] = "elementwise_mean",
    data_range: Optional[float] = None,
    return_contrast_sensitivity: bool = False,
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
    """Computes Universal Image Quality Index.

    Args:
        preds: estimated image; batch along dim 0, channels along dim 1 (``preds.size(1)``)
        target: ground truth image
        kernel_size: size of the gaussian kernel (two odd positive ints)
        sigma: Standard deviation of the gaussian kernel (two positive floats)
        reduction: a method to reduce metric score over labels.

            - ``'elementwise_mean'``: takes the mean (default)
            - ``'sum'``: takes the sum
            - ``'none'`` or ``None``: no reduction will be applied

        data_range: accepted for API compatibility; not used by this computation
            (the formula below has no range-dependent stabilising constants)
        return_contrast_sensitivity: accepted for API compatibility; currently ignored,
            a single tensor is always returned

    Returns:
        The UQI map over positions whose window lies fully inside the image, reduced by ``reduction``.

    Raises:
        ValueError: if ``kernel_size``/``sigma`` do not have length two, a kernel size is even or
            non-positive, or a sigma is non-positive.

    Example:
        >>> preds = torch.rand([16, 1, 16, 16])
        >>> target = preds * 0.75
        >>> preds, target = _uqi_update(preds, target)
        >>> _uqi_compute(preds, target)
        tensor(0.9216)
    """
    if len(kernel_size) != 2 or len(sigma) != 2:
        raise ValueError(
            "Expected `kernel_size` and `sigma` to have the length of two."
            f" Got kernel_size: {len(kernel_size)} and sigma: {len(sigma)}."
        )

    if any(x % 2 == 0 or x <= 0 for x in kernel_size):
        raise ValueError(f"Expected `kernel_size` to have odd positive number. Got {kernel_size}.")

    if any(y <= 0 for y in sigma):
        raise ValueError(f"Expected `sigma` to have positive number. Got {sigma}.")

    device = preds.device
    channel = preds.size(1)
    dtype = preds.dtype
    kernel = _gaussian_kernel_2d(channel, kernel_size, sigma, dtype, device)
    pad_h = (kernel_size[0] - 1) // 2
    pad_w = (kernel_size[1] - 1) // 2

    # F.pad lists the last dimension (width) first. Each dim must be padded by the same amount it is
    # cropped by below; otherwise anisotropic kernels select a shifted/oversized region.
    preds = F.pad(preds, (pad_w, pad_w, pad_h, pad_h), mode="reflect")
    target = F.pad(target, (pad_w, pad_w, pad_h, pad_h), mode="reflect")

    input_list = torch.cat((preds, target, preds * preds, target * target, preds * target))  # (5 * B, C, H, W)
    outputs = F.conv2d(input_list, kernel, groups=channel)
    # Five chunks of B samples, one per stacked input above.
    output_list = outputs.split(preds.shape[0])

    mu_pred_sq = output_list[0].pow(2)
    mu_target_sq = output_list[1].pow(2)
    mu_pred_target = output_list[0] * output_list[1]

    sigma_pred_sq = output_list[2] - mu_pred_sq
    sigma_target_sq = output_list[3] - mu_target_sq
    sigma_pred_target = output_list[4] - mu_pred_target

    upper = 2 * sigma_pred_target
    lower = sigma_pred_sq + sigma_target_sq

    uqi_idx = ((2 * mu_pred_target) * upper) / ((mu_pred_sq + mu_target_sq) * lower)

    # Explicit stop indices: ``p:-p`` would select nothing when p == 0 (kernel size 1).
    height, width = uqi_idx.shape[-2:]
    uqi_idx = uqi_idx[..., pad_h : height - pad_h, pad_w : width - pad_w]

    return reduce(uqi_idx, reduction)