Exemple #1
0
def compute_meandice(y_pred,
                     y,
                     include_background=False,
                     to_onehot_y=True,
                     mutually_exclusive=True,
                     add_sigmoid=False,
                     logit_thresh=None):
    """Compute the Dice score from full-size tensors and return its mean.

    Args:
        y_pred (torch.Tensor): input data to compute, typically segmentation model output.
            The first dim is batch and the second is channel, example shape: [16, 3, 32, 32].
        y (torch.Tensor): ground truth to compute mean dice metric, the first dim is batch.
        include_background (Bool): whether to include the Dice computation on the first
            channel of the predicted output. When False, the first channel of both `y` and
            `y_pred` is dropped (only if that tensor has more than one channel).
            Defaults to False.
        to_onehot_y (Bool): whether to convert `y` into the one-hot format, using the number
            of channels of `y_pred`.
        mutually_exclusive (Bool): if True, `y_pred` will be converted into a binary matrix using
            a combination of argmax and to_onehot.
        add_sigmoid (Bool): whether to apply a sigmoid to `y_pred` before thresholding.
            Only used when `mutually_exclusive` is False.
        logit_thresh (Float): the threshold value used to convert `y_pred` into a binary matrix.
            Only allowed when `mutually_exclusive` is False; if None, `y_pred` is used as-is.

    Returns:
        torch.Tensor: the mean of the per-sample, per-channel Dice scores.

    Raises:
        ValueError: if `logit_thresh` is given while `mutually_exclusive` is True.

    Note:
        This method provides two options to convert `y_pred` into a binary matrix:
            (1) when `mutually_exclusive` is True, it uses a combination of argmax and to_onehot;
            (2) when `mutually_exclusive` is False, it uses a threshold `logit_thresh`
                (optionally with a sigmoid function before thresholding).
        A sample/channel where both `y` and `y_pred` are empty has a zero denominator and
        yields NaN (0/0), which then propagates into the returned mean.

    """
    n_channels_y_pred = y_pred.shape[1]

    if mutually_exclusive:
        if logit_thresh is not None:
            raise ValueError(
                '`logit_thresh` is incompatible when mutually_exclusive is True.'
            )
        y_pred = torch.argmax(y_pred, dim=1, keepdim=True)
        y_pred = one_hot(y_pred, n_channels_y_pred)
    else:  # channel-wise thresholding
        if add_sigmoid:
            y_pred = torch.sigmoid(y_pred)
        if logit_thresh is not None:
            y_pred = (y_pred >= logit_thresh).float()

    if to_onehot_y:
        y = one_hot(y, n_channels_y_pred)

    if not include_background:
        # a single-channel tensor has no separate background channel to drop
        y = y[:, 1:] if y.shape[1] > 1 else y
        y_pred = y_pred[:, 1:] if y_pred.shape[1] > 1 else y_pred

    # reducing only spatial dimensions (not batch nor channels)
    reduce_axis = list(range(2, y_pred.dim()))
    intersection = torch.sum(y * y_pred, reduce_axis)

    y_o = torch.sum(y, reduce_axis)
    y_pred_o = torch.sum(y_pred, reduce_axis)
    denominator = y_o + y_pred_o

    f = (2.0 * intersection) / denominator
    # final reduce_mean across batches and channels
    return torch.mean(f)
Exemple #2
0
    def forward(self, input: torch.Tensor, target: torch.Tensor, smooth: float = 1e-5):
        """
        Compute the generalized (class-weighted) Dice loss.

        Args:
            input (tensor): prediction, the shape should be BNH[WD].
            target (tensor): ground truth, the shape should be BNH[WD].
            smooth: a small constant to avoid nan.

        Returns:
            ``1 - (2 * sum_c w_c * I_c + smooth) / (sum_c w_c * (G_c + P_c) + smooth)`` per
            sample, reduced by ``self.reduction``: "mean" / "sum" give a scalar, "none"
            gives one value per sample (shape [B]).

        Raises:
            ValueError: if ``self.reduction`` is not "mean", "sum" or "none".
        """
        pred = torch.sigmoid(input) if self.sigmoid else input
        n_pred_ch = pred.shape[1]

        if n_pred_ch != 1:
            if self.softmax:
                pred = torch.softmax(pred, 1)
            if self.to_onehot_y:
                target = one_hot(target, n_pred_ch)
            if not self.include_background:
                # drop the background channel (index 0) from both tensors
                target, pred = target[:, 1:], pred[:, 1:]
        else:
            # with a single channel the multi-class options do not apply; warn in order
            single_channel_notes = (
                (self.softmax, "single channel prediction, `softmax=True` ignored."),
                (self.to_onehot_y, "single channel prediction, `to_onehot_y=True` ignored."),
                (not self.include_background, "single channel prediction, `include_background=False` ignored."),
            )
            for enabled, note in single_channel_notes:
                if enabled:
                    warnings.warn(note)

        assert (
            target.shape == pred.shape
        ), f"ground truth has differing shape ({target.shape}) from input ({pred.shape})"

        # sum over the spatial axes only, keeping batch and channel dimensions
        spatial_dims = list(range(2, len(pred.shape)))
        overlap = torch.sum(target * pred, spatial_dims)
        gt_volume = torch.sum(target, spatial_dims)
        pred_volume = torch.sum(pred, spatial_dims)
        total = gt_volume + pred_volume

        weights = self.w_func(gt_volume.float())
        for row in weights:
            # an infinite weight (presumably w_func of an empty class, e.g. a reciprocal of
            # zero volume) is replaced by the row maximum computed after zeroing it
            is_inf = torch.isinf(row)
            row[is_inf] = 0.0
            row[is_inf] = torch.max(row)

        weighted_overlap = (overlap * weights).sum(1)
        weighted_total = (total * weights).sum(1)
        loss = 1.0 - (2.0 * weighted_overlap + smooth) / (weighted_total + smooth)

        if self.reduction == "mean":
            return torch.mean(loss)  # average over the batch
        if self.reduction == "sum":
            return torch.sum(loss)  # sum over the batch
        if self.reduction == "none":
            return loss  # one loss per sample
        raise ValueError(f"reduction={self.reduction} is invalid.")
