def test_multioutput_returns_correctly(self):
     """Checks the loss on 2-D inputs where each column is a separate output.

     A perfect prediction must give zero loss. For the second case the
     squared errors over all 12 elements sum to 22, so the expected loss
     is 22 / 12.
     """
     actual_loss = mean_squared_error(jnp.array([[0, 1, 2], [0, 1, 2]]),
                                      jnp.array([[0, 1, 2], [0, 1, 2]]))
     self.assertEqual(0, actual_loss)
     actual_loss = mean_squared_error(
         jnp.array([[1, 2, 3, 4], [4, 3, 2, 1], [1, 1, 0, 0]]),
         jnp.array([[1, 2, 0, 1], [4, 3, 1, 1], [0, 0, 0, 1]]))
     # Squared errors per row: [0, 0, 9, 9], [0, 0, 1, 0], [1, 1, 0, 1].
     # Compare approximately: the loss is likely float32 (JAX's default float
     # dtype), so an exact match against a rounded decimal literal is brittle.
     self.assertAlmostEqual(22 / 12, float(actual_loss), places=6)
 def test_single_output_returns_correctly(self):
     """Checks the loss on 1-D inputs.

     Covers three cases:
     - a perfect prediction, which gives zero loss;
     - an all-zero prediction, whose squared errors are [0, 1, 4]
       (mean 5 / 3);
     - a constant offset of one on every sample, which gives a loss of
       exactly 1.
     """
     actual_loss = mean_squared_error(jnp.array([0, 1, 2]),
                                      jnp.array([0, 1, 2]))
     self.assertEqual(0, actual_loss)
     actual_loss = mean_squared_error(jnp.array([0, 1, 2]),
                                      jnp.array([0, 0, 0]))
     # Compare approximately: the loss is likely float32 (JAX's default float
     # dtype), so an exact match against a rounded decimal literal is brittle.
     self.assertAlmostEqual(5 / 3, float(actual_loss), places=6)
     # Based on scikit-learn: https://github.com/scikit-learn/scikit-learn/blob
     # /ffbb1b4a0bbb58fdca34a30856c6f7faace87c67/sklearn/metrics/tests/test_regression.py#L25
     y_true = jnp.arange(50)
     y_pred = y_true + 1
     actual_loss = mean_squared_error(y_true, y_pred)
     # Every squared error is exactly 1, so an exact comparison is safe here.
     self.assertEqual(1, actual_loss)
 def test_raises_when_number_of_multioutput_outputs_not_equal(self):
     """Rejects 2-D inputs whose number of outputs (columns) differs.

     The first argument has shape (2, 3) and the second has shape (2, 2).
     """
     with self.assertRaises(TypeError):
         three_outputs = jnp.array([[1, 2, 3], [10, 11, 12]])
         two_outputs = jnp.array([[0.2, 0.7], [0.6, 0.5]])
         mean_squared_error(three_outputs, two_outputs)
 def test_raises_when_number_of_samples_not_equal_multioutput(self):
     """Rejects 2-D inputs whose number of samples (rows) differs.

     The first argument has shape (2, 2) and the second has shape (3, 2).
     """
     with self.assertRaises(TypeError):
         two_rows = jnp.array([[0, 1], [1, 2]])
         three_rows = jnp.array([[0.2, 0.7], [0.6, 0.5], [0.4, 0.1]])
         mean_squared_error(two_rows, three_rows)
 def test_raises_when_number_of_samples_not_equal(self):
     """Rejects 1-D inputs of different lengths (2 vs. 3 samples)."""
     with self.assertRaises(TypeError):
         two_samples = jnp.array([0, 0])
         three_samples = jnp.array([0, 0, 0])
         mean_squared_error(two_samples, three_samples)