def accuracy(
        pred: torch.Tensor,
        target: torch.Tensor,
        num_classes: Optional[int] = None,
        reduction: str = 'elementwise_mean',
) -> torch.Tensor:
    """
    Computes the accuracy classification score

    Args:
        pred: predicted labels
        target: ground truth labels
        num_classes: number of classes
        reduction: a method for reducing accuracies over labels (default: takes the mean)
            Available reduction methods:

            - elementwise_mean: takes the mean
            - none: pass array
            - sum: add elements

    Return:
        A Tensor with the classification score.
    """
    # guard before any stat computation: num_classes cannot be inferred from an all-zero target
    if not (target > 0).any() and num_classes is None:
        raise RuntimeError("cannot infer num_classes when target is all zero")

    tps, fps, tns, fns, sups = stat_scores_multiple_classes(
        pred=pred, target=target, num_classes=num_classes)

    if reduction in ('elementwise_mean', 'sum'):
        return reduce(sum(tps) / sum(sups), reduction=reduction)
    if reduction == 'none':
        return reduce(tps / sups, reduction=reduction)
    raise ValueError(f'Reduction parameter unknown: {reduction}')
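# A minimal usage sketch for `accuracy` above (micro accuracy: total true
# positives over total support). The expected value assumes
# `stat_scores_multiple_classes` infers four classes from the inputs:
import torch

pred = torch.tensor([0, 1, 2, 2])
target = torch.tensor([0, 1, 2, 3])
accuracy(pred, target)  # 3 correct out of 4 -> tensor(0.7500)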
def test_v1_3_0_deprecated_metrics():
    from pytorch_lightning.metrics.functional.classification import to_onehot
    with pytest.deprecated_call(match='will be removed in v1.3'):
        to_onehot(torch.tensor([1, 2, 3]))

    from pytorch_lightning.metrics.functional.classification import to_categorical
    with pytest.deprecated_call(match='will be removed in v1.3'):
        to_categorical(torch.tensor([[0.2, 0.5], [0.9, 0.1]]))

    from pytorch_lightning.metrics.functional.classification import get_num_classes
    with pytest.deprecated_call(match='will be removed in v1.3'):
        get_num_classes(pred=torch.tensor([0, 1]), target=torch.tensor([1, 1]))

    x_binary = torch.tensor([0, 1, 2, 3])
    y_binary = torch.tensor([0, 1, 2, 3])

    from pytorch_lightning.metrics.functional.classification import roc
    with pytest.deprecated_call(match='will be removed in v1.3'):
        roc(pred=x_binary, target=y_binary)

    from pytorch_lightning.metrics.functional.classification import _roc
    with pytest.deprecated_call(match='will be removed in v1.3'):
        _roc(pred=x_binary, target=y_binary)

    x_multy = torch.tensor([
        [0.85, 0.05, 0.05, 0.05],
        [0.05, 0.85, 0.05, 0.05],
        [0.05, 0.05, 0.85, 0.05],
        [0.05, 0.05, 0.05, 0.85],
    ])
    y_multy = torch.tensor([0, 1, 3, 2])

    from pytorch_lightning.metrics.functional.classification import multiclass_roc
    with pytest.deprecated_call(match='will be removed in v1.3'):
        multiclass_roc(pred=x_multy, target=y_multy)

    from pytorch_lightning.metrics.functional.classification import average_precision
    with pytest.deprecated_call(match='will be removed in v1.3'):
        average_precision(pred=x_binary, target=y_binary)

    from pytorch_lightning.metrics.functional.classification import precision_recall_curve
    with pytest.deprecated_call(match='will be removed in v1.3'):
        precision_recall_curve(pred=x_binary, target=y_binary)

    from pytorch_lightning.metrics.functional.classification import multiclass_precision_recall_curve
    with pytest.deprecated_call(match='will be removed in v1.3'):
        multiclass_precision_recall_curve(pred=x_multy, target=y_multy)

    from pytorch_lightning.metrics.functional.reduction import reduce
    with pytest.deprecated_call(match='will be removed in v1.3'):
        reduce(torch.tensor([0, 1, 1, 0]), 'sum')

    from pytorch_lightning.metrics.functional.reduction import class_reduce
    with pytest.deprecated_call(match='will be removed in v1.3'):
        class_reduce(torch.randint(1, 10, (50,)).float(),
                     torch.randint(10, 20, (50,)).float(),
                     torch.randint(1, 100, (50,)).float())
def test_reduce():
    start_tensor = torch.rand(50, 40, 30)

    assert torch.allclose(reduce(start_tensor, 'elementwise_mean'), torch.mean(start_tensor))
    assert torch.allclose(reduce(start_tensor, 'sum'), torch.sum(start_tensor))
    assert torch.allclose(reduce(start_tensor, 'none'), start_tensor)

    with pytest.raises(ValueError):
        reduce(start_tensor, 'error_reduction')
def precision_recall(
        pred: torch.Tensor,
        target: torch.Tensor,
        num_classes: Optional[int] = None,
        reduction: str = 'elementwise_mean',
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Computes precision and recall scores

    Args:
        pred: estimated probabilities
        target: ground-truth labels
        num_classes: number of classes
        reduction: method for reducing precision-recall values (default: takes the mean)
            Available reduction methods:

            - elementwise_mean: takes the mean
            - none: pass array
            - sum: add elements

    Return:
        Tensor with precision and recall

    Example:

        >>> x = torch.tensor([0, 1, 2, 3])
        >>> y = torch.tensor([0, 1, 2, 2])
        >>> precision_recall(x, y)
        (tensor(0.7500), tensor(0.6250))
    """
    tps, fps, tns, fns, sups = stat_scores_multiple_classes(
        pred=pred, target=target, num_classes=num_classes)

    tps = tps.to(torch.float)
    fps = fps.to(torch.float)
    fns = fns.to(torch.float)

    precision = tps / (tps + fps)
    recall = tps / (tps + fns)

    # drop NaN after zero division; solution by justus,
    # see https://discuss.pytorch.org/t/how-to-set-nan-in-tensor-to-0/3918/9
    precision[precision != precision] = 0
    recall[recall != recall] = 0

    precision = reduce(precision, reduction=reduction)
    recall = reduce(recall, reduction=reduction)
    return precision, recall
def mse(pred: torch.Tensor, target: torch.Tensor,
        reduction: str = 'elementwise_mean') -> torch.Tensor:
    """
    Computes mean squared error

    Args:
        pred: estimated labels
        target: ground truth labels
        reduction: a method to reduce metric score over labels (default: takes the mean)
            Available reduction methods:

            - elementwise_mean: takes the mean
            - none: pass array
            - sum: add elements

    Return:
        Tensor with MSE

    Example:

        >>> x = torch.tensor([0., 1, 2, 3])
        >>> y = torch.tensor([0., 1, 2, 2])
        >>> mse(x, y)
        tensor(0.2500)
    """
    mse = F.mse_loss(pred, target, reduction='none')
    mse = reduce(mse, reduction=reduction)
    return mse
def dice_score(
        pred: torch.Tensor,
        target: torch.Tensor,
        bg: bool = False,
        nan_score: float = 0.0,
        no_fg_score: float = 0.0,
        reduction: str = 'elementwise_mean',
) -> torch.Tensor:
    n_classes = pred.shape[1]
    # `bg` becomes the index of the first class to score: 1 skips the background class
    bg = (1 - int(bool(bg)))
    scores = torch.zeros(n_classes - bg, device=pred.device, dtype=torch.float32)
    for i in range(bg, n_classes):
        if not (target == i).any():
            # no foreground class
            scores[i - bg] += no_fg_score
            continue

        tp, fp, tn, fn = stat_scores(pred=pred, target=target, class_index=i)
        denom = (2 * tp + fp + fn).to(torch.float)

        if torch.isclose(denom, torch.zeros_like(denom)).any():
            # nan result
            score_cls = nan_score
        else:
            score_cls = (2 * tp).to(torch.float) / denom

        scores[i - bg] += score_cls
    return reduce(scores, reduction=reduction)
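# A minimal usage sketch for the undocumented `dice_score` variant above.
# The expected values assume `stat_scores` argmaxes probability inputs, as the
# documented variant later in this file does:
import torch

pred = torch.tensor([[0.85, 0.05, 0.05, 0.05],
                     [0.05, 0.85, 0.05, 0.05],
                     [0.05, 0.05, 0.85, 0.05],
                     [0.05, 0.05, 0.05, 0.85]])
target = torch.tensor([0, 1, 3, 2])
dice_score(pred, target)           # background (class 0) excluded -> tensor(0.3333)
dice_score(pred, target, bg=True)  # background included -> tensor(0.5000)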
def mse(
        pred: torch.Tensor,
        target: torch.Tensor,
        reduction: str = 'elementwise_mean'
) -> torch.Tensor:
    """
    Computes mean squared error

    Args:
        pred: estimated labels
        target: ground truth labels
        reduction: a method to reduce metric score over labels.

            - ``'elementwise_mean'``: takes the mean (default)
            - ``'sum'``: takes the sum
            - ``'none'``: no reduction will be applied

    Return:
        Tensor with MSE

    Example:

        >>> x = torch.tensor([0., 1, 2, 3])
        >>> y = torch.tensor([0., 1, 2, 2])
        >>> mse(x, y)
        tensor(0.2500)
    """
    mse = F.mse_loss(pred, target, reduction='none')
    mse = reduce(mse, reduction=reduction)
    return mse
def mae(pred: torch.Tensor, target: torch.Tensor,
        reduction: str = 'elementwise_mean') -> torch.Tensor:
    """
    Computes mean absolute error

    Args:
        pred: estimated labels
        target: ground truth labels
        reduction: method for reducing mae (default: takes the mean)
            Available reduction methods:

            - elementwise_mean: takes the mean
            - none: pass array
            - sum: add elements

    Return:
        Tensor with MAE

    Example:

        >>> x = torch.tensor([0., 1, 2, 3])
        >>> y = torch.tensor([0., 1, 2, 2])
        >>> mae(x, y)
        tensor(0.2500)
    """
    mae = F.l1_loss(pred, target, reduction='none')
    mae = reduce(mae, reduction=reduction)
    return mae
def compute(self):
    scores = torch.zeros(self.n_classes, device=self.true_positive.device, dtype=torch.float32)

    for class_idx in range(self.n_classes):
        if class_idx == self.ignore_index:
            continue

        tp = self.true_positive[class_idx]
        fp = self.false_positive[class_idx]
        fn = self.false_negative[class_idx]
        sup = self.support[class_idx]

        # If this class is absent in the target (no support) AND absent in the pred (no true or false
        # positives), then use the absent_score for this class.
        if sup + tp + fp == 0:
            scores[class_idx] = self.absent_score
            continue

        denominator = tp + fp + fn
        score = tp.to(torch.float) / denominator
        scores[class_idx] = score

    # Remove the ignored class index from the scores.
    if (self.ignore_index is not None) and (0 <= self.ignore_index < self.n_classes):
        scores = torch.cat([scores[:self.ignore_index], scores[self.ignore_index + 1:]])

    return reduce(scores, reduction=self.reduction)
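# Worked example for the per-class IoU computed above (hypothetical counts,
# not taken from any real run):
#   tp=3, fp=1, fn=2          ->  score = 3 / (3 + 1 + 2) = 0.5
#   tp=0, fp=0, sup=0         ->  class absent in pred and target -> absent_score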
def dice_score(
        pred: torch.Tensor,
        target: torch.Tensor,
        bg: bool = False,
        nan_score: float = 0.0,
        no_fg_score: float = 0.0,
        reduction: str = 'elementwise_mean',
) -> torch.Tensor:
    """
    Compute dice score from prediction scores

    Args:
        pred: estimated probabilities
        target: ground-truth labels
        bg: whether to also compute dice for the background
        nan_score: score to return, if a NaN occurs during computation
        no_fg_score: score to return, if no foreground pixel was found in target
        reduction: a method for reducing accuracies over labels (default: takes the mean)
            Available reduction methods:

            - elementwise_mean: takes the mean
            - none: pass array
            - sum: add elements

    Return:
        Tensor containing dice score

    Example:

        >>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05],
        ...                      [0.05, 0.85, 0.05, 0.05],
        ...                      [0.05, 0.05, 0.85, 0.05],
        ...                      [0.05, 0.05, 0.05, 0.85]])
        >>> target = torch.tensor([0, 1, 3, 2])
        >>> dice_score(pred, target)
        tensor(0.3333)
    """
    num_classes = pred.shape[1]
    bg = (1 - int(bool(bg)))
    scores = torch.zeros(num_classes - bg, device=pred.device, dtype=torch.float32)
    for i in range(bg, num_classes):
        if not (target == i).any():
            # no foreground class
            scores[i - bg] += no_fg_score
            continue

        tp, fp, tn, fn, sup = stat_scores(pred=pred, target=target, class_index=i)
        denom = (2 * tp + fp + fn).to(torch.float)
        # nan result
        score_cls = (2 * tp).to(torch.float) / denom if torch.is_nonzero(denom) else nan_score

        scores[i - bg] += score_cls
    return reduce(scores, reduction=reduction)
def _ssim_compute(
    preds: torch.Tensor,
    target: torch.Tensor,
    kernel_size: Sequence[int] = (11, 11),
    sigma: Sequence[float] = (1.5, 1.5),
    reduction: str = "elementwise_mean",
    data_range: Optional[float] = None,
    k1: float = 0.01,
    k2: float = 0.03,
):
    if len(kernel_size) != 2 or len(sigma) != 2:
        raise ValueError(
            "Expected `kernel_size` and `sigma` to have the length of two."
            f" Got kernel_size: {len(kernel_size)} and sigma: {len(sigma)}."
        )

    if any(x % 2 == 0 or x <= 0 for x in kernel_size):
        raise ValueError(f"Expected `kernel_size` to have odd positive number. Got {kernel_size}.")

    if any(y <= 0 for y in sigma):
        raise ValueError(f"Expected `sigma` to have positive number. Got {sigma}.")

    if data_range is None:
        data_range = max(preds.max() - preds.min(), target.max() - target.min())

    c1 = pow(k1 * data_range, 2)
    c2 = pow(k2 * data_range, 2)
    device = preds.device

    channel = preds.size(1)
    kernel = _gaussian_kernel(channel, kernel_size, sigma, device)
    input_list = torch.cat([preds, target, preds * preds, target * target, preds * target])  # (5 * B, C, H, W)
    outputs = F.conv2d(input_list, kernel, groups=channel)
    output_list = [outputs[x * preds.size(0): (x + 1) * preds.size(0)] for x in range(len(outputs))]

    mu_pred_sq = output_list[0].pow(2)
    mu_target_sq = output_list[1].pow(2)
    mu_pred_target = output_list[0] * output_list[1]

    sigma_pred_sq = output_list[2] - mu_pred_sq
    sigma_target_sq = output_list[3] - mu_target_sq
    sigma_pred_target = output_list[4] - mu_pred_target

    upper = 2 * sigma_pred_target + c2
    lower = sigma_pred_sq + sigma_target_sq + c2

    ssim_idx = ((2 * mu_pred_target + c1) * upper) / ((mu_pred_sq + mu_target_sq + c1) * lower)

    return reduce(ssim_idx, reduction)
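# For reference, the helper above evaluates the standard SSIM index per local
# neighbourhood before `reduce` aggregates the map:
#
#   SSIM(x, y) = ((2 * mu_x * mu_y + c1) * (2 * sigma_xy + c2))
#                / ((mu_x^2 + mu_y^2 + c1) * (sigma_x^2 + sigma_y^2 + c2))
#
# where the means and (co)variances are Gaussian-weighted local statistics
# obtained from the single grouped convolution over the concatenated inputs.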
def fbeta_score(
        pred: torch.Tensor,
        target: torch.Tensor,
        beta: float,
        num_classes: Optional[int] = None,
        reduction: str = 'elementwise_mean',
) -> torch.Tensor:
    """
    Computes the F-beta score which is a weighted harmonic mean of precision and recall.
    It ranges between 0 and 1, where 1 is perfect and 0 is the worst value.

    Args:
        pred: estimated probabilities
        target: ground-truth labels
        beta: weights recall when combining the score.
            beta < 1: more weight to precision.
            beta > 1: more weight to recall
            beta = 0: only precision
            beta -> inf: only recall
        num_classes: number of classes
        reduction: method for reducing F-score (default: takes the mean)
            Available reduction methods:

            - elementwise_mean: takes the mean
            - none: pass array
            - sum: add elements

    Return:
        Tensor with the value of F-score. It is a value between 0-1.

    Example:

        >>> x = torch.tensor([0, 1, 2, 3])
        >>> y = torch.tensor([0, 1, 2, 2])
        >>> fbeta_score(x, y, 0.2)
        tensor(0.7407)
    """
    prec, rec = precision_recall(pred=pred, target=target,
                                 num_classes=num_classes, reduction='none')

    nom = (1 + beta ** 2) * prec * rec
    denom = (beta ** 2) * prec + rec
    fbeta = nom / denom

    # drop NaN after zero division
    fbeta[fbeta != fbeta] = 0

    return reduce(fbeta, reduction=reduction)
def precision_recall(
        pred: torch.Tensor,
        target: torch.Tensor,
        num_classes: Optional[int] = None,
        reduction: str = 'elementwise_mean',
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Computes precision and recall scores

    Args:
        pred: estimated probabilities
        target: ground-truth labels
        num_classes: number of classes
        reduction: method for reducing precision-recall values (default: takes the mean)
            Available reduction methods:

            - elementwise_mean: takes the mean
            - none: pass array
            - sum: add elements

    Return:
        Tensor with precision and recall
    """
    tps, fps, tns, fns, sups = stat_scores_multiple_classes(
        pred=pred, target=target, num_classes=num_classes)

    tps = tps.to(torch.float)
    fps = fps.to(torch.float)
    fns = fns.to(torch.float)

    precision = tps / (tps + fps)
    recall = tps / (tps + fns)

    precision = reduce(precision, reduction=reduction)
    recall = reduce(recall, reduction=reduction)
    return precision, recall
def iou(
        pred: torch.Tensor,
        target: torch.Tensor,
        num_classes: Optional[int] = None,
        remove_bg: bool = False,
        reduction: str = 'elementwise_mean'
):
    """
    Intersection over union, or Jaccard index calculation.

    Args:
        pred: Tensor containing predictions
        target: Tensor containing targets
        num_classes: Optionally specify the number of classes
        remove_bg: Flag to state whether a background class has been included
            within input parameters. If true, will remove background class. If
            false, return IoU over all classes.
            Assumes that background is '0' class in input tensor
        reduction: a method for reducing IoU over labels (default: takes the mean)
            Available reduction methods:

            - elementwise_mean: takes the mean
            - none: pass array
            - sum: add elements

    Returns:
        IoU score: Tensor containing single value if reduction is
        'elementwise_mean', or number of classes if reduction is 'none'

    Example:

        >>> target = torch.randint(0, 1, (10, 25, 25))
        >>> pred = torch.tensor(target)
        >>> pred[2:5, 7:13, 9:15] = 1 - pred[2:5, 7:13, 9:15]
        >>> iou(pred, target)
        tensor(0.4914)
    """
    tps, fps, tns, fns, sups = stat_scores_multiple_classes(pred, target, num_classes)
    if remove_bg:
        tps = tps[1:]
        fps = fps[1:]
        fns = fns[1:]
    iou = tps / (fps + fns + tps)
    return reduce(iou, reduction=reduction)
def _iou_from_confmat(
    confmat: torch.Tensor,
    num_classes: int,
    ignore_index: Optional[int] = None,
    absent_score: float = 0.0,
    reduction: str = 'elementwise_mean',
):
    intersection = torch.diag(confmat)
    union = confmat.sum(0) + confmat.sum(1) - intersection

    # If this class is absent in both target AND pred (union == 0), then use the absent_score for this class.
    scores = intersection.float() / union.float()
    scores[union == 0] = absent_score

    # Remove the ignored class index from the scores.
    if ignore_index is not None and 0 <= ignore_index < num_classes:
        scores = torch.cat([
            scores[:ignore_index],
            scores[ignore_index + 1:],
        ])
    return reduce(scores, reduction=reduction)
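# A minimal sketch of the helper above on a hand-built 3x3 confusion matrix
# (the union term uses both row and column sums, so its orientation does not
# affect the result):
import torch

confmat = torch.tensor([[2., 0., 0.],
                        [0., 1., 1.],
                        [0., 0., 0.]])
_iou_from_confmat(confmat, num_classes=3)
# intersection = diag = [2, 1, 0]; union = row sums + col sums - diag = [2, 2, 1]
# -> per-class IoU [1.0, 0.5, 0.0]; elementwise_mean -> tensor(0.5000)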
def mae(pred: torch.Tensor, target: torch.Tensor,
        reduction: str = 'elementwise_mean',
        return_state: bool = False) -> torch.Tensor:
    """
    Computes mean absolute error

    Args:
        pred: estimated labels
        target: ground truth labels
        reduction: a method to reduce metric score over labels.

            - ``'elementwise_mean'``: takes the mean (default)
            - ``'sum'``: takes the sum
            - ``'none'``: no reduction will be applied

        return_state: returns an internal state (a dict of partial sums) that
            can be ddp reduced before doing the final calculation

    Return:
        Tensor with MAE

    Example:

        >>> x = torch.tensor([0., 1, 2, 3])
        >>> y = torch.tensor([0., 1, 2, 2])
        >>> mae(x, y)
        tensor(0.2500)
    """
    mae = F.l1_loss(pred, target, reduction='none')
    if return_state:
        return {'absolute_error': mae.sum(), 'n_observations': torch.tensor(mae.numel())}
    mae = reduce(mae, reduction=reduction)
    return mae
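# A minimal sketch of the `return_state` path above, e.g. for a two-step
# (sync-then-finalize) aggregation; the keys come straight from the dict
# returned by the function:
import torch

x = torch.tensor([0., 1, 2, 3])
y = torch.tensor([0., 1, 2, 2])
state = mae(x, y, return_state=True)
# {'absolute_error': tensor(1.), 'n_observations': tensor(4)}
state['absolute_error'] / state['n_observations']  # tensor(0.2500)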
def iou(
        pred: torch.Tensor,
        target: torch.Tensor,
        ignore_index: Optional[int] = None,
        absent_score: float = 0.0,
        num_classes: Optional[int] = None,
        reduction: str = 'elementwise_mean',
) -> torch.Tensor:
    """
    Intersection over union, or Jaccard index calculation.

    Args:
        pred: Tensor containing predictions
        target: Tensor containing targets
        ignore_index: optional int specifying a target class to ignore. If given, this class index does not contribute
            to the returned score, regardless of reduction method. Has no effect if given an int that is not in the
            range [0, num_classes-1], where num_classes is either given or derived from pred and target. By default,
            no index is ignored, and all classes are used.
        absent_score: score to use for an individual class, if no instances of the class index were present in
            `pred` AND no instances of the class index were present in `target`. For example, if we have 3 classes,
            [0, 0] for `pred`, and [0, 2] for `target`, then class 1 would be assigned the `absent_score`.
            Default is 0.0.
        num_classes: Optionally specify the number of classes
        reduction: a method to reduce metric score over labels.

            - ``'elementwise_mean'``: takes the mean (default)
            - ``'sum'``: takes the sum
            - ``'none'``: no reduction will be applied

    Return:
        IoU score: Tensor containing single value if reduction is
        'elementwise_mean', or number of classes if reduction is 'none'

    Example:

        >>> target = torch.randint(0, 1, (10, 25, 25))
        >>> pred = torch.tensor(target)
        >>> pred[2:5, 7:13, 9:15] = 1 - pred[2:5, 7:13, 9:15]
        >>> iou(pred, target)
        tensor(0.4914)
    """
    num_classes = get_num_classes(pred=pred, target=target, num_classes=num_classes)

    tps, fps, tns, fns, sups = stat_scores_multiple_classes(pred, target, num_classes)

    scores = torch.zeros(num_classes, device=pred.device, dtype=torch.float32)

    for class_idx in range(num_classes):
        if class_idx == ignore_index:
            continue

        tp = tps[class_idx]
        fp = fps[class_idx]
        fn = fns[class_idx]
        sup = sups[class_idx]

        # If this class is absent in the target (no support) AND absent in the pred (no true or false
        # positives), then use the absent_score for this class.
        if sup + tp + fp == 0:
            scores[class_idx] = absent_score
            continue

        denom = tp + fp + fn
        # Note that we do not need to worry about division-by-zero here since we know (sup + tp + fp != 0) from above,
        # which means ((tp+fn) + tp + fp != 0), which means (2tp + fp + fn != 0). Since all vars are non-negative, we
        # can conclude (tp + fp + fn > 0), meaning the denominator is non-zero for each class.
        score = tp.to(torch.float) / denom
        scores[class_idx] = score

    # Remove the ignored class index from the scores.
    if ignore_index is not None and 0 <= ignore_index < num_classes:
        scores = torch.cat([
            scores[:ignore_index],
            scores[ignore_index + 1:],
        ])

    return reduce(scores, reduction=reduction)
def ssim(
        pred: torch.Tensor,
        target: torch.Tensor,
        kernel_size: Sequence[int] = (11, 11),
        sigma: Sequence[float] = (1.5, 1.5),
        reduction: str = "elementwise_mean",
        data_range: Optional[float] = None,
        k1: float = 0.01,
        k2: float = 0.03
) -> torch.Tensor:
    """
    Computes Structural Similarity Index Measure

    Args:
        pred: estimated image
        target: ground truth image
        kernel_size: size of the gaussian kernel (default: (11, 11))
        sigma: Standard deviation of the gaussian kernel (default: (1.5, 1.5))
        reduction: a method to reduce metric score over labels.

            - ``'elementwise_mean'``: takes the mean (default)
            - ``'sum'``: takes the sum
            - ``'none'``: no reduction will be applied

        data_range: Range of the image. If ``None``, it is determined from the image (max - min)
        k1: Parameter of SSIM. Default: 0.01
        k2: Parameter of SSIM. Default: 0.03

    Return:
        Tensor with SSIM score

    Example:

        >>> pred = torch.rand([16, 1, 16, 16])
        >>> target = pred * 0.75
        >>> ssim(pred, target)
        tensor(0.9219)
    """
    if pred.dtype != target.dtype:
        raise TypeError(
            "Expected `pred` and `target` to have the same data type."
            f" Got pred: {pred.dtype} and target: {target.dtype}.")

    if pred.shape != target.shape:
        raise ValueError(
            "Expected `pred` and `target` to have the same shape."
            f" Got pred: {pred.shape} and target: {target.shape}.")

    if len(pred.shape) != 4 or len(target.shape) != 4:
        raise ValueError(
            "Expected `pred` and `target` to have BxCxHxW shape."
            f" Got pred: {pred.shape} and target: {target.shape}.")

    if len(kernel_size) != 2 or len(sigma) != 2:
        raise ValueError(
            "Expected `kernel_size` and `sigma` to have the length of two."
            f" Got kernel_size: {len(kernel_size)} and sigma: {len(sigma)}.")

    if any(x % 2 == 0 or x <= 0 for x in kernel_size):
        raise ValueError(f"Expected `kernel_size` to have odd positive number. Got {kernel_size}.")

    if any(y <= 0 for y in sigma):
        raise ValueError(f"Expected `sigma` to have positive number. Got {sigma}.")

    if data_range is None:
        data_range = max(pred.max() - pred.min(), target.max() - target.min())

    C1 = pow(k1 * data_range, 2)
    C2 = pow(k2 * data_range, 2)
    device = pred.device

    channel = pred.size(1)
    kernel = _gaussian_kernel(channel, kernel_size, sigma, device)

    # Concatenate
    # pred for mu_pred
    # target for mu_target
    # pred * pred for sigma_pred
    # target * target for sigma_target
    # pred * target for sigma_pred_target
    input_list = torch.cat([pred, target, pred * pred, target * target, pred * target])  # (5 * B, C, H, W)
    outputs = F.conv2d(input_list, kernel, groups=channel)
    output_list = [outputs[x * pred.size(0):(x + 1) * pred.size(0)] for x in range(len(outputs))]

    mu_pred_sq = output_list[0].pow(2)
    mu_target_sq = output_list[1].pow(2)
    mu_pred_target = output_list[0] * output_list[1]

    sigma_pred_sq = output_list[2] - mu_pred_sq
    sigma_target_sq = output_list[3] - mu_target_sq
    sigma_pred_target = output_list[4] - mu_pred_target

    UPPER = 2 * sigma_pred_target + C2
    LOWER = sigma_pred_sq + sigma_target_sq + C2

    ssim_idx = ((2 * mu_pred_target + C1) * UPPER) / ((mu_pred_sq + mu_target_sq + C1) * LOWER)

    return reduce(ssim_idx, reduction)
def ssim(
        pred: torch.Tensor,
        target: torch.Tensor,
        kernel_size: Sequence[int] = (11, 11),
        sigma: Sequence[float] = (1.5, 1.5),
        reduction: str = "elementwise_mean",
        data_range: Optional[float] = None,
        k1: float = 0.01,
        k2: float = 0.03
) -> torch.Tensor:
    """
    Computes Structural Similarity Index Measure

    Args:
        pred: Estimated image
        target: Ground truth image
        kernel_size: Size of the gaussian kernel. Default: (11, 11)
        sigma: Standard deviation of the gaussian kernel. Default: (1.5, 1.5)
        reduction: A method for reducing ssim over all elements in the ``pred`` tensor.
            Default: ``elementwise_mean``

            Available reduction methods:

            - elementwise_mean: takes the mean
            - none: pass array
            - sum: add elements

        data_range: Range of the image. If ``None``, it is determined from the image (max - min)
        k1: Parameter of SSIM. Default: 0.01
        k2: Parameter of SSIM. Default: 0.03

    Returns:
        A Tensor with SSIM

    Example:

        >>> pred = torch.rand([16, 1, 16, 16])
        >>> target = pred * 1.25
        >>> ssim(pred, target)
        tensor(0.9520)
    """
    if pred.dtype != target.dtype:
        raise TypeError(
            "Expected `pred` and `target` to have the same data type."
            f" Got pred: {pred.dtype} and target: {target.dtype}.")

    if pred.shape != target.shape:
        raise ValueError(
            "Expected `pred` and `target` to have the same shape."
            f" Got pred: {pred.shape} and target: {target.shape}.")

    if len(pred.shape) != 4 or len(target.shape) != 4:
        raise ValueError(
            "Expected `pred` and `target` to have BxCxHxW shape."
            f" Got pred: {pred.shape} and target: {target.shape}.")

    if len(kernel_size) != 2 or len(sigma) != 2:
        raise ValueError(
            "Expected `kernel_size` and `sigma` to have the length of two."
            f" Got kernel_size: {len(kernel_size)} and sigma: {len(sigma)}.")

    if any(x % 2 == 0 or x <= 0 for x in kernel_size):
        raise ValueError(f"Expected `kernel_size` to have odd positive number. Got {kernel_size}.")

    if any(y <= 0 for y in sigma):
        raise ValueError(f"Expected `sigma` to have positive number. Got {sigma}.")

    if data_range is None:
        data_range = max(pred.max() - pred.min(), target.max() - target.min())

    C1 = pow(k1 * data_range, 2)
    C2 = pow(k2 * data_range, 2)
    device = pred.device

    channel = pred.size(1)
    kernel = _gaussian_kernel(channel, kernel_size, sigma, device)

    mu_pred = F.conv2d(pred, kernel, groups=channel)
    mu_target = F.conv2d(target, kernel, groups=channel)

    mu_pred_sq = mu_pred.pow(2)
    mu_target_sq = mu_target.pow(2)
    mu_pred_target = mu_pred * mu_target

    sigma_pred_sq = F.conv2d(pred * pred, kernel, groups=channel) - mu_pred_sq
    sigma_target_sq = F.conv2d(target * target, kernel, groups=channel) - mu_target_sq
    sigma_pred_target = F.conv2d(pred * target, kernel, groups=channel) - mu_pred_target

    UPPER = 2 * sigma_pred_target + C2
    LOWER = sigma_pred_sq + sigma_target_sq + C2

    ssim_idx = ((2 * mu_pred_target + C1) * UPPER) / ((mu_pred_sq + mu_target_sq + C1) * LOWER)

    return reduce(ssim_idx, reduction)
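# Both `ssim` variants above assume a `_gaussian_kernel` helper that is not
# shown here. A plausible sketch (an assumption, not the source's definition)
# matching the call signature `_gaussian_kernel(channel, kernel_size, sigma,
# device)` and the (channel, 1, kH, kW) weight shape the grouped conv2d expects:
import torch


def _gaussian(window_size: int, sigma: float, device: torch.device) -> torch.Tensor:
    # 1D Gaussian, normalised to sum to 1, shape (1, window_size)
    dist = torch.arange(window_size, dtype=torch.float32, device=device) - (window_size - 1) / 2
    gauss = torch.exp(-(dist / sigma).pow(2) / 2)
    return (gauss / gauss.sum()).unsqueeze(0)


def _gaussian_kernel(channel, kernel_size, sigma, device):
    # Outer product of two 1D Gaussians, expanded to one depthwise kernel per channel
    kernel_x = _gaussian(kernel_size[0], sigma[0], device)
    kernel_y = _gaussian(kernel_size[1], sigma[1], device)
    kernel = torch.matmul(kernel_x.t(), kernel_y)  # (kernel_size[0], kernel_size[1])
    return kernel.expand(channel, 1, kernel_size[0], kernel_size[1])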