Exemple #3
0
    def forward(self, input: torch.Tensor, target: torch.Tensor, smooth: float = 1e-5):
        """
        Compute the Tversky loss ``1 - (TP + s) / (TP + alpha * FP + beta * FN + s)``.

        Args:
            input (tensor): prediction, the shape should be BNH[WD].
            target (tensor): ground truth, the shape should be BNH[WD].
            smooth (float): a small constant to avoid nan.

        Returns:
            The per-sample, per-channel loss (shape [B, N]) for reduction "none", otherwise
            its sum ("sum") or mean ("mean").

        Raises:
            ValueError: if ``self.reduction`` is not "sum", "none" or "mean".
        """
        probs = torch.sigmoid(input) if self.do_sigmoid else input
        n_pred_ch = probs.shape[1]

        if n_pred_ch != 1:
            if self.do_softmax:
                probs = torch.softmax(probs, 1)
            if self.to_onehot_y:
                target = one_hot(target, n_pred_ch)
            if not self.include_background:
                # skip the background channel (index 0)
                target, probs = target[:, 1:], probs[:, 1:]
        else:
            if self.do_softmax:
                warnings.warn("single channel prediction, `do_softmax=True` ignored.")
            if self.to_onehot_y:
                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
            if not self.include_background:
                warnings.warn("single channel prediction, `include_background=False` ignored.")

        assert (
            target.shape == probs.shape
        ), f"ground truth has differing shape ({target.shape}) from input ({probs.shape})"

        # sum over spatial axes only; batch and channel axes are kept
        axes = list(range(2, len(probs.shape)))

        true_pos = torch.sum(probs * target, axes)
        false_pos = self.alpha * torch.sum(probs * (1 - target), axes)
        false_neg = self.beta * torch.sum((1 - probs) * target, axes)

        score = 1.0 - (true_pos + smooth) / (true_pos + false_pos + false_neg + smooth)

        if self.reduction == "sum":
            return score.sum()
        if self.reduction == "none":
            return score
        if self.reduction == "mean":
            return score.mean()
        raise ValueError(f"reduction={self.reduction} is invalid.")
Exemple #4
0
    def test_shape(self, input_data, expected_shape, expected_result=None):
        """Check ``one_hot(**input_data)``: shape, optional values, and dtype (float unless given)."""
        output = one_hot(**input_data)
        self.assertEqual(output.shape, expected_shape)

        if expected_result is not None:
            self.assertTrue(np.allclose(expected_result, output.numpy()))

        # when no dtype is requested, the encoding is expected to default to float
        wanted_dtype = input_data["dtype"] if "dtype" in input_data else torch.float
        self.assertEqual(output.dtype, wanted_dtype)
Exemple #5
0
    def forward(self, input, target, smooth=1e-5):
        """
        Compute the Dice loss, or the soft Jaccard loss when ``self.jaccard`` is set.

        Args:
            input (tensor): the shape should be BNH[WD].
            target (tensor): the shape should be BNH[WD].
            smooth (float): a small constant to avoid nan.

        Returns:
            The per-sample, per-channel loss (shape [B, N]) for reduction "none", otherwise
            its sum ("sum") or mean ("mean").

        Raises:
            ValueError: if ``self.reduction`` is not "sum", "none" or "mean".
        """
        if self.do_sigmoid:
            input = torch.sigmoid(input)
        n_pred_ch = input.shape[1]
        if n_pred_ch == 1:
            if self.do_softmax:
                warnings.warn("single channel prediction, `do_softmax=True` ignored.")
            if self.to_onehot_y:
                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
            if not self.include_background:
                warnings.warn("single channel prediction, `include_background=False` ignored.")
        else:
            if self.do_softmax:
                input = torch.softmax(input, 1)
            if self.to_onehot_y:
                target = one_hot(target, n_pred_ch)
            if not self.include_background:
                # if skipping background, removing first channel
                target = target[:, 1:]
                input = input[:, 1:]
        assert (
            target.shape == input.shape
        ), f"ground truth has differing shape ({target.shape}) from input ({input.shape})"

        # reducing only spatial dimensions (not batch nor channels)
        reduce_axis = list(range(2, len(input.shape)))
        intersection = torch.sum(target * input, reduce_axis)

        if self.squared_pred:
            target = torch.pow(target, 2)
            input = torch.pow(input, 2)

        ground_o = torch.sum(target, reduce_axis)
        pred_o = torch.sum(input, reduce_axis)

        denominator = ground_o + pred_o

        if self.jaccard:
            # Jaccard = I / (G + P - I). The shared formula below has 2 * I in the numerator,
            # so the union must be doubled too; subtracting I alone gives 2I / (G + P - I),
            # which exceeds 1 and makes the loss negative (e.g. -1 for a perfect match).
            denominator = 2.0 * (denominator - intersection)

        f = 1.0 - (2.0 * intersection + smooth) / (denominator + smooth)
        if self.reduction == "sum":
            return f.sum()  # sum over the batch and channel dims
        if self.reduction == "none":
            return f  # returns [N, n_classes] losses
        if self.reduction == "mean":
            return f.mean()  # the batch and channel average
        raise ValueError(f"reduction={self.reduction} is invalid.")
