Code Example #1
File: test_utils.py  Project: asmith26/jax_toolkit
    def test_supported_loss_returns_correctly_with_loss_kwargs(self):
        import haiku as hk

        def net_function(x: jnp.ndarray) -> jnp.ndarray:
            net = hk.Sequential([])
            return net(x)

        net_transform = hk.transform(net_function)
        actual_loss_function_wrapper = get_haiku_loss_function(
            net_transform,
            loss="sigmoid_focal_crossentropy",
            alpha=None,
            gamma=None)

        # Check works
        rng = jax.random.PRNGKey(42)
        params = net_transform.init(rng, jnp.array(0))

        self.assertEqual(
            0,
            actual_loss_function_wrapper(params,
                                         x=jnp.array([0]),
                                         y_true=jnp.array([0])))
        self.assertEqual(
            0,
            actual_loss_function_wrapper(params,
                                         x=jnp.array([1]),
                                         y_true=jnp.array([1])))
        # When alpha and gamma are None, it should be equal to log_loss
        self.assertEqual(
            log_loss(y_true=jnp.array([0.27]), y_pred=jnp.array([0])),
            actual_loss_function_wrapper(params,
                                         x=jnp.array([0]),
                                         y_true=jnp.array([0.27])),
        )
        self.assertEqual(
            log_loss(y_true=jnp.array([0.97]), y_pred=jnp.array([1])),
            actual_loss_function_wrapper(params,
                                         x=jnp.array([1]),
                                         y_true=jnp.array([0.97])),
        )
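A minimal sketch of how the wrapper returned by get_haiku_loss_function could be plugged into a plain SGD update. It relies only on the (params, x=..., y_true=...) call signature exercised in the test above; the batch values and learning rate below are made up for illustration.

import jax
import jax.numpy as jnp

# Hypothetical batch; actual_loss_function_wrapper and params are the objects built in the test above.
x_batch = jnp.array([0.0, 1.0])
y_batch = jnp.array([0.0, 1.0])
learning_rate = 0.1

# Differentiate the scalar loss with respect to the Haiku parameters (the first positional argument).
loss, grads = jax.value_and_grad(actual_loss_function_wrapper)(
    params, x=x_batch, y_true=y_batch)

# Plain SGD update applied over the parameter pytree.
params = jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)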
Code Example #2
    def test_multiclass_returns_correctly(self):
        actual_loss = log_loss(jnp.array([[0, 1], [1, 0]]),
                               jnp.array([[0, 1], [1, 0]]))
        self.assertEqual(0, actual_loss)
        actual_loss = log_loss(jnp.array([[0, 1], [1, 0]]),
                               jnp.array([[0, 1], [1, 0]]),
                               normalize=False)
        np.testing.assert_array_equal(jnp.array([0, 0]), actual_loss)
        # Based on scikit-learn: https://github.com/scikit-learn/scikit-learn/blob
        # /ffbb1b4a0bbb58fdca34a30856c6f7faace87c67/sklearn/metrics/tests/test_classification.py#L2135
        actual_loss = log_loss(
            jnp.array([[0, 1, 0], [1, 0, 0], [0, 0, 1]]),
            jnp.array([[0.2, 0.7, 0.1], [0.6, 0.2, 0.2], [0.6, 0.1, 0.3]]))
        self.assertEqual(0.69049114, actual_loss)
        actual_loss = log_loss(
            jnp.array([[0, 1, 0], [1, 0, 0], [0, 0, 1]]),
            jnp.array([[0.2, 0.7, 0.1], [0.6, 0.2, 0.2], [0.6, 0.1, 0.3]]),
            normalize=False,
        )
        np.testing.assert_array_equal(
            jnp.array([0.35667497, 0.5108256, 1.2039728]), actual_loss)
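The expected numbers in this test can be reproduced by hand with the usual categorical cross-entropy, -sum(y_true * log(y_pred)) per row, averaged when normalize=True. A quick check using the standard formula (not jax_toolkit's internal implementation):

import jax.numpy as jnp

y_true = jnp.array([[0, 1, 0], [1, 0, 0], [0, 0, 1]])
y_pred = jnp.array([[0.2, 0.7, 0.1], [0.6, 0.2, 0.2], [0.6, 0.1, 0.3]])

# Per-sample categorical cross-entropy: -sum_k y_true[:, k] * log(y_pred[:, k]).
per_sample = -jnp.sum(y_true * jnp.log(y_pred), axis=-1)
print(per_sample)         # approx [0.35667497, 0.5108256, 1.2039728]  (normalize=False)
print(per_sample.mean())  # approx 0.69049114                          (normalize=True)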
Code Example #3
    def test_binary_returns_correctly(self):
        actual_loss = sigmoid_focal_crossentropy(jnp.array([1]),
                                                 jnp.array([1]))
        self.assertEqual(0, actual_loss)
        actual_loss = sigmoid_focal_crossentropy(jnp.array([0]),
                                                 jnp.array([0]))
        self.assertEqual(0, actual_loss)
        actual_loss = sigmoid_focal_crossentropy(jnp.array([1, 0]),
                                                 jnp.array([1, 0]))
        self.assertEqual(0, actual_loss)
        # Based on tensorflow_addons: https://github.com/tensorflow/addons/blob/v0.10.0/tensorflow_addons/losses
        # /tests/focal_loss_test.py#L106
        actual_loss = sigmoid_focal_crossentropy(
            jnp.array([1, 1, 1, 0, 0, 0]),
            jnp.array([0.97, 0.91, 0.73, 0.27, 0.09, 0.03]),
            alpha=None,
            gamma=None)
        # When alpha and gamma are None, it should be equal to log_loss
        expected_loss = log_loss(
            jnp.array([1, 1, 1, 0, 0, 0]),
            jnp.array([0.97, 0.91, 0.73, 0.27, 0.09, 0.03]))
        self.assertEqual(expected_loss, actual_loss)
        actual_loss = sigmoid_focal_crossentropy(
            jnp.array([1, 1, 1, 0, 0, 0]),
            jnp.array([0.97, 0.91, 0.73, 0.27, 0.09, 0.03]),
            alpha=None,
            gamma=2.0)
        self.assertEqual(0.007911247, actual_loss)
        actual_loss = sigmoid_focal_crossentropy(
            jnp.array([1, 1, 1, 0, 0, 0]),
            jnp.array([0.97, 0.91, 0.73, 0.27, 0.09, 0.03]),
            alpha=None,
            gamma=2.0,
            normalize=False,
        )
        np.testing.assert_array_equal(
            jnp.array([
                2.7413207e-05, 7.6391589e-04, 2.2942409e-02, 2.2942409e-02,
                7.6391734e-04, 2.7413207e-05
            ]),
            actual_loss,
        )
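The gamma=2.0 expectations match the standard focal-loss formulation used by tensorflow_addons: each sample's binary cross-entropy scaled by (1 - p_t)**gamma, where p_t is the probability assigned to the true class. A small arithmetic check of that claim, with alpha disabled as in the test:

import jax.numpy as jnp

y_true = jnp.array([1, 1, 1, 0, 0, 0])
y_pred = jnp.array([0.97, 0.91, 0.73, 0.27, 0.09, 0.03])
gamma = 2.0

# Probability the prediction assigns to the true class.
p_t = jnp.where(y_true == 1, y_pred, 1.0 - y_pred)
# Focal modulation of the cross-entropy: (1 - p_t)^gamma * -log(p_t).
per_sample = (1.0 - p_t) ** gamma * -jnp.log(p_t)
print(per_sample)         # approx [2.74e-05, 7.64e-04, 2.29e-02, 2.29e-02, 7.64e-04, 2.74e-05]
print(per_sample.mean())  # approx 0.007911247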
Code Example #4
    def test_binary_returns_correctly(self):
        actual_loss = log_loss(jnp.array([1]), jnp.array([1]))
        self.assertEqual(0, actual_loss)
        actual_loss = log_loss(jnp.array([0]), jnp.array([0]))
        self.assertEqual(0, actual_loss)
        actual_loss = log_loss(jnp.array([1, 0]), jnp.array([1, 0]))
        self.assertEqual(0, actual_loss)
        actual_loss = log_loss(jnp.array([1, 0]), jnp.array([1, 1]))
        self.assertEqual(17.269388, actual_loss)
        # Based on scikit-learn: https://github.com/scikit-learn/scikit-learn/blob
        # /ffbb1b4a0bbb58fdca34a30856c6f7faace87c67/sklearn/metrics/tests/test_classification.py#L2135
        actual_loss = log_loss(jnp.array([0, 0, 0, 1, 1, 1]),
                               jnp.array([0.5, 0.9, 0.99, 0.1, 0.25, 0.999]))
        self.assertEqual(1.8817972, actual_loss)
        actual_loss = log_loss(jnp.array([0, 0, 0, 1, 1, 1]),
                               jnp.array([0.5, 0.9, 0.99, 0.1, 0.25, 0.999]),
                               normalize=False)
        np.testing.assert_array_equal(
            jnp.array([
                6.9314718e-01, 2.3025849e00, 4.6051712e00, 2.3025851e00,
                1.3862944e00, 1.0004875e-03
            ]),
            actual_loss,
        )
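These expected values follow from element-wise binary cross-entropy, -(y * log(p) + (1 - y) * log(1 - p)), averaged when normalize=True. The 17.269388 case additionally depends on how log(0) is clipped internally, so only the scikit-learn-based inputs are checked in this sketch:

import jax.numpy as jnp

y_true = jnp.array([0, 0, 0, 1, 1, 1])
y_pred = jnp.array([0.5, 0.9, 0.99, 0.1, 0.25, 0.999])

# Element-wise binary cross-entropy.
per_sample = -(y_true * jnp.log(y_pred) + (1 - y_true) * jnp.log(1 - y_pred))
print(per_sample)         # approx [0.693, 2.303, 4.605, 2.303, 1.386, 0.001]  (normalize=False)
print(per_sample.mean())  # approx 1.8817972                                   (normalize=True)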
Code Example #5
    def test_raises_when_attempt_to_use_not_one_hot_encoded_multiclass(self):
        with self.assertRaises(TypeError) as _:
            log_loss(
                jnp.array([1, 0, 2]),
                jnp.array([[0.2, 0.7, 0.1], [0.6, 0.2, 0.2], [0.6, 0.1, 0.3]]))
Code Example #6
    def test_raises_when_number_of_multiclass_classes_not_equal(self):
        with self.assertRaises(TypeError) as _:
            log_loss(jnp.array([[0, 0, 1], [0, 0, 1], [0, 0, 1]]),
                     jnp.array([[0.2, 0.7], [0.6, 0.5], [0.4, 0.1]]))
Code Example #7
    def test_raises_when_number_of_samples_not_equal_multiclass(self):
        with self.assertRaises(ValueError) as _:
            log_loss(jnp.array([[0, 1], [1, 0]]),
                     jnp.array([[0.2, 0.7], [0.6, 0.5], [0.4, 0.1]]))
Code Example #8
    def test_raises_when_number_of_samples_not_equal(self):
        with self.assertRaises(ValueError) as _:
            log_loss(jnp.array([0, 1]), jnp.array([0, 1, 0]))
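Taken together, Code Examples #5 through #8 pin down the input validation: mismatched sample counts raise ValueError, while multiclass targets that are not one-hot encoded, or whose class count disagrees with y_pred, raise TypeError. A hypothetical helper illustrating the same checks (this is not jax_toolkit's actual code, only the behaviour implied by these tests):

import jax.numpy as jnp

def check_log_loss_inputs(y_true: jnp.ndarray, y_pred: jnp.ndarray) -> None:
    # Different number of samples -> ValueError (Code Examples #7 and #8).
    if y_true.shape[0] != y_pred.shape[0]:
        raise ValueError("y_true and y_pred must contain the same number of samples.")
    # Multiclass predictions need one-hot targets with a matching class count -> TypeError
    # (Code Examples #5 and #6).
    if y_pred.ndim == 2 and (y_true.ndim != 2 or y_true.shape[1] != y_pred.shape[1]):
        raise TypeError("Multiclass y_true must be one-hot encoded with the same "
                        "number of classes as y_pred.")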