Example #1
File: trainer_lib.py  Project: e7dud7e/trax
  def reset(self, output_dir, init_checkpoint=None):
    """Reset the model parameters.

    Restores the parameters from the given output_dir if a checkpoint exists,
    otherwise randomly initializes them.

    Does not re-jit the model.

    Args:
      output_dir: Output directory.
      init_checkpoint: Initial checkpoint to use (default $output_dir/model.pkl)
    """
    self.close()
    self._output_dir = output_dir
    if output_dir is not None:
      tf.io.gfile.makedirs(output_dir)
    else:
      assert not self._should_save_checkpoints
      assert not self._should_write_summaries

    # Create summary writers and history.
    if self._should_write_summaries:
      self._train_sw = jaxboard.SummaryWriter(os.path.join(output_dir, 'train'),
                                              enable=self._is_chief)
      self._eval_sw = jaxboard.SummaryWriter(os.path.join(output_dir, 'eval'),
                                             enable=self._is_chief)

    # Reset the train and eval streams.
    self._train_stream = _repeat_stream(self._inputs.train_stream,
                                        self._n_devices)
    # TODO(lukaszkaiser): add an option to evaluate exactly on the full eval
    #   set by adding a padding and stopping the stream when too large.
    self._eval_stream = _repeat_stream(
        self._inputs.eval_stream, self._n_devices)
    self._train_eval_stream = _repeat_stream(
        self._inputs.train_eval_stream, self._n_devices)

    # Restore the training state.
    if output_dir is not None:
      state = load_trainer_state(output_dir, init_checkpoint)
    else:
      state = TrainerState(step=None, opt_state=None,
                           history=trax_history.History(), model_state=None)
    self._step = state.step or 0
    history = state.history
    self._lr_fn = self._lr_schedule(history)
    self._history = history
    if state.opt_state:
      opt_state = state.opt_state
      model_state = state.model_state
    else:
      opt_state, model_state = self._new_opt_state_and_model_state()
      model_state = self._for_n_devices(model_state)
    self._opt_state = OptState(*self._for_n_devices(opt_state))
    self._model_state = model_state
    if not state.opt_state and self._should_save_checkpoints:
      self.save_state(keep=False)

    self.update_nontrainable_params()
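The `_repeat_stream` helper used by `reset` above is not part of this snippet. A minimal sketch of what it might look like, assuming it only needs to cycle the input stream forever so training never exhausts it:

def _repeat_stream(stream, n_devices):
  """Sketch (assumed behavior): yield examples from `stream` indefinitely."""
  while True:
    for example in stream(n_devices):
      yield example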
Example #2
 def test_returns_start_lr_when_there_are_no_metrics(self):
     history = trax_history.History()
     start_lr = 1e-3
     schedule = self._make_schedule(
         history,
         control_configs=(('learning_rate', start_lr, (1e-9, 1.0),
                           False), ),
     )
     self.assertEqual(schedule(0)['learning_rate'], start_lr)
Example #3
 def test_clips_observations(self):
     history = trax_history.History()
     self._append_metrics(history, ('eval', 'loss'), [-10, 10])
     observations = online_tune.history_to_observations(
         history,
         metrics=(('eval', 'loss'), ),
         observation_range=(-2, 2),
         include_lr=False,
     )
     np.testing.assert_array_equal(observations, [[-2], [2]])
Example #4
 def test_clips_observations(self):
     history = trax_history.History()
     self._append_metrics(history, ("eval", "loss"), [-10, 10])
     observations = online_tune.history_to_observations(
         history,
         metrics=(("eval", "loss"), ),
         observation_range=(-2, 2),
         control_configs=None,
     )
     np.testing.assert_array_equal(observations, [[-2], [2]])
Example #5
 def test_clips_new_learning_rate(self):
     history = trax_history.History()
     self._append_metrics(history, online_tune.LEARNING_RATE_METRIC, [1e-3])
     new_lr = online_tune.new_learning_rate(
         action=0,
         history=history,
         action_multipliers=(4.0, 1.0, 0.25),
         max_lr=3e-3,
     )
     np.testing.assert_almost_equal(new_lr, 3e-3)
Example #6
 def test_calculates_new_learning_rate(self):
     history = trax_history.History()
     self._append_metrics(history, online_tune.LEARNING_RATE_METRIC,
                          [1e-2, 1e-3])
     new_lr = online_tune.new_learning_rate(
         action=2,
         history=history,
         action_multipliers=(0.5, 1.0, 2.0),
         max_lr=1.0,
     )
     np.testing.assert_almost_equal(new_lr, 2e-3)
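Examples #5 and #6 together pin down `online_tune.new_learning_rate`: it takes the most recent learning rate recorded in the history, scales it by the multiplier selected by `action`, and clips the result to `max_lr`. A minimal sketch consistent with both tests (a hypothetical reconstruction; the assumption that `History.get` returns `(step, value)` pairs is mine):

def new_learning_rate(action, history, action_multipliers, max_lr):
    # Most recent learning rate recorded under LEARNING_RATE_METRIC.
    (_, last_lr) = history.get(*online_tune.LEARNING_RATE_METRIC)[-1]
    # Scale by the chosen multiplier, then clip from above:
    # 1e-3 * 4.0 -> 3e-3 in Example #5; 1e-3 * 2.0 -> 2e-3 in Example #6.
    return min(last_lr * action_multipliers[action], max_lr)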
Example #7
 def test_converts_history_to_observations_without_controls(self):
     history = trax_history.History()
     self._append_metrics(history, ("train", "loss"), [1.0, 0.07])
     self._append_metrics(history, ("eval", "accuracy"), [0.12, 0.68])
     observations = online_tune.history_to_observations(
         history,
         metrics=(("eval", "accuracy"), ("train", "loss")),
         observation_range=(-1, 1),
         control_configs=None,
     )
     np.testing.assert_array_almost_equal(observations,
                                          [[0.12, 1.0], [0.68, 0.07]])
Example #8
 def test_clips_updated_control_with_flipping(self):
     config = ("momentum", None, (0.5, 0.99), True)
     history = trax_history.History()
     self._append_metrics(history, online_tune.control_metric("momentum"),
                          [0.985])
     new_control = online_tune.update_control(
         control_config=config,
         action=0,
         history=history,
         action_multipliers=(0.5, 1.0, 2.0),
     )
     np.testing.assert_almost_equal(new_control, 0.99)
Example #9
 def test_converts_history_to_observations_without_learning_rate(self):
     history = trax_history.History()
     self._append_metrics(history, ('train', 'loss'), [3.0, 1.07])
     self._append_metrics(history, ('eval', 'accuracy'), [0.12, 0.68])
     observations = online_tune.history_to_observations(
         history,
         metrics=(('eval', 'accuracy'), ('train', 'loss')),
         observation_range=(0, 5),
         include_lr=False,
     )
     np.testing.assert_array_equal(observations,
                                   [[0.12, 3.0], [0.68, 1.07]])
Example #10
 def test_converts_history_to_observations_with_learning_rate(self):
     history = trax_history.History()
     self._append_metrics(history, ('train', 'training/learning_rate'),
                          [1e-3, 1e-4])
     observations = online_tune.history_to_observations(
         history,
         metrics=(),
         observation_range=(0, 5),
         include_lr=True,
     )
     self.assertEqual(observations.shape, (2, 1))
     ((log_lr_1, ), (log_lr_2, )) = observations
     self.assertGreater(log_lr_1, log_lr_2)
Example #11
 def test_converts_history_to_observations_with_controls(self):
     history = trax_history.History()
     self._append_metrics(history, ("train", "training/learning_rate"),
                          [1e-3, 1e-4])
     observations = online_tune.history_to_observations(
         history,
         metrics=(),
         observation_range=(0, 5),
         control_configs=(("learning_rate", None, (1e-9, 10.0), False), ),
     )
     self.assertEqual(observations.shape, (2, 1))
     ((log_lr_1, ), (log_lr_2, )) = observations
     self.assertGreater(log_lr_1, log_lr_2)
Example #12
 def test_clips_updated_control_without_flipping(self):
     config = ("learning_rate", None, (1e-9, 10.0), False)
     history = trax_history.History()
     self._append_metrics(history,
                          online_tune.control_metric("learning_rate"),
                          [7.0])
     new_control = online_tune.update_control(
         control_config=config,
         action=2,
         history=history,
         action_multipliers=(0.5, 1.0, 2.0),
     )
     np.testing.assert_almost_equal(new_control, 10.0)
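Examples #8 and #12 together pin down `online_tune.update_control`: the control's last recorded value is scaled by the selected multiplier and clipped to the configured `(low, high)` range, and when the flip flag is set (as for momentum) the multiplier is applied in `(1 - value)` space so the update nudges the control toward 1. A sketch consistent with both tests (hypothetical reconstruction, same `History.get` assumption as above):

def update_control(control_config, action, history, action_multipliers):
    (name, _, (low, high), flip) = control_config
    (_, last) = history.get(*online_tune.control_metric(name))[-1]
    multiplier = action_multipliers[action]
    if flip:
        # Example #8: momentum 0.985 -> 1 - 0.015 * 0.5 = 0.9925.
        new_value = 1 - (1 - last) * multiplier
    else:
        # Example #12: learning rate 7.0 -> 7.0 * 2.0 = 14.0.
        new_value = last * multiplier
    # Clip to the configured range: 0.9925 -> 0.99 and 14.0 -> 10.0 above.
    return min(max(new_value, low), high)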
Example #13
 def test_changes_lr_when_there_are_some_metrics(self):
     history = trax_history.History()
     history.append('eval', 'metrics/accuracy', step=0, value=0.8)
     history.append(*online_tune.control_metric('learning_rate'),
                    step=0,
                    value=1e-4)
     schedule = self._make_schedule(
         history,
         control_configs=(('learning_rate', 1e-3, (1e-9, 1.0), False), ),
         observation_metrics=(('eval', 'metrics/accuracy'), ),
         action_multipliers=(0.5, 2.0),
     )
     new_lr = schedule(123)['learning_rate']
     self.assertTrue(np.allclose(new_lr, 5e-5) or np.allclose(new_lr, 2e-4))
Example #14
def load_trainer_state(output_dir):
  """Returns a TrainerState instance loaded from the given `output_dir`."""
  weights_file = os.path.join(output_dir, 'model.pkl')
  if not gfile.exists(weights_file):
    return TrainerState(step=None, opt_state=None,
                        history=trax_history.History(), model_state=None)

  pkl_module = utils.get_pickle_module()
  with gfile.GFile(weights_file, 'rb') as f:
    (opt_state, step, history, model_state) = pkl_module.load(f)
  log('Model loaded from %s at step %d' % (weights_file, step))
  logging.debug('From loaded model : history = %s', history)
  return TrainerState(step=step, opt_state=OptState(*opt_state),
                      history=history, model_state=model_state)
Example #15
 def test_works_with_serialized_policy(self):
     history = trax_history.History()
     history.append('eval', 'metrics/accuracy', step=0, value=0.8)
     history.append(*online_tune.control_metric('learning_rate'),
                    step=0,
                    value=1e-4)
     schedule = self._make_schedule(
         history,
         control_configs=(('learning_rate', 1e-3, (1e-9, 1.0), False), ),
         observation_metrics=(('eval', 'metrics/accuracy'), ),
         action_multipliers=(0.5, 2.0),
         vocab_size=16,
     )
     new_lr = schedule(123)['learning_rate']
     self.assertTrue(np.allclose(new_lr, 5e-5) or np.allclose(new_lr, 2e-4))
Example #16
def load_trainer_state(output_dir, weights_file=None):
  """Returns a TrainerState instance loaded from the given `output_dir`."""
  if weights_file is None:
    weights_file = os.path.join(output_dir, 'model.pkl')
    if not tf.io.gfile.exists(weights_file):
      return TrainerState(step=None, opt_state=None,
                          history=trax_history.History(), model_state=None)
  elif not tf.io.gfile.exists(weights_file):
    raise ValueError('File not found: %s' % weights_file)

  with tf.io.gfile.GFile(weights_file, 'rb') as f:
    trainer_state_dict = pickle.load(f)
  trainer_state = trainer_state_from_dict(trainer_state_dict)
  log('Model loaded from %s at step %d' % (weights_file, trainer_state.step))
  logging.debug('From loaded model : history = %s', trainer_state.history)
  return trainer_state
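A hypothetical usage sketch (the directory path is illustrative): a missing checkpoint yields a `TrainerState` with `step=None`, which callers such as `reset` in Example #1 turn into step 0:

state = load_trainer_state('/tmp/train_dir')  # illustrative path
step = state.step or 0  # step is None when no checkpoint was found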
Example #17
def maybe_restore_opt_state(output_dir,
                            policy_and_value_opt_state=None,
                            policy_and_value_state=None):
    """Maybe restore the optimization state from the checkpoint dir.

  Optimization state includes parameters and optimizer slots.

  Args:
    output_dir: Directory where saved model checkpoints are stored.
    policy_and_value_opt_state: Default optimization state, returned if model
      isn't found.
    policy_and_value_state: state of the policy and value network.

  Returns:
    tuple (opt_state, state, epoch (int), opt_step (int)) where epoch is the
    epoch from which we restored the optimization state, 0 if no checkpoint was
    found, and opt_step is the total optimization step (sum of all optimization
    steps made up to the current epoch).
  """
    pkl_module = utils.get_pickle_module()
    epoch = 0
    total_opt_step = 0
    history = trax_history.History()
    for model_file in get_policy_model_files(output_dir):
        logging.info('Trying to restore model from %s', model_file)
        try:
            with tf.io.gfile.GFile(model_file, 'rb') as f:
                (policy_and_value_opt_state, policy_and_value_state,
                 total_opt_step, history) = pkl_module.load(f)
            epoch = get_epoch_from_policy_model_file(model_file)
            break
        except EOFError as e:
            logging.error('Unable to load model from: %s with %s', model_file,
                          e)
            # Try an older version.
            continue
    return (
        policy_and_value_opt_state,
        policy_and_value_state,
        epoch,
        total_opt_step,
        history,
    )
Example #18
 def test_works_with_multiple_controls(self):
     history = trax_history.History()
     history.append('eval', 'metrics/accuracy', step=0, value=0.8)
     history.append(*online_tune.control_metric('learning_rate'),
                    step=0,
                    value=1e-4)
     history.append(*online_tune.control_metric('weight_decay_rate'),
                    step=0,
                    value=1e-5)
     schedule = self._make_schedule(
         history,
         observation_metrics=(('eval', 'metrics/accuracy'), ),
         control_configs=(
             ('learning_rate', 1e-3, (1e-9, 1.0), False),
             ('weight_decay_rate', 1e-5, (1e-9, 1.0), False),
         ),
         action_multipliers=(1.0, ),
     )
     new_controls = schedule(123)
     self.assertIn('learning_rate', new_controls)
     self.assertIn('weight_decay_rate', new_controls)
Example #19
 def test_retrieves_historical_metric_values(self):
     history = trax_history.History()
     self._append_metrics(history, ('train', 'accuracy'), [0.1, 0.73])
     metric_values = online_tune.historical_metric_values(
         history, metric=('train', 'accuracy'), observation_range=(0, 5))
     np.testing.assert_array_equal(metric_values, [0.1, 0.73])
Example #20
 def test_schedule_from_lr_function(self):
     history = trax_history.History()
     schedule = lr_schedules.constant(history, 0.1)
     value = schedule(10)
     self.assertEqual(value['learning_rate'], 0.1)
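Example #20 shows the schedule contract used throughout: a schedule is a callable from a step number to a dict of control values. A minimal sketch of a constant schedule consistent with this test (hypothetical; the real `lr_schedules.constant` may differ in detail):

def constant(history, value):
    del history  # a constant schedule ignores the training history
    return lambda step: {'learning_rate': value}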
Example #21
 def test_retrieves_historical_metric_values(self):
     history = trax_history.History()
     self._append_metrics(history, ("train", "accuracy"), [0.1, 0.73])
     metric_values = online_tune.historical_metric_values(
         history, metric=("train", "accuracy"))
     np.testing.assert_array_equal(metric_values, [0.1, 0.73])
Example #22
 def test_clips_historical_metric_values(self):
     history = trax_history.History()
     self._append_metrics(history, ('train', 'loss'), [-10, 10])
     metric_values = online_tune.historical_metric_values(
         history, metric=('train', 'loss'), observation_range=(-1, 1))
     np.testing.assert_array_equal(metric_values, [-1, 1])
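Examples #19, #21, and #22 together pin down `online_tune.historical_metric_values`: it collects the metric's recorded values in order and, when an `observation_range` is given, clips them elementwise. A sketch consistent with all three tests (hypothetical reconstruction, again assuming `History.get` yields `(step, value)` pairs):

import numpy as np

def historical_metric_values(history, metric, observation_range=None):
    values = np.array([value for (_, value) in history.get(*metric)])
    if observation_range is not None:
        # Example #22: [-10, 10] clipped to (-1, 1) becomes [-1, 1].
        values = np.clip(values, *observation_range)
    return values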