def test_trimmed_error(self):
    """Check ``metrics.trimmed_error`` on the values 0..20 against a zero target.

    With quantile bounds 0.10 and 0.40, the expected result is the mean of the
    integers 2..7.
    """
    # The predictions are shuffled, so the assertion can only pass if the
    # metric does not depend on the input order.
    values = tf.range(0, 21, dtype=tf.float32)
    column = tf.reshape(values, (-1, 1))
    predictions = tf.random.shuffle(column)
    # NOTE(review): targets are (21,) while predictions are (21, 1); presumably
    # trimmed_error reconciles these shapes rather than broadcasting to
    # (21, 21) -- confirm against its implementation.
    targets = tf.zeros((21, ), dtype=tf.float32)
    trimmed = metrics.trimmed_error(targets, predictions, 0.10, 0.40)
    expected = np.mean(list(range(2, 8)))
    self.assertAllClose(trimmed, expected)
def call(self, y_true, y_pred):
    """Return ``tf.reduce_mean`` of ``metrics.trimmed_error(y_true, y_pred)``.

    The quantile bounds come from ``self._start_quantile`` and
    ``self._end_quantile``. The exponent is passed as ``power=self._power``.
    """
    start = self._start_quantile
    end = self._end_quantile
    trimmed = metrics.trimmed_error(
        y_true,
        y_pred,
        start,
        end,
        power=self._power,
    )
    return tf.reduce_mean(trimmed)