import torch

# F is assumed to be pytorch_toolbelt.losses.functional (inferred from context).
from pytorch_toolbelt.losses import functional as F


def test_sigmoid_focal_loss():
    # Confident, correct logits vs. weak/incorrect logits for the same targets.
    input_good = torch.tensor([10.0, -10.0, 10.0])
    input_bad = torch.tensor([-1.0, 2.0, 0.0])
    target = torch.tensor([1.0, 0.0, 1.0])

    loss_good = F.sigmoid_focal_loss(input_good, target)
    loss_bad = F.sigmoid_focal_loss(input_bad, target)
    # Confident, correct predictions should incur the lower focal loss.
    assert loss_good < loss_bad
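For reference, here is a minimal sketch of what a sigmoid (binary) focal loss computes, following the standard Lin et al. formulation FL = -alpha_t * (1 - p_t)^gamma * log(p_t). This is an illustration only; pytorch-toolbelt's actual implementation may differ in defaults, argument names, and normalization.

import torch
from torch.nn.functional import binary_cross_entropy_with_logits

def sigmoid_focal_loss(input: torch.Tensor, target: torch.Tensor,
                       gamma: float = 2.0, alpha: float = 0.25,
                       reduction: str = "mean") -> torch.Tensor:
    # Sketch of the standard binary focal loss, not the library's exact code.
    target = target.float()
    # log(p_t) via the numerically stable BCE-with-logits.
    logpt = -binary_cross_entropy_with_logits(input, target, reduction="none")
    pt = logpt.exp()
    # Down-weight easy examples by the modulating factor (1 - p_t)^gamma.
    loss = -((1.0 - pt) ** gamma) * logpt
    if alpha is not None:
        # alpha balances positive vs. negative examples.
        loss = loss * (alpha * target + (1.0 - alpha) * (1.0 - target))
    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    return loss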
Example #2
    def forward(self, input: torch.Tensor, target: torch.Tensor):
        # Drop samples whose target equals the ignore label.
        if self.ignore_index is not None:
            mask = target != self.ignore_index
            target = target[mask]
            input = input[mask]

        # Nothing left to score: return a zero loss on the right device.
        if len(target) == 0:
            return torch.tensor(0.0, device=input.device)

        # First term: one-vs-all sigmoid focal loss, one binary problem per class.
        focal_loss = 0
        num_classes = input.size(1)
        for cls in range(num_classes):
            cls_label_target = (target == cls).long()
            cls_label_input = input[:, cls]
            focal_loss += sigmoid_focal_loss(cls_label_input,
                                             cls_label_target,
                                             gamma=self.gamma,
                                             alpha=None)

        # Second term: quadratic weighted kappa loss on the softmax probabilities.
        y = F.log_softmax(input, dim=1).exp()
        target_one_hot = F.one_hot(target, input.size(1)).float()
        # Shift by +1 so the kappa term lies in [0, 2] instead of [-1, 1].
        kappa_loss = 1 + quad_kappa_loss_v2(
            y, target_one_hot, y_pow=self.y_pow, eps=self.eps)

        return kappa_loss + self.log_scale * focal_loss
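quad_kappa_loss_v2 is defined elsewhere in that code base. Below is a rough sketch of a differentiable quadratic weighted kappa loss in the same spirit: Cohen's kappa with quadratic weights, negated so the result lies in roughly [-1, 1], which is why the +1 shift above maps it into [0, 2]. The function name and the y_pow/eps handling are assumptions, not the actual v2 implementation.

import torch

def quad_kappa_loss(probs: torch.Tensor, target_one_hot: torch.Tensor,
                    y_pow: float = 1.0, eps: float = 1e-15) -> torch.Tensor:
    # Hypothetical stand-in for quad_kappa_loss_v2; the real one may differ.
    batch_size, num_ratings = probs.shape
    # Quadratic disagreement weights w[i, j] = (i - j)^2 / (N - 1)^2.
    ratings = torch.arange(num_ratings, dtype=torch.float32, device=probs.device)
    weights = (ratings.view(-1, 1) - ratings.view(1, -1)) ** 2 / (num_ratings - 1) ** 2
    # Optionally sharpen the predicted distribution, then renormalize.
    y = probs ** y_pow
    y = y / (eps + y.sum(dim=1, keepdim=True))
    # Observed (soft) confusion matrix and expected matrix from the marginals.
    observed = y.t() @ target_one_hot
    hist_pred = y.sum(dim=0)
    hist_true = target_one_hot.sum(dim=0)
    expected = hist_pred.view(-1, 1) @ hist_true.view(1, -1) / batch_size
    kappa = 1.0 - (weights * observed).sum() / ((weights * expected).sum() + eps)
    # Negate so that minimizing the loss maximizes kappa; range is about [-1, 1].
    return -kappa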
Example #3
    def forward(self, label_input, label_target):
        num_classes = label_input.size(1)
        loss = 0

        # Filter out anchors carrying the ignore label (e.g. -1) before computing the loss.
        if self.ignore_index is not None:
            not_ignored = label_target != self.ignore_index
            label_input = label_input[not_ignored]
            label_target = label_target[not_ignored]

        # One-vs-all decomposition: binary focal loss per class column, summed.
        for cls in range(num_classes):
            cls_label_target = (label_target == cls).long()
            cls_label_input = label_input[:, cls]

            loss += sigmoid_focal_loss(cls_label_input, cls_label_target,
                                       gamma=self.gamma, alpha=self.alpha)
        return loss
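The per-class loop is equivalent to running the binary focal loss over a one-hot encoding of the targets. A vectorized sketch follows, reusing the sigmoid_focal_loss sketch from Example #1 with reduction="none"; it assumes the library call in the loop mean-reduces over the batch, so the loop's total equals a per-element loss averaged over the batch and summed over classes.

import torch
from torch.nn.functional import one_hot

def multiclass_focal_loss(label_input: torch.Tensor,
                          label_target: torch.Tensor,
                          gamma: float = 2.0,
                          alpha: float = 0.25,
                          ignore_index: int = -1) -> torch.Tensor:
    # Hypothetical vectorized equivalent of the loop version above.
    # Drop ignored anchors first, as in the original forward().
    keep = label_target != ignore_index
    label_input, label_target = label_input[keep], label_target[keep]

    num_classes = label_input.size(1)
    targets = one_hot(label_target, num_classes).float()
    # Elementwise focal loss over every (sample, class) logit, then mean over
    # the batch and sum over classes, matching what the loop accumulates.
    elementwise = sigmoid_focal_loss(label_input, targets, gamma=gamma,
                                     alpha=alpha, reduction="none")
    return elementwise.mean(dim=0).sum()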