def focal_loss(
        input: torch.Tensor,
        target: torch.Tensor,
        alpha: float,
        gamma: float = 2.0,
        reduction: str = 'none',
        eps: float = 1e-8) -> torch.Tensor:
    r"""Function that computes Focal loss.

    See :class:`~kornia.losses.FocalLoss` for details.

    Args:
        input: class logits of shape ``B x C x *``; softmax is taken over
            dim 1.
        target: labels of shape ``B x *`` (same batch size and trailing dims
            as ``input`` without the class axis); presumably integer class
            indices, since they are passed to ``one_hot``.
        alpha: scalar factor multiplied into every loss element.
        gamma: focusing exponent applied to ``1 - p``.
        reduction: ``'none'`` | ``'mean'`` | ``'sum'``.
        eps: added to the probabilities inside the ``log`` to avoid
            ``log(0)``.

    Returns:
        The per-element loss summed over the class axis (``reduction='none'``),
        or its mean / sum as a 0-dim tensor.

    Raises:
        TypeError: if ``input`` is not a tensor.
        ValueError: on bad input rank, batch/size mismatch or device mismatch.
        NotImplementedError: on an unknown ``reduction``.
    """
    if not torch.is_tensor(input):
        raise TypeError("Input type is not a torch.Tensor. Got {}".format(
            type(input)))

    if not len(input.shape) >= 2:
        raise ValueError(
            "Invalid input shape, we expect BxCx*. Got: {}".format(
                input.shape))

    if input.size(0) != target.size(0):
        raise ValueError(
            'Expected input batch_size ({}) to match target batch_size ({}).'.
            format(input.size(0), target.size(0)))

    n = input.size(0)
    out_size = (n, ) + input.size()[2:]
    if target.size()[1:] != input.size()[2:]:
        raise ValueError('Expected target size {}, got {}'.format(
            out_size, target.size()))

    if not input.device == target.device:
        raise ValueError(
            "input and target must be in the same device. Got: {} and {}".
            format(input.device, target.device))

    # compute softmax over the classes axis
    input_soft: torch.Tensor = F.softmax(input, dim=1)

    # create the labels one hot tensor (assumed B x C x *, matching input)
    target_one_hot: torch.Tensor = one_hot(
        target, num_classes=input.shape[1],
        device=input.device, dtype=input.dtype)

    # compute the actual focal loss.
    # The weight uses the raw probabilities: adding eps first could push
    # 1 - p below zero (e.g. p == 1 in float64), and a negative base raised
    # to a fractional gamma is NaN. eps is only needed to keep log() finite.
    weight = torch.pow(1. - input_soft, gamma)
    focal = -alpha * weight * torch.log(input_soft + eps)
    loss_tmp = torch.sum(target_one_hot * focal, dim=1)

    if reduction == 'none':
        loss = loss_tmp
    elif reduction == 'mean':
        loss = torch.mean(loss_tmp)
    elif reduction == 'sum':
        loss = torch.sum(loss_tmp)
    else:
        raise NotImplementedError(
            "Invalid reduction mode: {}".format(reduction))
    return loss
def forward(  # type: ignore
        self,
        input: torch.Tensor,
        target: torch.Tensor) -> torch.Tensor:
    """Compute the mean Sørensen-Dice loss between logits and labels.

    Args:
        input: logits of shape ``B x N x H x W``; softmax is taken over dim 1.
        target: labels whose last two dims must equal ``H x W``; presumably
            ``B x H x W`` integer class indices, since they go to ``one_hot``.

    Returns:
        0-dim tensor ``mean(1 - dice)`` over the batch; ``self.eps`` is added
        to the denominator to avoid division by zero.

    Raises:
        TypeError: if ``input`` is not a tensor.
        ValueError: on bad input rank, spatial-shape or device mismatch.
    """
    if not torch.is_tensor(input):
        raise TypeError("Input type is not a torch.Tensor. Got {}".format(
            type(input)))
    if not len(input.shape) == 4:
        raise ValueError(
            "Invalid input shape, we expect BxNxHxW. Got: {}".format(
                input.shape))
    if not input.shape[-2:] == target.shape[-2:]:
        raise ValueError(
            "input and target shapes must be the same. Got: {} and {}".format(
                input.shape, target.shape))
    if not input.device == target.device:
        raise ValueError(
            "input and target must be in the same device. Got: {} and {}".
            format(input.device, target.device))

    # compute softmax over the classes axis
    input_soft = F.softmax(input, dim=1)

    # create the labels one hot tensor
    target_one_hot = one_hot(target, num_classes=input.shape[1],
                             device=input.device, dtype=input.dtype)

    # compute the actual dice score, one value per sample
    dims = (1, 2, 3)
    intersection = torch.sum(input_soft * target_one_hot, dims)
    cardinality = torch.sum(input_soft + target_one_hot, dims)

    dice_score = 2. * intersection / (cardinality + self.eps)
    # Python scalar instead of torch.tensor(1.): same result, no per-call
    # CPU tensor allocation.
    return torch.mean(1. - dice_score)