Exemple #6
0
    def forward(self, pred, ground, smooth=1e-5):
        """
        Compute the Dice loss, or the soft Jaccard loss when ``self.jaccard`` is set.

        Args:
            pred (tensor): the shape should be BNH[WD].
            ground (tensor): the shape should be BNH[WD].
            smooth (float): a small constant to avoid nan.

        Returns:
            Scalar loss: ``1 -`` the mean score over batch and channels.
        """
        if self.do_sigmoid:
            pred = torch.sigmoid(pred)
        n_pred_ch = pred.shape[1]
        if n_pred_ch == 1:
            if self.do_softmax:
                warnings.warn(
                    "single channel prediction, `do_softmax=True` ignored.")
            if self.to_onehot_y:
                warnings.warn(
                    "single channel prediction, `to_onehot_y=True` ignored.")
            if not self.include_background:
                warnings.warn(
                    "single channel prediction, `include_background=False` ignored."
                )
        else:
            if self.do_softmax:
                pred = torch.softmax(pred, 1)
            if self.to_onehot_y:
                ground = one_hot(ground, n_pred_ch)
            if not self.include_background:
                # if skipping background, removing first channel
                ground = ground[:, 1:]
                pred = pred[:, 1:]
                assert ground.shape == pred.shape, "ground truth one-hot has differing shape (%r) from pred (%r)" % (
                    ground.shape,
                    pred.shape,
                )

        # reducing only spatial dimensions (not batch nor channels)
        reduce_axis = list(range(2, len(pred.shape)))
        intersection = torch.sum(ground * pred, reduce_axis)

        if self.squared_pred:
            ground = torch.pow(ground, 2)
            pred = torch.pow(pred, 2)

        ground_o = torch.sum(ground, reduce_axis)
        pred_o = torch.sum(pred, reduce_axis)

        denominator = ground_o + pred_o

        if self.jaccard:
            # Jaccard = I / (G + P - I). The shared formula below has 2 * I in the numerator,
            # so the union must be doubled too; subtracting I alone gives 2I / (G + P - I),
            # which exceeds 1 and makes the loss negative (e.g. -1 for a perfect match).
            denominator = 2.0 * (denominator - intersection)

        f = (2.0 * intersection + smooth) / (denominator + smooth)
        return 1.0 - f.mean()  # final reduce_mean across batches and channels
Exemple #7
0
    def forward(self, pred, ground, smooth=1e-5):
        """
        Compute the generalized Dice loss.

        Args:
            pred (tensor): the shape should be BNH[WD].
            ground (tensor): the shape should be B1H[WD].
            smooth (float): a small constant to avoid nan.

        Returns:
            Scalar loss: ``1 -`` the batch mean of the per-sample generalized Dice score
            ``(2 * sum_c w_c * I_c + smooth) / (sum_c w_c * (P_c + G_c) + smooth)``.

        Raises:
            ValueError: if ``ground`` has more than one channel, or if ``do_softmax`` is set
                for a single-channel ``pred``.
        """
        if ground.shape[1] != 1:
            raise ValueError(
                "Ground truth should have only a single channel, shape is " +
                str(ground.shape))

        psum = pred.float()
        if self.do_sigmoid:
            psum = psum.sigmoid()  # use sigmoid activation
        if pred.shape[1] == 1:
            if self.do_softmax:
                raise ValueError(
                    'do_softmax is not compatible with single channel prediction.'
                )
            if not self.include_background:
                warnings.warn(
                    'single channel prediction, `include_background=False` ignored.'
                )
            tsum = ground
        else:  # multiclass dice loss
            if self.do_softmax:
                # NOTE: softmax is taken over the raw ``pred``, so a sigmoid applied above is discarded
                psum = torch.softmax(pred, 1)
            tsum = one_hot(ground, pred.shape[1])  # B1HW(D) -> BNHW(D)
            # exclude background category so that it doesn't overwhelm the other segmentations if they are small
            if not self.include_background:
                tsum = tsum[:, 1:]
                psum = psum[:, 1:]
        assert tsum.shape == psum.shape, (
            "Ground truth one-hot has differing shape (%r) from source (%r)" %
            (tsum.shape, psum.shape))

        batchsize, n_classes = tsum.shape[:2]
        tsum = tsum.float().view(batchsize, n_classes, -1)
        psum = psum.view(batchsize, n_classes, -1)

        intersection = psum * tsum
        sums = psum + tsum

        # per-class weights, shape [B, n_classes]
        w = self.w_func(tsum.sum(2))
        for b in w:
            # an infinite weight (presumably w_func of an absent class) is replaced by the
            # sample's maximum weight computed after zeroing the infinities
            infs = torch.isinf(b)
            b[infs] = 0.0
            b[infs] = torch.max(b)

        # Generalized Dice: weight each class, then sum over classes *before* taking the
        # ratio. Taking the ratio per class and averaging (the previous form) makes w_c
        # cancel between numerator and denominator, so the weighting had no effect.
        numerator = 2.0 * (intersection.sum(2) * w).sum(1) + smooth
        denominator = (sums.sum(2) * w).sum(1) + smooth
        score = numerator / denominator  # one score per sample, shape [B]
        return 1 - score.mean()
