def test_sqrt_hessian_mse_should_pass(extension):
    """Backward through the unmodified MSE loss under ``extension`` must not raise.

    ``extension`` is instantiated with no arguments and activated via ``bp``.
    No warning expectations are asserted here; only the absence of exceptions.
    """
    mse_loss = dummy_mse()
    active_extension = extension()
    with bp(active_extension):
        mse_loss.backward()
def test_sqrt_hessian_modified_mse_should_fail(extension):
    """Backward through a rescaled MSE loss under ``extension`` must emit a UserWarning.

    Despite the ``should_fail`` name, "failure" here means a ``UserWarning`` is
    raised during backward; the test does not expect an exception.
    """
    scaled_loss = dummy_mse() * 2
    # Multi-item ``with``: pytest.warns is entered first (outermost), then bp.
    with pytest.warns(UserWarning), bp(extension()):
        scaled_loss.backward()
def test_sqrt_hessian_crossentropy_should_pass(extension):
    """Backward through the unmodified cross-entropy loss under ``extension`` must not raise.

    ``extension`` is instantiated with no arguments and activated via ``bp``.
    No warning expectations are asserted here; only the absence of exceptions.
    """
    ce_loss = dummy_cross_entropy()
    active_extension = extension()
    with bp(active_extension):
        ce_loss.backward()