def test_supported_loss_returns_correctly_with_loss_kwargs(self):
    """Extra loss kwargs (alpha, gamma) are forwarded by get_haiku_loss_function.

    The network is an empty ``hk.Sequential``. The assertions below treat the
    input ``x`` as the prediction, i.e. they rely on that network passing its
    input through unchanged.
    """
    import haiku as hk

    def net_function(x: jnp.ndarray) -> jnp.ndarray:
        # Keep the argument named ``x``: the loss wrapper is called with
        # ``x=...`` and may forward it to ``apply`` by keyword.
        return hk.Sequential([])(x)

    transformed = hk.transform(net_function)
    loss_fn = get_haiku_loss_function(
        transformed,
        loss="sigmoid_focal_crossentropy",
        alpha=None,
        gamma=None,
    )

    params = transformed.init(jax.random.PRNGKey(42), jnp.array(0))

    # A prediction equal to its target gives zero loss.
    for label in (0, 1):
        self.assertEqual(
            0,
            loss_fn(params, x=jnp.array([label]), y_true=jnp.array([label])),
        )

    # With alpha and gamma both None, the focal loss should reduce to log_loss.
    for prediction, target in ((0, 0.27), (1, 0.97)):
        expected = log_loss(
            y_true=jnp.array([target]), y_pred=jnp.array([prediction]))
        self.assertEqual(
            expected,
            loss_fn(params, x=jnp.array([prediction]), y_true=jnp.array([target])),
        )
def test_multiclass_returns_correctly(self):
    """log_loss on one-hot multiclass targets, both averaged and per sample."""
    # Non-zero float32 results are compared with a relative tolerance rather
    # than bit-for-bit against decimal literals. Exact equality is brittle:
    # results can differ by a few ULPs across XLA backends.
    rtol = 1e-6

    # Perfect predictions: zero loss, both averaged and per sample.
    actual_loss = log_loss(jnp.array([[0, 1], [1, 0]]),
                           jnp.array([[0, 1], [1, 0]]))
    self.assertEqual(0, actual_loss)

    actual_loss = log_loss(jnp.array([[0, 1], [1, 0]]),
                           jnp.array([[0, 1], [1, 0]]),
                           normalize=False)
    np.testing.assert_array_equal(jnp.array([0, 0]), actual_loss)

    # Based on scikit-learn: https://github.com/scikit-learn/scikit-learn/blob
    # /ffbb1b4a0bbb58fdca34a30856c6f7faace87c67/sklearn/metrics/tests/test_classification.py#L2135
    y_true = jnp.array([[0, 1, 0], [1, 0, 0], [0, 0, 1]])
    y_pred = jnp.array([[0.2, 0.7, 0.1], [0.6, 0.2, 0.2], [0.6, 0.1, 0.3]])

    actual_loss = log_loss(y_true, y_pred)
    np.testing.assert_allclose(actual_loss, 0.69049114, rtol=rtol)

    # Per-sample values equal -log of the probability on the true class:
    # -log(0.7), -log(0.6), -log(0.3).
    actual_loss = log_loss(y_true, y_pred, normalize=False)
    np.testing.assert_allclose(
        actual_loss,
        jnp.array([0.35667497, 0.5108256, 1.2039728]),
        rtol=rtol,
    )
def test_binary_returns_correctly(self):
    """sigmoid_focal_crossentropy on binary targets, averaged and per sample."""
    # NOTE(review): another test method in this file, the one exercising
    # log_loss, is also named test_binary_returns_correctly. If both live in
    # the same TestCase class, the later definition silently replaces this
    # one. Confirm they are in separate classes.

    # Non-zero float32 results are compared with a relative tolerance rather
    # than bit-for-bit against decimal literals. Exact equality is brittle:
    # results can differ by a few ULPs across XLA backends.
    rtol = 1e-6

    # Perfect predictions give zero loss.
    actual_loss = sigmoid_focal_crossentropy(jnp.array([1]), jnp.array([1]))
    self.assertEqual(0, actual_loss)

    actual_loss = sigmoid_focal_crossentropy(jnp.array([0]), jnp.array([0]))
    self.assertEqual(0, actual_loss)

    actual_loss = sigmoid_focal_crossentropy(jnp.array([1, 0]),
                                             jnp.array([1, 0]))
    self.assertEqual(0, actual_loss)

    # Based on tensorflow_addons: https://github.com/tensorflow/addons/blob/v0.10.0/tensorflow_addons/losses
    # /tests/focal_loss_test.py#L106
    y_true = jnp.array([1, 1, 1, 0, 0, 0])
    y_pred = jnp.array([0.97, 0.91, 0.73, 0.27, 0.09, 0.03])

    # When alpha and gamma are None, it should be equal to log_loss.
    actual_loss = sigmoid_focal_crossentropy(
        y_true, y_pred, alpha=None, gamma=None)
    expected_loss = log_loss(y_true, y_pred)
    self.assertEqual(expected_loss, actual_loss)

    actual_loss = sigmoid_focal_crossentropy(
        y_true, y_pred, alpha=None, gamma=2.0)
    np.testing.assert_allclose(actual_loss, 0.007911247, rtol=rtol)

    actual_loss = sigmoid_focal_crossentropy(
        y_true,
        y_pred,
        alpha=None,
        gamma=2.0,
        normalize=False,
    )
    np.testing.assert_allclose(
        actual_loss,
        jnp.array([
            2.7413207e-05, 7.6391589e-04, 2.2942409e-02, 2.2942409e-02,
            7.6391734e-04, 2.7413207e-05
        ]),
        rtol=rtol,
    )
