def test_trimmed_error(self):
    """Check metrics.trimmed_error over the 0.10-0.40 quantile window.

    Predictions are the floats 0..20 in a random order and the targets are
    all zero. The test expects the result to equal the mean of 2..7 (4.5).
    """
    ordered = tf.reshape(tf.range(0, 21, dtype=tf.float32), (-1, 1))
    # Shuffled, presumably so the assertion also checks that trimmed_error
    # does not depend on the order of its inputs.
    predictions = tf.random.shuffle(ordered)
    # NOTE(review): targets are rank 1 (21,) while predictions are (21, 1);
    # presumably trimmed_error aligns or broadcasts these -- confirm.
    targets = tf.zeros((21, ), dtype=tf.float32)
    trimmed = metrics.trimmed_error(targets, predictions, 0.10, 0.40)
    # Presumably the 0.10-0.40 window over the 21 sorted errors keeps the
    # values 2 through 7 -- confirm against trimmed_error's indexing.
    expected = np.mean(np.arange(2, 8))
    self.assertAllClose(trimmed, expected)
Example no. 2
0
 def call(self, y_true, y_pred):
     """Return the mean of metrics.trimmed_error for this batch.

     The quantile bounds come from self._start_quantile and
     self._end_quantile. self._power is forwarded as the ``power`` keyword.
     NOTE(review): the (y_true, y_pred) signature looks like a Keras
     Loss.call override -- confirm against the enclosing class.
     """
     trimmed = metrics.trimmed_error(
         y_true,
         y_pred,
         self._start_quantile,
         self._end_quantile,
         power=self._power,
     )
     return tf.reduce_mean(trimmed)