Example #1
def accuracy(
    pred: torch.Tensor,
    target: torch.Tensor,
    num_classes: Optional[int] = None,
    reduction: str = 'elementwise_mean',
) -> torch.Tensor:
    """
    Computes the accuracy classification score

    Args:
        pred: predicted labels
        target: ground truth labels
        num_classes: number of classes
        reduction: a method for reducing accuracies over labels (default: takes the mean)
           Available reduction methods:

           - elementwise_mean: takes the mean
           - none: no reduction will be applied
           - sum: add elements

    Return:
         A Tensor with the classification score.
    """
    tps, fps, tns, fns, sups = stat_scores_multiple_classes(
        pred=pred, target=target, num_classes=num_classes)

    if not (target > 0).any() and num_classes is None:
        raise RuntimeError("cannot infer num_classes when target is all zero")

    if reduction in ('elementwise_mean', 'sum'):
        return reduce(sum(tps) / sum(sups), reduction=reduction)
    if reduction == 'none':
        return reduce(tps / sups, reduction=reduction)
    raise ValueError(f'Reduction parameter unknown: {reduction}')
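
A quick usage sketch, assuming the accuracy function above and its stat_scores_multiple_classes / reduce helpers are in scope:

import torch

pred = torch.tensor([0, 1, 2, 3])
target = torch.tensor([0, 1, 2, 2])
# micro-averaged accuracy: 3 of the 4 predicted labels match the target
print(accuracy(pred, target))  # tensor(0.7500)
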
Example #2
def test_v1_3_0_deprecated_metrics():
    from pytorch_lightning.metrics.functional.classification import to_onehot
    with pytest.deprecated_call(match='will be removed in v1.3'):
        to_onehot(torch.tensor([1, 2, 3]))

    from pytorch_lightning.metrics.functional.classification import to_categorical
    with pytest.deprecated_call(match='will be removed in v1.3'):
        to_categorical(torch.tensor([[0.2, 0.5], [0.9, 0.1]]))

    from pytorch_lightning.metrics.functional.classification import get_num_classes
    with pytest.deprecated_call(match='will be removed in v1.3'):
        get_num_classes(pred=torch.tensor([0, 1]), target=torch.tensor([1, 1]))

    x_binary = torch.tensor([0, 1, 2, 3])
    y_binary = torch.tensor([0, 1, 2, 3])

    from pytorch_lightning.metrics.functional.classification import roc
    with pytest.deprecated_call(match='will be removed in v1.3'):
        roc(pred=x_binary, target=y_binary)

    from pytorch_lightning.metrics.functional.classification import _roc
    with pytest.deprecated_call(match='will be removed in v1.3'):
        _roc(pred=x_binary, target=y_binary)

    x_multy = torch.tensor([
        [0.85, 0.05, 0.05, 0.05],
        [0.05, 0.85, 0.05, 0.05],
        [0.05, 0.05, 0.85, 0.05],
        [0.05, 0.05, 0.05, 0.85],
    ])
    y_multy = torch.tensor([0, 1, 3, 2])

    from pytorch_lightning.metrics.functional.classification import multiclass_roc
    with pytest.deprecated_call(match='will be removed in v1.3'):
        multiclass_roc(pred=x_multy, target=y_multy)

    from pytorch_lightning.metrics.functional.classification import average_precision
    with pytest.deprecated_call(match='will be removed in v1.3'):
        average_precision(pred=x_binary, target=y_binary)

    from pytorch_lightning.metrics.functional.classification import precision_recall_curve
    with pytest.deprecated_call(match='will be removed in v1.3'):
        precision_recall_curve(pred=x_binary, target=y_binary)

    from pytorch_lightning.metrics.functional.classification import multiclass_precision_recall_curve
    with pytest.deprecated_call(match='will be removed in v1.3'):
        multiclass_precision_recall_curve(pred=x_multy, target=y_multy)

    from pytorch_lightning.metrics.functional.reduction import reduce
    with pytest.deprecated_call(match='will be removed in v1.3'):
        reduce(torch.tensor([0, 1, 1, 0]), 'sum')

    from pytorch_lightning.metrics.functional.reduction import class_reduce
    with pytest.deprecated_call(match='will be removed in v1.3'):
        class_reduce(
            torch.randint(1, 10, (50, )).float(),
            torch.randint(10, 20, (50, )).float(),
            torch.randint(1, 100, (50, )).float())
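
For readers unfamiliar with the pattern above, pytest.deprecated_call asserts that the wrapped call emits a deprecation warning whose message matches the given regex. A minimal self-contained sketch with a hypothetical old_api helper:

import warnings

import pytest

def old_api():
    # stand-in for the deprecated Lightning functions exercised above
    warnings.warn("old_api will be removed in v1.3", DeprecationWarning)
    return 42

def test_old_api_warns():
    with pytest.deprecated_call(match="will be removed in v1.3"):
        assert old_api() == 42
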
Example #3
def test_reduce():
    start_tensor = torch.rand(50, 40, 30)

    assert torch.allclose(reduce(start_tensor, 'elementwise_mean'), torch.mean(start_tensor))
    assert torch.allclose(reduce(start_tensor, 'sum'), torch.sum(start_tensor))
    assert torch.allclose(reduce(start_tensor, 'none'), start_tensor)

    with pytest.raises(ValueError):
        reduce(start_tensor, 'error_reduction')
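
For reference, a minimal sketch of the reduce helper this test exercises, written here only to make its contract explicit (the actual implementation lives in pytorch_lightning.metrics.functional.reduction):

import torch

def reduce(to_reduce: torch.Tensor, reduction: str) -> torch.Tensor:
    # contract implied by test_reduce above
    if reduction == 'elementwise_mean':
        return torch.mean(to_reduce)
    if reduction == 'none':
        return to_reduce
    if reduction == 'sum':
        return torch.sum(to_reduce)
    raise ValueError('Reduction parameter unknown.')
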
Example #4
def precision_recall(
    pred: torch.Tensor,
    target: torch.Tensor,
    num_classes: Optional[int] = None,
    reduction: str = 'elementwise_mean',
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Computes precision and recall

    Args:
        pred: estimated probabilities
        target: ground-truth labels
        num_classes: number of classes
        reduction: method for reducing precision-recall values (default: takes the mean)
           Available reduction methods:

           - elementwise_mean: takes the mean
           - none: no reduction will be applied
           - sum: add elements

    Return:
        Tuple with the precision and recall tensors

    Example:

        >>> x = torch.tensor([0, 1, 2, 3])
        >>> y = torch.tensor([0, 1, 2, 2])
        >>> precision_recall(x, y)
        (tensor(0.7500), tensor(0.6250))

    """
    tps, fps, tns, fns, sups = stat_scores_multiple_classes(
        pred=pred, target=target, num_classes=num_classes)

    tps = tps.to(torch.float)
    fps = fps.to(torch.float)
    fns = fns.to(torch.float)

    precision = tps / (tps + fps)
    recall = tps / (tps + fns)

    # solution by justus, see https://discuss.pytorch.org/t/how-to-set-nan-in-tensor-to-0/3918/9
    precision[precision != precision] = 0
    recall[recall != recall] = 0

    precision = reduce(precision, reduction=reduction)
    recall = reduce(recall, reduction=reduction)
    return precision, recall
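
The NaN-masking lines rely on NaN being the only value that does not compare equal to itself; the trick in isolation:

import torch

t = torch.tensor([1.0, float('nan'), 2.0])
t[t != t] = 0  # NaN != NaN, so only NaN entries are replaced
print(t)  # tensor([1., 0., 2.])
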
Example #5
def mse(pred: torch.Tensor,
        target: torch.Tensor,
        reduction: str = 'elementwise_mean') -> torch.Tensor:
    """
    Computes mean squared error

    Args:
        pred: estimated labels
        target: ground truth labels
        reduction: a method to reduce metric score over labels (default: takes the mean)
            Available reduction methods:

            - elementwise_mean: takes the mean
            - none: no reduction will be applied
            - sum: add elements

    Return:
        Tensor with MSE

    Example:

        >>> x = torch.tensor([0., 1, 2, 3])
        >>> y = torch.tensor([0., 1, 2, 2])
        >>> mse(x, y)
        tensor(0.2500)

    """
    mse = F.mse_loss(pred, target, reduction='none')
    mse = reduce(mse, reduction=reduction)
    return mse
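
Since F.mse_loss supports 'mean' and 'sum' natively, reducing the unreduced losses is equivalent to the built-in reductions; a quick check of that equivalence:

import torch
import torch.nn.functional as F

x = torch.tensor([0., 1, 2, 3])
y = torch.tensor([0., 1, 2, 2])
per_element = F.mse_loss(x, y, reduction='none')
# 'elementwise_mean' over the unreduced losses equals F.mse_loss's own mean
assert torch.allclose(per_element.mean(), F.mse_loss(x, y, reduction='mean'))
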
Example #6
def dice_score(
    pred: torch.Tensor,
    target: torch.Tensor,
    bg: bool = False,
    nan_score: float = 0.0,
    no_fg_score: float = 0.0,
    reduction: str = 'elementwise_mean',
) -> torch.Tensor:
    n_classes = pred.shape[1]
    bg = (1 - int(bool(bg)))  # start index: skip background class 0 unless bg=True
    scores = torch.zeros(n_classes - bg,
                         device=pred.device,
                         dtype=torch.float32)
    for i in range(bg, n_classes):
        if not (target == i).any():
            # no foreground class
            scores[i - bg] += no_fg_score
            continue

        tp, fp, tn, fn = stat_scores(pred=pred, target=target, class_index=i)

        denom = (2 * tp + fp + fn).to(torch.float)

        if torch.isclose(denom, torch.zeros_like(denom)).any():
            # nan result
            score_cls = nan_score
        else:
            score_cls = (2 * tp).to(torch.float) / denom

        scores[i - bg] += score_cls
    return reduce(scores, reduction=reduction)
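
A hypothetical usage sketch, assuming this dice_score is in scope; the inputs mirror the doctest of the variant in Example #10, where two of the four one-hot predictions are swapped:

import torch

pred = torch.tensor([[0.85, 0.05, 0.05, 0.05],
                     [0.05, 0.85, 0.05, 0.05],
                     [0.05, 0.05, 0.85, 0.05],
                     [0.05, 0.05, 0.05, 0.85]])
target = torch.tensor([0, 1, 3, 2])
# classes 2 and 3 are swapped, so only class 1 is a perfect match
print(dice_score(pred, target))  # tensor(0.3333)
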
Example #7
def mse(
        pred: torch.Tensor,
        target: torch.Tensor,
        reduction: str = 'elementwise_mean'
) -> torch.Tensor:
    """
    Computes mean squared error

    Args:
        pred: estimated labels
        target: ground truth labels
        reduction: a method to reduce metric score over labels.

            - ``'elementwise_mean'``: takes the mean (default)
            - ``'sum'``: takes the sum
            - ``'none'``: no reduction will be applied

    Return:
        Tensor with MSE

    Example:

        >>> x = torch.tensor([0., 1, 2, 3])
        >>> y = torch.tensor([0., 1, 2, 2])
        >>> mse(x, y)
        tensor(0.2500)

    """
    mse = F.mse_loss(pred, target, reduction='none')
    mse = reduce(mse, reduction=reduction)
    return mse
Example #8
def mae(pred: torch.Tensor,
        target: torch.Tensor,
        reduction: str = 'elementwise_mean') -> torch.Tensor:
    """
    Computes mean absolute error

    Args:
        pred: estimated labels
        target: ground truth labels
        reduction: method for reducing mae (default: takes the mean)
            Available reduction methods:

            - elementwise_mean: takes the mean
            - none: no reduction will be applied
            - sum: add elements

    Return:
        Tensor with MAE

    Example:

        >>> x = torch.tensor([0., 1, 2, 3])
        >>> y = torch.tensor([0., 1, 2, 2])
        >>> mae(x, y)
        tensor(0.2500)

    """
    mae = F.l1_loss(pred, target, reduction='none')
    mae = reduce(mae, reduction=reduction)
    return mae
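
A quick numeric check of the doctest: the absolute errors here are [0, 0, 0, 1], so the default mean is 0.25 and the 'sum' reduction would give 1.0:

import torch
import torch.nn.functional as F

x = torch.tensor([0., 1, 2, 3])
y = torch.tensor([0., 1, 2, 2])
abs_err = F.l1_loss(x, y, reduction='none')  # tensor([0., 0., 0., 1.])
assert torch.allclose(abs_err.mean(), torch.tensor(0.25))
assert torch.allclose(abs_err.sum(), torch.tensor(1.0))
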
Example #9
    def compute(self):
        scores = torch.zeros(self.n_classes,
                             device=self.true_positive.device,
                             dtype=torch.float32)

        for class_idx in range(self.n_classes):
            if class_idx == self.ignore_index:
                continue

            tp = self.true_positive[class_idx]
            fp = self.false_positive[class_idx]
            fn = self.false_negative[class_idx]
            sup = self.support[class_idx]

            # If this class is absent in the target (no support) AND absent in the pred (no true or false
            # positives), then use the absent_score for this class.
            if sup + tp + fp == 0:
                scores[class_idx] = self.absent_score
                continue

            denominator = tp + fp + fn
            score = tp.to(torch.float) / denominator
            scores[class_idx] = score

        # Remove the ignored class index from the scores.
        if (self.ignore_index
                is not None) and (0 <= self.ignore_index < self.n_classes):
            scores = torch.cat(
                [scores[:self.ignore_index], scores[self.ignore_index + 1:]])

        return reduce(scores, reduction=self.reduction)
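
The final torch.cat drops exactly one class entry before the reduction; the slicing in isolation, with illustrative scores:

import torch

scores = torch.tensor([0.9, 0.5, 0.7])
ignore_index = 1
kept = torch.cat([scores[:ignore_index], scores[ignore_index + 1:]])
print(kept)  # tensor([0.9000, 0.7000])
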
Example #10
def dice_score(
    pred: torch.Tensor,
    target: torch.Tensor,
    bg: bool = False,
    nan_score: float = 0.0,
    no_fg_score: float = 0.0,
    reduction: str = 'elementwise_mean',
) -> torch.Tensor:
    """
    Compute dice score from prediction scores

    Args:
        pred: estimated probabilities
        target: ground-truth labels
        bg: whether to also compute dice for the background
        nan_score: score to return, if a NaN occurs during computation
        no_fg_score: score to return, if no foreground pixel was found in target
        reduction: a method for reducing dice scores over labels (default: takes the mean)
            Available reduction methods:

            - elementwise_mean: takes the mean
            - none: no reduction will be applied
            - sum: add elements

    Return:
        Tensor containing dice score

    Example:

        >>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05],
        ...                      [0.05, 0.85, 0.05, 0.05],
        ...                      [0.05, 0.05, 0.85, 0.05],
        ...                      [0.05, 0.05, 0.05, 0.85]])
        >>> target = torch.tensor([0, 1, 3, 2])
        >>> dice_score(pred, target)
        tensor(0.3333)

    """
    num_classes = pred.shape[1]
    bg = (1 - int(bool(bg)))  # start index: skip background class 0 unless bg=True
    scores = torch.zeros(num_classes - bg,
                         device=pred.device,
                         dtype=torch.float32)
    for i in range(bg, num_classes):
        if not (target == i).any():
            # no foreground class
            scores[i - bg] += no_fg_score
            continue

        tp, fp, tn, fn, sup = stat_scores(pred=pred,
                                          target=target,
                                          class_index=i)
        denom = (2 * tp + fp + fn).to(torch.float)
        # nan result
        score_cls = (2 * tp).to(torch.float) / denom if torch.is_nonzero(
            denom) else nan_score

        scores[i - bg] += score_cls
    return reduce(scores, reduction=reduction)
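
Note that torch.is_nonzero is only defined for single-element tensors, which is safe here because denom is a per-class scalar:

import torch

assert torch.is_nonzero(torch.tensor(3.0))
assert not torch.is_nonzero(torch.tensor(0.0))
# torch.is_nonzero(torch.tensor([1.0, 2.0])) would raise a RuntimeError
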
Example #11
def _ssim_compute(
    preds: torch.Tensor,
    target: torch.Tensor,
    kernel_size: Sequence[int] = (11, 11),
    sigma: Sequence[float] = (1.5, 1.5),
    reduction: str = "elementwise_mean",
    data_range: Optional[float] = None,
    k1: float = 0.01,
    k2: float = 0.03,
):
    if len(kernel_size) != 2 or len(sigma) != 2:
        raise ValueError(
            "Expected `kernel_size` and `sigma` to have the length of two."
            f" Got kernel_size: {len(kernel_size)} and sigma: {len(sigma)}."
        )

    if any(x % 2 == 0 or x <= 0 for x in kernel_size):
        raise ValueError(f"Expected `kernel_size` to have odd positive number. Got {kernel_size}.")

    if any(y <= 0 for y in sigma):
        raise ValueError(f"Expected `sigma` to have positive number. Got {sigma}.")

    if data_range is None:
        data_range = max(preds.max() - preds.min(), target.max() - target.min())

    c1 = pow(k1 * data_range, 2)
    c2 = pow(k2 * data_range, 2)
    device = preds.device

    channel = preds.size(1)
    kernel = _gaussian_kernel(channel, kernel_size, sigma, device)

    input_list = torch.cat([preds, target, preds * preds, target * target, preds * target])  # (5 * B, C, H, W)
    outputs = F.conv2d(input_list, kernel, groups=channel)
    # only the first five chunks are meaningful: one per concatenated input above
    output_list = [outputs[x * preds.size(0): (x + 1) * preds.size(0)] for x in range(5)]

    mu_pred_sq = output_list[0].pow(2)
    mu_target_sq = output_list[1].pow(2)
    mu_pred_target = output_list[0] * output_list[1]

    sigma_pred_sq = output_list[2] - mu_pred_sq
    sigma_target_sq = output_list[3] - mu_target_sq
    sigma_pred_target = output_list[4] - mu_pred_target

    upper = 2 * sigma_pred_target + c2
    lower = sigma_pred_sq + sigma_target_sq + c2

    ssim_idx = ((2 * mu_pred_target + c1) * upper) / ((mu_pred_sq + mu_target_sq + c1) * lower)

    return reduce(ssim_idx, reduction)
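
The snippet assumes a _gaussian_kernel helper that is not shown. A plausible sketch of such a helper, not necessarily the exact implementation: a separable 2D Gaussian window, normalized to sum to one and expanded to one depthwise filter per channel to match the groups=channel convolution above:

import torch

def _gaussian(kernel_size: int, sigma: float, device: torch.device) -> torch.Tensor:
    # 1D Gaussian window of shape (1, kernel_size), normalized to sum to 1
    dist = torch.arange(kernel_size, dtype=torch.float, device=device) - (kernel_size - 1) / 2
    gauss = torch.exp(-(dist ** 2) / (2 * sigma ** 2))
    return (gauss / gauss.sum()).unsqueeze(0)

def _gaussian_kernel(channel, kernel_size, sigma, device):
    # outer product of two 1D windows -> 2D kernel, repeated per channel
    kernel_x = _gaussian(kernel_size[0], sigma[0], device)
    kernel_y = _gaussian(kernel_size[1], sigma[1], device)
    kernel_2d = torch.matmul(kernel_x.t(), kernel_y)  # (k_x, k_y)
    return kernel_2d.expand(channel, 1, kernel_size[0], kernel_size[1])
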
Example #12
def fbeta_score(
    pred: torch.Tensor,
    target: torch.Tensor,
    beta: float,
    num_classes: Optional[int] = None,
    reduction: str = 'elementwise_mean',
) -> torch.Tensor:
    """
    Computes the F-beta score which is a weighted harmonic mean of precision and recall.
    It ranges between 0 and 1, where 1 is perfect and 0 is the worst value.

    Args:
        pred: estimated probabilities
        target: ground-truth labels
        beta: weights recall when combining the score.
            beta < 1: more weight to precision.
            beta > 1: more weight to recall.
            beta = 0: only precision.
            beta -> inf: only recall.
        num_classes: number of classes
        reduction: method for reducing F-score (default: takes the mean)
           Available reduction methods:

           - elementwise_mean: takes the mean
           - none: no reduction will be applied
           - sum: add elements.

    Return:
        Tensor with the value of F-score. It is a value between 0-1.

    Example:

        >>> x = torch.tensor([0, 1, 2, 3])
        >>> y = torch.tensor([0, 1, 2, 2])
        >>> fbeta_score(x, y, 0.2)
        tensor(0.7407)
    """
    prec, rec = precision_recall(pred=pred,
                                 target=target,
                                 num_classes=num_classes,
                                 reduction='none')

    nom = (1 + beta**2) * prec * rec
    denom = ((beta**2) * prec + rec)
    fbeta = nom / denom

    # drop NaN after zero division
    fbeta[fbeta != fbeta] = 0

    return reduce(fbeta, reduction=reduction)
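
With beta = 1 the score reduces to the familiar F1, the harmonic mean of precision and recall; the nom/denom arithmetic with illustrative values:

import torch

prec, rec = torch.tensor(0.75), torch.tensor(0.625)
beta = 1.0
f1 = (1 + beta ** 2) * prec * rec / ((beta ** 2) * prec + rec)
print(f1)  # tensor(0.6818)
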
Example #13
def precision_recall(
    pred: torch.Tensor,
    target: torch.Tensor,
    num_classes: Optional[int] = None,
    reduction: str = 'elementwise_mean',
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Computes precision and recall

    Args:
        pred: estimated probabilities
        target: ground-truth labels
        num_classes: number of classes
        reduction: method for reducing precision-recall values (default: takes the mean)
           Available reduction methods:

           - elementwise_mean: takes the mean
           - none: no reduction will be applied
           - sum: add elements

    Return:
        Tuple with the precision and recall tensors
    """
    tps, fps, tns, fns = stat_scores_multiple_classes(pred=pred,
                                                      target=target,
                                                      num_classes=num_classes)

    tps = tps.to(torch.float)
    fps = fps.to(torch.float)
    fns = fns.to(torch.float)

    precision = tps / (tps + fps)
    recall = tps / (tps + fns)

    precision = reduce(precision, reduction=reduction)
    recall = reduce(recall, reduction=reduction)
    return precision, recall
Example #14
def iou(pred: torch.Tensor,
        target: torch.Tensor,
        num_classes: Optional[int] = None,
        remove_bg: bool = False,
        reduction: str = 'elementwise_mean'):
    """
    Intersection over union, or Jaccard index calculation.

    Args:
        pred: Tensor containing predictions

        target: Tensor containing targets

        num_classes: Optionally specify the number of classes

        remove_bg: Flag to state whether a background class has been included
            within input parameters. If true, will remove background class. If
            false, return IoU over all classes.
            Assumes that background is '0' class in input tensor

        reduction: a method for reducing IoU over labels (default: takes the mean)
            Available reduction methods:

            - elementwise_mean: takes the mean
            - none: no reduction will be applied
            - sum: add elements

    Returns:
        IoU score : Tensor containing single value if reduction is
        'elementwise_mean', or number of classes if reduction is 'none'

    Example:

        >>> target = torch.randint(0, 1, (10, 25, 25))
        >>> pred = torch.tensor(target)
        >>> pred[2:5, 7:13, 9:15] = 1 - pred[2:5, 7:13, 9:15]
        >>> iou(pred, target)
        tensor(0.4914)

    """
    tps, fps, tns, fns, sups = stat_scores_multiple_classes(
        pred, target, num_classes)
    if remove_bg:
        tps = tps[1:]
        fps = fps[1:]
        fns = fns[1:]
    iou = tps / (fps + fns + tps)
    return reduce(iou, reduction=reduction)
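
The per-class arithmetic in isolation, with hypothetical counts for two classes:

import torch

tps = torch.tensor([4., 2.])
fps = torch.tensor([1., 0.])
fns = torch.tensor([0., 2.])
print(tps / (fps + fns + tps))  # tensor([0.8000, 0.5000])
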
Example #15
def _iou_from_confmat(
    confmat: torch.Tensor,
    num_classes: int,
    ignore_index: Optional[int] = None,
    absent_score: float = 0.0,
    reduction: str = 'elementwise_mean',
):
    intersection = torch.diag(confmat)
    union = confmat.sum(0) + confmat.sum(1) - intersection

    # If this class is absent in both target AND pred (union == 0), then use the absent_score for this class.
    scores = intersection.float() / union.float()
    scores[union == 0] = absent_score

    # Remove the ignored class index from the scores.
    if ignore_index is not None and 0 <= ignore_index < num_classes:
        scores = torch.cat([
            scores[:ignore_index],
            scores[ignore_index + 1:],
        ])
    return reduce(scores, reduction=reduction)
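
A worked example with a small hypothetical confusion matrix; the union for each class is its row sum plus its column sum minus the diagonal entry:

import torch

confmat = torch.tensor([[3, 1],
                        [2, 4]])
intersection = torch.diag(confmat)                      # tensor([3, 4])
union = confmat.sum(0) + confmat.sum(1) - intersection  # tensor([6, 7])
print(intersection.float() / union.float())             # tensor([0.5000, 0.5714])
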
Example #16
def mae(pred: torch.Tensor,
        target: torch.Tensor,
        reduction: str = 'elementwise_mean',
        return_state: bool = False) -> torch.Tensor:
    """
    Computes mean absolute error

    Args:
        pred: estimated labels
        target: ground truth labels
        reduction: a method to reduce metric score over labels.

            - ``'elementwise_mean'``: takes the mean (default)
            - ``'sum'``: takes the sum
            - ``'none'``: no reduction will be applied
        return_state: returns an internal state that can be DDP-reduced
            before doing the final calculation

    Return:
        Tensor with MAE

    Example:

        >>> x = torch.tensor([0., 1, 2, 3])
        >>> y = torch.tensor([0., 1, 2, 2])
        >>> mae(x, y)
        tensor(0.2500)

    """
    mae = F.l1_loss(pred, target, reduction='none')
    if return_state:
        return {
            'absolute_error': mae.sum(),
            'n_observations': torch.tensor(mae.numel())
        }
    mae = reduce(mae, reduction=reduction)
    return mae
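
A sketch of what return_state enables: the partial sums can be summed across processes (e.g. with a DDP all-reduce) before the final division, which stays correct even for unequal shard sizes, unlike averaging per-process means. Assuming the mae above is in scope:

import torch

state = mae(torch.tensor([0., 1, 2, 3]),
            torch.tensor([0., 1, 2, 2]),
            return_state=True)
# finish the reduction after (hypothetically) all-reducing both entries
print(state['absolute_error'] / state['n_observations'])  # tensor(0.2500)
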
Example #17
def iou(
    pred: torch.Tensor,
    target: torch.Tensor,
    ignore_index: Optional[int] = None,
    absent_score: float = 0.0,
    num_classes: Optional[int] = None,
    reduction: str = 'elementwise_mean',
) -> torch.Tensor:
    """
    Intersection over union, or Jaccard index calculation.

    Args:
        pred: Tensor containing predictions
        target: Tensor containing targets
        ignore_index: optional int specifying a target class to ignore. If given, this class index does not contribute
            to the returned score, regardless of reduction method. Has no effect if given an int that is not in the
            range [0, num_classes-1], where num_classes is either given or derived from pred and target. By default, no
            index is ignored, and all classes are used.
        absent_score: score to use for an individual class, if no instances of the class index were present in
            `pred` AND no instances of the class index were present in `target`. For example, if we have 3 classes,
            [0, 0] for `pred`, and [0, 2] for `target`, then class 1 would be assigned the `absent_score`. Default is
            0.0.
        num_classes: Optionally specify the number of classes
        reduction: a method to reduce metric score over labels.

            - ``'elementwise_mean'``: takes the mean (default)
            - ``'sum'``: takes the sum
            - ``'none'``: no reduction will be applied

    Return:
        IoU score : Tensor containing single value if reduction is
        'elementwise_mean', or number of classes if reduction is 'none'

    Example:

        >>> target = torch.randint(0, 1, (10, 25, 25))
        >>> pred = torch.tensor(target)
        >>> pred[2:5, 7:13, 9:15] = 1 - pred[2:5, 7:13, 9:15]
        >>> iou(pred, target)
        tensor(0.4914)

    """
    num_classes = get_num_classes(pred=pred,
                                  target=target,
                                  num_classes=num_classes)

    tps, fps, tns, fns, sups = stat_scores_multiple_classes(
        pred, target, num_classes)

    scores = torch.zeros(num_classes, device=pred.device, dtype=torch.float32)

    for class_idx in range(num_classes):
        if class_idx == ignore_index:
            continue

        tp = tps[class_idx]
        fp = fps[class_idx]
        fn = fns[class_idx]
        sup = sups[class_idx]

        # If this class is absent in the target (no support) AND absent in the pred (no true or false
        # positives), then use the absent_score for this class.
        if sup + tp + fp == 0:
            scores[class_idx] = absent_score
            continue

        denom = tp + fp + fn
        # Note that we do not need to worry about division-by-zero here since we know (sup + tp + fp != 0) from above,
        # which means ((tp+fn) + tp + fp != 0), which means (2tp + fp + fn != 0). Since all vars are non-negative, we
        # can conclude (tp + fp + fn > 0), meaning the denominator is non-zero for each class.
        score = tp.to(torch.float) / denom
        scores[class_idx] = score

    # Remove the ignored class index from the scores.
    if ignore_index is not None and 0 <= ignore_index < num_classes:
        scores = torch.cat([
            scores[:ignore_index],
            scores[ignore_index + 1:],
        ])

    return reduce(scores, reduction=reduction)
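
The absent_score case from the docstring made concrete, assuming this iou is in scope; class 1 appears in neither pred nor target, so it keeps the default absent_score of 0.0:

import torch

pred = torch.tensor([0, 0])
target = torch.tensor([0, 2])
# class 0: tp=1, fp=1, fn=0 -> 0.5; class 1: absent -> 0.0; class 2: tp=0, fn=1 -> 0.0
print(iou(pred, target, num_classes=3, reduction='none'))  # tensor([0.5000, 0.0000, 0.0000])
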
Example #18
def ssim(pred: torch.Tensor,
         target: torch.Tensor,
         kernel_size: Sequence[int] = (11, 11),
         sigma: Sequence[float] = (1.5, 1.5),
         reduction: str = "elementwise_mean",
         data_range: Optional[float] = None,
         k1: float = 0.01,
         k2: float = 0.03) -> torch.Tensor:
    """
    Computes Structural Similarity Index Measure

    Args:
        pred: estimated image
        target: ground truth image
        kernel_size: size of the gaussian kernel (default: (11, 11))
        sigma: Standard deviation of the gaussian kernel (default: (1.5, 1.5))
        reduction: a method to reduce metric score over labels.

            - ``'elementwise_mean'``: takes the mean (default)
            - ``'sum'``: takes the sum
            - ``'none'``: no reduction will be applied

        data_range: Range of the image. If ``None``, it is determined from the image (max - min)
        k1: Parameter of SSIM. Default: 0.01
        k2: Parameter of SSIM. Default: 0.03

    Return:
        Tensor with SSIM score

    Example:

        >>> pred = torch.rand([16, 1, 16, 16])
        >>> target = pred * 0.75
        >>> ssim(pred, target)
        tensor(0.9219)

    """
    if pred.dtype != target.dtype:
        raise TypeError(
            "Expected `pred` and `target` to have the same data type."
            f" Got pred: {pred.dtype} and target: {target.dtype}.")

    if pred.shape != target.shape:
        raise ValueError(
            "Expected `pred` and `target` to have the same shape."
            f" Got pred: {pred.shape} and target: {target.shape}.")

    if len(pred.shape) != 4 or len(target.shape) != 4:
        raise ValueError(
            "Expected `pred` and `target` to have BxCxHxW shape."
            f" Got pred: {pred.shape} and target: {target.shape}.")

    if len(kernel_size) != 2 or len(sigma) != 2:
        raise ValueError(
            "Expected `kernel_size` and `sigma` to have the length of two."
            f" Got kernel_size: {len(kernel_size)} and sigma: {len(sigma)}.")

    if any(x % 2 == 0 or x <= 0 for x in kernel_size):
        raise ValueError(
            f"Expected `kernel_size` to have odd positive number. Got {kernel_size}."
        )

    if any(y <= 0 for y in sigma):
        raise ValueError(
            f"Expected `sigma` to have positive number. Got {sigma}.")

    if data_range is None:
        data_range = max(pred.max() - pred.min(), target.max() - target.min())

    C1 = pow(k1 * data_range, 2)
    C2 = pow(k2 * data_range, 2)
    device = pred.device

    channel = pred.size(1)
    kernel = _gaussian_kernel(channel, kernel_size, sigma, device)

    # Concatenate
    # pred for mu_pred
    # target for mu_target
    # pred * pred for sigma_pred
    # target * target for sigma_target
    # pred * target for sigma_pred_target
    input_list = torch.cat(
        [pred, target, pred * pred, target * target,
         pred * target])  # (5 * B, C, H, W)
    outputs = F.conv2d(input_list, kernel, groups=channel)
    output_list = [
        outputs[x * pred.size(0):(x + 1) * pred.size(0)]
        for x in range(5)  # one chunk per concatenated input above
    ]

    mu_pred_sq = output_list[0].pow(2)
    mu_target_sq = output_list[1].pow(2)
    mu_pred_target = output_list[0] * output_list[1]

    sigma_pred_sq = output_list[2] - mu_pred_sq
    sigma_target_sq = output_list[3] - mu_target_sq
    sigma_pred_target = output_list[4] - mu_pred_target

    UPPER = 2 * sigma_pred_target + C2
    LOWER = sigma_pred_sq + sigma_target_sq + C2

    ssim_idx = ((2 * mu_pred_target + C1) * UPPER) / (
        (mu_pred_sq + mu_target_sq + C1) * LOWER)

    return reduce(ssim_idx, reduction)
Example #19
def ssim(pred: torch.Tensor,
         target: torch.Tensor,
         kernel_size: Sequence[int] = (11, 11),
         sigma: Sequence[float] = (1.5, 1.5),
         reduction: str = "elementwise_mean",
         data_range: Optional[float] = None,
         k1: float = 0.01,
         k2: float = 0.03) -> torch.Tensor:
    """
    Computes Structural Similarity Index Measure

    Args:
        pred: Estimated image
        target: Ground truth image
        kernel_size: Size of the gaussian kernel. Default: (11, 11)
        sigma: Standard deviation of the gaussian kernel. Default: (1.5, 1.5)
        reduction: A method for reducing SSIM over all elements in the ``pred`` tensor. Default: ``elementwise_mean``

            Available reduction methods:

            - elementwise_mean: takes the mean
            - none: no reduction will be applied
            - sum: add elements

        data_range: Range of the image. If ``None``, it is determined from the image (max - min)
        k1: Parameter of SSIM. Default: 0.01
        k2: Parameter of SSIM. Default: 0.03

    Returns:
        A Tensor with SSIM

    Example:

        >>> pred = torch.rand([16, 1, 16, 16])
        >>> target = pred * 1.25
        >>> ssim(pred, target)
        tensor(0.9520)
    """

    if pred.dtype != target.dtype:
        raise TypeError(
            "Expected `pred` and `target` to have the same data type."
            f" Got pred: {pred.dtype} and target: {target.dtype}.")

    if pred.shape != target.shape:
        raise ValueError(
            "Expected `pred` and `target` to have the same shape."
            f" Got pred: {pred.shape} and target: {target.shape}.")

    if len(pred.shape) != 4 or len(target.shape) != 4:
        raise ValueError(
            "Expected `pred` and `target` to have BxCxHxW shape."
            f" Got pred: {pred.shape} and target: {target.shape}.")

    if len(kernel_size) != 2 or len(sigma) != 2:
        raise ValueError(
            "Expected `kernel_size` and `sigma` to have the length of two."
            f" Got kernel_size: {len(kernel_size)} and sigma: {len(sigma)}.")

    if any(x % 2 == 0 or x <= 0 for x in kernel_size):
        raise ValueError(
            f"Expected `kernel_size` to have odd positive number. Got {kernel_size}."
        )

    if any(y <= 0 for y in sigma):
        raise ValueError(
            f"Expected `sigma` to have positive number. Got {sigma}.")

    if data_range is None:
        data_range = max(pred.max() - pred.min(), target.max() - target.min())

    C1 = pow(k1 * data_range, 2)
    C2 = pow(k2 * data_range, 2)
    device = pred.device

    channel = pred.size(1)
    kernel = _gaussian_kernel(channel, kernel_size, sigma, device)
    mu_pred = F.conv2d(pred, kernel, groups=channel)
    mu_target = F.conv2d(target, kernel, groups=channel)

    mu_pred_sq = mu_pred.pow(2)
    mu_target_sq = mu_target.pow(2)
    mu_pred_target = mu_pred * mu_target

    sigma_pred_sq = F.conv2d(pred * pred, kernel, groups=channel) - mu_pred_sq
    sigma_target_sq = F.conv2d(target * target, kernel,
                               groups=channel) - mu_target_sq
    sigma_pred_target = F.conv2d(pred * target, kernel,
                                 groups=channel) - mu_pred_target

    UPPER = 2 * sigma_pred_target + C2
    LOWER = sigma_pred_sq + sigma_target_sq + C2

    ssim_idx = ((2 * mu_pred_target + C1) * UPPER) / (
        (mu_pred_sq + mu_target_sq + C1) * LOWER)

    return reduce(ssim_idx, reduction)