def _current_reward_metric(self):
    """Returns the most recent value of the reward metric from training history."""
    metric_values = online_tune.historical_metric_values(
        self._trainer.state.history,
        self._reward_metric,
    )
    assert metric_values.shape[0] > 0, (
        "No values in history for metric {}.".format(self._reward_metric))
    return metric_values[-1]
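The method above only reads the latest entry of the series. For context, below is a minimal sketch of what online_tune.historical_metric_values might do, inferred from how it is called in these snippets rather than copied from Trax: it assumes history.get(mode, name) yields (step, value) pairs, and that values are clipped to observation_range (the clipping behavior is what the tests below exercise); the default range is also an assumption.

import numpy as np

def historical_metric_values(history, metric, observation_range=(-np.inf, np.inf)):
    # Sketch: gather every recorded value of `metric` from the history and
    # clip it to `observation_range`. Assumes `history.get(mode, name)` returns
    # (step, value) pairs; the default range is an assumption as well.
    mode, name = metric
    values = np.array([value for (_, value) in history.get(mode, name)])
    low, high = observation_range
    return np.clip(values, low, high)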
Example 2
def test_clips_historical_metric_values(self):
    history = trax_history.History()
    self._append_metrics(history, ("train", "loss"), [-10, 10])
    metric_values = online_tune.historical_metric_values(
        history, metric=("train", "loss"), observation_range=(-1, 1))
    # Values outside observation_range are clipped to its bounds.
    np.testing.assert_array_equal(metric_values, [-1, 1])
Example 3
def test_retrieves_historical_metric_values(self):
    history = trax_history.History()
    self._append_metrics(history, ("train", "accuracy"), [0.1, 0.73])
    metric_values = online_tune.historical_metric_values(
        history, metric=("train", "accuracy"), observation_range=(0, 5))
    # Values within observation_range are returned unchanged, in insertion order.
    np.testing.assert_array_equal(metric_values, [0.1, 0.73])
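Both tests rely on a small _append_metrics helper that is not shown in this listing. A plausible sketch, assuming trax_history.History exposes an append(mode, metric, step, value) method, could look like this:

def _append_metrics(self, history, metric, values):
    # Hypothetical test helper: record each value under the (mode, metric) key,
    # using the position in `values` as the training step. The History.append
    # signature used here is an assumption.
    mode, name = metric
    for step, value in enumerate(values):
        history.append(mode, name, step, value)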