def focal_loss(input: torch.Tensor, target: torch.Tensor, alpha: float,
               gamma: float = 2.0, reduction: str = 'none',
               eps: float = 1e-8) -> torch.Tensor:
    r"""Function that computes Focal loss.

    See :class:`~kornia.losses.FocalLoss` for details.

    Args:
        input (torch.Tensor): logits tensor with shape :math:`(N, C, *)`
            where C = number of classes.
        target (torch.Tensor): labels tensor with shape :math:`(N, *)` where
            each value is a class index in :math:`[0, C - 1]`.
        alpha (float): weighting factor :math:`\alpha \in [0, 1]`.
        gamma (float, optional): focusing parameter :math:`\gamma >= 0`.
            Default: 2.0.
        reduction (str, optional): 'none' | 'mean' | 'sum'. Default: 'none'.
        eps (float, optional): kept for backward compatibility; no longer
            needed since the loss now uses ``log_softmax`` directly.
            Default: 1e-8.

    Return:
        torch.Tensor: the computed loss.
    """
    if not torch.is_tensor(input):
        raise TypeError("Input type is not a torch.Tensor. Got {}".format(
            type(input)))
    if not len(input.shape) >= 2:
        raise ValueError(
            "Invalid input shape, we expect BxCx*. Got: {}".format(
                input.shape))
    if input.size(0) != target.size(0):
        raise ValueError(
            'Expected input batch_size ({}) to match target batch_size ({}).'.
            format(input.size(0), target.size(0)))
    n = input.size(0)
    out_size = (n, ) + input.size()[2:]
    if target.size()[1:] != input.size()[2:]:
        raise ValueError('Expected target size {}, got {}'.format(
            out_size, target.size()))
    if not input.device == target.device:
        raise ValueError(
            "input and target must be in the same device. Got: {} and {}".
            format(input.device, target.device))

    # FIX: use log_softmax instead of log(softmax(x) + eps). The old form is
    # numerically unstable: softmax underflows to 0 for very confident wrong
    # logits and the eps fudge term biases every log-probability.
    log_input_soft: torch.Tensor = F.log_softmax(input, dim=1)
    input_soft: torch.Tensor = log_input_soft.exp()

    # Select p_t and log(p_t) for the target class of each element via gather
    # (equivalent to summing the one-hot product over the class axis, but
    # avoids materializing a full one-hot tensor).
    index = target.unsqueeze(1)
    logpt = log_input_soft.gather(1, index).squeeze(1)
    pt = input_soft.gather(1, index).squeeze(1)

    # focal term: -alpha * (1 - p_t)^gamma * log(p_t)
    loss_tmp = -alpha * torch.pow(1. - pt, gamma) * logpt

    if reduction == 'none':
        loss = loss_tmp
    elif reduction == 'mean':
        loss = torch.mean(loss_tmp)
    elif reduction == 'sum':
        loss = torch.sum(loss_tmp)
    else:
        raise NotImplementedError(
            "Invalid reduction mode: {}".format(reduction))
    return loss
