Example 1
    def test_default_args(self, inputs: ndarray, targets: ndarray):
        """ Ensures default arguments have not changed. """
        inputs = tensor(inputs)
        targets = tensor(targets, dtype=torch.long)
        assert_allclose(
            desired=softmax_focal_loss(inputs,
                                       targets,
                                       alpha=1.0,
                                       gamma=0.0,
                                       reduction="mean"),
            actual=softmax_focal_loss(inputs, targets),
            err_msg="`softmax_focal_loss` default args changed",
        )
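For reference, the defaults asserted above imply a signature like softmax_focal_loss(inputs, targets, alpha=1.0, gamma=0.0, reduction="mean"). A minimal sketch of a numerically stable implementation consistent with these tests, assuming log_softmax is used to avoid evaluating log(softmax(x)) directly (the name softmax_focal_loss_sketch and its body are illustrative, not the library's actual code):

import torch
import torch.nn.functional as F
from torch import Tensor

def softmax_focal_loss_sketch(inputs: Tensor,
                              targets: Tensor,
                              alpha: float = 1.0,
                              gamma: float = 0.0,
                              reduction: str = "mean") -> Tensor:
    """ Hypothetical stand-in for the function under test. """
    if reduction not in ("mean", "sum", "none"):
        raise ValueError("invalid `reduction`: " + repr(reduction))
    # log_softmax is computed in a numerically stable way, unlike
    # torch.log(F.softmax(x)), which underflows for extreme logits
    log_pc = F.log_softmax(inputs, dim=1)[range(len(targets)), targets]
    loss = -alpha * (1 - torch.exp(log_pc)) ** gamma * log_pc
    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    return loss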
Example 2
    def test_nan_in_grad(
        self,
        inputs: ndarray,
        alpha: float,
        gamma: float,
        dtype: torch.dtype,
        data: st.SearchStrategy,
    ):
        """ Ensures, across a wide range of inputs, that the focal loss gradient is not nan. """
        targets = data.draw(
            hnp.arrays(
                dtype=int,
                shape=(inputs.shape[0], ),
                elements=st.integers(0, inputs.shape[1] - 1),
            ),
            label="targets",
        )

        inputs = tensor(inputs, dtype=dtype, requires_grad=True)
        targets = tensor(targets, dtype=torch.long)
        loss = softmax_focal_loss(inputs, targets, alpha=alpha, gamma=gamma)

        loss.backward()
        assert not np.any(np.isnan(
            inputs.grad.numpy())), "focal loss gradient is nan"
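To see what this guards against: a naive implementation built on torch.log(F.softmax(...)) underflows for strongly negative logits, and the resulting log(0) = -inf backpropagates into NaN gradients. A small illustration of the failure mode (the specific logit values and the hard-coded gamma = 2, alpha = 1 are arbitrary):

import torch
import torch.nn.functional as F

inputs = torch.tensor([[-100.0, 100.0]], requires_grad=True)
targets = torch.tensor([0])

# softmax underflows to exactly 0.0 for the first class
pc = F.softmax(inputs, dim=1)[range(len(targets)), targets]
naive_loss = (-(1 - pc) ** 2 * torch.log(pc)).mean()  # log(0) -> -inf
naive_loss.backward()
print(inputs.grad)  # tensor([[nan, nan]])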
Example 3
    def test_reductions(self, reduction_name, reduction_function, inputs,
                        alpha, gamma, dtype, data):
        """ Ensure the reductions mean, sum, and none respectively take the mean, sum, and do not reduce. """
        targets = data.draw(
            hnp.arrays(
                dtype=int,
                shape=(inputs.shape[0], ),
                elements=st.integers(0, inputs.shape[1] - 1),
            ),
            label="targets",
        )

        inputs1 = tensor(inputs, dtype=dtype, requires_grad=True)
        inputs2 = tensor(inputs, dtype=dtype, requires_grad=True)
        targets = tensor(targets, dtype=torch.int64)

        # numerically-stable focal loss
        loss = softmax_focal_loss(inputs1,
                                  targets,
                                  alpha=alpha,
                                  gamma=gamma,
                                  reduction=reduction_name)

        # naive focal loss
        input = F.softmax(inputs2, dim=1)
        pc = input[(range(len(targets)), targets)]
        naive_loss = -alpha * (1 - pc)**gamma * torch.log(pc)
        if reduction_function is not None:
            naive_loss = reduction_function(naive_loss)

        assert_allclose(
            actual=loss.detach().numpy(),
            desired=naive_loss.detach().numpy(),
            atol=1e-5,
            rtol=1e-5,
            err_msg="focal loss with reduction='{}' does not match naive "
            "implementation on numerically-stable domain".format(
                reduction_name),
        )

        if naive_loss.dim() == 0:
            loss.backward()
            naive_loss.backward()
            assert_allclose(
                actual=inputs1.grad.numpy(),
                desired=inputs2.grad.numpy(),
                atol=1e-5,
                rtol=1e-5,
                err_msg="focal loss gradient with reduction='{}' does not "
                "match that of naive loss on numerically-stable domain".format(
                    reduction_name),
            )
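The reduction_name/reduction_function pairs are presumably injected by a parametrize decorator pairing each reduction string with its reference operation; a hypothetical decorator fragment consistent with the test body above (note that "none" maps to None, matching the `if reduction_function is not None` branch):

@pytest.mark.parametrize(
    "reduction_name, reduction_function",
    [("mean", torch.mean), ("sum", torch.sum), ("none", None)],
)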
Example 4
    def test_matches_simple_implementation(
        self,
        inputs: ndarray,
        alpha: float,
        gamma: float,
        dtype: torch.dtype,
        data: st.SearchStrategy,
    ):
        """ Ensures that focal loss matches a naive-implementation over the domain where numerical
        stability is not an issue. """
        targets = data.draw(
            hnp.arrays(
                dtype=int,
                shape=(inputs.shape[0], ),
                elements=st.integers(0, inputs.shape[1] - 1),
            ),
            label="targets",
        )

        inputs1 = tensor(inputs, dtype=dtype, requires_grad=True)
        inputs2 = tensor(inputs, dtype=dtype, requires_grad=True)
        targets = tensor(targets, dtype=torch.int64)

        # numerically-stable focal loss
        loss = softmax_focal_loss(inputs1, targets, alpha=alpha, gamma=gamma)
        loss.backward()

        # naive focal loss
        input = F.softmax(inputs2, dim=1)
        pc = input[(range(len(targets)), targets)]
        naive_loss = (-alpha * (1 - pc)**gamma * torch.log(pc)).mean()
        naive_loss.backward()

        assert_allclose(
            actual=loss.detach().numpy(),
            desired=naive_loss.detach().numpy(),
            atol=1e-5,
            rtol=1e-5,
            err_msg="focal loss does not match naive implementation on "
            "numerically-stable domain",
        )
        assert_allclose(
            actual=inputs1.grad.numpy(),
            desired=inputs2.grad.numpy(),
            atol=1e-5,
            rtol=1e-5,
            err_msg="focal loss gradient does not match that of naive loss on "
            "numerically-stable domain",
        )
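The "numerically-stable domain" mentioned in the docstring is presumably enforced by the hypothesis strategies that generate inputs: bounding the logit magnitudes keeps softmax away from underflow, so torch.log(F.softmax(...)) in the naive version stays finite. A hypothetical @given decorator fragment along those lines (all bounds illustrative):

@given(
    inputs=hnp.arrays(
        dtype=float,
        shape=hnp.array_shapes(min_dims=2, max_dims=2, min_side=1),
        elements=st.floats(-10.0, 10.0),  # bounded logits: no underflow
    ),
    alpha=st.floats(0.0, 2.0),
    gamma=st.floats(0.0, 5.0),
    dtype=st.sampled_from([torch.float32, torch.float64]),
    data=st.data(),
)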
Example 5
    def test_matches_binary_classification(self, pc: float, alpha: float,
                                           gamma: float):
        """ Ensures that focal loss matches the explicit binary-classification formulation from the paper. """
        loss = -alpha * (1 - pc)**gamma * np.log(pc)
        inputs = tensor([[np.log(pc), np.log(1 - pc)]])
        targets = tensor([0])
        assert_allclose(
            desired=loss,
            actual=softmax_focal_loss(inputs,
                                      targets,
                                      alpha=alpha,
                                      gamma=gamma),
            atol=1e-5,
            rtol=1e-5,
            err_msg="focal loss does not reduce to binary-classification form",
        )
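The two logits are chosen so that the softmax recovers p exactly: the exponentials already sum to one, since

\[
\operatorname{softmax}\big(\log p,\ \log(1-p)\big)_0
  = \frac{e^{\log p}}{e^{\log p} + e^{\log(1-p)}}
  = \frac{p}{p + (1-p)} = p,
\]

so with target class 0 the multi-class loss collapses to the paper's binary form \(-\alpha\,(1-p)^{\gamma}\log p\), which is what the test computes directly with numpy.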
Example 6
    def test_matches_crossentropy(self, inputs: ndarray, alpha: float,
                                  dtype: torch.dtype, data: st.SearchStrategy):
        """ Ensures that focal loss w/ gamma=0 matches softmax cross-entropy (scaled by alpha). """
        targets = data.draw(
            hnp.arrays(
                dtype=int,
                shape=(inputs.shape[0], ),
                elements=st.integers(0, inputs.shape[1] - 1),
            ),
            label="targets",
        )

        inputs = tensor(inputs, dtype=dtype)
        targets = tensor(targets, dtype=torch.long)
        assert_allclose(
            desired=alpha * F.cross_entropy(inputs, targets),
            actual=softmax_focal_loss(inputs, targets, alpha=alpha, gamma=0.0),
            atol=1e-6,
            rtol=1e-6,
            err_msg=
            "Focal loss with gamma=0 fails to match cross-entropy loss.",
        )
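Why this holds: with gamma = 0 the modulating factor drops out,

\[
-\alpha\,(1-p_c)^{0}\,\log p_c = -\alpha\,\log p_c,
\]

and \(-\log p_c\) is precisely the softmax cross-entropy loss, so the focal loss reduces to alpha * F.cross_entropy(inputs, targets).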
Example 7
    def test_valid_args(self, reduction: str):
        """ Ensures that invalid arguments raise a ValueError. """
        inputs = tensor(np.random.rand(5, 5))
        targets = tensor(np.random.randint(0, 1, 5))
        with pytest.raises(ValueError):
            softmax_focal_loss(inputs, targets, reduction=reduction)
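The reduction values feeding this test are presumably parametrized with strings outside the valid ("mean", "sum", "none") set; a hypothetical decorator fragment (the invalid strings are illustrative):

@pytest.mark.parametrize("reduction", ["average", "Mean", "", "summ"])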