Ejemplo n.º 1
0
    def test_gmae_random_arrays_finite_values(self, y_true, y_pred):
        """Check ``gmae`` against the reference ``_correct_gmae`` on supplied input.

        ``y_true`` and ``y_pred`` come from a parametrization/data generator
        that is not visible here (presumably finite-valued, equal-length
        arrays — confirm against the decorator).
        """
        gmae_value = gmae(y_true, y_pred)
        expected_gmae = self._correct_gmae(y_true, y_pred)

        # Use a relative tolerance rather than bit-exact equality: the two
        # implementations may order their floating-point operations
        # differently, so tiny rounding differences are expected. The debug
        # prints were removed; pytest already reports the failing inputs.
        assert gmae_value == pytest.approx(expected_gmae)
Ejemplo n.º 2
0
    def test_gmae_dataframe(self):
        """``gmae`` accepts single-column DataFrames; the reference pair rounds to 2.82."""
        observed = pd.DataFrame([0, 1, 2, 3, 6, 5])
        forecast = pd.DataFrame([-1, 4, 5, 10, 4, 1])

        # Absolute errors are 1, 3, 3, 7, 2, 4; presumably GMAE is their
        # geometric mean, 504 ** (1 / 6) ~= 2.821, which rounds to 2.82.
        rounded = np.round(gmae(observed, forecast), decimals=2)

        assert 2.82 == rounded
Ejemplo n.º 3
0
    def test_gmae_array(self):
        """``gmae`` accepts NumPy arrays; the reference pair rounds to 2.82."""
        truth, estimate = np.array([0, 1, 2, 3, 6, 5]), np.array([-1, 4, 5, 10, 4, 1])

        score = np.round(gmae(truth, estimate), decimals=2)

        # Same data as the DataFrame and list variants, so the same expected value.
        assert 2.82 == score
Ejemplo n.º 4
0
    def test_gmae_list(self):
        """``gmae`` accepts plain Python lists; the reference pair rounds to 2.82."""
        actual = [0, 1, 2, 3, 6, 5]
        predicted = [-1, 4, 5, 10, 4, 1]

        expected = 2.82
        assert expected == np.round(gmae(actual, predicted), decimals=2)
Ejemplo n.º 5
0
    def test_zero_in_difference_gmae(self):
        """A single zero absolute difference makes GMAE zero."""
        # Position 4 matches exactly (4 vs 4), so one absolute difference is
        # zero, which should collapse the geometric mean to zero.
        truth = pd.DataFrame([0, 1, 2, 3, 4, 5])
        guess = pd.DataFrame([-1, 4, 5, 10, 4, 1])

        result = np.round(gmae(truth, guess), decimals=2)

        assert 0 == result
Ejemplo n.º 6
0
    def test_infinite_values(self):
        """``gmae`` must raise ``ValueError`` when predictions contain infinity."""
        targets = np.random.random(4)
        predictions = [0, np.inf, 2, 3]  # second entry is +inf

        with pytest.raises(ValueError):
            gmae(targets, predictions)
Ejemplo n.º 7
0
    def test_nan_values(self):
        """``gmae`` must raise ``ValueError`` when the ground truth contains NaN."""
        targets = [np.nan, 1, 2, 3]  # first entry is NaN
        predictions = np.random.random(4)

        with pytest.raises(ValueError):
            gmae(targets, predictions)
Ejemplo n.º 8
0
    def test_wrong_vector_length(self):
        """``gmae`` must raise ``ValueError`` when the inputs differ in length."""
        longer = np.random.random(5)
        shorter = np.random.random(4)

        with pytest.raises(ValueError):
            gmae(longer, shorter)