Example #1
0
 def evaluate(self, eval_steps):
     """Log one queued metric value and all nontrainable params to history.

     The first entry of `self.metrics_to_report` is removed and recorded
     under (HISTORY_MODE, METRIC) at the current step. Then each entry of
     `self.nontrainable_params` is recorded under the (mode, metric) pair
     that `online_tune.control_metric` returns for its name.

     Args:
       eval_steps: Ignored; deleted immediately.
     """
     del eval_steps  # Unused.
     history = self.state.history
     # Consumes the first queued value; presumably metrics_to_report is a
     # list, so this raises IndexError when nothing is queued.
     next_value = self.metrics_to_report.pop(0)
     history.append(mode=HISTORY_MODE, metric=METRIC, step=self.step,
                    value=next_value)
     for param_name, param_value in self.nontrainable_params.items():
         param_mode, param_metric = online_tune.control_metric(param_name)
         history.append(mode=param_mode, metric=param_metric,
                        step=self.step, value=param_value)
Example #2
0
 def test_clips_updated_control_with_flipping(self):
     """A flipped momentum control is clipped to its 0.99 upper bound.

     History holds momentum 0.985. Action 0 is applied with multipliers
     (0.5, 1.0, 2.0) and the flip flag (4th config entry) set, and the
     result is expected to land exactly on the upper bound of (0.5, 0.99).
     """
     momentum_metric = online_tune.control_metric("momentum")
     momentum_config = ("momentum", None, (0.5, 0.99), True)
     metrics_history = trax_history.History()
     self._append_metrics(metrics_history, momentum_metric, [0.985])
     updated_momentum = online_tune.update_control(
         control_config=momentum_config,
         action=0,
         history=metrics_history,
         action_multipliers=(0.5, 1.0, 2.0),
     )
     np.testing.assert_almost_equal(updated_momentum, 0.99)
 def test_works_with_multiple_controls(self):
     """The schedule returns an updated value for every configured control.

     Only one action multiplier (1.0) is available. Presumably every sampled
     action therefore leaves its control unchanged, so each new value should
     equal the value last recorded for that control in history. That value is
     1e-4 for learning_rate (not its 1e-3 start value) and 1e-5 for
     weight_decay_rate. Checking values as well as keys catches a control
     being read from the wrong history series or ignoring history.
     """
     history = trax_history.History()
     history.append("eval", "metrics/accuracy", step=0, value=0.8)
     history.append(
         *online_tune.control_metric("learning_rate"), step=0, value=1e-4
     )
     history.append(
         *online_tune.control_metric("weight_decay_rate"), step=0, value=1e-5
     )
     schedule = self._make_schedule(
         history,
         observation_metrics=(("eval", "metrics/accuracy"),),
         control_configs=(
             ("learning_rate", 1e-3, (1e-9, 1.0), False),
             ("weight_decay_rate", 1e-5, (1e-9, 1.0), False),
         ),
         action_multipliers=(1.0,),
     )
     new_controls = schedule(123)
     self.assertIn("learning_rate", new_controls)
     self.assertIn("weight_decay_rate", new_controls)
     # Identity multiplier: each control should keep its last recorded value.
     self.assertAlmostEqual(new_controls["learning_rate"], 1e-4, delta=1e-9)
     self.assertAlmostEqual(
         new_controls["weight_decay_rate"], 1e-5, delta=1e-9
     )
Example #4
0
 def test_clips_updated_control_without_flipping(self):
     """An unflipped learning rate is clipped to its 10.0 upper bound.

     History holds learning_rate 7.0. Action 2 is applied, which presumably
     selects multiplier 2.0 from (0.5, 1.0, 2.0), and the result is expected
     to be clipped to the upper bound of (1e-9, 10.0).
     """
     lr_metric = online_tune.control_metric("learning_rate")
     lr_config = ("learning_rate", None, (1e-9, 10.0), False)
     metrics_history = trax_history.History()
     self._append_metrics(metrics_history, lr_metric, [7.0])
     updated_lr = online_tune.update_control(
         control_config=lr_config,
         action=2,
         history=metrics_history,
         action_multipliers=(0.5, 1.0, 2.0),
     )
     np.testing.assert_almost_equal(updated_lr, 10.0)
Example #5
0
 def test_changes_lr_when_there_are_some_metrics(self):
     """With metrics in history, the schedule scales the learning rate.

     The last recorded learning rate is 1e-4 and the multipliers are
     (0.5, 2.0), so the new rate must be one of 5e-5 or 2e-4.
     """
     metrics_history = trax_history.History()
     metrics_history.append("eval", "metrics/accuracy", step=0, value=0.8)
     lr_mode, lr_metric = online_tune.control_metric("learning_rate")
     metrics_history.append(lr_mode, lr_metric, step=0, value=1e-4)
     schedule = self._make_schedule(
         metrics_history,
         observation_metrics=(("eval", "metrics/accuracy"),),
         action_multipliers=(0.5, 2.0),
     )
     controls = schedule(123)
     new_lr = controls["learning_rate"]
     expected_candidates = (5e-5, 2e-4)
     self.assertTrue(
         any(onp.allclose(new_lr, candidate)
             for candidate in expected_candidates))