Exemple #8
0
    def forward(self, pred, ground, smooth=1e-5):
        """
        Tversky loss: ``1 - mean((TP + s) / (TP + alpha * FP + beta * FN + s))``.

        Args:
            pred (tensor): the shape should be BNH[WD].
            ground (tensor): the shape should be BNH[WD].
            smooth (float): a small constant to avoid nan.

        Returns:
            Scalar loss averaged over batch and channels.
        """
        probs = torch.sigmoid(pred) if self.do_sigmoid else pred
        n_pred_ch = probs.shape[1]

        if n_pred_ch != 1:
            if self.do_softmax:
                probs = torch.softmax(probs, 1)
            if self.to_onehot_y:
                ground = one_hot(ground, n_pred_ch)
            if not self.include_background:
                # skip the background channel (index 0)
                ground, probs = ground[:, 1:], probs[:, 1:]
                assert ground.shape == probs.shape, (
                    'ground truth one-hot has differing shape (%r) from pred (%r)'
                    % (ground.shape, probs.shape))
        else:
            if self.do_softmax:
                warnings.warn(
                    'single channel prediction, `do_softmax=True` ignored.')
            if self.to_onehot_y:
                warnings.warn(
                    'single channel prediction, `to_onehot_y=True` ignored.')
            if not self.include_background:
                warnings.warn(
                    'single channel prediction, `include_background=False` ignored.'
                )

        # sum over spatial axes only; batch and channel axes are kept
        axes = list(range(2, len(probs.shape)))

        true_pos = torch.sum(probs * ground, axes)
        false_pos = self.alpha * torch.sum(probs * (1 - ground), axes)
        false_neg = self.beta * torch.sum((1 - probs) * ground, axes)

        ratio = (true_pos + smooth) / (true_pos + false_pos + false_neg + smooth)
        return 1.0 - ratio.mean()
Exemple #9
0
    def __call__(self, img, to_onehot=None, num_classes=None):
        """Split ``img`` along dim 1 into a list of single-channel slices.

        When ``to_onehot`` or ``self.to_onehot`` is truthy, ``img`` is first one-hot encoded
        with ``num_classes`` (falling back to ``self.num_classes`` when None).
        """
        if to_onehot or self.to_onehot:
            n_out = self.num_classes if num_classes is None else num_classes
            assert isinstance(n_out,
                              int), "must specify class number for One-Hot."
            img = one_hot(img, n_out)
        return [img[:, ch:ch + 1] for ch in range(img.shape[1])]
Exemple #10
0
    def __call__(self,
                 img,
                 to_onehot: Optional[bool] = None,
                 num_classes: Optional[int] = None
                 ):  # type: ignore # see issue #495
        """Split ``img`` along dim 1 into a list of single-channel slices.

        When ``to_onehot`` or ``self.to_onehot`` is truthy, ``img`` is first one-hot encoded
        with ``num_classes`` (falling back to ``self.num_classes`` when None).
        """
        if to_onehot or self.to_onehot:
            n_out = self.num_classes if num_classes is None else num_classes
            assert isinstance(n_out,
                              int), "must specify class number for One-Hot."
            img = one_hot(img, n_out)
        return [img[:, ch:ch + 1] for ch in range(img.shape[1])]
