Example 1
 def test_works_with_multiple_controls(self):
     history = trax_history.History()
     history.append('eval', 'metrics/accuracy', step=0, value=0.8)
     history.append(*online_tune.control_metric('learning_rate'),
                    step=0,
                    value=1e-4)
     history.append(*online_tune.control_metric('weight_decay_rate'),
                    step=0,
                    value=1e-5)
     schedule = self._make_schedule(
         history,
         observation_metrics=(('eval', 'metrics/accuracy'), ),
         control_configs=(
             ('learning_rate', 1e-3, (1e-9, 1.0), False),
             ('weight_decay_rate', 1e-5, (1e-9, 1.0), False),
         ),
         action_multipliers=(1.0, ),
     )
     new_controls = schedule(123)
     self.assertIn('learning_rate', new_controls)
     self.assertIn('weight_decay_rate', new_controls)
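The dict returned by the schedule pairs naturally with the logging pattern shown in the next example: each control name maps back to a (mode, metric) pair via online_tune.control_metric. A minimal sketch of that round trip; the log_new_controls helper and the step value are illustrative, not part of the library:

def log_new_controls(history, new_controls, step):
    # For every control value the schedule produced, look up the
    # (mode, metric) pair under which online_tune logs that control
    # and append the value to the history at the given step.
    for name, value in new_controls.items():
        mode, metric = online_tune.control_metric(name)
        history.append(mode, metric, step=step, value=value)

# e.g. log_new_controls(history, schedule(123), step=123)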
Example 2
 def evaluate(self, eval_steps):
     del eval_steps
     # Report the next scripted metric value as this step's evaluation
     # result, so the caller controls exactly what the schedule observes.
     self.state.history.append(mode=HISTORY_MODE,
                               metric=METRIC,
                               step=self.step,
                               value=self.metrics_to_report.pop(0))
     # Also log every nontrainable parameter under its control metric,
     # so the tuning policy can see the current control values.
     for (name, value) in self.nontrainable_params.items():
         (mode, metric) = online_tune.control_metric(name)
         self.state.history.append(mode=mode,
                                   metric=metric,
                                   step=self.step,
                                   value=value)
Example 3
 def test_clips_updated_control_with_flipping(self):
     config = ("momentum", None, (0.5, 0.99), True)
     history = trax_history.History()
     self._append_metrics(history, online_tune.control_metric("momentum"),
                          [0.985])
     new_control = online_tune.update_control(
         control_config=config,
         action=0,
         history=history,
         action_multipliers=(0.5, 1.0, 2.0),
     )
     np.testing.assert_almost_equal(new_control, 0.99)
Example 4
 def test_clips_updated_control_without_flipping(self):
     config = ("learning_rate", None, (1e-9, 10.0), False)
     history = trax_history.History()
     self._append_metrics(history,
                          online_tune.control_metric("learning_rate"),
                          [7.0])
     new_control = online_tune.update_control(
         control_config=config,
         action=2,
         history=history,
         action_multipliers=(0.5, 1.0, 2.0),
     )
     np.testing.assert_almost_equal(new_control, 10.0)
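Taken together, the two clipping tests above pin down the arithmetic: the last logged value is multiplied by the chosen action multiplier and the result is clamped to the configured (low, high) range; for a "flipped" control such as momentum the multiplier is applied to 1 - value and the result is flipped back before clamping. A rough reconstruction of that rule, inferred from the expected values rather than taken from the library:

def update_control_sketch(value, multiplier, low, high, flip):
    # Hypothetical re-derivation of the rule the two tests exercise.
    if flip:
        value = 1.0 - value       # work in "distance from 1.0" space
    value *= multiplier           # apply the chosen action multiplier
    if flip:
        value = 1.0 - value       # flip back to the original space
    return min(max(value, low), high)  # clamp to the allowed range

# Matches the expectations above:
#   update_control_sketch(7.0, 2.0, 1e-9, 10.0, flip=False) -> 10.0
#   update_control_sketch(0.985, 0.5, 0.5, 0.99, flip=True) -> 0.99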
Example 5
 def test_changes_lr_when_there_are_some_metrics(self):
     history = trax_history.History()
     history.append('eval', 'metrics/accuracy', step=0, value=0.8)
     history.append(*online_tune.control_metric('learning_rate'),
                    step=0,
                    value=1e-4)
     schedule = self._make_schedule(
         history,
         control_configs=(('learning_rate', 1e-3, (1e-9, 1.0), False), ),
         observation_metrics=(('eval', 'metrics/accuracy'), ),
         action_multipliers=(0.5, 2.0),
     )
     new_lr = schedule(123)['learning_rate']
     self.assertTrue(np.allclose(new_lr, 5e-5) or np.allclose(new_lr, 2e-4))
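Why the assertion admits two values: the action chosen by the schedule is not fixed in the test, so either multiplier may be applied. A quick worked check using the numbers from the test, not library code:

# Last logged learning rate: 1e-4; action multipliers: (0.5, 2.0).
# Possible outcomes: 1e-4 * 0.5 = 5e-5  or  1e-4 * 2.0 = 2e-4,
# both within the (1e-9, 1.0) clipping range, hence the two-way assert.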
Example 6
 def test_works_with_serialized_policy(self):
     history = trax_history.History()
     history.append('eval', 'metrics/accuracy', step=0, value=0.8)
     history.append(*online_tune.control_metric('learning_rate'),
                    step=0,
                    value=1e-4)
     schedule = self._make_schedule(
         history,
         control_configs=(('learning_rate', 1e-3, (1e-9, 1.0), False), ),
         observation_metrics=(('eval', 'metrics/accuracy'), ),
         action_multipliers=(0.5, 2.0),
         vocab_size=16,
     )
     new_lr = schedule(123)['learning_rate']
     self.assertTrue(np.allclose(new_lr, 5e-5) or np.allclose(new_lr, 2e-4))