# Assumed imports for this excerpt: the CUDA focal-loss op ships with mmcv,
# and weight_reduce_loss is the package's loss-reduction helper.
import torch

from mmcv.ops import sigmoid_focal_loss as _sigmoid_focal_loss

from .utils import weight_reduce_loss


def sigmoid_focal_loss(pred,
                       target,
                       weight=None,
                       gamma=2.0,
                       alpha=0.25,
                       reduction='mean',
                       avg_factor=None):
    r"""A wrapper of the cuda version `Focal Loss
    <https://arxiv.org/abs/1708.02002>`_.

    Args:
        pred (torch.Tensor): The prediction with shape (N, C), C is the
            number of classes.
        target (torch.Tensor): The learning label of the prediction.
        weight (torch.Tensor, optional): Sample-wise loss weight.
        gamma (float, optional): The gamma for calculating the modulating
            factor. Defaults to 2.0.
        alpha (float, optional): A balanced form for Focal Loss.
            Defaults to 0.25.
        reduction (str, optional): The method used to reduce the loss into
            a scalar. Defaults to 'mean'. Options are "none", "mean" and
            "sum".
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
    """
    # Function.apply does not accept keyword arguments, so the decorator
    # "weighted_loss" is not applicable.
    loss = _sigmoid_focal_loss(pred.contiguous(), target, gamma, alpha, None,
                               'none')
    if weight is not None:
        if weight.shape != loss.shape:
            if weight.size(0) == loss.size(0):
                # For most cases, weight is of shape (num_priors, ), which
                # means it does not have the second axis num_class.
                weight = weight.view(-1, 1)
            else:
                # Sometimes a weight per anchor per class is also needed,
                # e.g. in FSAF. But it may be flattened to shape
                # (num_priors * num_class, ), while loss is still of shape
                # (num_priors, num_class).
                assert weight.numel() == loss.numel()
                weight = weight.view(loss.size(0), -1)
        assert weight.ndim == loss.ndim
    loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
    return loss
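# A minimal usage sketch (hypothetical, not part of the original module) for
# the detection-style wrapper above. It assumes the mmcv CUDA op is
# available, so the tensors live on the GPU; `pred` holds raw logits of
# shape (N, C) and `target` holds integer class indices of shape (N, ).
def _demo_detection_focal_loss(focal_loss=sigmoid_focal_loss):
    num_priors, num_classes = 8, 4
    pred = torch.randn(num_priors, num_classes, device='cuda')
    target = torch.randint(0, num_classes, (num_priors, ), device='cuda')
    # A per-prior weight of shape (num_priors, ) is reshaped to
    # (num_priors, 1) inside the wrapper and broadcast over the classes.
    weight = torch.ones(num_priors, device='cuda')
    return focal_loss(pred, target, weight=weight, reduction='mean')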
def sigmoid_focal_loss(pred,
                       target,
                       one_hot_target,
                       weight=None,
                       gamma=2.0,
                       alpha=0.5,
                       class_weight=None,
                       valid_mask=None,
                       reduction='mean',
                       avg_factor=None):
    r"""A wrapper of the cuda version `Focal Loss
    <https://arxiv.org/abs/1708.02002>`_.

    Args:
        pred (torch.Tensor): The prediction with shape (N, C), C is the
            number of classes.
        target (torch.Tensor): The learning label of the prediction.
            Its shape should be (N, ).
        one_hot_target (torch.Tensor): The learning label with shape (N, C).
        weight (torch.Tensor, optional): Sample-wise loss weight.
        gamma (float, optional): The gamma for calculating the modulating
            factor. Defaults to 2.0.
        alpha (float | list[float], optional): A balanced form for Focal
            Loss. Defaults to 0.5.
        class_weight (list[float], optional): Weight of each class.
            Defaults to None.
        valid_mask (torch.Tensor, optional): A mask that uses 1 to mark the
            valid samples and 0 to mark the ignored samples. Default: None.
        reduction (str, optional): The method used to reduce the loss into
            a scalar. Defaults to 'mean'. Options are "none", "mean" and
            "sum".
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
    """
    # Function.apply does not accept keyword arguments, so the decorator
    # "weighted_loss" is not applicable.
    final_weight = torch.ones(1, pred.size(1)).type_as(pred)
    if isinstance(alpha, list):
        # _sigmoid_focal_loss doesn't accept alpha of list type. Therefore,
        # if a list is given, we set the input alpha as 0.5. This means
        # setting equal weight for the foreground class and the background
        # class. Multiplying the loss by 2 undoes the effect of setting
        # alpha to 0.5. The list-type alpha is then applied per class
        # through final_weight below.
        loss = _sigmoid_focal_loss(pred.contiguous(), target.contiguous(),
                                   gamma, 0.5, None, 'none') * 2
        alpha = pred.new_tensor(alpha)
        final_weight = final_weight * (
            alpha * one_hot_target + (1 - alpha) * (1 - one_hot_target))
    else:
        loss = _sigmoid_focal_loss(pred.contiguous(), target.contiguous(),
                                   gamma, alpha, None, 'none')
    if weight is not None:
        if weight.shape != loss.shape and weight.size(0) == loss.size(0):
            # For most cases, weight is of shape (N, ), which means it does
            # not have the second axis num_class.
            weight = weight.view(-1, 1)
        assert weight.dim() == loss.dim()
        final_weight = final_weight * weight
    if class_weight is not None:
        final_weight = final_weight * pred.new_tensor(class_weight)
    if valid_mask is not None:
        final_weight = final_weight * valid_mask
    loss = weight_reduce_loss(loss, final_weight, reduction, avg_factor)
    return loss
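# A minimal usage sketch (hypothetical, not part of the original module) for
# the segmentation-style variant above, exercising the list-type alpha
# branch. The one-hot labels are derived from the flat targets, and the
# valid mask keeps every sample; in practice it would zero out ignore-index
# pixels. The mmcv CUDA op again requires GPU tensors.
def _demo_segmentation_focal_loss(focal_loss=sigmoid_focal_loss):
    import torch.nn.functional as F
    num_pixels, num_classes = 16, 3
    pred = torch.randn(num_pixels, num_classes, device='cuda')
    target = torch.randint(0, num_classes, (num_pixels, ), device='cuda')
    one_hot_target = F.one_hot(target, num_classes).type_as(pred)
    # Shape (num_pixels, 1) broadcasts against the (N, C) loss weights.
    valid_mask = torch.ones(num_pixels, 1, device='cuda')
    return focal_loss(
        pred,
        target,
        one_hot_target,
        alpha=[0.25, 0.5, 0.75],  # per-class balancing weights
        valid_mask=valid_mask,
        reduction='mean')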