def test_binary_returns_correctly(self):
    """log_loss on binary targets, averaged and per sample."""
    # Non-zero float32 results are compared with a relative tolerance rather
    # than bit-for-bit against decimal literals. Exact equality is brittle:
    # results can differ by a few ULPs across XLA backends.
    rtol = 1e-6

    # Perfect predictions give zero loss.
    actual_loss = log_loss(jnp.array([1]), jnp.array([1]))
    self.assertEqual(0, actual_loss)

    actual_loss = log_loss(jnp.array([0]), jnp.array([0]))
    self.assertEqual(0, actual_loss)

    actual_loss = log_loss(jnp.array([1, 0]), jnp.array([1, 0]))
    self.assertEqual(0, actual_loss)

    # One confidently wrong prediction out of two. The finite value presumably
    # comes from clipping predictions at an epsilon near 1e-15
    # (-log(1e-15) / 2 ~= 17.2694). Confirm against log_loss's implementation.
    actual_loss = log_loss(jnp.array([1, 0]), jnp.array([1, 1]))
    np.testing.assert_allclose(actual_loss, 17.269388, rtol=rtol)

    # Based on scikit-learn: https://github.com/scikit-learn/scikit-learn/blob
    # /ffbb1b4a0bbb58fdca34a30856c6f7faace87c67/sklearn/metrics/tests/test_classification.py#L2135
    y_true = jnp.array([0, 0, 0, 1, 1, 1])
    y_pred = jnp.array([0.5, 0.9, 0.99, 0.1, 0.25, 0.999])

    actual_loss = log_loss(y_true, y_pred)
    np.testing.assert_allclose(actual_loss, 1.8817972, rtol=rtol)

    actual_loss = log_loss(y_true, y_pred, normalize=False)
    np.testing.assert_allclose(
        actual_loss,
        jnp.array([
            6.9314718e-01, 2.3025849e00, 4.6051712e00, 2.3025851e00,
            1.3862944e00, 1.0004875e-03
        ]),
        rtol=rtol,
    )
def test_raises_when_attempt_to_use_not_one_hot_encoded_multiclass(self):
    """log_loss raises TypeError for 1-D labels paired with 2-D y_pred."""
    with self.assertRaises(TypeError):
        integer_labels = jnp.array([1, 0, 2])
        class_probabilities = jnp.array(
            [[0.2, 0.7, 0.1], [0.6, 0.2, 0.2], [0.6, 0.1, 0.3]])
        log_loss(integer_labels, class_probabilities)
def test_raises_when_number_of_multiclass_classes_not_equal(self):
    """log_loss raises TypeError when y_true and y_pred disagree on class count."""
    with self.assertRaises(TypeError):
        three_class_targets = jnp.array([[0, 0, 1], [0, 0, 1], [0, 0, 1]])
        two_class_predictions = jnp.array([[0.2, 0.7], [0.6, 0.5], [0.4, 0.1]])
        log_loss(three_class_targets, two_class_predictions)
def test_raises_when_number_of_samples_not_equal_multiclass(self):
    """log_loss raises ValueError when multiclass inputs differ in row count."""
    with self.assertRaises(ValueError):
        two_rows = jnp.array([[0, 1], [1, 0]])
        three_rows = jnp.array([[0.2, 0.7], [0.6, 0.5], [0.4, 0.1]])
        log_loss(two_rows, three_rows)
def test_raises_when_number_of_samples_not_equal(self):
    """log_loss raises ValueError when binary inputs differ in length."""
    with self.assertRaises(ValueError):
        two_samples = jnp.array([0, 1])
        three_samples = jnp.array([0, 1, 0])
        log_loss(two_samples, three_samples)