def test_multioutput_returns_correctly(self):
    """Check mean_squared_error on 2-D (multi-output) inputs."""
    actual_loss = mean_squared_error(
        jnp.array([[0, 1, 2], [0, 1, 2]]),
        jnp.array([[0, 1, 2], [0, 1, 2]]),
    )
    # Identical inputs yield an exactly-zero error, so exact equality is safe.
    self.assertEqual(0, actual_loss)

    actual_loss = mean_squared_error(
        jnp.array([[1, 2, 3, 4], [4, 3, 2, 1], [1, 1, 0, 0]]),
        jnp.array([[1, 2, 0, 1], [4, 3, 1, 1], [0, 0, 0, 1]]),
    )
    # Squared errors: row 0 -> 9 + 9, row 1 -> 1, row 2 -> 1 + 1 + 1; total 22
    # over 12 elements, i.e. 11/6. Compare approximately: the loss is a
    # floating-point value whose precision depends on JAX's x64 setting, so an
    # exact check against a rounded literal such as 1.8333334 is fragile.
    self.assertAlmostEqual(11 / 6, float(actual_loss), places=6)
def test_single_output_returns_correctly(self):
    """Check mean_squared_error on 1-D (single-output) inputs."""
    actual_loss = mean_squared_error(jnp.array([0, 1, 2]), jnp.array([0, 1, 2]))
    # Identical inputs yield an exactly-zero error, so exact equality is safe.
    self.assertEqual(0, actual_loss)

    actual_loss = mean_squared_error(jnp.array([0, 1, 2]), jnp.array([0, 0, 0]))
    # Squared errors are 0, 1, 4 -> 5 / 3. Compared approximately because the
    # loss is floating point; a rounded literal like 1.6666667 only matches
    # float32 and breaks when JAX x64 is enabled.
    self.assertAlmostEqual(5 / 3, float(actual_loss), places=6)

    # Based on scikit-learn:
    # https://github.com/scikit-learn/scikit-learn/blob/ffbb1b4a0bbb58fdca34a30856c6f7faace87c67/sklearn/metrics/tests/test_regression.py#L25
    y_true = jnp.arange(50)
    y_pred = y_true + 1
    actual_loss = mean_squared_error(y_true, y_pred)
    # Every prediction is off by exactly 1, so the error is exactly 1.
    self.assertEqual(1, actual_loss)
def test_raises_when_number_of_multioutput_outputs_not_equal(self):
    """A 2x3 target against a 2x2 prediction must raise TypeError."""
    targets = jnp.array([[1, 2, 3], [10, 11, 12]])
    predictions = jnp.array([[0.2, 0.7], [0.6, 0.5]])
    with self.assertRaises(TypeError):
        mean_squared_error(targets, predictions)
def test_raises_when_number_of_samples_not_equal_multioutput(self):
    """A 2x2 target against a 3x2 prediction (row counts differ) must raise TypeError."""
    targets = jnp.array([[0, 1], [1, 2]])
    predictions = jnp.array([[0.2, 0.7], [0.6, 0.5], [0.4, 0.1]])
    with self.assertRaises(TypeError):
        mean_squared_error(targets, predictions)
def test_raises_when_number_of_samples_not_equal(self):
    """1-D inputs of lengths 2 and 3 must raise TypeError."""
    targets = jnp.array([0, 0])
    predictions = jnp.array([0, 0, 0])
    with self.assertRaises(TypeError):
        mean_squared_error(targets, predictions)