Example 1
    def forward(self, y_pred: torch.Tensor,
                y_true: torch.Tensor) -> torch.Tensor:
        """
        Args:
            y_pred : the shape should be BNH[WD], where N is the number of classes.
                It only supports binary segmentation.
                The input should be the original logits since it will be transformed by
                    a sigmoid in the forward function.
            y_true : the shape should be BNH[WD], where N is the number of classes.
                It only supports binary segmentation.

        Raises:
            ValueError: When input and target have different shapes.
            ValueError: When len(y_pred.shape) is neither 4 nor 5.
            ValueError: When the number of classes found in y_true does not match ``self.num_classes``.
        """
        if y_pred.shape != y_true.shape:
            raise ValueError(
                f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})"
            )

        if len(y_pred.shape) not in (4, 5):
            raise ValueError(
                f"input must have 4 or 5 dimensions (BNHW[D]), but got shape {y_pred.shape}")

        if y_pred.shape[1] == 1:
            y_pred = one_hot(y_pred, num_classes=self.num_classes)
            y_true = one_hot(y_true, num_classes=self.num_classes)

        if torch.max(y_true) != self.num_classes - 1:
            raise ValueError(
                f"Please make sure the maximum label value is {self.num_classes - 1} (num_classes - 1)"
            )

        n_pred_ch = y_pred.shape[1]
        if self.to_onehot_y:
            if n_pred_ch == 1:
                warnings.warn(
                    "single channel prediction, `to_onehot_y=True` ignored.")
            else:
                y_true = one_hot(y_true, num_classes=n_pred_ch)

        asy_focal_loss = self.asy_focal_loss(y_pred, y_true)
        asy_focal_tversky_loss = self.asy_focal_tversky_loss(y_pred, y_true)

        loss: torch.Tensor = self.weight * asy_focal_loss + (
            1 - self.weight) * asy_focal_tversky_loss

        if self.reduction == LossReduction.SUM.value:
            return torch.sum(loss)  # sum over the batch and channel dims
        if self.reduction == LossReduction.NONE.value:
            return loss  # returns [N, num_classes] losses
        if self.reduction == LossReduction.MEAN.value:
            return torch.mean(loss)
        raise ValueError(
            f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].'
        )
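A minimal usage sketch for the forward above; the class name (AsymmetricUnifiedFocalLoss) and its constructor arguments are assumptions inferred from the attributes used in the body (self.num_classes, self.weight, self.reduction):

# hypothetical usage; the class name and constructor arguments are assumptions
import torch
loss_fn = AsymmetricUnifiedFocalLoss(num_classes=2, weight=0.5, reduction="mean")
y_pred = torch.randn(4, 2, 32, 32)                    # B=4, N=2 classes, raw logits
y_true = torch.randint(0, 2, (4, 2, 32, 32)).float()  # binary labels, same shape
loss = loss_fn(y_pred, y_true)                        # scalar tensor under "mean"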
Example 2
    def __call__(
        self, img: Union[Sequence[NdarrayOrTensor],
                         NdarrayOrTensor]) -> NdarrayOrTensor:
        img_ = self.get_stacked_torch(img)

        if self.num_classes is not None:
            has_ch_dim = True
            if img_.ndimension() > 1 and img_.shape[1] > 1:
                warnings.warn(
                    "no need to specify num_classes for One-Hot format data.")
            else:
                if img_.ndimension() == 1:
                    # if no channel dim, need to remove channel dim after voting
                    has_ch_dim = False
                img_ = one_hot(img_, self.num_classes, dim=1)

        img_ = torch.mean(img_.float(), dim=0)

        if self.num_classes is not None:
            # if not One-Hot, use "argmax" to vote the most common class
            out_pt = torch.argmax(img_, dim=0, keepdim=has_ch_dim)
        else:
            # for One-Hot data, round the float number to 0 or 1
            out_pt = torch.round(img_)
        return self.post_convert(out_pt, img)
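A usage sketch, assuming this __call__ belongs to a voting-ensemble transform along the lines of MONAI's VoteEnsemble (the class name and constructor are assumptions):

# hypothetical usage; VoteEnsemble(num_classes=...) is an assumption
import torch
preds = [torch.randint(0, 3, (1, 4, 4)) for _ in range(5)]  # 5 models' class-index maps
vote = VoteEnsemble(num_classes=3)
consensus = vote(preds)  # majority class per voxel, shape (1, 4, 4)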
Example 3
    def forward(self, y_pred: torch.Tensor,
                y_true: torch.Tensor) -> torch.Tensor:
        n_pred_ch = y_pred.shape[1]

        if self.to_onehot_y:
            if n_pred_ch == 1:
                warnings.warn(
                    "single channel prediction, `to_onehot_y=True` ignored.")
            else:
                y_true = one_hot(y_true, num_classes=n_pred_ch)

        if y_true.shape != y_pred.shape:
            raise ValueError(
                f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})"
            )

        y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon)
        cross_entropy = -y_true * torch.log(y_pred)

        back_ce = torch.pow(1 - y_pred[:, 0], self.gamma) * cross_entropy[:, 0]
        back_ce = (1 - self.delta) * back_ce

        fore_ce = cross_entropy[:, 1]
        fore_ce = self.delta * fore_ce

        loss = torch.mean(
            torch.sum(torch.stack([back_ce, fore_ce], dim=1), dim=1))
        return loss
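The forward above is an asymmetric focal cross entropy over two channels: the background term (channel 0) is down-weighted by (1 - delta) and focused by (1 - p)^gamma, while the foreground term (channel 1) keeps the plain cross entropy scaled by delta. A standalone numeric sketch (the delta, gamma and epsilon values are assumptions):

# standalone sketch of the asymmetric focal CE; parameter values are assumptions
import torch
delta, gamma, eps = 0.7, 2.0, 1e-7
y_pred = torch.tensor([[[0.9], [0.1]]])  # B=1, N=2 channels, one spatial element
y_true = torch.tensor([[[1.0], [0.0]]])  # a background voxel
y_pred = torch.clamp(y_pred, eps, 1.0 - eps)
ce = -y_true * torch.log(y_pred)
back = (1 - delta) * torch.pow(1 - y_pred[:, 0], gamma) * ce[:, 0]
fore = delta * ce[:, 1]
print(torch.mean(back + fore))  # a well-classified background voxel contributes almost nothing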
Example 4
    def __call__(
            self, img: Union[Sequence[torch.Tensor],
                             torch.Tensor]) -> torch.Tensor:
        img_ = torch.stack(img) if isinstance(img,
                                              (tuple,
                                               list)) else torch.as_tensor(img)
        if self.num_classes is not None:
            has_ch_dim = True
            if img_.ndimension() > 2 and img_.shape[2] > 1:
                warnings.warn(
                    "no need to specify num_classes for One-Hot format data.")
            else:
                if img_.ndimension() == 2:
                    # if no channel dim, need to remove channel dim after voting
                    has_ch_dim = False
                img_ = one_hot(img_, self.num_classes, dim=2)

        img_ = torch.mean(img_.float(), dim=0)

        if self.num_classes is not None:
            # if not One-Hot, use "argmax" to vote the most common class
            return torch.argmax(img_, dim=1, keepdim=has_ch_dim)
        else:
            # for One-Hot data, round the float number to 0 or 1
            return torch.round(img_)
Example 5
    def __call__(
        self,
        img,
        argmax: Optional[bool] = None,
        to_onehot: Optional[bool] = None,
        n_classes: Optional[int] = None,
        threshold_values: Optional[bool] = None,
        logit_thresh: Optional[float] = None,
    ):
        """
        Args:
            argmax: whether to execute argmax function on input data before transform.
                Defaults to ``self.argmax``.
            to_onehot: whether to convert input data into the one-hot format.
                Defaults to ``self.to_onehot``.
            n_classes: the number of classes to convert to One-Hot format.
                Defaults to ``self.n_classes``.
            threshold_values: whether to threshold the float values to 0 or 1.
                Defaults to ``self.threshold_values``.
            logit_thresh: the threshold value for the thresholding operation.
                Defaults to ``self.logit_thresh``.

        """
        if argmax or self.argmax:
            img = torch.argmax(img, dim=1, keepdim=True)

        if to_onehot or self.to_onehot:
            _nclasses = self.n_classes if n_classes is None else n_classes
            assert isinstance(_nclasses, int), "One of self.n_classes or n_classes must be an integer"
            img = one_hot(img, _nclasses)

        if threshold_values or self.threshold_values:
            img = img >= (self.logit_thresh if logit_thresh is None else logit_thresh)

        return img.float()
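A usage sketch; the owning class resembles an early MONAI AsDiscrete operating on batch-first data (hence the dim=1 argmax), and the class name and constructor are assumptions:

# hypothetical usage; the class name and constructor are assumptions
import torch
post = AsDiscrete(argmax=True, to_onehot=True, n_classes=3)
logits = torch.randn(2, 3, 8, 8)  # B=2, 3 classes
discrete = post(logits)           # one-hot float tensor of shape (2, 3, 8, 8)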
Example 6
    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Args:
            input: the shape should be BNH[WD]. The input should be the original logits
                due to the restriction of ``monai.losses.FocalLoss``.
            target: the shape should be BNH[WD] or B1H[WD].

        Raises:
            ValueError: When number of dimensions for input and target are different.
            ValueError: When number of channels for target is neither 1 nor the same as input.

        """
        if len(input.shape) != len(target.shape):
            raise ValueError("the number of dimensions for input and target should be the same.")

        n_pred_ch = input.shape[1]

        if self.to_onehot_y:
            if n_pred_ch == 1:
                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
            else:
                target = one_hot(target, num_classes=n_pred_ch)

        if not self.include_background:
            if n_pred_ch == 1:
                warnings.warn("single channel prediction, `include_background=False` ignored.")
            else:
                # if skipping background, removing first channel
                target = target[:, 1:]
                input = input[:, 1:]

        dice_loss = self.dice(input, target)
        focal_loss = self.focal(input, target)
        total_loss: torch.Tensor = self.lambda_dice * dice_loss + self.lambda_focal * focal_loss
        return total_loss
Example 7
    def test_shape(self, input_data, expected_shape, expected_result=None):
        result = one_hot(**input_data)
        self.assertEqual(result.shape, expected_shape)
        if expected_result is not None:
            self.assertTrue(np.allclose(expected_result, result.numpy()))

        if "dtype" in input_data:
            self.assertEqual(result.dtype, input_data["dtype"])
        else:
            # by default, expecting float type
            self.assertEqual(result.dtype, torch.float)
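For reference, a small sketch of the one_hot semantics these assertions exercise (assuming MONAI's monai.networks.one_hot):

# sketch of the one_hot behavior under test, assuming monai.networks.one_hot
import torch
from monai.networks import one_hot
labels = torch.tensor([[[0.0, 1.0], [2.0, 1.0]]])        # (B=1, 2, 2) label map
result = one_hot(labels[:, None], num_classes=3, dim=1)  # shape (1, 3, 2, 2)
print(result.dtype)  # torch.float32 by default, as the test asserts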
Example 8
    def __call__(
        self,
        img: torch.Tensor,
        argmax: Optional[bool] = None,
        to_onehot: Optional[bool] = None,
        num_classes: Optional[int] = None,
        threshold_values: Optional[bool] = None,
        logit_thresh: Optional[float] = None,
        rounding: Optional[str] = None,
        n_classes: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Args:
            img: the input tensor data to convert; if there is no channel dimension
                when converting to `One-Hot`, one will be added automatically.
            argmax: whether to execute argmax function on input data before transform.
                Defaults to ``self.argmax``.
            to_onehot: whether to convert input data into the one-hot format.
                Defaults to ``self.to_onehot``.
            num_classes: the number of classes to convert to One-Hot format.
                Defaults to ``self.num_classes``.
            threshold_values: whether to threshold the float values to 0 or 1.
                Defaults to ``self.threshold_values``.
            logit_thresh: the threshold value for the thresholding operation.
                Defaults to ``self.logit_thresh``.
            rounding: if not None, round the data according to the specified option,
                available options: ["torchrounding"].

        .. deprecated:: 0.6.0
            ``n_classes`` is deprecated, use ``num_classes`` instead.

        """
        # in case the new num_classes is default but you still call deprecated n_classes
        if n_classes is not None and num_classes is None:
            num_classes = n_classes
        if argmax or self.argmax:
            img = torch.argmax(img, dim=0, keepdim=True)

        if to_onehot or self.to_onehot:
            _nclasses = self.num_classes if num_classes is None else num_classes
            if not isinstance(_nclasses, int):
                raise AssertionError("One of self.num_classes or num_classes must be an integer")
            img = one_hot(img, num_classes=_nclasses, dim=0)

        if threshold_values or self.threshold_values:
            img = img >= (self.logit_thresh if logit_thresh is None else logit_thresh)

        rounding = self.rounding if rounding is None else rounding
        if rounding is not None:
            look_up_option(rounding, ["torchrounding"])
            img = torch.round(img)

        return img.float()
Example 9
    def __call__(
        self,
        img: NdarrayOrTensor,
        argmax: Optional[bool] = None,
        to_onehot: Optional[int] = None,
        threshold: Optional[float] = None,
        rounding: Optional[str] = None
    ) -> NdarrayOrTensor:
        """
        Args:
            img: the input tensor data to convert; if there is no channel dimension
                when converting to `One-Hot`, one will be added automatically.
            argmax: whether to execute argmax function on input data before transform.
                Defaults to ``self.argmax``.
            to_onehot: if not None, convert input data into the one-hot format with specified number of classes.
                Defaults to ``self.to_onehot``.
            threshold: if not None, threshold the float values to 0 or 1 using the specified threshold value.
                Defaults to ``self.threshold``.
            rounding: if not None, round the data according to the specified option,
                available options: ["torchrounding"].
        """

        img_t: torch.Tensor
        img_t, *_ = convert_data_type(img, torch.Tensor)  # type: ignore
        if argmax or self.argmax:
            img_t = torch.argmax(img_t, dim=self.kwargs.get("dim", 0), keepdim=self.kwargs.get("keepdim", True))

        to_onehot = self.to_onehot if to_onehot is None else to_onehot
        if to_onehot is not None:
            if not isinstance(to_onehot, int):
                raise ValueError("the number of classes for One-Hot must be an integer.")
            img_t = one_hot(
                img_t, num_classes=to_onehot, dim=self.kwargs.get("dim", 0), dtype=self.kwargs.get("dtype", torch.float)
            )

        threshold = self.threshold if threshold is None else threshold
        if threshold is not None:
            img_t = img_t >= threshold

        rounding = self.rounding if rounding is None else rounding
        if rounding is not None:
            look_up_option(rounding, ["torchrounding"])
            img_t = torch.round(img_t)

        img, *_ = convert_to_dst_type(img_t, img, dtype=self.kwargs.get("dtype", torch.float))
        return img
Example 10
    def __call__(self, img, to_onehot: Optional[bool] = None, num_classes: Optional[int] = None):
        """
        Args:
            to_onehot: whether to convert the data to One-Hot format first.
                Defaults to ``self.to_onehot``.
            num_classes: the class number used to convert to One-Hot format if `to_onehot` is True.
                Defaults to ``self.num_classes``.
        """
        if to_onehot or self.to_onehot:
            if num_classes is None:
                num_classes = self.num_classes
            assert isinstance(num_classes, int), "must specify class number for One-Hot."
            img = one_hot(img, num_classes)
        n_classes = img.shape[1]
        outputs = list()
        for i in range(n_classes):
            outputs.append(img[:, i : i + 1])

        return outputs
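A usage sketch, assuming this __call__ belongs to a channel-splitting transform such as an early MONAI SplitChannel (the class name and constructor are assumptions):

# hypothetical usage; the class name and constructor are assumptions
import torch
splitter = SplitChannel(to_onehot=True, num_classes=3)
seg = torch.randint(0, 3, (2, 1, 4, 4)).float()  # batch of class-index maps
channels = splitter(seg)  # list of 3 tensors, each of shape (2, 1, 4, 4)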
Example 11
def do_binarization(
    input_data: torch.Tensor,
    bin_mode: str = "threshold",
    bin_threshold: Union[float, Sequence[float]] = 0.5,
) -> torch.Tensor:
    """
    Args:
        input_data: the input that to be binarized, in the shape [B] or [BN] or [BNHW] or [BNHWD].
        bin_mode: can be ``"threshold"`` or ``"mutually_exclusive"``, or a callable function.
            - ``"threshold"``, a single threshold or a sequence of thresholds should be set.
            - ``"mutually_exclusive"``, `input_data` will be converted by a combination of
            argmax and to_onehot.
        bin_threshold: the threshold to binarize the input data; it can be a single value or a
            sequence of values, where each value represents the threshold for a class.

    Raises:
        AssertionError: when `bin_threshold` is a sequence and the input has the shape [B].
        AssertionError: when `bin_threshold` is a sequence but the length != the number of classes.
        AssertionError: when `bin_mode` is ``"mutually_exclusive"`` and the input has the shape [B].
        AssertionError: when `bin_mode` is ``"mutually_exclusive"`` and the input has the shape [B, 1].
    """
    input_ndim = input_data.ndimension()
    if bin_mode == "threshold":
        if isinstance(bin_threshold, Sequence):
            assert input_ndim > 1, "a sequence of thresholds is only supported for multi-class tasks."
            error_hint = "the length of the sequence should be the same as the number of classes."
            assert input_data.shape[1] == len(bin_threshold), "{}".format(
                error_hint)
            for cls_num in range(input_data.shape[1]):
                input_data[:, cls_num] = (input_data[:, cls_num] >
                                          bin_threshold[cls_num]).float()
        else:
            input_data = (input_data > bin_threshold).float()
    elif bin_mode == "mutually_exclusive":
        assert input_ndim > 1, "mutually_exclusive is used for multi-class tasks."
        n_classes = input_data.shape[1]
        assert n_classes > 1, "mutually_exclusive is used for multi-class tasks."
        input_data = torch.argmax(input_data, dim=1, keepdim=True)
        input_data = one_hot(input_data, num_classes=n_classes)
    return input_data
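A short usage sketch for do_binarization as defined above:

# usage sketch for the function above
import torch
probs = torch.rand(4, 3, 8, 8)  # BNHW class probabilities
onehot = do_binarization(probs, bin_mode="mutually_exclusive")  # argmax + one-hot, (4, 3, 8, 8)
perclass = do_binarization(torch.rand(4, 3, 8, 8), bin_threshold=[0.3, 0.5, 0.7])  # per-class thresholds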
Example 12
 def test_consistency_with_cross_entropy_2d_onehot_label(self):
     # For gamma=0 the focal loss reduces to the cross entropy loss
     focal_loss = FocalLoss(to_onehot_y=False, gamma=0.0, reduction="mean")
     ce = nn.CrossEntropyLoss(reduction="mean")
     max_error = 0
     class_num = 10
     batch_size = 128
     for _ in range(100):
         # Create a random tensor of shape (batch_size, class_num, 8, 4)
         x = torch.rand(batch_size, class_num, 8, 4, requires_grad=True)
         # Create a random batch of classes
         l = torch.randint(low=0, high=class_num, size=(batch_size, 1, 8, 4))
         if torch.cuda.is_available():
             x = x.cuda()
             l = l.cuda()
         output0 = focal_loss(x, one_hot(l, num_classes=class_num))
         output1 = ce(x, l[:, 0]) / class_num
         a = float(output0.cpu().detach())
         b = float(output1.cpu().detach())
         if abs(a - b) > max_error:
             max_error = abs(a - b)
     self.assertAlmostEqual(max_error, 0.0, places=3)
Example 13
 def test_consistency_with_cross_entropy_classification_01(self):
     # for gamma=0.1 the focal loss differs from the cross entropy loss
     focal_loss = FocalLoss(to_onehot_y=True, gamma=0.1, reduction="mean")
     ce = nn.BCEWithLogitsLoss(reduction="mean")
     max_error = 0
     class_num = 10
     batch_size = 128
     for _ in range(100):
         # Create a random scores tensor of shape (batch_size, class_num)
         x = torch.rand(batch_size, class_num, requires_grad=True)
         # Create a random batch of classes
         l = torch.randint(low=0, high=class_num, size=(batch_size, 1))
         l = l.long()
         if torch.cuda.is_available():
             x = x.cuda()
             l = l.cuda()
         output0 = focal_loss(x, l)
         output1 = ce(x, one_hot(l, num_classes=class_num))
         a = float(output0.cpu().detach())
         b = float(output1.cpu().detach())
         if abs(a - b) > max_error:
             max_error = abs(a - b)
     self.assertNotAlmostEqual(max_error, 0.0, places=3)
Example 14
    def forward(self, y_pred: torch.Tensor,
                y_true: torch.Tensor) -> torch.Tensor:
        n_pred_ch = y_pred.shape[1]

        if self.to_onehot_y:
            if n_pred_ch == 1:
                warnings.warn(
                    "single channel prediction, `to_onehot_y=True` ignored.")
            else:
                y_true = one_hot(y_true, num_classes=n_pred_ch)

        if y_true.shape != y_pred.shape:
            raise ValueError(
                f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})"
            )

        # clip the prediction to avoid NaN
        y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon)
        axis = list(range(2, len(y_pred.shape)))

        # Calculate true positives (tp), false negatives (fn) and false positives (fp)
        tp = torch.sum(y_true * y_pred, dim=axis)
        fn = torch.sum(y_true * (1 - y_pred), dim=axis)
        fp = torch.sum((1 - y_true) * y_pred, dim=axis)
        dice_class = (tp +
                      self.epsilon) / (tp + self.delta * fn +
                                       (1 - self.delta) * fp + self.epsilon)

        # Calculate losses separately for each class, enhancing both classes
        back_dice = 1 - dice_class[:, 0]
        fore_dice = (1 - dice_class[:, 1]) * torch.pow(1 - dice_class[:, 1],
                                                       -self.gamma)

        # Average class scores
        loss = torch.mean(torch.stack([back_dice, fore_dice], dim=-1))
        return loss
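Note that the foreground term above, (1 - d) * (1 - d)^(-gamma), simplifies to (1 - d)^(1 - gamma), so with 0 < gamma < 1 the harder (lower-Dice) class receives a proportionally larger loss. A quick numeric check (the gamma value is an assumption):

# numeric check of the foreground focusing term; gamma is an assumed value
import torch
d = torch.tensor([0.9, 0.5])  # per-class Dice for an easy and a hard class
gamma = 0.75
print((1 - d) * torch.pow(1 - d, -gamma))  # equals (1 - d) ** (1 - gamma)
print(torch.pow(1 - d, 1 - gamma))         # same values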
Example 15
    def __call__(
            self,
            img: NdarrayOrTensor,
            argmax: Optional[bool] = None,
            to_onehot: Optional[int] = None,
            threshold: Optional[float] = None,
            rounding: Optional[str] = None,
            n_classes: Optional[int] = None,  # deprecated
            num_classes: Optional[int] = None,  # deprecated
            logit_thresh: Optional[float] = None,  # deprecated
            threshold_values: Optional[bool] = None,  # deprecated
    ) -> NdarrayOrTensor:
        """
        Args:
            img: the input tensor data to convert; if there is no channel dimension
                when converting to `One-Hot`, one will be added automatically.
            argmax: whether to execute argmax function on input data before transform.
                Defaults to ``self.argmax``.
            to_onehot: if not None, convert input data into the one-hot format with specified number of classes.
                Defaults to ``self.to_onehot``.
            threshold: if not None, threshold the float values to 0 or 1 using the specified threshold value.
                Defaults to ``self.threshold``.
            rounding: if not None, round the data according to the specified option,
                available options: ["torchrounding"].

        .. deprecated:: 0.6.0
            ``n_classes`` is deprecated, use ``to_onehot`` instead.

        .. deprecated:: 0.7.0
            ``num_classes`` is deprecated, use ``to_onehot`` instead.
            ``logit_thresh`` is deprecated, use ``threshold`` instead.
            ``threshold_values`` is deprecated, use ``threshold`` instead.

        """
        if isinstance(to_onehot, bool):
            warnings.warn(
                "`to_onehot=True/False` is deprecated, please use `to_onehot=num_classes` instead."
            )
            to_onehot = num_classes if to_onehot else None
        if isinstance(threshold, bool):
            warnings.warn(
                "`threshold_values=True/False` is deprecated, please use `threshold=value` instead."
            )
            threshold = logit_thresh if threshold else None
        img = convert_to_tensor(img, track_meta=get_track_meta())
        img_t, *_ = convert_data_type(img, torch.Tensor)
        if argmax or self.argmax:
            img_t = torch.argmax(img_t, dim=0, keepdim=True)

        to_onehot = self.to_onehot if to_onehot is None else to_onehot
        if to_onehot is not None:
            if not isinstance(to_onehot, int):
                raise AssertionError(
                    "the number of classes for One-Hot must be an integer.")
            img_t = one_hot(img_t, num_classes=to_onehot, dim=0)

        threshold = self.threshold if threshold is None else threshold
        if threshold is not None:
            img_t = img_t >= threshold

        rounding = self.rounding if rounding is None else rounding
        if rounding is not None:
            look_up_option(rounding, ["torchrounding"])
            img_t = torch.round(img_t)

        img, *_ = convert_to_dst_type(img_t, img, dtype=torch.float)
        return img
Example 16
    def test_convergence(self, loss_type, loss_args, forward_args):
        """
        The goal of this test is to assess whether the gradient of the loss
        function is correct, by checking that a one-layer neural network can
        be trained to segment a single image: the initial predictions should
        be poor and the final predictions should match the target.
        """
        learning_rate = 0.001
        max_iter = 40

        # define a simple 3d example
        target_seg = torch.tensor(
            [[
                # slice 0
                [[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]],
                # slice 1
                [[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]],
                # slice 2
                [[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]],
            ]],
            device=self.device,
        )
        target_seg = torch.unsqueeze(target_seg, dim=0)
        image = 12 * target_seg + 27
        image = image.float().to(self.device)
        num_classes = 2
        num_voxels = 3 * 4 * 4

        target_onehot = one_hot(target_seg, num_classes=num_classes)

        # define a one layer model
        class OnelayerNet(nn.Module):
            def __init__(self):
                super(OnelayerNet, self).__init__()
                self.layer_1 = nn.Linear(num_voxels, 200)
                self.acti = nn.ReLU()
                self.layer_2 = nn.Linear(200, num_voxels * num_classes)

            def forward(self, x):
                x = x.view(-1, num_voxels)
                x = self.layer_1(x)
                x = self.acti(x)
                x = self.layer_2(x)
                x = x.view(-1, num_classes, 3, 4, 4)
                return x

        # initialise the network
        net = OnelayerNet().to(self.device)

        # initialize the loss
        loss = loss_type(**loss_args)

        # initialize an Adam optimizer
        optimizer = optim.Adam(net.parameters(), lr=learning_rate)

        loss_history = []
        init_output = None

        # train the network
        for iter_i in range(max_iter):
            # set the gradient to zero
            optimizer.zero_grad()

            # forward pass
            output = net(image)
            if init_output is None:
                init_output = torch.argmax(output, 1).detach().cpu().numpy()

            if loss_args["to_onehot_y"] is False:
                loss_val = loss(output, target_onehot, **forward_args)
            else:
                loss_val = loss(output, target_seg, **forward_args)

            if iter_i % 10 == 0:
                pred = torch.argmax(output, 1).detach().cpu().numpy()
                gt = target_seg.detach().cpu().numpy()[:, 0]
                print(
                    f"{loss_type.__name__} iter: {iter_i}, acc: {np.sum(pred == gt) / np.prod(pred.shape)}"
                )

            # backward pass
            loss_val.backward()
            optimizer.step()

            # stats
            loss_history.append(loss_val.item())

        pred = torch.argmax(output, 1).detach().cpu().numpy()
        target = target_seg.detach().cpu().numpy()[:, 0]
        # initial predictions are bad
        self.assertTrue(not np.allclose(init_output, target))
        # final predictions are good
        np.testing.assert_allclose(pred, target)
Example 17
    def forward(self,
                input: torch.Tensor,
                target: torch.Tensor,
                smooth: float = 1e-5) -> torch.Tensor:
        """
        Args:
            input: the shape should be BNH[WD].
            target: the shape should be BNH[WD].
            smooth: a small constant to avoid nan.
        Raises:
            ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"].
        """
        if self.sigmoid:
            input = torch.sigmoid(input)

        n_pred_ch = input.shape[1]
        if self.softmax:
            if n_pred_ch == 1:
                warnings.warn(
                    "single channel prediction, `softmax=True` ignored.")
            else:
                input = torch.softmax(input, 1)

        if self.other_act is not None:
            input = self.other_act(input)

        if self.to_onehot_y:
            if n_pred_ch == 1:
                warnings.warn(
                    "single channel prediction, `to_onehot_y=True` ignored.")
            else:
                target = one_hot(target, num_classes=n_pred_ch)

        if not self.include_background:
            if n_pred_ch == 1:
                warnings.warn(
                    "single channel prediction, `include_background=False` ignored."
                )
            else:
                # if skipping background, removing first channel
                target = target[:, 1:]
                input = input[:, 1:]

        assert (
            target.shape == input.shape
        ), f"ground truth has differing shape ({target.shape}) from input ({input.shape})"

        # reducing only spatial dimensions (not batch nor channels)
        reduce_axis = list(range(2, len(input.shape)))
        intersection = torch.sum(target * input, dim=reduce_axis)

        if self.squared_pred:
            target = torch.pow(target, 2)
            input = torch.pow(input, 2)

        ground_o = torch.sum(target, dim=reduce_axis)
        pred_o = torch.sum(input, dim=reduce_axis)

        denominator = ground_o + pred_o

        if self.jaccard:
            denominator = 2.0 * (denominator - intersection)

        f: torch.Tensor = 1.0 - (2.0 * intersection + smooth) / (denominator +
                                                                 smooth)

        if self.reduction == LossReduction.MEAN.value:
            f = torch.mean(f)  # the batch and channel average
        elif self.reduction == LossReduction.SUM.value:
            f = torch.sum(f)  # sum over the batch and channel dims
        elif self.reduction == LossReduction.NONE.value:
            pass  # returns [N, n_classes] losses
        else:
            raise ValueError(
                f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].'
            )

        return f
Example 18
def compute_roc_auc(
    y_pred: torch.Tensor,
    y: torch.Tensor,
    to_onehot_y: bool = False,
    softmax: bool = False,
    other_act: Optional[Callable] = None,
    average: Union[Average, str] = Average.MACRO,
):
    """Computes Area Under the Receiver Operating Characteristic Curve (ROC AUC). Referring to:
    `sklearn.metrics.roc_auc_score <https://scikit-learn.org/stable/modules/generated/
    sklearn.metrics.roc_auc_score.html#sklearn.metrics.roc_auc_score>`_.

    Args:
        y_pred: input data to compute, typical classification model output.
            it must be One-Hot format and first dim is batch, example shape: [16] or [16, 2].
        y: ground truth to compute ROC AUC metric, the first dim is batch.
            example shape: [16, 1] will be converted into [16, 2] (where `2` is inferred from `y_pred`).
        to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
        softmax: whether to add softmax function to `y_pred` before computation. Defaults to False.
        other_act: callable function to replace `softmax` as activation layer if needed. Defaults to ``None``.
            for example: `other_act = lambda x: torch.log_softmax(x, dim=1)`.
        average: {``"macro"``, ``"weighted"``, ``"micro"``, ``"none"``}
            Type of averaging performed if not binary classification.
            Defaults to ``"macro"``.

            - ``"macro"``: calculate metrics for each label, and find their unweighted mean.
                This does not take label imbalance into account.
            - ``"weighted"``: calculate metrics for each label, and find their average,
                weighted by support (the number of true instances for each label).
            - ``"micro"``: calculate metrics globally by considering each element of the label
                indicator matrix as a label.
            - ``"none"``: the scores for each class are returned.

    Raises:
        ValueError: When ``y_pred`` dimension is not one of [1, 2].
        ValueError: When ``y`` dimension is not one of [1, 2].
        ValueError: When ``softmax=True`` and ``other_act is not None``. Incompatible values.
        TypeError: When ``other_act`` is not an ``Optional[Callable]``.
        ValueError: When ``average`` is not one of ["macro", "weighted", "micro", "none"].

    Note:
        ROCAUC expects y to be composed of 0's and 1's. `y_pred` must be either probability estimates or confidence values.

    """
    y_pred_ndim = y_pred.ndimension()
    y_ndim = y.ndimension()
    if y_pred_ndim not in (1, 2):
        raise ValueError(
            "Predictions should be of shape (batch_size, n_classes) or (batch_size, )."
        )
    if y_ndim not in (1, 2):
        raise ValueError(
            "Targets should be of shape (batch_size, n_classes) or (batch_size, )."
        )
    if y_pred_ndim == 2 and y_pred.shape[1] == 1:
        y_pred = y_pred.squeeze(dim=-1)
        y_pred_ndim = 1
    if y_ndim == 2 and y.shape[1] == 1:
        y = y.squeeze(dim=-1)

    if y_pred_ndim == 1:
        if to_onehot_y:
            warnings.warn(
                "y_pred has only one channel, to_onehot_y=True ignored.")
        if softmax:
            warnings.warn("y_pred has only one channel, softmax=True ignored.")
        return _calculate(y, y_pred)
    else:
        n_classes = y_pred.shape[1]
        if to_onehot_y:
            y = one_hot(y, n_classes)
        if softmax and other_act is not None:
            raise ValueError(
                "Incompatible values: softmax=True and other_act is not None.")
        if softmax:
            y_pred = y_pred.float().softmax(dim=1)
        if other_act is not None:
            if not callable(other_act):
                raise TypeError(
                    f"other_act must be None or callable but is {type(other_act).__name__}."
                )
            y_pred = other_act(y_pred)

        assert y.shape == y_pred.shape, "data shapes of y_pred and y do not match."

        average = Average(average)
        if average == Average.MICRO:
            return _calculate(y.flatten(), y_pred.flatten())
        else:
            y, y_pred = y.transpose(0, 1), y_pred.transpose(0, 1)
            auc_values = [
                _calculate(y_, y_pred_) for y_, y_pred_ in zip(y, y_pred)
            ]
            if average == Average.NONE:
                return auc_values
            if average == Average.MACRO:
                return np.mean(auc_values)
            if average == Average.WEIGHTED:
                weights = [sum(y_) for y_ in y]
                return np.average(auc_values, weights=weights)
            raise ValueError(
                f'Unsupported average: {average}, available options are ["macro", "weighted", "micro", "none"].'
            )
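A short usage sketch for compute_roc_auc as defined above:

# usage sketch for the function above
import torch
y_pred = torch.rand(16, 2)        # two-class scores, first dim is batch
y = torch.randint(0, 2, (16, 1))  # class indices
auc = compute_roc_auc(y_pred, y, to_onehot_y=True, softmax=True)  # macro-averaged AUC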
Example 19
def compute_meandice(
    y_pred: torch.Tensor,
    y: torch.Tensor,
    include_background: bool = True,
    to_onehot_y: bool = False,
    mutually_exclusive: bool = False,
    sigmoid: bool = False,
    other_act: Optional[Callable] = None,
    logit_thresh: float = 0.5,
) -> torch.Tensor:
    """Computes Dice score metric from full size Tensor and collects average.

    Args:
        y_pred: input data to compute, typical segmentation model output.
            it must be one-hot format and first dim is batch, example shape: [16, 3, 32, 32].
        y: ground truth to compute mean dice metric, the first dim is batch.
            example shape: [16, 1, 32, 32] will be converted into [16, 3, 32, 32].
            alternative shape: [16, 3, 32, 32] and set `to_onehot_y=False` to use 3-class labels directly.
        include_background: whether to skip Dice computation on the first channel of
            the predicted output. Defaults to True.
        to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
        mutually_exclusive: if True, `y_pred` will be converted into a binary matrix using
            a combination of argmax and to_onehot.  Defaults to False.
        sigmoid: whether to add sigmoid function to y_pred before computation. Defaults to False.
        other_act: callable function to replace `sigmoid` as activation layer if needed, Defaults to ``None``.
            for example: `other_act = torch.tanh`.
        logit_thresh: the threshold value used to convert (for example, after sigmoid if `sigmoid=True`)
            `y_pred` into a binary matrix. Defaults to 0.5.

    Raises:
        ValueError: When ``sigmoid=True`` and ``other_act is not None``. Incompatible values.
        TypeError: When ``other_act`` is not an ``Optional[Callable]``.
        ValueError: When ``sigmoid=True`` and ``mutually_exclusive=True``. Incompatible values.

    Returns:
        Dice scores per batch and per class, (shape [batch_size, n_classes]).

    Note:
        This method provides two options to convert `y_pred` into a binary matrix
            (1) when `mutually_exclusive` is True, it uses a combination of ``argmax`` and ``to_onehot``,
            (2) when `mutually_exclusive` is False, it uses a threshold ``logit_thresh``
                (optionally with a ``sigmoid`` function before thresholding).

    """
    n_classes = y_pred.shape[1]
    n_len = len(y_pred.shape)
    if sigmoid and other_act is not None:
        raise ValueError(
            "Incompatible values: sigmoid=True and other_act is not None.")
    if sigmoid:
        y_pred = y_pred.float().sigmoid()

    if other_act is not None:
        if not callable(other_act):
            raise TypeError(
                f"other_act must be None or callable but is {type(other_act).__name__}."
            )
        y_pred = other_act(y_pred)

    if n_classes == 1:
        if mutually_exclusive:
            warnings.warn(
                "y_pred has only one class, mutually_exclusive=True ignored.")
        if to_onehot_y:
            warnings.warn(
                "y_pred has only one channel, to_onehot_y=True ignored.")
        if not include_background:
            warnings.warn(
                "y_pred has only one channel, include_background=False ignored."
            )
        # make both y and y_pred binary
        y_pred = (y_pred >= logit_thresh).float()
        y = (y > 0).float()
    else:  # multi-channel y_pred
        # make both y and y_pred binary
        if mutually_exclusive:
            if sigmoid:
                raise ValueError(
                    "Incompatible values: sigmoid=True and mutually_exclusive=True."
                )
            y_pred = torch.argmax(y_pred, dim=1, keepdim=True)
            y_pred = one_hot(y_pred, num_classes=n_classes)
        else:
            y_pred = (y_pred >= logit_thresh).float()
        if to_onehot_y:
            y = one_hot(y, num_classes=n_classes)

    if not include_background:
        y = y[:, 1:] if y.shape[1] > 1 else y
        y_pred = y_pred[:, 1:] if y_pred.shape[1] > 1 else y_pred

    assert y.shape == y_pred.shape, "Ground truth one-hot has differing shape (%r) from source (%r)" % (
        y.shape,
        y_pred.shape,
    )
    y = y.float()
    y_pred = y_pred.float()

    # reducing only spatial dimensions (not batch nor channels)
    reduce_axis = list(range(2, n_len))
    intersection = torch.sum(y * y_pred, dim=reduce_axis)

    y_o = torch.sum(y, reduce_axis)
    y_pred_o = torch.sum(y_pred, dim=reduce_axis)
    denominator = y_o + y_pred_o

    f = torch.where(y_o > 0, (2.0 * intersection) / denominator,
                    torch.tensor(float("nan"), device=y_o.device))
    return f  # returns array of Dice shape: [batch, n_classes]
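A short usage sketch for compute_meandice as defined above:

# usage sketch for the function above
import torch
y_pred = torch.randn(16, 3, 32, 32)       # raw network output, 3 classes
y = torch.randint(0, 3, (16, 1, 32, 32))  # class-index ground truth
dice = compute_meandice(y_pred, y, to_onehot_y=True, mutually_exclusive=True)
print(dice.shape)  # torch.Size([16, 3]): per-batch, per-class Dice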
Example 20
    def forward(self, input: torch.Tensor, target: torch.Tensor, smooth: float = 1e-5):
        """
        Args:
            input (tensor): the shape should be BNH[WD].
            target (tensor): the shape should be BNH[WD].
            smooth: a small constant to avoid nan.

        Raises:
            ValueError: reduction={self.reduction} is invalid.

        """
        if self.sigmoid:
            input = torch.sigmoid(input)
        n_pred_ch = input.shape[1]
        if n_pred_ch == 1:
            if self.softmax:
                warnings.warn("single channel prediction, `softmax=True` ignored.")
            if self.to_onehot_y:
                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
            if not self.include_background:
                warnings.warn("single channel prediction, `include_background=False` ignored.")
        else:
            if self.softmax:
                input = torch.softmax(input, 1)
            if self.to_onehot_y:
                target = one_hot(target, n_pred_ch)
            if not self.include_background:
                # if skipping background, removing first channel
                target = target[:, 1:]
                input = input[:, 1:]
        assert (
            target.shape == input.shape
        ), f"ground truth has differing shape ({target.shape}) from input ({input.shape})"

        # reducing only spatial dimensions (not batch nor channels)
        reduce_axis = list(range(2, len(input.shape)))
        intersection = torch.sum(target * input, reduce_axis)

        ground_o = torch.sum(target, reduce_axis)
        pred_o = torch.sum(input, reduce_axis)

        denominator = ground_o + pred_o

        w = self.w_func(ground_o.float())
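        # zero out infinite weights (arising from empty classes), then replace
        # them with the largest finite weight so they do not dominate the loss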
        for b in w:
            infs = torch.isinf(b)
            b[infs] = 0.0
            b[infs] = torch.max(b)

        f = 1.0 - (2.0 * (intersection * w).sum(1) + smooth) / ((denominator * w).sum(1) + smooth)

        if self.reduction == LossReduction.MEAN:
            f = torch.mean(f)  # the batch and channel average
        elif self.reduction == LossReduction.SUM:
            f = torch.sum(f)  # sum over the batch and channel dims
        elif self.reduction == LossReduction.NONE:
            pass  # returns [N, n_classes] losses
        else:
            raise ValueError(f"reduction={self.reduction} is invalid.")

        return f
Example 21
    def forward(self, input: torch.Tensor,
                target: torch.Tensor) -> torch.Tensor:
        """
        Args:
            input: the shape should be BNH[WD].
            target: the shape should be BNH[WD].

        Raises:
            ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"].

        """
        if self.sigmoid:
            input = torch.sigmoid(input)

        n_pred_ch = input.shape[1]
        if self.softmax:
            if n_pred_ch == 1:
                warnings.warn(
                    "single channel prediction, `softmax=True` ignored.")
            else:
                input = torch.softmax(input, 1)

        if self.other_act is not None:
            input = self.other_act(input)

        if self.to_onehot_y:
            if n_pred_ch == 1:
                warnings.warn(
                    "single channel prediction, `to_onehot_y=True` ignored.")
            else:
                target = one_hot(target, num_classes=n_pred_ch)

        if not self.include_background:
            if n_pred_ch == 1:
                warnings.warn(
                    "single channel prediction, `include_background=False` ignored."
                )
            else:
                # if skipping background, removing first channel
                target = target[:, 1:]
                input = input[:, 1:]

        if target.shape != input.shape:
            raise AssertionError(
                f"ground truth has differing shape ({target.shape}) from input ({input.shape})"
            )

        p0 = input
        p1 = 1 - p0
        g0 = target
        g1 = 1 - g0

        # reducing only spatial dimensions (not batch nor channels)
        reduce_axis: List[int] = torch.arange(2, len(input.shape)).tolist()
        if self.batch:
            # reducing spatial dimensions and batch
            reduce_axis = [0] + reduce_axis

        tp = torch.sum(p0 * g0, reduce_axis)
        fp = self.alpha * torch.sum(p0 * g1, reduce_axis)
        fn = self.beta * torch.sum(p1 * g0, reduce_axis)
        numerator = tp + self.smooth_nr
        denominator = tp + fp + fn + self.smooth_dr

        score: torch.Tensor = 1.0 - numerator / denominator

        if self.reduction == LossReduction.SUM.value:
            return torch.sum(score)  # sum over the batch and channel dims
        if self.reduction == LossReduction.NONE.value:
            return score  # returns [N, n_classes] losses
        if self.reduction == LossReduction.MEAN.value:
            return torch.mean(score)
        raise ValueError(
            f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].'
        )
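A usage sketch, assuming this forward belongs to a Tversky-style loss class such as MONAI's TverskyLoss (the constructor arguments shown are assumptions):

# hypothetical usage; the class name and constructor arguments are assumptions
import torch
loss_fn = TverskyLoss(to_onehot_y=True, softmax=True, alpha=0.3, beta=0.7)
logits = torch.randn(2, 4, 16, 16)            # B=2, 4 classes
labels = torch.randint(0, 4, (2, 1, 16, 16))  # class-index ground truth
print(loss_fn(logits, labels))                # scalar for the default "mean" reduction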
Example 22
    def forward(self,
                input,
                target: torch.Tensor,
                smooth: float = 1e-5) -> torch.Tensor:
        """
        Args:
            input: the shape should be BNH[WD].
            target: the shape should be BNH[WD].
            smooth: a small constant to avoid nan.

        Raises:
            ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"].

        """

        x, att_maps = input

        loss_function_single_channel = Dice(to_onehot_y=False, softmax=False)

        total_att_loss = 0

        if self.supervised_attention:
            L = len(att_maps)
            att_losses = []

            G_l = target
            for level in range(L):
                # A[level] are the attention maps as they arrive here
                # G[level] are downsampled ground truth maps, they are converted to one-hot inside the loss-function
                att_loss = loss_function_single_channel(
                    att_maps[L - level - 1], G_l)
                att_losses.append(att_loss)
                total_att_loss = total_att_loss + 1 / L * att_loss

                if level < L - 1:
                    shape_curr_att_map = att_maps[L - level - 1].shape
                    shape_next_att_map = att_maps[L - level - 2].shape

                    # assert that shape of current attention map is multiple of next att map in all dimensions
                    assert all([
                        x % y == 0
                        for x, y in zip(shape_curr_att_map, shape_next_att_map)
                    ])
                    shape_ratio = [
                        x // y
                        for x, y in zip(shape_curr_att_map, shape_next_att_map)
                    ]

                    kernel_size_and_stride = shape_ratio[2:5]
                    G_l = torch.nn.MaxPool3d(
                        kernel_size=kernel_size_and_stride,
                        stride=kernel_size_and_stride)(G_l)
        hardness_weight = None
        if self.hardness_weighting:
            hardness_lambda = 0.6
            hardness_weight = hardness_lambda * abs(
                torch.softmax(x, dim=1) -
                one_hot(target, num_classes=x.shape[1])) + (1.0 -
                                                            hardness_lambda)
            # optional debugging: visualize slices of `hardness_weight`,
            # `torch.softmax(x, dim=1)` and the one-hot target with matplotlib

        loss_function_multi_channel = Dice(to_onehot_y=True,
                                           softmax=True,
                                           hardness_weight=hardness_weight)
        pred_loss = loss_function_multi_channel(x, target)
        return total_att_loss + pred_loss
Example 23
    def forward(self,
                input: torch.Tensor,
                target: torch.Tensor,
                smooth: float = 1e-5):
        """
        Args:
            input: the shape should be BNH[WD].
            target: the shape should be BNH[WD].
            smooth: a small constant to avoid nan.

        Raises:
            ValueError: reduction={self.reduction} is invalid.

        """
        if self.sigmoid:
            input = torch.sigmoid(input)
        n_pred_ch = input.shape[1]
        if n_pred_ch == 1:
            if self.softmax:
                warnings.warn(
                    "single channel prediction, `softmax=True` ignored.")
            if self.to_onehot_y:
                warnings.warn(
                    "single channel prediction, `to_onehot_y=True` ignored.")
            if not self.include_background:
                warnings.warn(
                    "single channel prediction, `include_background=False` ignored."
                )
        else:
            if self.softmax:
                input = torch.softmax(input, 1)
            if self.to_onehot_y:
                target = one_hot(target, n_pred_ch)
            if not self.include_background:
                # if skipping background, removing first channel
                target = target[:, 1:]
                input = input[:, 1:]
        assert (
            target.shape == input.shape
        ), f"ground truth has differing shape ({target.shape}) from input ({input.shape})"

        p0 = input
        p1 = 1 - p0
        g0 = target
        g1 = 1 - g0

        # reducing only spatial dimensions (not batch nor channels)
        reduce_axis = list(range(2, len(input.shape)))

        tp = torch.sum(p0 * g0, reduce_axis)
        fp = self.alpha * torch.sum(p0 * g1, reduce_axis)
        fn = self.beta * torch.sum(p1 * g0, reduce_axis)

        numerator = tp + smooth
        denominator = tp + fp + fn + smooth

        score = 1.0 - numerator / denominator

        if self.reduction == LossReduction.SUM:
            return score.sum()  # sum over the batch and channel dims
        if self.reduction == LossReduction.NONE:
            return score  # returns [N, n_classes] losses
        if self.reduction == LossReduction.MEAN:
            return score.mean()
        raise ValueError(f"reduction={self.reduction} is invalid.")
Example 24
    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Args:
            input: the shape should be BNH[WD].
            target: the shape should be BNH[WD].

        Raises:
            ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"].

        """
        if self.sigmoid:
            input = torch.sigmoid(input)

        n_pred_ch = input.shape[1]
        if self.softmax:
            if n_pred_ch == 1:
                warnings.warn("single channel prediction, `softmax=True` ignored.")
            else:
                input = torch.softmax(input, 1)

        if self.other_act is not None:
            input = self.other_act(input)

        if self.to_onehot_y:
            if n_pred_ch == 1:
                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
            else:
                target = one_hot(target, num_classes=n_pred_ch)

        if not self.include_background:
            if n_pred_ch == 1:
                warnings.warn("single channel prediction, `include_background=False` ignored.")
            else:
                # if skipping background, removing first channel
                target = target[:, 1:]
                input = input[:, 1:]

        if target.shape != input.shape:
            raise AssertionError(f"ground truth has differing shape ({target.shape}) from input ({input.shape})")

        # reducing only spatial dimensions (not batch nor channels)
        reduce_axis = list(range(2, len(input.shape)))
        if self.batch:
            # reducing spatial dimensions and batch
            reduce_axis = [0] + reduce_axis

        intersection = torch.sum(target * input, dim=reduce_axis)

        ### uncomment the lines below to enable label weights
        # if self.label_weights is not None:  # add weights to labels
        #     bs = intersection.shape[0]
        #     w = torch.tensor(self.label_weights, dtype=torch.float32, device=torch.device('cuda:0'))
        #     w = w.repeat(bs, 1)  # change size to [BS, num of classes]
        #     intersection = w * intersection

        if self.squared_pred:
            target = torch.pow(target, 2)
            input = torch.pow(input, 2)

        ground_o = torch.sum(target, dim=reduce_axis)
        pred_o = torch.sum(input, dim=reduce_axis)

        denominator = ground_o + pred_o

        if self.jaccard:
            denominator = 2.0 * (denominator - intersection)

        f: torch.Tensor = 1.0 - (2.0 * intersection + self.smooth_nr) / (denominator + self.smooth_dr)

        if self.reduction == LossReduction.MEAN.value:
            f = torch.mean(f)  # the batch and channel average
        elif self.reduction == LossReduction.SUM.value:
            f = torch.sum(f)  # sum over the batch and channel dims
        elif self.reduction == LossReduction.NONE.value:
            pass  # returns [N, n_classes] losses
        else:
            raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')

        return f
Example 25
    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Args:
            input: the shape should be BNH[WD].
            target: the shape should be BNH[WD].

        Raises:
            ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"].

        """
        if self.sigmoid:
            input = torch.sigmoid(input)
        n_pred_ch = input.shape[1]
        if self.softmax:
            if n_pred_ch == 1:
                warnings.warn("single channel prediction, `softmax=True` ignored.")
            else:
                input = torch.softmax(input, 1)

        if self.other_act is not None:
            input = self.other_act(input)

        if self.to_onehot_y:
            if n_pred_ch == 1:
                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
            else:
                target = one_hot(target, num_classes=n_pred_ch)

        if not self.include_background:
            if n_pred_ch == 1:
                warnings.warn("single channel prediction, `include_background=False` ignored.")
            else:
                # if skipping background, removing first channel
                target = target[:, 1:]
                input = input[:, 1:]

        if target.shape != input.shape:
            raise AssertionError(f"ground truth has differing shape ({target.shape}) from input ({input.shape})")

        # reducing only spatial dimensions (not batch nor channels)
        reduce_axis: List[int] = torch.arange(2, len(input.shape)).tolist()
        if self.batch:
            reduce_axis = [0] + reduce_axis
        intersection = torch.sum(target * input, reduce_axis)

        ground_o = torch.sum(target, reduce_axis)
        pred_o = torch.sum(input, reduce_axis)

        denominator = ground_o + pred_o

        w = self.w_func(ground_o.float())
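        # zero out infinite weights (arising from empty classes), then replace
        # them with the largest finite weight so they do not dominate the loss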
        for b in w:
            infs = torch.isinf(b)
            b[infs] = 0.0
            b[infs] = torch.max(b)

        f: torch.Tensor = 1.0 - (2.0 * (intersection * w).sum(0 if self.batch else 1) + self.smooth_nr) / (
            (denominator * w).sum(0 if self.batch else 1) + self.smooth_dr
        )

        if self.reduction == LossReduction.MEAN.value:
            f = torch.mean(f)  # the batch and channel average
        elif self.reduction == LossReduction.SUM.value:
            f = torch.sum(f)  # sum over the batch and channel dims
        elif self.reduction == LossReduction.NONE.value:
            pass  # returns [N, n_classes] losses
        else:
            raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')

        return f
Example 26
    def forward(self, input: torch.Tensor,
                target: torch.Tensor) -> torch.Tensor:
        """
        Args:
            input: the shape should be BNH[WD], where N is the number of classes.
            target: the shape should be BNH[WD] or B1H[WD], where N is the number of classes.

        Raises:
            AssertionError: When input and target (after one hot transform if set)
                have different shapes.
            ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"].

        Example:
            >>> from monai.losses.dice import *  # NOQA
            >>> import torch
            >>> from monai.losses.dice import DiceLoss
            >>> B, C, H, W = 7, 5, 3, 2
            >>> input = torch.rand(B, C, H, W)
            >>> target_idx = torch.randint(low=0, high=C - 1, size=(B, H, W)).long()
            >>> target = one_hot(target_idx[:, None, ...], num_classes=C)
            >>> self = DiceLoss(reduction='none')
            >>> loss = self(input, target)
            >>> assert np.broadcast_shapes(loss.shape, input.shape) == input.shape
        """
        if self.sigmoid:
            input = torch.sigmoid(input)

        n_pred_ch = input.shape[1]
        if self.softmax:
            if n_pred_ch == 1:
                warnings.warn(
                    "single channel prediction, `softmax=True` ignored.")
            else:
                input = torch.softmax(input, 1)

        if self.other_act is not None:
            input = self.other_act(input)

        if self.to_onehot_y:
            if n_pred_ch == 1:
                warnings.warn(
                    "single channel prediction, `to_onehot_y=True` ignored.")
            else:
                target = one_hot(target, num_classes=n_pred_ch)

        if not self.include_background:
            if n_pred_ch == 1:
                warnings.warn(
                    "single channel prediction, `include_background=False` ignored."
                )
            else:
                # if skipping background, removing first channel
                target = target[:, 1:]
                input = input[:, 1:]

        if target.shape != input.shape:
            raise AssertionError(
                f"ground truth has different shape ({target.shape}) from input ({input.shape})"
            )

        # reducing only spatial dimensions (not batch nor channels)
        reduce_axis: List[int] = torch.arange(2, len(input.shape)).tolist()
        if self.batch:
            # reducing spatial dimensions and batch
            reduce_axis = [0] + reduce_axis

        intersection = torch.sum(target * input, dim=reduce_axis)

        if self.squared_pred:
            target = torch.pow(target, 2)
            input = torch.pow(input, 2)

        ground_o = torch.sum(target, dim=reduce_axis)
        pred_o = torch.sum(input, dim=reduce_axis)

        denominator = ground_o + pred_o

        if self.jaccard:
            denominator = 2.0 * (denominator - intersection)

        f: torch.Tensor = 1.0 - (2.0 * intersection + self.smooth_nr) / (
            denominator + self.smooth_dr)

        if self.reduction == LossReduction.MEAN.value:
            f = torch.mean(f)  # the batch and channel average
        elif self.reduction == LossReduction.SUM.value:
            f = torch.sum(f)  # sum over the batch and channel dims
        elif self.reduction == LossReduction.NONE.value:
            # If we are not computing voxelwise loss components at least
            # make sure a none reduction maintains a broadcastable shape
            broadcast_shape = list(f.shape[0:2]) + [1] * (len(input.shape) - 2)
            f = f.view(broadcast_shape)
        else:
            raise ValueError(
                f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].'
            )

        return f
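
A short sketch of two optional behaviours in the forward above, assuming this is MONAI's DiceLoss: with `jaccard=True` the denominator becomes 2 * (ground_o + pred_o - intersection), i.e. an IoU-style loss, while `squared_pred=True` squares both terms of the denominator:

import torch
from monai.losses import DiceLoss

pred = torch.rand(4, 2, 8, 8)
truth = torch.randint(0, 2, (4, 2, 8, 8)).float()
dice = DiceLoss(sigmoid=True)
iou = DiceLoss(sigmoid=True, jaccard=True)  # union-style denominator
print(dice(pred, truth), iou(pred, truth))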
Example n. 27
    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Args:
            input: the shape should be BNH[WD].
            target: the shape should be BNH[WD].

        Raises:
            AssertionError: When input and target have different shapes.
            ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"].

        """
        if self.sigmoid:
            input = torch.sigmoid(input)
        n_pred_ch = input.shape[1]
        if self.softmax:
            if n_pred_ch == 1:
                warnings.warn("single channel prediction, `softmax=True` ignored.")
            else:
                input = torch.softmax(input, 1)

        if self.other_act is not None:
            input = self.other_act(input)

        if self.to_onehot_y:
            if n_pred_ch == 1:
                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
            else:
                target = one_hot(target, num_classes=n_pred_ch)

        if not self.include_background:
            if n_pred_ch == 1:
                warnings.warn("single channel prediction, `include_background=False` ignored.")
            else:
                # if skipping background, removing first channel
                target = target[:, 1:]
                input = input[:, 1:]

        if target.shape != input.shape:
            raise AssertionError(f"ground truth has differing shape ({target.shape}) from input ({input.shape})")

        # reducing only spatial dimensions (not batch nor channels)
        reduce_axis: List[int] = torch.arange(2, len(input.shape)).tolist()
        if self.batch:
            reduce_axis = [0] + reduce_axis
        intersection = torch.sum(target * input, reduce_axis)

        ground_o = torch.sum(target, reduce_axis)
        pred_o = torch.sum(input, reduce_axis)

        denominator = ground_o + pred_o

        w = self.w_func(ground_o.float())
        infs = torch.isinf(w)
        if self.batch:
            w[infs] = 0.0  # zero the infinite weights so torch.max sees only finite values
            w = w + infs * torch.max(w)  # then replace them with the largest finite weight
        else:
            w[infs] = 0.0
            # per-sample maximum over the channel dim, broadcast back to [B, 1]
            max_values = torch.max(w, dim=1)[0].unsqueeze(dim=1)
            w = w + infs * max_values  # replace infs with each sample's largest finite weight

        final_reduce_dim = 0 if self.batch else 1
        numer = 2.0 * (intersection * w).sum(final_reduce_dim, keepdim=True) + self.smooth_nr
        denom = (denominator * w).sum(final_reduce_dim, keepdim=True) + self.smooth_dr
        f: torch.Tensor = 1.0 - (numer / denom)

        if self.reduction == LossReduction.MEAN.value:
            f = torch.mean(f)  # the batch and channel average
        elif self.reduction == LossReduction.SUM.value:
            f = torch.sum(f)  # sum over the batch and channel dims
        elif self.reduction == LossReduction.NONE.value:
            # If we are not computing voxelwise loss components at least
            # make sure a none reduction maintains a broadcastable shape
            broadcast_shape = list(f.shape[0:2]) + [1] * (len(input.shape) - 2)
            f = f.view(broadcast_shape)
        else:
            raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')

        return f
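
A standalone sketch of the weighting step above, assuming w_func is the "square" option (w = 1 / ground_o**2), so a class absent from the ground truth yields an infinite weight:

import torch

ground_o = torch.tensor([[100.0, 0.0, 25.0]])        # class 1 is absent -> inf weight
w = 1.0 / (ground_o * ground_o)
infs = torch.isinf(w)
w[infs] = 0.0                                        # zero infs so max() is finite
w = w + infs * torch.max(w, dim=1, keepdim=True)[0]  # inf -> row-wise max
print(w)  # tensor([[1.0000e-04, 1.6000e-03, 1.6000e-03]])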
Example n. 28
    def forward(self, input: torch.Tensor,
                target: torch.Tensor) -> torch.Tensor:
        """
        Args:
            input: the shape should be BNH[WD], where N is the number of classes.
                The input should be the original logits since it will be transformed by
                `F.log_softmax` in the forward function.
            target: the shape should be BNH[WD] or B1H[WD], where N is the number of classes.

        Raises:
            AssertionError: When input and target (after one hot transform if set)
                have different shapes.
            ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"].
            ValueError: When ``self.weight`` is a sequence and the length is not equal to the
                number of classes.
            ValueError: When ``self.weight`` is/contains a value that is less than 0.

        """
        n_pred_ch = input.shape[1]

        if self.to_onehot_y:
            if n_pred_ch == 1:
                warnings.warn(
                    "single channel prediction, `to_onehot_y=True` ignored.")
            else:
                target = one_hot(target, num_classes=n_pred_ch)

        if not self.include_background:
            if n_pred_ch == 1:
                warnings.warn(
                    "single channel prediction, `include_background=False` ignored."
                )
            else:
                # if skipping background, removing first channel
                target = target[:, 1:]
                input = input[:, 1:]

        if target.shape != input.shape:
            raise AssertionError(
                f"ground truth has different shape ({target.shape}) from input ({input.shape})"
            )

        i = input
        t = target

        # Change the shape of input and target to B x N x num_voxels.
        i = i.view(i.size(0), i.size(1), -1)
        t = t.view(t.size(0), t.size(1), -1)

        # Compute the log-probabilities.
        logpt = F.log_softmax(i, dim=1)
        # Recover the probabilities.
        pt = torch.exp(logpt)  # B,N,H*W

        if self.weight is not None:
            class_weight: Optional[torch.Tensor] = None
            if isinstance(self.weight, (float, int)):
                class_weight = torch.as_tensor([self.weight] * i.size(1))
            else:
                class_weight = torch.as_tensor(self.weight)
                if class_weight.size(0) != i.size(1):
                    raise ValueError(
                        "the length of the weight sequence should be the same as the number of classes. "
                        "If `include_background=False`, the number should not include class 0."
                    )
            if class_weight.min() < 0:
                raise ValueError(
                    "the value/values of weights should be no less than 0.")
            class_weight = class_weight.to(i)
            # Convert the weight to a map in which each voxel
            # has the weight associated with the ground-truth label
            # associated with this voxel in target.
            at = class_weight[None, :, None]  # N => 1,N,1
            at = at.expand((t.size(0), -1, t.size(2)))  # 1,N,1 => B,N,H*W
            # Multiply the log proba by their weights.
            logpt = logpt * at

        # Compute the loss for the mini-batch: (1 - pt)^gamma modulates the cross-entropy term.
        weight = torch.pow(-pt + 1.0, self.gamma)
        loss = torch.mean(-weight * t * logpt, dim=-1)
        if self.reduction == LossReduction.SUM.value:
            return loss.sum()
        if self.reduction == LossReduction.NONE.value:
            return loss
        if self.reduction == LossReduction.MEAN.value:
            return loss.mean()
        raise ValueError(
            f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].'
        )
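
A small numeric sketch of the focal modulation used above: (1 - pt)**gamma shrinks the contribution of well-classified voxels (pt close to 1) while keeping the hard ones:

import torch

pt = torch.tensor([0.05, 0.50, 0.95])  # predicted probability of the true class
for gamma in (0.0, 2.0):
    print(gamma, torch.pow(1.0 - pt, gamma))
# gamma=0 reduces to plain cross-entropy weighting (all ones);
# gamma=2 gives 0.9025, 0.25, 0.0025 -> easy voxels are down-weighted.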
Example n. 29
def compute_confusion_metric(
    y_pred: torch.Tensor,
    y: torch.Tensor,
    to_onehot_y: bool = False,
    activation: Optional[Union[str, Callable]] = None,
    bin_mode: Optional[str] = "threshold",
    bin_threshold: Union[float, Sequence[float]] = 0.5,
    metric_name: str = "hit_rate",
    average: Union[Average, str] = Average.MACRO,
    zero_division: int = 0,
) -> Union[np.ndarray, List[float], float]:
    """
    Compute confusion matrix related metrics. This function supports the calculation of all
    metrics mentioned in: `Confusion matrix <https://en.wikipedia.org/wiki/Confusion_matrix>`_.
    Before calculating, an activation function and/or a binarization step can be applied to
    pre-process the original inputs. Zero division is handled by replacing the result with a
    single value. Referring to:
    `sklearn.metrics <https://scikit-learn.org/stable/modules/classes.html#module-sklearn.metrics>`_.

    Args:
        y_pred: predictions. For classification tasks,
            `y_pred` should have the shape [B] or [BN]. For segmentation tasks,
            the shape should be [BNHW] or [BNHWD].
        y: ground truth, the first dim is batch.
        to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
        activation: [``"sigmoid"``, ``"softmax"``]
            Activation method, if specified, an activation function will be employed for `y_pred`.
            Defaults to None.
            The parameter can also be a callable function, for example:
            ``activation = lambda x: torch.log_softmax(x)``.
        bin_mode: [``"threshold"``, ``"mutually_exclusive"``]
            Binarization method, if specified, a binarization manipulation will be employed
            for `y_pred`.

            - ``"threshold"``, a single threshold or a sequence of thresholds should be set.
            - ``"mutually_exclusive"``, `y_pred` will be converted by a combination of `argmax` and `to_onehot`.
        bin_threshold: the threshold for binarization, can be a single value or a sequence of
            values, where each value represents the threshold for a class.
        metric_name: [``"sensitivity"``, ``"specificity"``, ``"precision"``, ``"negative predictive value"``,
            ``"miss rate"``, ``"fall out"``, ``"false discovery rate"``, ``"false omission rate"``,
            ``"prevalence threshold"``, ``"threat score"``, ``"accuracy"``, ``"balanced accuracy"``,
            ``"f1 score"``, ``"matthews correlation coefficient"``, ``"fowlkes mallows index"``,
            ``"informedness"``, ``"markedness"``]
            Some of the metrics have multiple aliases (as shown in the wikipedia page aforementioned),
            and you can also input those names instead.
        average: [``"macro"``, ``"weighted"``, ``"micro"``, ``"none"``]
            Type of averaging performed if not binary classification.
            Defaults to ``"macro"``.

            - ``"macro"``: calculate metrics for each label, and find their unweighted mean.
                This does not take label imbalance into account.
            - ``"weighted"``: calculate metrics for each label, and find their average,
                weighted by support (the number of true instances for each label).
            - ``"micro"``: calculate metrics globally by considering each element of the label
                indicator matrix as a label.
            - ``"none"``: the scores for each class are returned.
        zero_division: the value to return when there is a zero division, for example, when all
            predictions and labels are negative. Defaults to 0.
    Raises:
        AssertionError: when data shapes of `y_pred` and `y` do not match.
        AssertionError: when an activation function and ``mutually_exclusive`` mode are specified at the same time.
    """

    y_pred_ndim, y_ndim = y_pred.ndimension(), y.ndimension()
    # one-hot for ground truth
    if to_onehot_y:
        if y_pred_ndim == 1:
            warnings.warn("y_pred has only one channel, to_onehot_y=True ignored.")
        else:
            n_classes = y_pred.shape[1]
            y = one_hot(y, num_classes=n_classes)
    # check shape
    assert y.shape == y_pred.shape, "data shapes of y_pred and y do not match."
    # activation for predictions
    if activation is not None:
        assert bin_mode != "mutually_exclusive", "activation is unnecessary for mutually exclusive classes."
        y_pred = do_activation(y_pred, activation=activation)
    # binarization for predictions
    if bin_mode is not None:
        y_pred = do_binarization(y_pred, bin_mode=bin_mode, bin_threshold=bin_threshold)
    # get confusion matrix elements
    con_list = cal_confusion_matrix_elements(y_pred, y)
    # get simplified metric name
    metric_name = check_metric_name_and_unify(metric_name)
    result = do_calculate_metric(con_list, metric_name, average=average, zero_division=zero_division)
    return result
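
A minimal usage sketch, assuming compute_confusion_metric and its helpers above are importable from the surrounding module (this mirrors an older monai.metrics API):

import torch

y_pred = torch.rand(2, 3, 8, 8)        # B=2, N=3 class scores per voxel
y = torch.randint(0, 3, (2, 1, 8, 8))  # label-index map, B1HW
sensitivity = compute_confusion_metric(
    y_pred, y,
    to_onehot_y=True,                  # y -> one-hot with 3 channels
    bin_mode="mutually_exclusive",     # argmax + one-hot on y_pred
    metric_name="sensitivity",
    average="macro",
)
print(sensitivity)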
Example n. 30
def dice_loss(input: torch.Tensor,
              target: torch.Tensor,
              include_background: bool = True,
              softmax: bool = False,
              to_onehot: bool = True,
              squared_pred: bool = False,
              reduction: Union[LossReduction, str] = LossReduction.MEAN,
              smooth: float = 1e-5):
    """
    The Dice loss function, from
    Milletari, F. et al. (2016) V-Net: Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation, 3DV, 2016.

    Args:
        input: prediction tensor, the shape should be BNH[WD].
        target: target tensor, the shape should be BNH[WD].
        include_background: whether to include the background class (channel 0) in the loss computation. Defaults to True.
        softmax: if True, apply a softmax function to the prediction.
        to_onehot: whether to convert `target` into the one-hot format. Defaults to True.
        squared_pred: use squared versions of targets and predictions in the denominator or not.
        reduction: {``"none"``, ``"mean"``, ``"sum"``}
                Specifies the reduction to apply to the output. Defaults to ``"mean"``.
                - ``"none"``: no reduction will be applied.
                - ``"mean"``: the sum of the output will be divided by the number of elements in the output.
                - ``"sum"``: the output will be summed.
        smooth: a small constant to avoid nan.
    """

    n_pred_ch = input.shape[1]
    if softmax:
        input = torch.softmax(input, 1)

    if to_onehot:
        if n_pred_ch == 1:
            warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
        else:
            # F.one_hot cannot be used here because it would return BNH[WD]C
            # (the class dim C is appended last instead of at dim 1)
            target = one_hot(target.to(torch.int64), num_classes=n_pred_ch)

    if not include_background:
        if n_pred_ch == 1:
            warnings.warn("single channel prediction, `include_background=False` ignored.")
        else:
            # if skipping background, removing first channel
            target = target[:, 1:]
            input = input[:, 1:]

    assert (
            target.shape == input.shape
    ), f"ground truth has differing shape ({target.shape}) from input ({input.shape})"

    # reducing only spatial dimensions (not batch nor channels)
    reduce_axis = list(range(2, len(input.shape)))
    intersection = torch.sum(target * input, dim=reduce_axis)

    if squared_pred:
        target = torch.pow(target, 2)
        input = torch.pow(input, 2)

    ground_o = torch.sum(target, dim=reduce_axis)
    pred_o = torch.sum(input, dim=reduce_axis)

    denominator = ground_o + pred_o

    f = 1.0 - (2.0 * intersection + smooth) / (denominator + smooth)

    reduction = LossReduction(reduction).value
    if reduction == LossReduction.MEAN.value:
        f = torch.mean(f)  # the batch and channel average
    elif reduction == LossReduction.SUM.value:
        f = torch.sum(f)  # sum over the batch and channel dims
    elif reduction == LossReduction.NONE.value:
        pass  # returns [N, n_classes] losses
    else:
        raise ValueError(f'Unsupported reduction: {reduction}, available options are ["mean", "sum", "none"].')

    return f
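
A minimal usage sketch of dice_loss, assuming it is importable from the surrounding module:

import torch

pred = torch.softmax(torch.rand(2, 3, 8, 8), dim=1)  # pre-activated probabilities
truth = torch.randint(0, 3, (2, 1, 8, 8))            # label-index map, B1HW
print(dice_loss(pred, truth, softmax=False, to_onehot=True))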