def tversky_loss(input: torch.Tensor, target: torch.Tensor, alpha: float,
                 beta: float, eps: float = 1e-8) -> torch.Tensor:
    r"""Function that computes Tversky loss.

    See :class:`~kornia.losses.TverskyLoss` for details.

    Args:
        input (torch.Tensor): logits tensor with shape :math:`(N, C, H, W)`
            where C = number of classes.
        target (torch.Tensor): labels tensor with shape :math:`(N, H, W)`
            where each value is a class index in :math:`[0, C - 1]`.
        alpha (float): penalty weight for false positives.
        beta (float): penalty weight for false negatives.
        eps (float, optional): scalar for numerical stability. Default: 1e-8.

    Return:
        torch.Tensor: the computed loss.
    """
    if not torch.is_tensor(input):
        raise TypeError("Input type is not a torch.Tensor. Got {}".format(
            type(input)))
    if not len(input.shape) == 4:
        raise ValueError(
            "Invalid input shape, we expect BxNxHxW. Got: {}".format(
                input.shape))
    if not input.shape[-2:] == target.shape[-2:]:
        # BUG FIX: the message previously formatted input.shape twice, hiding
        # the mismatching target shape from the user.
        raise ValueError(
            "input and target shapes must be the same. Got: {} and {}".format(
                input.shape, target.shape))
    if not input.device == target.device:
        raise ValueError(
            "input and target must be in the same device. Got: {} and {}".
            format(input.device, target.device))

    # compute softmax over the classes axis
    input_soft: torch.Tensor = F.softmax(input, dim=1)

    # one-hot labels with layout (B, C, H, W); F.one_hot yields (B, H, W, C),
    # so move the class axis next to the batch axis.
    target_one_hot: torch.Tensor = F.one_hot(
        target, num_classes=input.shape[1]).permute(0, 3, 1, 2).to(input.dtype)

    # Tversky index per sample: |PG| / (|PG| + alpha*FP + beta*FN)
    dims = (1, 2, 3)
    intersection = torch.sum(input_soft * target_one_hot, dims)
    fps = torch.sum(input_soft * (-target_one_hot + 1.), dims)
    fns = torch.sum((-input_soft + 1.) * target_one_hot, dims)

    numerator = intersection
    denominator = intersection + alpha * fps + beta * fns
    tversky_loss = numerator / (denominator + eps)
    return torch.mean(-tversky_loss + 1.)
def dice_loss(input: torch.Tensor, target: torch.Tensor,
              eps: float = 1e-8) -> torch.Tensor:
    r"""Function that computes Sørensen-Dice Coefficient loss.

    See :class:`~kornia.losses.DiceLoss` for details.

    Args:
        input (torch.Tensor): logits tensor with shape :math:`(N, C, H, W)`
            where C = number of classes.
        target (torch.Tensor): labels tensor with shape :math:`(N, H, W)`
            where each value is a class index in :math:`[0, C - 1]`.
        eps (float, optional): scalar for numerical stability. Default: 1e-8.

    Return:
        torch.Tensor: the computed loss.
    """
    if not torch.is_tensor(input):
        raise TypeError("Input type is not a torch.Tensor. Got {}".format(
            type(input)))
    if not len(input.shape) == 4:
        raise ValueError(
            "Invalid input shape, we expect BxNxHxW. Got: {}".format(
                input.shape))
    if not input.shape[-2:] == target.shape[-2:]:
        # BUG FIX: the message previously formatted input.shape twice, hiding
        # the mismatching target shape from the user.
        raise ValueError(
            "input and target shapes must be the same. Got: {} and {}".format(
                input.shape, target.shape))
    if not input.device == target.device:
        raise ValueError(
            "input and target must be in the same device. Got: {} and {}".
            format(input.device, target.device))

    # compute softmax over the classes axis
    input_soft: torch.Tensor = F.softmax(input, dim=1)

    # one-hot labels with layout (B, C, H, W); F.one_hot yields (B, H, W, C),
    # so move the class axis next to the batch axis.
    target_one_hot: torch.Tensor = F.one_hot(
        target, num_classes=input.shape[1]).permute(0, 3, 1, 2).to(input.dtype)

    # Dice score per sample: 2|X ∩ Y| / (|X| + |Y|)
    dims = (1, 2, 3)
    intersection = torch.sum(input_soft * target_one_hot, dims)
    cardinality = torch.sum(input_soft + target_one_hot, dims)

    dice_score = 2. * intersection / (cardinality + eps)
    return torch.mean(-dice_score + 1.)