def tversky_loss(
        input: torch.Tensor,
        target: torch.Tensor,
        alpha: float,
        beta: float,
        eps: float = 1e-8) -> torch.Tensor:
    r"""Function that computes Tversky loss.

    See :class:`~kornia.losses.TverskyLoss` for details.

    Args:
        input: logits of shape ``B x N x H x W``; softmax is taken over dim 1.
        target: labels whose last two dims must equal ``H x W``; presumably
            ``B x H x W`` integer class indices, since they go to ``one_hot``.
        alpha: weight of the false-positive term in the denominator.
        beta: weight of the false-negative term in the denominator.
        eps: added to the denominator to avoid division by zero.

    Returns:
        0-dim tensor ``mean(1 - tversky_index)`` over the batch.

    Raises:
        TypeError: if ``input`` is not a tensor.
        ValueError: on bad input rank, spatial-shape or device mismatch.
    """
    if not torch.is_tensor(input):
        raise TypeError("Input type is not a torch.Tensor. Got {}".format(
            type(input)))

    if not len(input.shape) == 4:
        raise ValueError(
            "Invalid input shape, we expect BxNxHxW. Got: {}".format(
                input.shape))

    if not input.shape[-2:] == target.shape[-2:]:
        raise ValueError(
            "input and target shapes must be the same. Got: {} and {}".format(
                input.shape, target.shape))

    if not input.device == target.device:
        raise ValueError(
            "input and target must be in the same device. Got: {} and {}".
            format(input.device, target.device))

    # compute softmax over the classes axis
    input_soft: torch.Tensor = F.softmax(input, dim=1)

    # create the labels one hot tensor
    target_one_hot: torch.Tensor = one_hot(
        target, num_classes=input.shape[1],
        device=input.device, dtype=input.dtype)

    # compute the soft true positives, false positives and false negatives,
    # one value per sample
    dims = (1, 2, 3)
    intersection = torch.sum(input_soft * target_one_hot, dims)
    fps = torch.sum(input_soft * (1. - target_one_hot), dims)
    fns = torch.sum((1. - input_soft) * target_one_hot, dims)

    denominator = intersection + alpha * fps + beta * fns
    tversky_index = intersection / (denominator + eps)
    return torch.mean(1. - tversky_index)
def dice_loss(
        input: torch.Tensor,
        target: torch.Tensor,
        eps: float = 1e-8) -> torch.Tensor:
    r"""Function that computes Sørensen-Dice Coefficient loss.

    See :class:`~kornia.losses.DiceLoss` for details.

    Args:
        input: logits of shape ``B x N x H x W``; softmax is taken over dim 1.
        target: labels whose last two dims must equal ``H x W``; presumably
            ``B x H x W`` integer class indices, since they go to ``one_hot``.
        eps: added to the denominator to avoid division by zero.

    Returns:
        0-dim tensor ``mean(1 - dice)`` over the batch.

    Raises:
        TypeError: if ``input`` is not a tensor.
        ValueError: on bad input rank, spatial-shape or device mismatch.
    """
    if not torch.is_tensor(input):
        raise TypeError("Input type is not a torch.Tensor. Got {}".format(
            type(input)))

    if not len(input.shape) == 4:
        raise ValueError(
            "Invalid input shape, we expect BxNxHxW. Got: {}".format(
                input.shape))

    if not input.shape[-2:] == target.shape[-2:]:
        raise ValueError(
            "input and target shapes must be the same. Got: {} and {}".format(
                input.shape, target.shape))

    if not input.device == target.device:
        raise ValueError(
            "input and target must be in the same device. Got: {} and {}".
            format(input.device, target.device))

    # compute softmax over the classes axis
    input_soft: torch.Tensor = F.softmax(input, dim=1)

    # create the labels one hot tensor
    target_one_hot: torch.Tensor = one_hot(
        target, num_classes=input.shape[1],
        device=input.device, dtype=input.dtype)

    # compute the actual dice score, one value per sample
    dims = (1, 2, 3)
    intersection = torch.sum(input_soft * target_one_hot, dims)
    cardinality = torch.sum(input_soft + target_one_hot, dims)

    dice_score = 2. * intersection / (cardinality + eps)
    return torch.mean(1. - dice_score)