Exemple #11
0
    def forward(self, pred, ground, smooth=1e-5):
        """
        Compute the Dice loss over the whole flattened volume of each sample.

        Args:
            pred (tensor): prediction, shape BNH[WD].
            ground (tensor): single-channel ground truth, shape B1H[WD].
            smooth (float): a small constant to avoid nan.

        Returns:
            Scalar loss: ``1 -`` the batch mean of the per-sample Dice score
            ``(2 * |P * G| + smooth) / (|P| + |G| + smooth)``.

        Raises:
            ValueError: if ``ground`` has more than one channel, if ``do_softmax`` is set for a
                single-channel ``pred``, or if ``do_sigmoid`` and ``do_softmax`` are both set
                for a multi-channel ``pred``.
        """
        import warnings

        if ground.shape[1] != 1:
            raise ValueError(
                "Ground truth should have only a single channel, shape is " +
                str(ground.shape))

        psum = pred.float()
        if self.do_sigmoid:
            psum = psum.sigmoid()  # use sigmoid activation
        if pred.shape[1] == 1:
            if self.do_softmax:
                raise ValueError(
                    'do_softmax is not compatible with single channel prediction.'
                )
            if not self.include_background:
                # informational only: the option is ignored, so warn rather than abort
                warnings.warn(
                    'single channel prediction, `include_background=False` ignored.'
                )
            tsum = ground
        else:  # multiclass dice loss
            if self.do_softmax:
                if self.do_sigmoid:
                    raise ValueError(
                        'do_sigmoid=True and do_softmax=True are not compatible.'
                    )
                psum = torch.softmax(pred, 1)
            tsum = one_hot(ground, pred.shape[1])  # B1HW(D) -> BNHW(D)
            # exclude background category so that it doesn't overwhelm the other segmentations if they are small
            if not self.include_background:
                tsum = tsum[:, 1:]
                psum = psum[:, 1:]
        assert tsum.shape == psum.shape, (
            "Ground truth one-hot has differing shape (%r) from source (%r)" %
            (tsum.shape, psum.shape))

        batchsize = ground.size(0)
        tsum = tsum.float().view(batchsize, -1)
        psum = psum.view(batchsize, -1)

        intersection = psum * tsum
        sums = psum + tsum

        # Standard smoothed Dice. Smoothing must not be inside the factor 2: otherwise an
        # empty prediction with an empty ground truth scores 2 and the loss becomes -1.
        score = (2.0 * intersection.sum(1) + smooth) / (sums.sum(1) + smooth)
        return 1 - score.sum() / batchsize
Exemple #12
0
    def __call__(self,
                 img,
                 argmax=None,
                 to_onehot=None,
                 n_classes=None,
                 threshold_values=None,
                 logit_thresh=None):
        """Discretize ``img`` with optional argmax, one-hot encoding and thresholding.

        Each step runs when either the call argument or the matching attribute on ``self``
        is truthy, so a falsy call argument cannot switch off an enabled attribute.
        ``n_classes`` / ``logit_thresh`` fall back to ``self.n_classes`` /
        ``self.logit_thresh`` when None.

        Returns:
            The processed tensor cast to float.
        """
        out = img
        if argmax or self.argmax:
            out = torch.argmax(out, dim=1, keepdim=True)

        if to_onehot or self.to_onehot:
            num = n_classes if n_classes is not None else self.n_classes
            out = one_hot(out, num)

        if threshold_values or self.threshold_values:
            cutoff = logit_thresh if logit_thresh is not None else self.logit_thresh
            out = out >= cutoff

        return out.float()
Exemple #13
0
    def __call__(  # type: ignore # see issue #495
        self,
        img,
        argmax: Optional[bool] = None,
        to_onehot: Optional[bool] = None,
        n_classes: Optional[int] = None,
        threshold_values: Optional[bool] = None,
        logit_thresh: Optional[float] = None,
    ):
        """Discretize ``img`` with optional argmax, one-hot encoding and thresholding.

        Each step runs when either the call argument or the matching attribute on ``self``
        is truthy. ``n_classes`` / ``logit_thresh`` fall back to ``self.n_classes`` /
        ``self.logit_thresh`` when None.

        Returns:
            The processed tensor cast to float.

        Raises:
            AssertionError: if one-hot encoding is requested but the resolved class count
                is not an ``int``.
        """
        out = img
        if argmax or self.argmax:
            out = torch.argmax(out, dim=1, keepdim=True)

        if to_onehot or self.to_onehot:
            num = n_classes if n_classes is not None else self.n_classes
            assert isinstance(
                num,
                int), "One of self.n_classes or n_classes must be an integer"
            out = one_hot(out, num)

        if threshold_values or self.threshold_values:
            cutoff = logit_thresh if logit_thresh is not None else self.logit_thresh
            out = out >= cutoff

        return out.float()
