Example #1
    def __init__(
        self,
        include_background: bool = True,
        to_onehot_y: bool = False,
        sigmoid: bool = False,
        softmax: bool = False,
        other_act: Optional[Callable] = None,
        w_type: Union[Weight, str] = Weight.SQUARE,
        reduction: Union[LossReduction, str] = LossReduction.MEAN,
        smooth_nr: float = 1e-5,
        smooth_dr: float = 1e-5,
        batch: bool = False,
        gamma: float = 2.0,
        focal_weight: Optional[Union[Sequence[float], float, int, torch.Tensor]] = None,
        lambda_gdl: float = 1.0,
        lambda_focal: float = 1.0,
    ) -> None:
        super().__init__()
        self.generalized_dice = GeneralizedDiceLoss(
            include_background=include_background,
            to_onehot_y=to_onehot_y,
            sigmoid=sigmoid,
            softmax=softmax,
            other_act=other_act,
            w_type=w_type,
            reduction=reduction,
            smooth_nr=smooth_nr,
            smooth_dr=smooth_dr,
            batch=batch,
        )
        self.focal = FocalLoss(
            include_background=include_background,
            to_onehot_y=to_onehot_y,
            gamma=gamma,
            weight=focal_weight,
            reduction=reduction,
        )
        if lambda_gdl < 0.0:
            raise ValueError("lambda_gdl should be no less than 0.0.")
        if lambda_focal < 0.0:
            raise ValueError("lambda_focal should be no less than 0.0.")
        self.lambda_gdl = lambda_gdl
        self.lambda_focal = lambda_focal
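
This constructor only stores the two component losses and their trade-off weights; the snippet ends before the `forward` method. As a rough sketch of how such components are typically combined (the method below is an illustrative assumption, not part of the snippet), assuming `input` holds logits of shape (B, C, ...):

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # Each component loss applies its own activation / one-hot conversion.
        gdl_loss = self.generalized_dice(input, target)
        focal_loss = self.focal(input, target)
        # Both lambdas were validated to be >= 0.0 in __init__.
        return self.lambda_gdl * gdl_loss + self.lambda_focal * focal_loss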
Example #2
    def __init__(
        self,
        include_background: bool = True,
        to_onehot_y: bool = False,
        sigmoid: bool = False,
        softmax: bool = False,
        other_act: Optional[Callable] = None,
        squared_pred: bool = False,
        jaccard: bool = False,
        reduction: str = "mean",
        smooth_nr: float = 1e-5,
        smooth_dr: float = 1e-5,
        batch: bool = False,
        gamma: float = 2.0,
        focal_weight: Optional[Union[Sequence[float], float, int,
                                     torch.Tensor]] = None,
        lambda_dice: float = 1.0,
        lambda_focal: float = 1.0,
    ) -> None:
        """
        Args:
            ``gamma``, ``focal_weight`` and ``lambda_focal`` are only used for the focal loss.
            ``include_background``, ``to_onehot_y`` and ``reduction`` are used for both losses,
            and the remaining parameters are only used for the dice loss.
            include_background: if False, channel index 0 (the background category) is excluded from the calculation.
            to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
            sigmoid: if True, apply a sigmoid function to the prediction. Only used by the `DiceLoss`;
                there is no need to specify an activation function for `FocalLoss`.
            softmax: if True, apply a softmax function to the prediction. Only used by the `DiceLoss`;
                there is no need to specify an activation function for `FocalLoss`.
            other_act: callable function to execute an activation layer other than `sigmoid` or `softmax`.
                Defaults to ``None``; for example: `other_act = torch.tanh`.
                Only used by the `DiceLoss`; there is no need to specify an activation function for `FocalLoss`.
            squared_pred: whether to use squared versions of targets and predictions in the denominator.
            jaccard: whether to compute the Jaccard index (soft IoU) instead of dice.
            reduction: {``"none"``, ``"mean"``, ``"sum"``}
                Specifies the reduction to apply to the output. Defaults to ``"mean"``.

                - ``"none"``: no reduction will be applied.
                - ``"mean"``: the sum of the output will be divided by the number of elements in the output.
                - ``"sum"``: the output will be summed.

            smooth_nr: a small constant added to the numerator to avoid zero.
            smooth_dr: a small constant added to the denominator to avoid nan.
            batch: whether to sum the intersection and union areas over the batch dimension before dividing.
                Defaults to False, i.e. a Dice loss value is computed independently for each item in the batch
                before any `reduction`.
            gamma: value of the exponent gamma in the definition of the Focal loss.
            focal_weight: weights to apply to the voxels of each class. If None, no weights are applied.
                The input can be a single value (same weight for all classes) or a sequence of values
                (the length of the sequence should be the same as the number of classes).
            lambda_dice: the trade-off weight value for dice loss. The value should be no less than 0.0.
                Defaults to 1.0.
            lambda_focal: the trade-off weight value for focal loss. The value should be no less than 0.0.
                Defaults to 1.0.

        """
        super().__init__()
        self.dice = DiceLoss(
            include_background=include_background,
            to_onehot_y=to_onehot_y,
            sigmoid=sigmoid,
            softmax=softmax,
            other_act=other_act,
            squared_pred=squared_pred,
            jaccard=jaccard,
            reduction=reduction,
            smooth_nr=smooth_nr,
            smooth_dr=smooth_dr,
            batch=batch,
        )
        self.focal = FocalLoss(
            include_background=include_background,
            to_onehot_y=to_onehot_y,
            gamma=gamma,
            weight=focal_weight,
            reduction=reduction,
        )
        if lambda_dice < 0.0:
            raise ValueError("lambda_dice should be no less than 0.0.")
        if lambda_focal < 0.0:
            raise ValueError("lambda_focal should be no less than 0.0.")
        self.lambda_dice = lambda_dice
        self.lambda_focal = lambda_focal
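
A minimal usage sketch for this constructor, assuming it belongs to MONAI's `DiceFocalLoss` (the class name is not shown in the snippet):

    import torch
    from monai.losses import DiceFocalLoss  # assumed source of the snippet

    # 4 classes, batch of 2; target is an integer label map, hence to_onehot_y=True.
    criterion = DiceFocalLoss(softmax=True, to_onehot_y=True, lambda_dice=1.0, lambda_focal=0.5)
    logits = torch.randn(2, 4, 16, 16, 16)
    labels = torch.randint(0, 4, (2, 1, 16, 16, 16))
    loss = criterion(logits, labels)  # scalar, since reduction defaults to "mean"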
Example #3
    def loss_fx(self):
        return FocalLoss(reduction=self.reduction)
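
Here `loss_fx` acts as a small factory that builds a fresh `FocalLoss` configured with the host object's `reduction` setting. A hedged sketch of how it might be consumed (the host class below is hypothetical, for illustration only):

    import torch
    from monai.losses import FocalLoss

    class Segmenter:  # hypothetical host for the method above
        def __init__(self, reduction: str = "mean") -> None:
            self.reduction = reduction

        def loss_fx(self):
            return FocalLoss(reduction=self.reduction)

    model = Segmenter()
    logits = torch.randn(2, 3, 32, 32)
    target = torch.randint(0, 2, (2, 3, 32, 32)).float()  # one-hot-style target
    loss = model.loss_fx()(logits, target)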