def reset(self, output_dir, init_checkpoint=None):
  """Reset the model parameters.

  Restores the parameters from the given output_dir if a checkpoint exists,
  otherwise randomly initializes them. Does not re-jit the model.

  Args:
    output_dir: Output directory.
    init_checkpoint: Initial checkpoint to use (default $output_dir/model.pkl)
  """
  self.close()
  self._output_dir = output_dir
  if output_dir is not None:
    tf.io.gfile.makedirs(output_dir)
  else:
    assert not self._should_save_checkpoints
    assert not self._should_write_summaries

  # Create summary writers and history.
  if self._should_write_summaries:
    self._train_sw = jaxboard.SummaryWriter(
        os.path.join(output_dir, 'train'), enable=self._is_chief)
    self._eval_sw = jaxboard.SummaryWriter(
        os.path.join(output_dir, 'eval'), enable=self._is_chief)

  # Reset the train and eval streams.
  self._train_stream = _repeat_stream(self._inputs.train_stream,
                                      self._n_devices)
  # TODO(lukaszkaiser): add an option to evaluate exactly on the full eval
  #   set by adding a padding and stopping the stream when too large.
  self._eval_stream = _repeat_stream(
      self._inputs.eval_stream, self._n_devices)
  self._train_eval_stream = _repeat_stream(
      self._inputs.train_eval_stream, self._n_devices)

  # Restore the training state.
  if output_dir is not None:
    state = load_trainer_state(output_dir, init_checkpoint)
  else:
    state = TrainerState(step=None, opt_state=None,
                         history=trax_history.History(), model_state=None)
  self._step = state.step or 0
  history = state.history
  self._lr_fn = self._lr_schedule(history)
  self._history = history
  if state.opt_state:
    opt_state = state.opt_state
    model_state = state.model_state
  else:
    opt_state, model_state = self._new_opt_state_and_model_state()
    model_state = self._for_n_devices(model_state)
  self._opt_state = OptState(*self._for_n_devices(opt_state))
  self._model_state = model_state
  if not state.opt_state and self._should_save_checkpoints:
    self.save_state(keep=False)

  self.update_nontrainable_params()

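# Usage sketch (not from the source; `trainer`, `warm_start` and the paths are
# hypothetical): `reset` can be called repeatedly to re-point an existing
# Trainer, e.g. to warm-start a new run from another run's checkpoint.
def warm_start(trainer, old_dir, new_dir):
  # Restore weights from `old_dir`'s checkpoint, then write summaries and
  # checkpoints under `new_dir` going forward.
  trainer.reset(output_dir=new_dir,
                init_checkpoint=os.path.join(old_dir, 'model.pkl'))
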
def test_returns_start_lr_when_there_are_no_metrics(self):
  history = trax_history.History()
  start_lr = 1e-3
  schedule = self._make_schedule(
      history,
      control_configs=(('learning_rate', start_lr, (1e-9, 1.0), False),),
  )
  self.assertEqual(schedule(0)['learning_rate'], start_lr)

def test_clips_observations(self):
  history = trax_history.History()
  self._append_metrics(history, ('eval', 'loss'), [-10, 10])
  observations = online_tune.history_to_observations(
      history,
      metrics=(('eval', 'loss'),),
      observation_range=(-2, 2),
      include_lr=False,
  )
  np.testing.assert_array_equal(observations, [[-2], [2]])

def test_clips_observations(self):
  history = trax_history.History()
  self._append_metrics(history, ("eval", "loss"), [-10, 10])
  observations = online_tune.history_to_observations(
      history,
      metrics=(("eval", "loss"),),
      observation_range=(-2, 2),
      control_configs=None,
  )
  np.testing.assert_array_equal(observations, [[-2], [2]])

def test_clips_new_learning_rate(self):
  history = trax_history.History()
  self._append_metrics(history, online_tune.LEARNING_RATE_METRIC, [1e-3])
  new_lr = online_tune.new_learning_rate(
      action=0,
      history=history,
      action_multipliers=(4.0, 1.0, 0.25),
      max_lr=3e-3,
  )
  np.testing.assert_almost_equal(new_lr, 3e-3)

def test_calculates_new_learning_rate(self):
  history = trax_history.History()
  self._append_metrics(
      history, online_tune.LEARNING_RATE_METRIC, [1e-2, 1e-3])
  new_lr = online_tune.new_learning_rate(
      action=2,
      history=history,
      action_multipliers=(0.5, 1.0, 2.0),
      max_lr=1.0,
  )
  np.testing.assert_almost_equal(new_lr, 2e-3)

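# Sketch of the behavior the two tests above pin down (an assumption about
# `new_learning_rate`, not the library's implementation): the new rate is the
# most recent value recorded under LEARNING_RATE_METRIC times the multiplier
# selected by `action`, clipped at `max_lr`.
def sketch_new_learning_rate(action, last_lr, action_multipliers, max_lr):
  return min(last_lr * action_multipliers[action], max_lr)

# test_clips_new_learning_rate:      1e-3 * 4.0 = 4e-3, clipped to 3e-3.
# test_calculates_new_learning_rate: 1e-3 * 2.0 = 2e-3, below max_lr=1.0.
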
def test_converts_history_to_observations_without_controls(self):
  history = trax_history.History()
  self._append_metrics(history, ("train", "loss"), [1.0, 0.07])
  self._append_metrics(history, ("eval", "accuracy"), [0.12, 0.68])
  observations = online_tune.history_to_observations(
      history,
      metrics=(("eval", "accuracy"), ("train", "loss")),
      observation_range=(-1, 1),
      control_configs=None,
  )
  np.testing.assert_array_almost_equal(
      observations, [[0.12, 1.0], [0.68, 0.07]])

def test_clips_updated_control_with_flipping(self):
  config = ("momentum", None, (0.5, 0.99), True)
  history = trax_history.History()
  self._append_metrics(
      history, online_tune.control_metric("momentum"), [0.985])
  new_control = online_tune.update_control(
      control_config=config,
      action=0,
      history=history,
      action_multipliers=(0.5, 1.0, 2.0),
  )
  np.testing.assert_almost_equal(new_control, 0.99)

def test_converts_history_to_observations_without_learning_rate(self):
  history = trax_history.History()
  self._append_metrics(history, ('train', 'loss'), [3.0, 1.07])
  self._append_metrics(history, ('eval', 'accuracy'), [0.12, 0.68])
  observations = online_tune.history_to_observations(
      history,
      metrics=(('eval', 'accuracy'), ('train', 'loss')),
      observation_range=(0, 5),
      include_lr=False,
  )
  np.testing.assert_array_equal(observations, [[0.12, 3.0], [0.68, 1.07]])

def test_converts_history_to_observations_with_learning_rate(self):
  history = trax_history.History()
  self._append_metrics(
      history, ('train', 'training/learning_rate'), [1e-3, 1e-4])
  observations = online_tune.history_to_observations(
      history,
      metrics=(),
      observation_range=(0, 5),
      include_lr=True,
  )
  self.assertEqual(observations.shape, (2, 1))
  ((log_lr_1,), (log_lr_2,)) = observations
  self.assertGreater(log_lr_1, log_lr_2)

def test_converts_history_to_observations_with_controls(self):
  history = trax_history.History()
  self._append_metrics(
      history, ("train", "training/learning_rate"), [1e-3, 1e-4])
  observations = online_tune.history_to_observations(
      history,
      metrics=(),
      observation_range=(0, 5),
      control_configs=(("learning_rate", None, (1e-9, 10.0), False),),
  )
  self.assertEqual(observations.shape, (2, 1))
  ((log_lr_1,), (log_lr_2,)) = observations
  self.assertGreater(log_lr_1, log_lr_2)

def test_clips_updated_control_without_flipping(self):
  config = ("learning_rate", None, (1e-9, 10.0), False)
  history = trax_history.History()
  self._append_metrics(
      history, online_tune.control_metric("learning_rate"), [7.0])
  new_control = online_tune.update_control(
      control_config=config,
      action=2,
      history=history,
      action_multipliers=(0.5, 1.0, 2.0),
  )
  np.testing.assert_almost_equal(new_control, 10.0)

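# Sketch of `update_control` as pinned down by the two clipping tests above
# (an assumption about the implementation, not taken from the source). With
# flip=False the control is scaled by the chosen multiplier and clipped to
# its range; with flip=True it is reflected around 1 first (useful for
# quantities like momentum that approach 1), scaled, reflected back, then
# clipped.
def sketch_update_control(value, multiplier, control_range, flip):
  low, high = control_range
  if flip:
    value = 1.0 - value            # e.g. momentum 0.985 -> 0.015
  value *= multiplier              # 0.015 * 0.5 = 0.0075
  if flip:
    value = 1.0 - value            # -> 0.9925
  return min(max(value, low), high)  # clipped to 0.99 / 10.0 in the tests
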
def test_changes_lr_when_there_are_some_metrics(self):
  history = trax_history.History()
  history.append('eval', 'metrics/accuracy', step=0, value=0.8)
  history.append(
      *online_tune.control_metric('learning_rate'), step=0, value=1e-4)
  schedule = self._make_schedule(
      history,
      control_configs=(('learning_rate', 1e-3, (1e-9, 1.0), False),),
      observation_metrics=(('eval', 'metrics/accuracy'),),
      action_multipliers=(0.5, 2.0),
  )
  new_lr = schedule(123)['learning_rate']
  self.assertTrue(np.allclose(new_lr, 5e-5) or np.allclose(new_lr, 2e-4))

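# Note on the two accepted values above (an inference from the test, not from
# the source): the schedule's action choice is not seeded, so either
# multiplier may be applied to the last rate recorded in history (1e-4):
#   1e-4 * 0.5 = 5e-5    or    1e-4 * 2.0 = 2e-4
# The start value 1e-3 in the control config appears to be used only when
# history holds no learning-rate entry yet, as in the no-metrics test above.
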
def load_trainer_state(output_dir):
  """Returns a TrainerState instance loaded from the given `output_dir`."""
  weights_file = os.path.join(output_dir, 'model.pkl')
  if not gfile.exists(weights_file):
    return TrainerState(step=None, opt_state=None,
                        history=trax_history.History(), model_state=None)
  pkl_module = utils.get_pickle_module()
  with gfile.GFile(weights_file, 'rb') as f:
    (opt_state, step, history, model_state) = pkl_module.load(f)
  log('Model loaded from %s at step %d' % (weights_file, step))
  logging.debug('From loaded model : history = %s', history)
  return TrainerState(step=step, opt_state=OptState(*opt_state),
                      history=history, model_state=model_state)

def test_works_with_serialized_policy(self):
  history = trax_history.History()
  history.append('eval', 'metrics/accuracy', step=0, value=0.8)
  history.append(
      *online_tune.control_metric('learning_rate'), step=0, value=1e-4)
  schedule = self._make_schedule(
      history,
      control_configs=(('learning_rate', 1e-3, (1e-9, 1.0), False),),
      observation_metrics=(('eval', 'metrics/accuracy'),),
      action_multipliers=(0.5, 2.0),
      vocab_size=16,
  )
  new_lr = schedule(123)['learning_rate']
  self.assertTrue(np.allclose(new_lr, 5e-5) or np.allclose(new_lr, 2e-4))

def load_trainer_state(output_dir, weights_file=None):
  """Returns a TrainerState instance loaded from the given `output_dir`."""
  if weights_file is None:
    weights_file = os.path.join(output_dir, 'model.pkl')
    # A missing default checkpoint means a fresh run: start from scratch.
    if not tf.io.gfile.exists(weights_file):
      return TrainerState(step=None, opt_state=None,
                          history=trax_history.History(), model_state=None)
  elif not tf.io.gfile.exists(weights_file):
    # An explicitly requested checkpoint, however, must exist.
    raise ValueError('File not found: %s' % weights_file)
  with tf.io.gfile.GFile(weights_file, 'rb') as f:
    trainer_state_dict = pickle.load(f)
  trainer_state = trainer_state_from_dict(trainer_state_dict)
  log('Model loaded from %s at step %d'
      % (weights_file, trainer_state.step))
  logging.debug('From loaded model : history = %s', trainer_state.history)
  return trainer_state

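# Usage sketch (hypothetical helper, not from the source): since
# `load_trainer_state` falls back to an empty TrainerState with step=None when
# the default checkpoint is missing, callers can branch on `step` to tell a
# fresh run from a resumed one.
def resume_or_start(output_dir):
  state = load_trainer_state(output_dir)
  if state.step is None:
    logging.info('No checkpoint in %s; starting from scratch.', output_dir)
  return state
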
def maybe_restore_opt_state(output_dir,
                            policy_and_value_opt_state=None,
                            policy_and_value_state=None):
  """Maybe restore the optimization state from the checkpoint dir.

  Optimization state includes parameters and optimizer slots.

  Args:
    output_dir: Directory where saved model checkpoints are stored.
    policy_and_value_opt_state: Default optimization state, returned if model
      isn't found.
    policy_and_value_state: state of the policy and value network.

  Returns:
    tuple (opt_state, state, epoch (int), opt_step (int), history) where epoch
    is the epoch from which we restored the optimization state, 0 if no
    checkpoint was found, opt_step is the total optimization step (sum of all
    optimization steps made up to the current epoch), and history is the
    restored training history.
  """
  pkl_module = utils.get_pickle_module()
  epoch = 0
  total_opt_step = 0
  history = trax_history.History()
  for model_file in get_policy_model_files(output_dir):
    logging.info('Trying to restore model from %s', model_file)
    try:
      with tf.io.gfile.GFile(model_file, 'rb') as f:
        (policy_and_value_opt_state, policy_and_value_state, total_opt_step,
         history) = pkl_module.load(f)
      epoch = get_epoch_from_policy_model_file(model_file)
      break
    except EOFError as e:
      logging.error('Unable to load model from: %s with %s', model_file, e)
      # Try an older version.
      continue
  return (
      policy_and_value_opt_state,
      policy_and_value_state,
      epoch,
      total_opt_step,
      history,
  )

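# Usage sketch (hypothetical; `fresh_opt_state`, `fresh_net_state` and the
# path stand in for freshly initialized defaults and a real run directory):
# resume from the newest readable policy checkpoint, keeping the passed-in
# defaults when none can be loaded.
(opt_state, net_state, epoch, opt_step, history) = maybe_restore_opt_state(
    output_dir='/tmp/ppo_run',
    policy_and_value_opt_state=fresh_opt_state,
    policy_and_value_state=fresh_net_state,
)
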
def test_works_with_multiple_controls(self):
  history = trax_history.History()
  history.append('eval', 'metrics/accuracy', step=0, value=0.8)
  history.append(
      *online_tune.control_metric('learning_rate'), step=0, value=1e-4)
  history.append(
      *online_tune.control_metric('weight_decay_rate'), step=0, value=1e-5)
  schedule = self._make_schedule(
      history,
      observation_metrics=(('eval', 'metrics/accuracy'),),
      control_configs=(
          ('learning_rate', 1e-3, (1e-9, 1.0), False),
          ('weight_decay_rate', 1e-5, (1e-9, 1.0), False),
      ),
      action_multipliers=(1.0,),
  )
  new_controls = schedule(123)
  self.assertIn('learning_rate', new_controls)
  self.assertIn('weight_decay_rate', new_controls)

def test_retrieves_historical_metric_values(self):
  history = trax_history.History()
  self._append_metrics(history, ('train', 'accuracy'), [0.1, 0.73])
  metric_values = online_tune.historical_metric_values(
      history, metric=('train', 'accuracy'), observation_range=(0, 5))
  np.testing.assert_array_equal(metric_values, [0.1, 0.73])

def test_schedule_from_lr_function(self):
  history = trax_history.History()
  schedule = lr_schedules.constant(history, 0.1)
  value = schedule(10)
  self.assertEqual(value['learning_rate'], 0.1)

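# Usage note: as the test shows, `lr_schedules.constant` exposes the same
# callable interface as the adaptive schedules above (step -> mapping of
# control names to values), so training code can swap between them freely.
#   schedule = lr_schedules.constant(history, 0.1)
#   schedule(step)['learning_rate']  # 0.1 at every step
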
def test_retrieves_historical_metric_values(self):
  history = trax_history.History()
  self._append_metrics(history, ("train", "accuracy"), [0.1, 0.73])
  metric_values = online_tune.historical_metric_values(
      history, metric=("train", "accuracy"))
  np.testing.assert_array_equal(metric_values, [0.1, 0.73])

def test_clips_historical_metric_values(self):
  history = trax_history.History()
  self._append_metrics(history, ('train', 'loss'), [-10, 10])
  metric_values = online_tune.historical_metric_values(
      history, metric=('train', 'loss'), observation_range=(-1, 1))
  np.testing.assert_array_equal(metric_values, [-1, 1])