def test_config(self):
    """Check that name and normalizer survive a config round trip.

    A metric is built with an explicit normalizer and name, verified, then
    rebuilt through ``get_config()`` / ``from_config()`` and verified again
    against the same expectations.
    """
    expected_normalizer = [1, 3]
    original = metrics.MeanRelativeError(
        normalizer=constant_op.constant(
            expected_normalizer, dtype=dtypes.float32),
        name='mre')
    self.assertEqual(original.name, 'mre')
    self.assertArrayNear(
        self.evaluate(original.normalizer), expected_normalizer, 1e-1)

    # Rebuild the metric purely from its serialized configuration.
    restored = metrics.MeanRelativeError.from_config(original.get_config())
    self.assertEqual(restored.name, 'mre')
    self.assertArrayNear(
        self.evaluate(restored.normalizer), expected_normalizer, 1e-1)
def test_zero_normalizer(self):
    """Check that an all-zero normalizer makes the metric evaluate to 0."""
    predictions = constant_op.constant([2, 4], dtype=dtypes.float32)
    labels = constant_op.constant([1, 3])

    # The normalizer has the labels' shape but every entry is zero.
    zero_normalizer = array_ops.zeros_like(labels)
    metric = metrics.MeanRelativeError(normalizer=zero_normalizer)
    self.evaluate(variables.variables_initializer(metric.variables))

    outcome = self.evaluate(metric(labels, predictions))
    self.assertEqual(outcome, 0)
def test_unweighted(self):
    """Compare the unweighted metric against a NumPy reference value.

    The reference is the mean of ``|y_pred - y_true| / y_true`` over the four
    samples; the metric is normalized by ``y_true`` and must match it within
    an absolute tolerance of 1e-3.
    """
    predictions_np = np.asarray([2, 4, 6, 8], dtype=np.float32)
    labels_np = np.asarray([1, 3, 2, 3], dtype=np.float32)

    # Reference value computed with NumPy only.
    relative_errors = np.abs(predictions_np - labels_np) / labels_np
    expected = relative_errors.mean()

    # Both tensors are fed as a single row of four values.
    row_shape = (1, 4)
    predictions = constant_op.constant(
        predictions_np, shape=row_shape, dtype=dtypes.float32)
    labels = constant_op.constant(labels_np, shape=row_shape)

    metric = metrics.MeanRelativeError(normalizer=labels)
    self.evaluate(variables.variables_initializer(metric.variables))

    outcome = self.evaluate(metric(labels, predictions))
    self.assertAllClose(outcome, expected, atol=1e-3)
def test_weighted(self):
    """Compare the sample-weighted metric against a NumPy reference value.

    The reference is the sum of ``|y_pred - y_true| / y_true`` scaled by the
    per-sample weights; the metric is normalized by ``y_true`` and must match
    it within an absolute tolerance of 1e-3.
    """
    predictions_np = np.asarray([2, 4, 6, 8], dtype=np.float32)
    labels_np = np.asarray([1, 3, 2, 3], dtype=np.float32)
    weights_np = np.asarray([0.2, 0.3, 0.5, 0], dtype=np.float32)

    # The weights add up to 1, so this weighted sum is also the weighted
    # mean; the last sample has zero weight and contributes nothing.
    relative_errors = np.abs(predictions_np - labels_np) / labels_np
    expected = (relative_errors * weights_np).sum()

    predictions = constant_op.constant(predictions_np, dtype=dtypes.float32)
    labels = constant_op.constant(labels_np)

    metric = metrics.MeanRelativeError(normalizer=labels)
    self.evaluate(variables.variables_initializer(metric.variables))

    weights = constant_op.constant(weights_np)
    outcome = self.evaluate(
        metric(labels, predictions, sample_weight=weights))
    self.assertAllClose(outcome, expected, atol=1e-3)