from typing import Sequence, Union

import torch

from monai.networks.utils import one_hot  # assumed import path for the `one_hot` helper


def do_binarization(
    input_data: torch.Tensor,
    bin_mode: str = "threshold",
    bin_threshold: Union[float, Sequence[float]] = 0.5,
) -> torch.Tensor:
    """
    Args:
        input_data: the input that to be binarized, in the shape [B] or [BN] or [BNHW] or [BNHWD].
        bin_mode: can be ``"threshold"`` or ``"mutually_exclusive"``, or a callable function.
            - ``"threshold"``, a single threshold or a sequence of thresholds should be set.
            - ``"mutually_exclusive"``, `input_data` will be converted by a combination of
            argmax and to_onehot.
        bin_threshold: the threshold to binarize the input data, can be a single value or a sequence of
            values that each one of the value represents a threshold for a class.

    Raises:
        AssertionError: when `bin_threshold` is a sequence and the input has the shape [B].
        AssertionError: when `bin_threshold` is a sequence but the length != the number of classes.
        AssertionError: when `bin_mode` is ``"mutually_exclusive"`` the input has the shape [B].
        AssertionError: when `bin_mode` is ``"mutually_exclusive"`` the input has the shape [B, 1].
    """
    input_ndim = input_data.ndimension()
    if bin_mode == "threshold":
        if isinstance(bin_threshold, Sequence):
            assert input_ndim > 1, "a sequence of thresholds is only supported for multi-class inputs."
            error_hint = "the length of the sequence should be the same as the number of classes."
            assert input_data.shape[1] == len(bin_threshold), error_hint
            for cls_num in range(input_data.shape[1]):
                input_data[:, cls_num] = (input_data[:, cls_num] > bin_threshold[cls_num]).float()
        else:
            input_data = (input_data > bin_threshold).float()
    elif bin_mode == "mutually_exclusive":
        assert input_ndim > 1, "mutually_exclusive is used for multi-class tasks."
        n_classes = input_data.shape[1]
        assert n_classes > 1, "mutually_exclusive is used for multi-class tasks."
        input_data = torch.argmax(input_data, dim=1, keepdim=True)
        input_data = one_hot(input_data, num_classes=n_classes)
    else:
        raise ValueError(f'Unsupported bin_mode: {bin_mode}, available options are ["threshold", "mutually_exclusive"].')
    return input_data
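# Usage sketch (not from the source; the tensor values are made up for illustration).
# Note that the sequence-threshold branch modifies its input in place, hence the .clone().
probs = torch.tensor([[0.2, 0.7, 0.4],
                      [0.9, 0.1, 0.6]])  # shape [B=2, N=3]
per_class = do_binarization(probs.clone(), bin_mode="threshold", bin_threshold=[0.5, 0.5, 0.5])
# -> tensor([[0., 1., 0.], [1., 0., 1.]])
exclusive = do_binarization(probs, bin_mode="mutually_exclusive")
# -> tensor([[0., 1., 0.], [1., 0., 0.]])  (argmax + one-hot)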
Example #2
import warnings
from typing import Callable, Optional

import torch

from monai.networks.utils import one_hot  # assumed import path for the `one_hot` helper


def compute_meandice(
    y_pred: torch.Tensor,
    y: torch.Tensor,
    include_background: bool = True,
    to_onehot_y: bool = False,
    mutually_exclusive: bool = False,
    sigmoid: bool = False,
    other_act: Optional[Callable] = None,
    logit_thresh: float = 0.5,
) -> torch.Tensor:
    """Computes Dice score metric from full size Tensor and collects average.

    Args:
        y_pred: input data to compute, typically segmentation model output.
            It must be in one-hot format and the first dim is batch, example shape: [16, 3, 32, 32].
        y: ground truth to compute the mean Dice metric on, the first dim is batch.
            example shape: [16, 1, 32, 32] will be converted into [16, 3, 32, 32] when `to_onehot_y=True`.
            alternative shape: [16, 3, 32, 32] with `to_onehot_y=False`, to use 3-class labels directly.
        include_background: whether to include Dice computation on the first channel (assumed
            background) of the predicted output. Defaults to True.
        to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
        mutually_exclusive: if True, `y_pred` will be converted into a binary matrix using
            a combination of argmax and to_onehot.  Defaults to False.
        sigmoid: whether to apply a sigmoid function to `y_pred` before computation. Defaults to False.
        other_act: callable function to replace `sigmoid` as the activation layer if needed. Defaults to ``None``.
            for example: `other_act = torch.tanh`.
        logit_thresh: the threshold value used to convert (for example, after sigmoid if `sigmoid=True`)
            `y_pred` into a binary matrix. Defaults to 0.5.

    Raises:
        ValueError: When ``sigmoid=True`` and ``other_act is not None``. Incompatible values.
        TypeError: When ``other_act`` is not an ``Optional[Callable]``.
        ValueError: When ``sigmoid=True`` and ``mutually_exclusive=True``. Incompatible values.

    Returns:
        Dice scores per batch and per class, (shape [batch_size, n_classes]).

    Note:
        This method provides two options to convert `y_pred` into a binary matrix:
            (1) when `mutually_exclusive` is True, it uses a combination of ``argmax`` and ``to_onehot``;
            (2) when `mutually_exclusive` is False, it uses a threshold ``logit_thresh``
                (optionally with a ``sigmoid`` function before thresholding).

    """
    n_classes = y_pred.shape[1]
    n_len = len(y_pred.shape)
    if sigmoid and other_act is not None:
        raise ValueError("Incompatible values: sigmoid=True and other_act is not None.")
    if sigmoid:
        y_pred = y_pred.float().sigmoid()

    if other_act is not None:
        if not callable(other_act):
            raise TypeError(f"other_act must be None or callable but is {type(other_act).__name__}.")
        y_pred = other_act(y_pred)

    if n_classes == 1:
        if mutually_exclusive:
            warnings.warn("y_pred has only one class, mutually_exclusive=True ignored.")
        if to_onehot_y:
            warnings.warn("y_pred has only one channel, to_onehot_y=True ignored.")
        if not include_background:
            warnings.warn("y_pred has only one channel, include_background=False ignored.")
        # make both y and y_pred binary
        y_pred = (y_pred >= logit_thresh).float()
        y = (y > 0).float()
    else:  # multi-channel y_pred
        # make both y and y_pred binary
        if mutually_exclusive:
            if sigmoid:
                raise ValueError("Incompatible values: sigmoid=True and mutually_exclusive=True.")
            y_pred = torch.argmax(y_pred, dim=1, keepdim=True)
            y_pred = one_hot(y_pred, num_classes=n_classes)
        else:
            y_pred = (y_pred >= logit_thresh).float()
        if to_onehot_y:
            y = one_hot(y, num_classes=n_classes)

    if not include_background:
        y = y[:, 1:] if y.shape[1] > 1 else y
        y_pred = y_pred[:, 1:] if y_pred.shape[1] > 1 else y_pred

    assert y.shape == y_pred.shape, f"Ground truth one-hot has differing shape ({y.shape}) from source ({y_pred.shape})"
    y = y.float()
    y_pred = y_pred.float()

    # reducing only spatial dimensions (not batch nor channels)
    reduce_axis = list(range(2, n_len))
    intersection = torch.sum(y * y_pred, dim=reduce_axis)

    y_o = torch.sum(y, dim=reduce_axis)
    y_pred_o = torch.sum(y_pred, dim=reduce_axis)
    denominator = y_o + y_pred_o

    f = torch.where(y_o > 0, (2.0 * intersection) / denominator, torch.tensor(float("nan"), device=y_o.device))
    return f  # returns array of Dice shape: [batch, n_classes]
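# Usage sketch (not from the source): a toy batch of 2 samples, 3 classes, 4x4 images.
y_pred = torch.rand(2, 3, 4, 4)        # raw scores; argmax-ed by mutually_exclusive=True
y = torch.randint(0, 3, (2, 1, 4, 4))  # class-index labels, expanded by to_onehot_y=True
dice = compute_meandice(
    y_pred, y,
    to_onehot_y=True,
    mutually_exclusive=True,
    include_background=False,
)
print(dice.shape)  # torch.Size([2, 2]) -- one Dice score per sample and foreground class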
Example #3
    def forward(self, input: torch.Tensor,
                target: torch.Tensor) -> torch.Tensor:
        """
        Args:
            input: the shape should be BNH[WD].
            target: the shape should be BNH[WD].

        Raises:
            ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"].

        """
        if self.sigmoid:
            input = torch.sigmoid(input)
        n_pred_ch = input.shape[1]
        if self.softmax:
            if n_pred_ch == 1:
                warnings.warn(
                    "single channel prediction, `softmax=True` ignored.")
            else:
                input = torch.softmax(input, 1)

        if self.other_act is not None:
            input = self.other_act(input)

        if self.to_onehot_y:
            if n_pred_ch == 1:
                warnings.warn(
                    "single channel prediction, `to_onehot_y=True` ignored.")
            else:
                target = one_hot(target, num_classes=n_pred_ch)

        if not self.include_background:
            if n_pred_ch == 1:
                warnings.warn(
                    "single channel prediction, `include_background=False` ignored."
                )
            else:
                # if skipping background, removing first channel
                target = target[:, 1:]
                input = input[:, 1:]

        if target.shape != input.shape:
            raise AssertionError(
                f"ground truth has differing shape ({target.shape}) from input ({input.shape})"
            )

        # reducing only spatial dimensions (not batch nor channels)
        reduce_axis = list(range(2, len(input.shape)))
        if self.batch:
            reduce_axis = [0] + reduce_axis
        intersection = torch.sum(target * input, reduce_axis)

        ground_o = torch.sum(target, reduce_axis)
        pred_o = torch.sum(input, reduce_axis)

        denominator = ground_o + pred_o

        w = self.w_func(ground_o.float())  # per-class weights, e.g. inverse squared class volume
        for b in w:
            # classes missing from the ground truth yield infinite weights:
            # zero them out, then replace them with the largest finite weight
            infs = torch.isinf(b)
            b[infs] = 0.0
            b[infs] = torch.max(b)

        # weighted numerator/denominator, summed over the class dimension
        # (dim 0 when the batch dim was already folded into the reduction)
        numer = 2.0 * (intersection * w).sum(0 if self.batch else 1) + self.smooth_nr
        denom = (denominator * w).sum(0 if self.batch else 1) + self.smooth_dr
        f: torch.Tensor = 1.0 - numer / denom
        f = torch.log(torch.cosh(f))  # log-cosh transform of the generalized Dice term
        if self.reduction == LossReduction.MEAN.value:
            f = torch.mean(f)  # the batch and channel average
        elif self.reduction == LossReduction.SUM.value:
            f = torch.sum(f)  # sum over the batch and channel dims
        elif self.reduction == LossReduction.NONE.value:
            pass  # returns [N, n_classes] losses
        else:
            raise ValueError(
                f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].'
            )

        return f
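# Standalone sketch of the core computation above (not from the source). `w_func` is
# assumed here to be the inverse squared class volume commonly used for generalized Dice,
# and the smooth terms are assumed to be 1e-5; the actual values are whatever the loss
# instance was configured with.
import torch

target = torch.tensor([[[1.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 1.0]]])  # [B=1, C=2, 4]
pred = torch.tensor([[[0.8, 0.7, 0.2, 0.1], [0.2, 0.3, 0.8, 0.9]]])

intersection = torch.sum(target * pred, dim=[2])
ground_o = torch.sum(target, dim=[2])
denominator = ground_o + torch.sum(pred, dim=[2])

w = 1.0 / (ground_o.float() ** 2)  # assumed weighting; infinite entries would need the fix above
f = 1.0 - (2.0 * (intersection * w).sum(1) + 1e-5) / ((denominator * w).sum(1) + 1e-5)
loss = torch.log(torch.cosh(f))  # ~0.02 for this toy input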
Example #4
import warnings
from typing import Callable, List, Optional, Sequence, Union

import numpy as np
import torch

# `one_hot`, `Average`, `do_activation`, `do_binarization`, `cal_confusion_matrix_elements`,
# `check_metric_name_and_unify` and `do_calculate_metric` are helpers defined alongside
# this function in the same codebase.


def compute_confusion_metric(
    y_pred: torch.Tensor,
    y: torch.Tensor,
    to_onehot_y: bool = False,
    activation: Optional[Union[str, Callable]] = None,
    bin_mode: Optional[str] = "threshold",
    bin_threshold: Union[float, Sequence[float]] = 0.5,
    metric_name: str = "hit_rate",
    average: Union[Average, str] = Average.MACRO,
    zero_division: int = 0,
) -> Union[np.ndarray, List[float], float]:
    """
    Compute confusion-matrix-related metrics. This function supports all the metrics
    mentioned in: `Confusion matrix <https://en.wikipedia.org/wiki/Confusion_matrix>`_.
    Before calculating, an activation function and/or a binarization step can be applied
    to pre-process the original inputs. Zero division is handled by replacing the result
    with a single predefined value. Referring to:
    `sklearn.metrics <https://scikit-learn.org/stable/modules/classes.html#module-sklearn.metrics>`_.

    Args:
        y_pred: predictions. For classification tasks, `y_pred` should have the shape [B]
            or [BN]. For segmentation tasks, the shape should be [BNHW] or [BNHWD].
        y: ground truth, the first dim is batch.
        to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
        activation: [``"sigmoid"``, ``"softmax"``]
            Activation method, if specified, an activation function will be employed for `y_pred`.
            Defaults to None.
            The parameter can also be a callable function, for example:
            ``activation = lambda x: torch.log_softmax(x, dim=1)``.
        bin_mode: [``"threshold"``, ``"mutually_exclusive"``]
            Binarization method, if specified, a binarization manipulation will be employed
            for `y_pred`.

            - ``"threshold"``, a single threshold or a sequence of thresholds should be set.
            - ``"mutually_exclusive"``, `y_pred` will be converted by a combination of `argmax` and `to_onehot`.
        bin_threshold: the threshold for binarization, either a single value or a sequence of
            values, each of which represents the threshold for one class.
        metric_name: [``"sensitivity"``, ``"specificity"``, ``"precision"``, ``"negative predictive value"``,
            ``"miss rate"``, ``"fall out"``, ``"false discovery rate"``, ``"false omission rate"``,
            ``"prevalence threshold"``, ``"threat score"``, ``"accuracy"``, ``"balanced accuracy"``,
            ``"f1 score"``, ``"matthews correlation coefficient"``, ``"fowlkes mallows index"``,
            ``"informedness"``, ``"markedness"``]
            Some of the metrics have multiple aliases (as shown in the wikipedia page aforementioned),
            and you can also input those names instead.
        average: [``"macro"``, ``"weighted"``, ``"micro"``, ``"none"``]
            Type of averaging performed if not binary classification.
            Defaults to ``"macro"``.

            - ``"macro"``: calculate metrics for each label, and find their unweighted mean.
                This does not take label imbalance into account.
            - ``"weighted"``: calculate metrics for each label, and find their average,
                weighted by support (the number of true instances for each label).
            - ``"micro"``: calculate metrics globally by considering each element of the label
                indicator matrix as a label.
            - ``"none"``: the scores for each class are returned.
        zero_division: the value to return when there is a zero division, for example, when all
            predictions and labels are negative. Defaults to 0.

    Raises:
        AssertionError: when the data shapes of `y_pred` and `y` do not match.
        AssertionError: when an activation function and the ``mutually_exclusive`` mode are specified at the same time.
    """

    y_pred_ndim, y_ndim = y_pred.ndimension(), y.ndimension()
    # one-hot for ground truth
    if to_onehot_y:
        if y_pred_ndim == 1:
            warnings.warn(
                "y_pred has only one channel, to_onehot_y=True ignored.")
        else:
            n_classes = y_pred.shape[1]
            y = one_hot(y, num_classes=n_classes)
    # check shape
    assert y.shape == y_pred.shape, "data shapes of y_pred and y do not match."
    # activation for predictions
    if activation is not None:
        assert bin_mode != "mutually_exclusive", "activation is unnecessary for mutually exclusive classes."
        y_pred = do_activation(y_pred, activation=activation)
    # binarization for predictions
    if bin_mode is not None:
        y_pred = do_binarization(y_pred,
                                 bin_mode=bin_mode,
                                 bin_threshold=bin_threshold)
    # get confusion matrix elements
    con_list = cal_confusion_matrix_elements(y_pred, y)
    # get simplified metric name
    metric_name = check_metric_name_and_unify(metric_name)
    result = do_calculate_metric(con_list,
                                 metric_name,
                                 average=average,
                                 zero_division=zero_division)
    return result
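# Usage sketch (not from the source), assuming the helper functions this metric calls
# are importable. `y_pred` holds per-class scores; thresholding at 0.5 binarizes them.
y_pred = torch.tensor([[0.9, 0.1], [0.3, 0.7], [0.8, 0.2], [0.4, 0.6]])
y = torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]])
sensitivity = compute_confusion_metric(
    y_pred, y,
    bin_mode="threshold", bin_threshold=0.5,
    metric_name="sensitivity", average="macro",
)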
Example #5
    def forward(self, input: torch.Tensor,
                target: torch.Tensor) -> torch.Tensor:
        """
        Args:
            input: the shape should be BNH[WD].
            target: the shape should be BNH[WD].

        Raises:
            ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"].

        """
        if self.sigmoid:
            input = torch.sigmoid(input)

        n_pred_ch = input.shape[1]
        if self.softmax:
            if n_pred_ch == 1:
                warnings.warn(
                    "single channel prediction, `softmax=True` ignored.")
            else:
                input = torch.softmax(input, 1)

        if self.other_act is not None:
            input = self.other_act(input)

        if self.to_onehot_y:
            if n_pred_ch == 1:
                warnings.warn(
                    "single channel prediction, `to_onehot_y=True` ignored.")
            else:
                target = one_hot(target, num_classes=n_pred_ch)

        if not self.include_background:
            if n_pred_ch == 1:
                warnings.warn(
                    "single channel prediction, `include_background=False` ignored."
                )
            else:
                # if skipping background, removing first channel
                target = target[:, 1:]
                input = input[:, 1:]

        if target.shape != input.shape:
            raise AssertionError(
                f"ground truth has differing shape ({target.shape}) from input ({input.shape})"
            )

        p0 = input    # predicted probability of belonging to the class
        p1 = 1 - p0   # predicted probability of not belonging to the class
        g0 = target   # ground truth foreground
        g1 = 1 - g0   # ground truth background

        # reducing only spatial dimensions (not batch nor channels)
        reduce_axis = list(range(2, len(input.shape)))
        if self.batch:
            # reducing spatial dimensions and batch
            reduce_axis = [0] + reduce_axis

        tp = torch.sum(p0 * g0, reduce_axis)
        fp = self.alpha * torch.sum(p0 * g1, reduce_axis)  # alpha weights false positives
        fn = self.beta * torch.sum(p1 * g0, reduce_axis)   # beta weights false negatives
        numerator = tp + self.smooth_nr
        denominator = tp + fp + fn + self.smooth_dr

        score: torch.Tensor = 1.0 - numerator / denominator

        if self.reduction == LossReduction.SUM.value:
            return torch.sum(score)  # sum over the batch and channel dims
        if self.reduction == LossReduction.NONE.value:
            return score  # returns [N, n_classes] losses
        if self.reduction == LossReduction.MEAN.value:
            return torch.mean(score)
        raise ValueError(
            f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].'
        )
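# Standalone sketch of the score above (not from the source), with assumed values
# alpha=0.3, beta=0.7 so that false negatives are penalized more than false positives,
# and an assumed smooth term of 1e-5.
import torch

alpha, beta, smooth = 0.3, 0.7, 1e-5
p0 = torch.tensor([[[0.9, 0.8, 0.1, 0.2]]])  # predicted foreground probability, [B=1, C=1, 4]
g0 = torch.tensor([[[1.0, 1.0, 1.0, 0.0]]])  # ground truth

tp = torch.sum(p0 * g0, dim=[2])
fp = alpha * torch.sum(p0 * (1 - g0), dim=[2])
fn = beta * torch.sum((1 - p0) * g0, dim=[2])
score = 1.0 - (tp + smooth) / (tp + fp + fn + smooth)  # ~0.33 for this toy input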
Example #6
import warnings
from typing import Callable, List, Optional, Union

import numpy as np
import torch

# `one_hot`, `Average` and the per-class `_calculate` helper are defined alongside
# this function in the same codebase.


def compute_roc_auc(
    y_pred: torch.Tensor,
    y: torch.Tensor,
    to_onehot_y: bool = False,
    softmax: bool = False,
    other_act: Optional[Callable] = None,
    average: Union[Average, str] = Average.MACRO,
) -> Union[np.ndarray, List[float], float]:
    """Computes Area Under the Receiver Operating Characteristic Curve (ROC AUC). Referring to:
    `sklearn.metrics.roc_auc_score <https://scikit-learn.org/stable/modules/generated/
    sklearn.metrics.roc_auc_score.html#sklearn.metrics.roc_auc_score>`_.

    Args:
        y_pred: input data to compute, typically classification model output;
            probability estimates or confidence values with the first dim as batch,
            example shape: [16] or [16, 2].
        y: ground truth to compute the ROC AUC metric on, the first dim is batch.
            example shape: [16, 1] will be converted into [16, 2] (where `2` is inferred from `y_pred`).
        to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
        softmax: whether to apply a softmax function to `y_pred` before computation. Defaults to False.
        other_act: callable function to replace `softmax` as the activation layer if needed. Defaults to ``None``.
            for example: `other_act = lambda x: torch.log_softmax(x, dim=1)`.
        average: {``"macro"``, ``"weighted"``, ``"micro"``, ``"none"``}
            Type of averaging performed if not binary classification.
            Defaults to ``"macro"``.

            - ``"macro"``: calculate metrics for each label, and find their unweighted mean.
                This does not take label imbalance into account.
            - ``"weighted"``: calculate metrics for each label, and find their average,
                weighted by support (the number of true instances for each label).
            - ``"micro"``: calculate metrics globally by considering each element of the label
                indicator matrix as a label.
            - ``"none"``: the scores for each class are returned.

    Raises:
        ValueError: When ``y_pred`` dimension is not one of [1, 2].
        ValueError: When ``y`` dimension is not one of [1, 2].
        ValueError: When ``softmax=True`` and ``other_act is not None``. Incompatible values.
        TypeError: When ``other_act`` is not an ``Optional[Callable]``.
        ValueError: When ``average`` is not one of ["macro", "weighted", "micro", "none"].

    Note:
        ROCAUC expects `y` to consist of 0s and 1s. `y_pred` must be either probability estimates or confidence values.

    """
    y_pred_ndim = y_pred.ndimension()
    y_ndim = y.ndimension()
    if y_pred_ndim not in (1, 2):
        raise ValueError(
            "Predictions should be of shape (batch_size, n_classes) or (batch_size, )."
        )
    if y_ndim not in (1, 2):
        raise ValueError(
            "Targets should be of shape (batch_size, n_classes) or (batch_size, )."
        )
    if y_pred_ndim == 2 and y_pred.shape[1] == 1:
        y_pred = y_pred.squeeze(dim=-1)
        y_pred_ndim = 1
    if y_ndim == 2 and y.shape[1] == 1:
        y = y.squeeze(dim=-1)

    if y_pred_ndim == 1:
        if to_onehot_y:
            warnings.warn(
                "y_pred has only one channel, to_onehot_y=True ignored.")
        if softmax:
            warnings.warn("y_pred has only one channel, softmax=True ignored.")
        return _calculate(y, y_pred)
    else:
        n_classes = y_pred.shape[1]
        if to_onehot_y:
            y = one_hot(y, n_classes)
        if softmax and other_act is not None:
            raise ValueError(
                "Incompatible values: softmax=True and other_act is not None.")
        if softmax:
            y_pred = y_pred.float().softmax(dim=1)
        if other_act is not None:
            if not callable(other_act):
                raise TypeError(
                    f"other_act must be None or callable but is {type(other_act).__name__}."
                )
            y_pred = other_act(y_pred)

        assert y.shape == y_pred.shape, "data shapes of y_pred and y do not match."

        average = Average(average)
        if average == Average.MICRO:
            return _calculate(y.flatten(), y_pred.flatten())
        else:
            y, y_pred = y.transpose(0, 1), y_pred.transpose(0, 1)
            auc_values = [
                _calculate(y_, y_pred_) for y_, y_pred_ in zip(y, y_pred)
            ]
            if average == Average.NONE:
                return auc_values
            if average == Average.MACRO:
                return np.mean(auc_values)
            if average == Average.WEIGHTED:
                weights = [sum(y_) for y_ in y]
                return np.average(auc_values, weights=weights)
            raise ValueError(
                f'Unsupported average: {average}, available options are ["macro", "weighted", "micro", "none"].'
            )
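# Usage sketch (not from the source), assuming `one_hot`, `Average` and the per-class
# `_calculate` helper are available from the same module.
y_pred = torch.tensor([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7], [0.6, 0.4]])  # [B=4, N=2]
y = torch.tensor([[1], [0], [1], [0]])  # integer labels, squeezed then one-hot encoded
auc = compute_roc_auc(y_pred, y, to_onehot_y=True, softmax=True, average="macro")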