def forward(  # type: ignore
        self,
        input: torch.Tensor,
        target: torch.Tensor) -> torch.Tensor:
    """Compute the focal loss between logits and labels.

    Reads ``self.alpha``, ``self.gamma``, ``self.eps`` and ``self.reduction``.

    Args:
        input: logits of shape ``B x N x H x W``; softmax is taken over dim 1.
        target: labels whose last two dims must equal ``H x W``; presumably
            ``B x H x W`` integer class indices, since they go to ``one_hot``.

    Returns:
        The per-element loss summed over the class axis
        (``reduction='none'``), or its mean / sum as a 0-dim tensor.

    Raises:
        TypeError: if ``input`` is not a tensor.
        ValueError: on bad input rank, spatial-shape or device mismatch.
        NotImplementedError: on an unknown ``self.reduction``.
    """
    if not torch.is_tensor(input):
        raise TypeError("Input type is not a torch.Tensor. Got {}".format(
            type(input)))
    if not len(input.shape) == 4:
        raise ValueError(
            "Invalid input shape, we expect BxNxHxW. Got: {}".format(
                input.shape))
    if not input.shape[-2:] == target.shape[-2:]:
        raise ValueError(
            "input and target shapes must be the same. Got: {} and {}".format(
                input.shape, target.shape))
    if not input.device == target.device:
        raise ValueError(
            "input and target must be in the same device. Got: {} and {}".
            format(input.device, target.device))

    # compute softmax over the classes axis
    input_soft = F.softmax(input, dim=1)

    # create the labels one hot tensor
    target_one_hot = one_hot(target, num_classes=input.shape[1],
                             device=input.device, dtype=input.dtype)

    # compute the actual focal loss.
    # as_tensor accepts gamma stored either as a float or as a tensor.
    gamma = torch.as_tensor(self.gamma, dtype=input.dtype,
                            device=input.device)
    # The weight uses the raw probabilities: adding eps first could push
    # 1 - p below zero (e.g. p == 1 in float64), and a negative base raised
    # to a fractional gamma is NaN. eps is only needed to keep log() finite.
    weight = torch.pow(1. - input_soft, gamma)
    focal = -self.alpha * weight * torch.log(input_soft + self.eps)
    loss_tmp = torch.sum(target_one_hot * focal, dim=1)

    if self.reduction == 'none':
        loss = loss_tmp
    elif self.reduction == 'mean':
        loss = torch.mean(loss_tmp)
    elif self.reduction == 'sum':
        loss = torch.sum(loss_tmp)
    else:
        raise NotImplementedError("Invalid reduction mode: {}".format(
            self.reduction))
    return loss
def forward(  # type: ignore
        self,
        input: torch.Tensor,
        target: torch.Tensor) -> torch.Tensor:
    """Compute the mean Tversky loss between logits and labels.

    Reads ``self.alpha`` (false-positive weight), ``self.beta``
    (false-negative weight) and ``self.eps`` (denominator guard).

    Args:
        input: logits of shape ``B x N x H x W``; softmax is taken over dim 1.
        target: labels whose last two dims must equal ``H x W``; presumably
            ``B x H x W`` integer class indices, since they go to ``one_hot``.

    Returns:
        0-dim tensor ``mean(1 - tversky_index)`` over the batch.

    Raises:
        TypeError: if ``input`` is not a tensor.
        ValueError: on bad input rank, spatial-shape or device mismatch.
    """
    if not torch.is_tensor(input):
        raise TypeError("Input type is not a torch.Tensor. Got {}".format(
            type(input)))
    if not len(input.shape) == 4:
        raise ValueError(
            "Invalid input shape, we expect BxNxHxW. Got: {}".format(
                input.shape))
    if not input.shape[-2:] == target.shape[-2:]:
        raise ValueError(
            "input and target shapes must be the same. Got: {} and {}".format(
                input.shape, target.shape))
    if not input.device == target.device:
        raise ValueError(
            "input and target must be in the same device. Got: {} and {}".
            format(input.device, target.device))

    # compute softmax over the classes axis
    input_soft = F.softmax(input, dim=1)

    # create the labels one hot tensor
    target_one_hot = one_hot(target, num_classes=input.shape[1],
                             device=input.device, dtype=input.dtype)

    # compute the soft true positives, false positives and false negatives,
    # one value per sample
    dims = (1, 2, 3)
    intersection = torch.sum(input_soft * target_one_hot, dims)
    fps = torch.sum(input_soft * (1. - target_one_hot), dims)
    fns = torch.sum((1. - input_soft) * target_one_hot, dims)

    denominator = intersection + self.alpha * fps + self.beta * fns
    tversky_index = intersection / (denominator + self.eps)
    # Python scalar instead of torch.tensor(1.): same result, no per-call
    # CPU tensor allocation.
    return torch.mean(1. - tversky_index)