def focal_loss(input: torch.Tensor, target: torch.Tensor, alpha: float,
               gamma: float = 2.0, reduction: str = 'none',
               eps: float = 1e-8) -> torch.Tensor:
    r"""Criterion that computes Focal loss.

    According to :cite:`lin2018focal`, the Focal loss is computed as follows:

    .. math::

        \text{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \, \text{log}(p_t)

    Where:
       - :math:`p_t` is the model's estimated probability for each class.

    Args:
        input (torch.Tensor): logits tensor with shape :math:`(N, C, *)` where
            C = number of classes.
        target (torch.Tensor): labels tensor with shape :math:`(N, *)` where
            each value is :math:`0 ≤ targets[i] ≤ C−1`.
        alpha (float): Weighting factor :math:`\alpha \in [0, 1]`.
        gamma (float, optional): Focusing parameter :math:`\gamma >= 0`.
            Default 2.
        reduction (str, optional): Specifies the reduction to apply to the
            output: ‘none’ | ‘mean’ | ‘sum’. ‘none’: no reduction will be
            applied, ‘mean’: the sum of the output will be divided by the
            number of elements in the output, ‘sum’: the output will be
            summed. Default: ‘none’.
        eps (float, optional): kept for backward compatibility; no longer
            needed since the loss now uses ``log_softmax`` directly.
            Default: 1e-8.

    Return:
        torch.Tensor: the computed loss.

    Example:
        >>> N = 5  # num_classes
        >>> input = torch.randn(1, N, 3, 5, requires_grad=True)
        >>> target = torch.empty(1, 3, 5, dtype=torch.long).random_(N)
        >>> output = focal_loss(input, target, alpha=0.5, gamma=2.0, reduction='mean')
        >>> output.backward()
    """
    if not isinstance(input, torch.Tensor):
        raise TypeError("Input type is not a torch.Tensor. Got {}".format(
            type(input)))
    if not len(input.shape) >= 2:
        raise ValueError(
            "Invalid input shape, we expect BxCx*. Got: {}".format(
                input.shape))
    if input.size(0) != target.size(0):
        raise ValueError(
            'Expected input batch_size ({}) to match target batch_size ({}).'.
            format(input.size(0), target.size(0)))
    n = input.size(0)
    out_size = (n, ) + input.size()[2:]
    if target.size()[1:] != input.size()[2:]:
        raise ValueError('Expected target size {}, got {}'.format(
            out_size, target.size()))
    if not input.device == target.device:
        raise ValueError(
            "input and target must be in the same device. Got: {} and {}".
            format(input.device, target.device))

    # compute softmax over the classes axis.
    # FIX: take the log term from log_softmax instead of log(softmax + eps):
    # the old form is numerically unstable (softmax underflows to 0 for very
    # confident logits) and the eps term biases every log-probability.
    input_soft: torch.Tensor = F.softmax(input, dim=1)
    log_input_soft: torch.Tensor = F.log_softmax(input, dim=1)

    # create the labels one hot tensor
    target_one_hot: torch.Tensor = one_hot(target,
                                           num_classes=input.shape[1],
                                           device=input.device,
                                           dtype=input.dtype)

    # compute the actual focal loss: -alpha * (1 - p)^gamma * log(p), then
    # keep only the target class term via the one-hot mask.
    weight = torch.pow(-input_soft + 1., gamma)
    focal = -alpha * weight * log_input_soft
    loss_tmp = torch.sum(target_one_hot * focal, dim=1)

    if reduction == 'none':
        loss = loss_tmp
    elif reduction == 'mean':
        loss = torch.mean(loss_tmp)
    elif reduction == 'sum':
        loss = torch.sum(loss_tmp)
    else:
        raise NotImplementedError(
            "Invalid reduction mode: {}".format(reduction))
    return loss