Exemple #14
0
def compute_roc_auc(
    y_pred: torch.Tensor,
    y: torch.Tensor,
    to_onehot_y: bool = False,
    softmax: bool = False,
    average: Optional[str] = "macro",
):
    """Compute the area under the ROC curve (ROC AUC), following
    `sklearn.metrics.roc_auc_score <https://scikit-learn.org/stable/modules/generated/
    sklearn.metrics.roc_auc_score.html#sklearn.metrics.roc_auc_score>`_.

    Args:
        y_pred (torch.Tensor): classification model output, batch-first, shape [B] or [B, C].
            A [B, 1] tensor is treated as [B].
        y (torch.Tensor): ground truth, batch-first, shape [B] or [B, C]. A [B, 1] tensor is
            treated as [B]; with ``to_onehot_y=True`` it is expanded to [B, C], where ``C``
            comes from ``y_pred``.
        to_onehot_y: whether to convert ``y`` into the one-hot format. Defaults to False.
        softmax: whether to apply softmax over dim 1 of ``y_pred`` first. Defaults to False.
        average (`macro|weighted|micro|None`): averaging used for the multi-class case.
            Default is 'macro'.

            - 'macro': unweighted mean of the per-class AUCs.
            - 'weighted': mean of the per-class AUCs weighted by the number of positive
              labels of each class.
            - 'micro': a single AUC over all flattened label/score pairs.
            - None: the list of per-class AUCs is returned.

    Raises:
        ValueError: if either tensor is not 1-D or 2-D, or ``average`` is unsupported.
        AssertionError: if ``y`` and ``y_pred`` shapes differ in the multi-class case.

    Note:
        ``y`` is expected to contain 0's and 1's; ``y_pred`` holds probability estimates or
        confidence values.

    """
    if y_pred.ndimension() not in (1, 2):
        raise ValueError(
            "predictions should be of shape (batch_size, n_classes) or (batch_size, )."
        )
    if y.ndimension() not in (1, 2):
        raise ValueError(
            "targets should be of shape (batch_size, n_classes) or (batch_size, )."
        )

    # a trailing singleton channel is collapsed so [B, 1] behaves like [B]
    if y_pred.ndimension() == 2 and y_pred.shape[1] == 1:
        y_pred = y_pred.squeeze(dim=-1)
    if y.ndimension() == 2 and y.shape[1] == 1:
        y = y.squeeze(dim=-1)

    if y_pred.ndimension() == 1:
        # binary case: channel-oriented options have nothing to act on
        if to_onehot_y:
            warnings.warn(
                "y_pred has only one channel, to_onehot_y=True ignored.")
        if softmax:
            warnings.warn("y_pred has only one channel, softmax=True ignored.")
        return _calculate(y, y_pred)

    n_classes = y_pred.shape[1]
    if to_onehot_y:
        y = one_hot(y, n_classes)
    if softmax:
        y_pred = y_pred.float().softmax(dim=1)

    assert y.shape == y_pred.shape, "data shapes of y_pred and y do not match."

    if average == "micro":
        return _calculate(y.flatten(), y_pred.flatten())

    # after transposing, each row holds one class across the batch
    y_by_class = y.transpose(0, 1)
    pred_by_class = y_pred.transpose(0, 1)
    auc_values = [_calculate(labels, scores) for labels, scores in zip(y_by_class, pred_by_class)]

    if average is None:
        return auc_values
    if average == "macro":
        return np.mean(auc_values)
    if average == "weighted":
        weights = [sum(labels) for labels in y_by_class]
        return np.average(auc_values, weights=weights)
    raise ValueError("unsupported average method.")
Exemple #15
0
def compute_meandice(y_pred,
                     y,
                     include_background=True,
                     to_onehot_y=True,
                     mutually_exclusive=False,
                     add_sigmoid=False,
                     logit_thresh=0.5):
    """Computes dice score metric from full size Tensor, per sample and per class.

    Args:
        y_pred (torch.Tensor): input data to compute, typical segmentation model output.
            The first dim is batch and the second is channel, example shape: [16, 3, 32, 32].
        y (torch.Tensor): ground truth to compute mean dice metric, the first dim is batch.
            example shape: [16, 1, 32, 32] will be converted into [16, 3, 32, 32].
            alternative shape: [16, 3, 32, 32] and set `to_onehot_y=False` to use 3-class labels directly.
        include_background (Bool): whether to include the Dice computation on the first channel
            of the predicted output. When False, the first channel of both `y` and `y_pred` is
            dropped (only if that tensor has more than one channel). Defaults to True.
        to_onehot_y (Bool): whether to convert `y` into the one-hot format. Defaults to True.
            Ignored (with a warning) when `y_pred` has a single channel.
        mutually_exclusive (Bool): if True, `y_pred` will be converted into a binary matrix using
            a combination of argmax and to_onehot.  Defaults to False.
        add_sigmoid (Bool): whether to add sigmoid function to y_pred before computation. Defaults to False.
        logit_thresh (Float): the threshold value used to convert (after sigmoid if `add_sigmoid=True`)
            `y_pred` into a binary matrix. Defaults to 0.5.

    Returns:
        Dice scores per batch and per class (shape: [batch_size, n_classes]). Entries whose
        ground-truth channel is empty are NaN.

    Raises:
        ValueError: if `add_sigmoid` and `mutually_exclusive` are both True for a
            multi-channel `y_pred`.

    Note:
        This method provides two options to convert `y_pred` into a binary matrix
            (1) when `mutually_exclusive` is True, it uses a combination of ``argmax`` and ``to_onehot``,
            (2) when `mutually_exclusive` is False, it uses a threshold ``logit_thresh``
                (optionally with a ``sigmoid`` function before thresholding).

    """
    n_classes = y_pred.shape[1]
    n_len = len(y_pred.shape)

    if add_sigmoid:
        y_pred = y_pred.float().sigmoid()

    if n_classes == 1:
        if mutually_exclusive:
            warnings.warn('y_pred has only one class, mutually_exclusive=True ignored.')
        if to_onehot_y:
            warnings.warn('y_pred has only one channel, to_onehot_y=True ignored.')
        if not include_background:
            warnings.warn('y_pred has only one channel, include_background=False ignored.')
        # make both y and y_pred binary
        y_pred = (y_pred >= logit_thresh).float()
        y = (y > 0).float()
    else:  # multi-channel y_pred
        # make both y and y_pred binary
        if mutually_exclusive:
            if add_sigmoid:
                raise ValueError('add_sigmoid=True is incompatible with mutually_exclusive=True.')
            y_pred = torch.argmax(y_pred, dim=1, keepdim=True)
            y_pred = one_hot(y_pred, n_classes)
        else:
            y_pred = (y_pred >= logit_thresh).float()
        if to_onehot_y:
            y = one_hot(y, n_classes)

    if not include_background:
        y = y[:, 1:] if y.shape[1] > 1 else y
        y_pred = y_pred[:, 1:] if y_pred.shape[1] > 1 else y_pred

    assert y.shape == y_pred.shape, ("Ground truth one-hot has differing shape (%r) from source (%r)" %
                                     (y.shape, y_pred.shape))

    # reducing only spatial dimensions (not batch nor channels)
    reduce_axis = list(range(2, n_len))
    intersection = torch.sum(y * y_pred, reduce_axis)

    y_o = torch.sum(y, reduce_axis)
    y_pred_o = torch.sum(y_pred, reduce_axis)
    denominator = y_o + y_pred_o

    # Dice is undefined when the ground truth channel is empty: report NaN there
    f = torch.where(y_o > 0, (2.0 * intersection) / denominator, torch.tensor(float('nan')).to(y_o.float()))
    return f  # returns array of Dice shape: [Batch, n_classes]
