def _setup_metrics(self):
    """Populate ``self.metric_functions`` with this feature's evaluation metrics.

    The ``LOSS`` entry reuses ``self.eval_loss_function``; the ``ERROR``,
    ``MEAN_SQUARED_ERROR``, ``MEAN_ABSOLUTE_ERROR`` and ``R2`` entries are
    freshly constructed metric objects.
    """
    # NOTE(review): ``metric_functions`` may be a class-level default dict
    # (assumed; confirm on the owning class). Writing into it in place would
    # share these metric objects across every instance, so rebind to a
    # per-instance copy first. Copying (rather than starting from {}) keeps
    # any entries that were already present.
    self.metric_functions = dict(self.metric_functions)
    self.metric_functions[LOSS] = self.eval_loss_function
    self.metric_functions[ERROR] = ErrorScore(name='metric_error')
    self.metric_functions[MEAN_SQUARED_ERROR] = MeanSquaredErrorMetric(
        name='metric_mse')
    self.metric_functions[MEAN_ABSOLUTE_ERROR] = MeanAbsoluteErrorMetric(
        name='metric_mae')
    self.metric_functions[R2] = R2Score(name='metric_r2')
def _setup_metrics(self):
    """Build a fresh, per-instance table of evaluation metrics.

    The ``LOSS`` entry tracks the configured loss: an ``MSEMetric`` when
    ``self.loss[TYPE]`` is ``'mean_squared_error'``, otherwise an
    ``MAEMetric``. Error, MSE, MAE and R2 metrics are always added.
    """
    # Rebinding (instead of mutating) is needed to shadow the class variable;
    # ``metrics`` is just a short local alias for the same new dict.
    metrics = self.metric_functions = {}

    loss_metric_cls = (
        MSEMetric if self.loss[TYPE] == 'mean_squared_error' else MAEMetric
    )
    metrics[LOSS] = loss_metric_cls(name='eval_loss')

    metrics[ERROR] = ErrorScore(name='metric_error')
    metrics[MEAN_SQUARED_ERROR] = MeanSquaredErrorMetric(name='metric_mse')
    metrics[MEAN_ABSOLUTE_ERROR] = MeanAbsoluteErrorMetric(name='metric_mae')
    metrics[R2] = R2Score(name='metric_r2')
def test_ErrorScore(generated_data):
    """Check ErrorScore's empty state, batch invariance, and bad-vs-good ordering.

    Asserts that:
    - a fresh metric reports NaN;
    - scoring the data in one batch or two batches (split at SPLIT_POINT)
      gives finite, closely matching results;
    - the score for ``y_bad`` predictions has a larger magnitude than the
      score for ``y_good`` predictions.
    """
    error_score = ErrorScore()

    def score_in_two_batches(y_pred):
        # Reset the metric, then feed the data as two batches split at
        # SPLIT_POINT, and return the accumulated result.
        error_score.reset_states()
        error_score.update_state(generated_data.y_true[:SPLIT_POINT],
                                 y_pred[:SPLIT_POINT])
        error_score.update_state(generated_data.y_true[SPLIT_POINT:],
                                 y_pred[SPLIT_POINT:])
        return error_score.result().numpy()

    # A metric that has seen no data reports NaN.
    assert np.isnan(error_score.result().numpy())

    # test as single batch
    error_score.update_state(generated_data.y_true, generated_data.y_good)
    good_single_batch = error_score.result().numpy()
    # np.isreal() is True for NaN/inf floats, so it cannot catch a broken
    # result; require a finite value instead.
    assert np.isfinite(good_single_batch)

    # test as two batches
    good_two_batch = score_in_two_batches(generated_data.y_good)
    assert np.isfinite(good_two_batch)

    # single batch and multi-batch should be very close
    assert np.isclose(good_single_batch, good_two_batch)

    # test for bad predictions
    bad_prediction_score = score_in_two_batches(generated_data.y_bad)
    assert np.isfinite(bad_prediction_score)

    # magnitude of bad predictions should be greater than good predictions
    assert np.abs(bad_prediction_score) > np.abs(good_two_batch)