def dice_loss(input: torch.Tensor, target: torch.Tensor,
              eps: float = 1e-8) -> torch.Tensor:
    r"""Criterion that computes Sørensen-Dice Coefficient loss.

    Following [1], the per-sample Dice coefficient between the softmax class
    scores :math:`X` and the one-hot labels :math:`Y` is:

    .. math::

        \text{Dice}(x, class) = \frac{2 |X| \cap |Y|}{|X| + |Y|}

    and the loss is :math:`1 - \text{Dice}(x, class)`, averaged over the
    batch.

    Reference:
        [1] https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient

    Args:
        input (torch.Tensor): logits tensor with shape :math:`(N, C, H, W)`
            where C = number of classes.
        target (torch.Tensor): labels tensor with shape :math:`(N, H, W)`
            where each value is :math:`0 ≤ targets[i] ≤ C−1`.
        eps (float, optional): Scalar to enforce numerical stabiliy.
            Default: 1e-8.

    Return:
        torch.Tensor: the computed loss.

    Example:
        >>> N = 5  # num_classes
        >>> input = torch.randn(1, N, 3, 5, requires_grad=True)
        >>> target = torch.empty(1, 3, 5, dtype=torch.long).random_(N)
        >>> output = dice_loss(input, target)
        >>> output.backward()
    """
    # -- input validation (guard clauses) --
    if not isinstance(input, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}")
    if len(input.shape) != 4:
        raise ValueError(
            f"Invalid input shape, we expect BxNxHxW. Got: {input.shape}")
    if input.shape[-2:] != target.shape[-2:]:
        raise ValueError(
            f"input and target shapes must be the same. Got: {input.shape} and {target.shape}")
    if input.device != target.device:
        raise ValueError(
            f"input and target must be in the same device. Got: {input.device} and {target.device}")

    # class probabilities over the channel axis
    probs: torch.Tensor = F.softmax(input, dim=1)

    # one-hot encode the labels to match the prediction layout
    labels_one_hot: torch.Tensor = one_hot(target,
                                           num_classes=input.shape[1],
                                           device=input.device,
                                           dtype=input.dtype)

    # per-sample Dice score reduced over class and spatial axes
    reduce_dims = (1, 2, 3)
    inter = torch.sum(probs * labels_one_hot, reduce_dims)
    card = torch.sum(probs + labels_one_hot, reduce_dims)
    dice = 2.0 * inter / (card + eps)

    return torch.mean(1.0 - dice)
def focal_loss(
    input: torch.Tensor,
    target: torch.Tensor,
    alpha: float,
    gamma: float = 2.0,
    reduction: str = 'none',
    eps: Optional[float] = None,
) -> torch.Tensor:
    r"""Criterion that computes Focal loss.

    Following :cite:`lin2018focal`:

    .. math::

        \text{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \, \text{log}(p_t)

    where :math:`p_t` is the model's estimated probability of the target
    class.

    Args:
        input: logits tensor with shape :math:`(N, C, *)` where C = number
          of classes.
        target: labels tensor with shape :math:`(N, *)` where each value is
          :math:`0 ≤ targets[i] ≤ C−1`.
        alpha: Weighting factor :math:`\alpha \in [0, 1]`.
        gamma: Focusing parameter :math:`\gamma >= 0`.
        reduction: ``'none'`` | ``'mean'`` | ``'sum'`` — return the
          per-element loss, its mean, or its sum.
        eps: Deprecated and ignored; passing a value only triggers a
          deprecation warning.

    Return:
        the computed loss.

    Example:
        >>> N = 5  # num_classes
        >>> input = torch.randn(1, N, 3, 5, requires_grad=True)
        >>> target = torch.empty(1, 3, 5, dtype=torch.long).random_(N)
        >>> output = focal_loss(input, target, alpha=0.5, gamma=2.0, reduction='mean')
        >>> output.backward()
    """
    # eps no longer has any effect: warn callers that still pass it
    # (skipped under torchscript, where warnings are unsupported).
    if eps is not None and not torch.jit.is_scripting():
        warnings.warn(
            "`focal_loss` has been reworked for improved numerical stability "
            "and the `eps` argument is no longer necessary",
            DeprecationWarning,
            stacklevel=2,
        )

    # -- input validation (guard clauses) --
    if not isinstance(input, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}")
    if input.ndim < 2:
        raise ValueError(
            f"Invalid input shape, we expect BxCx*. Got: {input.shape}")
    if input.size(0) != target.size(0):
        raise ValueError(
            f'Expected input batch_size ({input.size(0)}) to match target batch_size ({target.size(0)}).'
        )
    expected_target_size = (input.size(0), ) + input.size()[2:]
    if target.size()[1:] != input.size()[2:]:
        raise ValueError(
            f'Expected target size {expected_target_size}, got {target.size()}')
    if input.device != target.device:
        raise ValueError(
            f"input and target must be in the same device. Got: {input.device} and {target.device}"
        )

    # probabilities and log-probabilities over the class axis; log_softmax
    # keeps the log term numerically stable.
    soft: torch.Tensor = F.softmax(input, dim=1)
    log_soft: torch.Tensor = F.log_softmax(input, dim=1)

    # one-hot labels, laid out like the prediction tensor
    one_hot_target: torch.Tensor = one_hot(target,
                                           num_classes=input.shape[1],
                                           device=input.device,
                                           dtype=input.dtype)

    # FL(p_t) = -alpha * (1 - p_t)^gamma * log(p_t); contract over the
    # class axis to keep only the target-class term per element.
    modulating = torch.pow(1.0 - soft, gamma)
    focal = -alpha * modulating * log_soft
    per_element = torch.einsum('bc...,bc...->b...', (one_hot_target, focal))

    if reduction == 'none':
        return per_element
    if reduction == 'mean':
        return torch.mean(per_element)
    if reduction == 'sum':
        return torch.sum(per_element)
    raise NotImplementedError(f"Invalid reduction mode: {reduction}")