Exemple #16
0
    def forward(self, input: torch.Tensor,
                target: torch.Tensor) -> torch.Tensor:
        """
        Compute the Dice loss (optionally Jaccard, squared-prediction and powered variants).

        Args:
            input: the shape should be BNH[WD].
            target: the shape should be BNH[WD]
        Returns:
            ``(1 - (2 * I + smooth_num) / (D + smooth_den)) ** pow`` reduced by
            ``self.reduction``; for "none" the shape is [B, N], or [N] when
            ``self.batch_version`` is set.
        Raises:
            ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"].
        """
        pred = torch.sigmoid(input) if self.sigmoid else input
        n_pred_ch = pred.shape[1]
        single_channel = n_pred_ch == 1

        if self.softmax:
            if single_channel:
                warnings.warn(
                    "single channel prediction, `softmax=True` ignored.")
            else:
                pred = torch.softmax(pred, 1)

        if self.other_act is not None:
            pred = self.other_act(pred)

        if self.to_onehot_y:
            if single_channel:
                warnings.warn(
                    "single channel prediction, `to_onehot_y=True` ignored.")
            else:
                target = one_hot(target, num_classes=n_pred_ch)

        if not self.include_background:
            if single_channel:
                warnings.warn(
                    "single channel prediction, `include_background=False` ignored."
                )
            else:
                # drop the background channel (index 0) from both tensors
                target, pred = target[:, 1:], pred[:, 1:]

        assert (
            target.shape == pred.shape
        ), f"ground truth has differing shape ({target.shape}) from input ({pred.shape})"

        spatial_axes = list(range(2, len(pred.shape)))
        # batch_version also folds the batch axis (0) in, leaving only channels
        reduce_axis = [0] + spatial_axes if self.batch_version else spatial_axes

        intersection = torch.sum(target * pred, dim=reduce_axis)

        if self.squared_pred:
            target, pred = torch.pow(target, 2), torch.pow(pred, 2)

        denominator = torch.sum(target, dim=reduce_axis) + torch.sum(pred, dim=reduce_axis)
        if self.jaccard:
            # doubling the union keeps the 2*I numerator consistent: (2I)/(2U) = I/U
            denominator = 2.0 * (denominator - intersection)

        ratio = (2.0 * intersection + self.smooth_num) / (denominator + self.smooth_den)
        f: torch.Tensor = (1.0 - ratio)**self.pow

        if self.reduction == LossReduction.MEAN.value:
            return torch.mean(f)
        if self.reduction == LossReduction.SUM.value:
            return torch.sum(f)
        if self.reduction == LossReduction.NONE.value:
            return f
        raise ValueError(
            f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].'
        )
