def evaluate(self, eval_steps):
    """Append one scripted metric value and all control values to history.

    The next value is taken from the front of ``self.metrics_to_report`` and
    recorded under ``(HISTORY_MODE, METRIC)`` at ``self.step``. Afterwards,
    every entry of ``self.nontrainable_params`` is recorded at the same step
    under the (mode, metric) pair returned by ``online_tune.control_metric``.

    Args:
        eval_steps: Ignored; accepted only to match the caller's signature.

    Raises:
        IndexError: If ``self.metrics_to_report`` is empty.
    """
    del eval_steps  # Unused by this implementation.

    reported_value = self.metrics_to_report.pop(0)
    self.state.history.append(
        mode=HISTORY_MODE,
        metric=METRIC,
        step=self.step,
        value=reported_value,
    )

    # Mirror each current control value into history under its control metric.
    for param_name, param_value in self.nontrainable_params.items():
        control_mode, control_name = online_tune.control_metric(param_name)
        self.state.history.append(
            mode=control_mode,
            metric=control_name,
            step=self.step,
            value=param_value,
        )
def test_clips_updated_control_with_flipping(self):
    """Check the updated momentum control is clipped when the flip flag is set.

    History holds a single momentum value of 0.985 and the control config
    carries the range (0.5, 0.99) with its last element set to True. The
    result of ``online_tune.update_control`` for action 0 must be
    (almost) equal to 0.99, the upper end of that range.
    """
    history = trax_history.History()
    self._append_metrics(
        history, online_tune.control_metric("momentum"), [0.985]
    )

    # NOTE(review): presumably the config is (name, start value, (low, high),
    # flip) and `action` indexes into `action_multipliers` — confirm against
    # online_tune.update_control.
    updated = online_tune.update_control(
        control_config=("momentum", None, (0.5, 0.99), True),
        action=0,
        history=history,
        action_multipliers=(0.5, 1.0, 2.0),
    )

    np.testing.assert_almost_equal(updated, 0.99)
def test_works_with_multiple_controls(self):
    """Check the schedule returns a value for every configured control.

    History is seeded at step 0 with one eval accuracy and a current value
    for both ``learning_rate`` and ``weight_decay_rate``. A schedule built by
    ``self._make_schedule`` with two control configs and a single action
    multiplier of 1.0 is then called at step 123, and both control names must
    appear in its result.
    """
    history = trax_history.History()
    history.append("eval", "metrics/accuracy", step=0, value=0.8)

    # Current value of each control, recorded under its control metric.
    recorded_controls = (
        ("learning_rate", 1e-4),
        ("weight_decay_rate", 1e-5),
    )
    for control_name, current_value in recorded_controls:
        history.append(
            *online_tune.control_metric(control_name),
            step=0,
            value=current_value,
        )

    schedule = self._make_schedule(
        history,
        observation_metrics=(("eval", "metrics/accuracy"),),
        control_configs=(
            ("learning_rate", 1e-3, (1e-9, 1.0), False),
            ("weight_decay_rate", 1e-5, (1e-9, 1.0), False),
        ),
        action_multipliers=(1.0,),
    )
    new_controls = schedule(123)

    for control_name, _ in recorded_controls:
        self.assertIn(control_name, new_controls)
def test_clips_updated_control_without_flipping(self):
    """Check the updated learning rate is clipped when the flip flag is unset.

    History holds a single learning-rate value of 7.0 and the control config
    carries the range (1e-9, 10.0) with its last element set to False. The
    result of ``online_tune.update_control`` for action 2 must be
    (almost) equal to 10.0, the upper end of that range.
    """
    history = trax_history.History()
    self._append_metrics(
        history, online_tune.control_metric("learning_rate"), [7.0]
    )

    # NOTE(review): presumably `action` indexes into `action_multipliers`, so
    # action 2 would scale 7.0 by 2.0 before clipping — confirm against
    # online_tune.update_control.
    updated = online_tune.update_control(
        control_config=("learning_rate", None, (1e-9, 10.0), False),
        action=2,
        history=history,
        action_multipliers=(0.5, 1.0, 2.0),
    )

    np.testing.assert_almost_equal(updated, 10.0)
def test_changes_lr_when_there_are_some_metrics(self):
    """Check the schedule moves the learning rate once metrics are recorded.

    History is seeded at step 0 with one eval accuracy and a learning rate of
    1e-4. The schedule is built with action multipliers (0.5, 2.0), so the
    learning rate returned at step 123 must be close to 1e-4 scaled by one of
    them: 5e-5 or 2e-4.
    """
    history = trax_history.History()
    history.append("eval", "metrics/accuracy", step=0, value=0.8)
    history.append(
        *online_tune.control_metric("learning_rate"), step=0, value=1e-4
    )

    schedule = self._make_schedule(
        history,
        observation_metrics=(("eval", "metrics/accuracy"),),
        action_multipliers=(0.5, 2.0),
    )
    new_lr = schedule(123)["learning_rate"]

    # Which multiplier gets applied is not fixed here, so accept either one.
    acceptable_rates = (5e-5, 2e-4)
    self.assertTrue(
        any(onp.allclose(new_lr, rate) for rate in acceptable_rates)
    )