def tversky_loss(input: torch.Tensor, target: torch.Tensor, alpha: float,
                 beta: float, eps: float = 1e-8) -> torch.Tensor:
    r"""Criterion that computes Tversky Coefficient loss.

    According to :cite:`salehi2017tversky`, we compute the Tversky
    Coefficient as follows:

    .. math::

        \text{S}(P, G, \alpha; \beta) =
          \frac{|PG|}{|PG| + \alpha |P \setminus G| + \beta |G \setminus P|}

    Where:
       - :math:`P` and :math:`G` are the predicted and ground truth binary
         labels.
       - :math:`\alpha` and :math:`\beta` control the magnitude of the
         penalties for FPs and FNs, respectively.

    Note:
       - :math:`\alpha = \beta = 0.5` => dice coeff
       - :math:`\alpha = \beta = 1` => tanimoto coeff
       - :math:`\alpha + \beta = 1` => F beta coeff

    Args:
        input (torch.Tensor): logits tensor with shape :math:`(N, C, H, W)`
            where C = number of classes.
        target (torch.Tensor): labels tensor with shape :math:`(N, H, W)`
            where each value is :math:`0 ≤ targets[i] ≤ C−1`.
        alpha (float): the first coefficient in the denominator.
        beta (float): the second coefficient in the denominator.
        eps (float, optional): scalar for numerical stability. Default: 1e-8.

    Return:
        torch.Tensor: the computed loss.

    Example:
        >>> N = 5  # num_classes
        >>> input = torch.randn(1, N, 3, 5, requires_grad=True)
        >>> target = torch.empty(1, 3, 5, dtype=torch.long).random_(N)
        >>> output = tversky_loss(input, target, alpha=0.5, beta=0.5)
        >>> output.backward()
    """
    if not isinstance(input, torch.Tensor):
        raise TypeError("Input type is not a torch.Tensor. Got {}".format(
            type(input)))
    if not len(input.shape) == 4:
        raise ValueError(
            "Invalid input shape, we expect BxNxHxW. Got: {}".format(
                input.shape))
    if not input.shape[-2:] == target.shape[-2:]:
        # BUG FIX: the message previously formatted input.shape twice, hiding
        # the mismatching target shape from the user.
        raise ValueError(
            "input and target shapes must be the same. Got: {} and {}".format(
                input.shape, target.shape))
    if not input.device == target.device:
        raise ValueError(
            "input and target must be in the same device. Got: {} and {}".
            format(input.device, target.device))

    # compute softmax over the classes axis
    input_soft: torch.Tensor = F.softmax(input, dim=1)

    # create the labels one hot tensor
    target_one_hot: torch.Tensor = one_hot(target,
                                           num_classes=input.shape[1],
                                           device=input.device,
                                           dtype=input.dtype)

    # compute the Tversky index per sample: |PG| / (|PG| + a*FP + b*FN)
    dims = (1, 2, 3)
    intersection = torch.sum(input_soft * target_one_hot, dims)
    fps = torch.sum(input_soft * (-target_one_hot + 1.0), dims)
    fns = torch.sum((-input_soft + 1.0) * target_one_hot, dims)

    numerator = intersection
    denominator = intersection + alpha * fps + beta * fns
    tversky_loss = numerator / (denominator + eps)
    return torch.mean(-tversky_loss + 1.0)