Exemple #17
0
    def _test_epoch(self, epoch):
        """Run one evaluation pass over ``self.test_loader`` and return the averaged logs.

        Per batch: computes a weighted combination of a classification loss and a
        segmentation loss, updates running-mean trackers for the loss and every metric in
        ``self.metrics``, and logs step values to neptune. After the loop, the final running
        means are logged as epoch values and returned.

        Args:
            epoch: not used in this body.

        Returns:
            dict: ``'loss'`` plus one entry per metric name, each the running mean over all
            test batches.

        Raises:
            ValueError: if a metric's type is neither 'classification' nor 'segmentation'.
        """
        start = time.time()
        logs = {}
        loss_meter = AverageMetricTracker()
        # one running-mean tracker per metric, keyed by metric name (m[1])
        metric_trackers = {m[1]: AverageMetricTracker() for m in self.metrics}

        self.model.eval()

        for batch_idx, (x, y) in enumerate(self.test_loader):
            x = x.to(self.device)
            y = y.to(self.device)

            with torch.no_grad():
                # make predictions; judging by usage below, out1 is the segmentation output and
                # out2 the classification output
                # NOTE(review): assuming self.model is a torch.nn.Module, calling .forward()
                # directly skips registered hooks; self.model(x) is the usual form — confirm
                out1, out2 = self.model.forward(x)
                # presumably derives classification targets from the mask y (helper defined elsewhere)
                classes = class_count(y, self.num_classes)

                # calculate weighted loss: loss_weight scales the segmentation term,
                # (1 - loss_weight) the classification term
                clf_loss = self.criterion(out2, classes)
                sgm_loss = self.loss(out1, y)
                loss = ((1.0 - self.loss_weight) * clf_loss) + (self.loss_weight * sgm_loss)

            loss_value = loss.cpu().detach().numpy()
            loss_meter.add(loss_value)
            loss_logs = {'loss': loss_meter.mean}
            logs.update(loss_logs)

            # neptune logging (test step)
            neptune.log_metric('test_loss_step', loss_value)

            for m in self.metrics:
                metric = m[0]  # unpack metric class
                metric_name = m[1]  # unpack metric name
                metric_type = m[2]  # unpack metric type

                if metric_type == 'classification':
                    metric_value = metric(out2, classes).cpu().detach().numpy()
                    metric_trackers[metric_name].add(metric_value)

                elif metric_type == 'segmentation':
                    # NOTE(review): argmax/round already keep out1's device, so .to(d) looks redundant
                    d = out1.get_device() if out1.is_cuda else 'cpu'
                    if not self.binary:
                        # multi-class: hard label per pixel, then one-hot over out1's channels
                        y_pred = torch.argmax(out1, dim=1, keepdim=True).to(d)
                        y_pred = one_hot(y_pred, num_classes=out1.shape[1])
                    else:
                        # binary: round to 0/1 (assumes out1 is already in [0, 1] — confirm)
                        y_pred = torch.round(out1).to(d)

                    # calculate metric; some metrics return a tuple, whose [0][0] element is used
                    metric_value = metric(y_pred, y)
                    metric_value = metric_value[0][0] if isinstance(metric_value, tuple) else metric_value
                    metric_value = metric_value.cpu().detach().numpy()
                    metric_trackers[metric_name].add(metric_value)

                else:
                    raise ValueError(f'Type {metric_type} is not a valid metric type.')

                # neptune logging (test step)
                neptune.log_metric('test_' + metric_name + '_step', metric_value)

            # refresh the running means after every batch; the last refresh is what is returned
            metrics_logs = {k: v.mean for k, v in metric_trackers.items()}
            logs.update(metrics_logs)

        # neptune logging (test epoch)
        for k, v in logs.items():
            neptune.log_metric('test_' + k + '_epoch', v)

        duration = time.time() - start

        if self.verbose:
            self._show_progress(duration, logs, stage='Test')

        return logs
Exemple #18
0
 def test_shape(self, input_data, expected_shape, expected_result=None):
     """Check the shape of ``one_hot`` output and, optionally, its values.

     Args:
         input_data: keyword arguments forwarded to ``one_hot``.
         expected_shape: shape the encoded result must have.
         expected_result: if given, values the result must match within
             ``np.allclose`` tolerance.
     """
     encoded = one_hot(**input_data)
     self.assertEqual(encoded.shape, expected_shape)
     if expected_result is None:
         return
     self.assertTrue(np.allclose(expected_result, encoded.numpy()))
Exemple #19
0
    def forward(self, pred, gt):
        """Compute the boundary loss between predicted logits and a label map.

        The loss is ``mean((1 - BF1) ** self.gamma)``, where BF1 is the boundary
        F1 score computed per sample and per class from boundary maps of the
        softmaxed prediction and of the one-hot ground truth.

        Args:
            pred: model output before softmax, shape (N, C, H, W).
            gt: ground-truth label map, shape (N, 1, H, W); it is encoded with
                ``one_hot(gt, C)``.

        Returns:
            Scalar tensor: ``(1 - BF1) ** self.gamma`` averaged over both the
            mini-batch and the class dimension.
        """
        n, c, _, _ = pred.shape
        smooth = 1e-7  # keeps every division below finite for empty maps

        def _boundary(x):
            # For a binary map, max_pool2d(1 - x) - (1 - x) is 1 exactly on
            # foreground pixels with a background pixel inside the
            # theta0 x theta0 window; soft maps give a soft version of that.
            inv = 1 - x
            pooled = F.max_pool2d(inv,
                                  kernel_size=self.theta0,
                                  stride=1,
                                  padding=(self.theta0 - 1) // 2)
            return pooled - inv

        def _extend(b):
            # Dilate a boundary map with a theta x theta max filter.
            return F.max_pool2d(b,
                                kernel_size=self.theta,
                                stride=1,
                                padding=(self.theta - 1) // 2)

        # softmax so the predicted maps lie in [0, 1] like the one-hot target
        pred = torch.softmax(pred, dim=1)
        one_hot_gt = one_hot(gt, c)

        gt_b = _boundary(one_hot_gt)
        pred_b = _boundary(pred)
        gt_b_ext = _extend(gt_b)
        pred_b_ext = _extend(pred_b)

        # flatten spatial dims: (N, C, H*W)
        gt_b = gt_b.view(n, c, -1)
        pred_b = pred_b.view(n, c, -1)
        gt_b_ext = gt_b_ext.view(n, c, -1)
        pred_b_ext = pred_b_ext.view(n, c, -1)

        # Precision: predicted boundary falling near the true boundary.
        # Recall: true boundary covered by the (extended) predicted boundary.
        P = torch.sum(pred_b * gt_b_ext, dim=2) / (torch.sum(pred_b, dim=2) + smooth)
        R = torch.sum(pred_b_ext * gt_b, dim=2) / (torch.sum(gt_b, dim=2) + smooth)

        # boundary F1 score per (sample, class), shape (N, C)
        BF1 = (2 * P * R) / (P + R + smooth)

        return torch.mean(torch.pow(1 - BF1, self.gamma))