示例#1
0
 def create_test_trainer(self, config, model_dir=None):
   """Builds a `Trainer` wired to a `MockTask` for use in tests.

   Args:
     config: Experiment config. Its `task`, `trainer` and `runtime` fields are
       read here, and the whole object is handed to the trainer.
     model_dir: Optional directory, forwarded to the mock task as
       `logging_dir` and to the best-checkpoint exporter factory.

   Returns:
     The constructed `trainer_lib.Trainer`.
   """
   mock = mock_task.MockTask(config.task, logging_dir=model_dir)
   exporter = train_lib.maybe_create_best_ckpt_exporter(config, model_dir)
   model = mock.build_model()
   optimizer = trainer_lib.create_optimizer(config.trainer, config.runtime)
   return trainer_lib.Trainer(
       config,
       mock,
       model=model,
       optimizer=optimizer,
       checkpoint_exporter=exporter)
  def test_recovery(self):
    """Exercises trainer recovery: a successful restore, then a hard failure.

    Part 1 trains one step with `loss_upper_bound=0.5` and
    `recovery_max_trials=2` after saving a checkpoint, and asserts that the
    model weights afterwards equal the weights captured before training.

    Part 2 swaps the task's `build_losses` for one that yields NaN, sets
    `recovery_max_trials=0`, and asserts that training raises `RuntimeError`.
    """

    def sgd_constant_lr():
      # Both experiment configs below use the same optimizer specification;
      # a fresh config object is built on every call.
      return cfg.OptimizationConfig({
          'optimizer': {
              'type': 'sgd'
          },
          'learning_rate': {
              'type': 'constant'
          }
      })

    # ---- Part 1: loss above the upper bound triggers a checkpoint restore.
    bounded_trainer_cfg = cfg.TrainerConfig(
        loss_upper_bound=0.5,
        recovery_max_trials=2,
        optimizer_config=sgd_constant_lr())
    config = cfg.ExperimentConfig(trainer=bounded_trainer_cfg)
    model_dir = self.get_temp_dir()
    trainer = self.create_test_trainer(config, model_dir=model_dir)

    # Save a checkpoint up front so that recovery has something to restore.
    ckpt_dir = self.get_temp_dir()
    checkpoint_manager = tf.train.CheckpointManager(
        trainer.checkpoint, ckpt_dir, max_to_keep=2)
    checkpoint_manager.save()
    trainer.add_recovery(config.trainer, checkpoint_manager=checkpoint_manager)

    weights_before = trainer.model.get_weights()
    one_step = tf.convert_to_tensor(1, dtype=tf.int32)
    trainer.train(one_step)
    # The training loss is 1.0 and the upper bound is 0.5, so recovery kicks
    # in and the weights must match the ones captured before training.
    weights_after = trainer.model.get_weights()
    for expected, actual in zip(weights_before, weights_after):
      self.assertAllEqual(expected, actual)

    # ---- Part 2: NaN loss with max_trials = 0 must raise RuntimeError.
    no_retry_trainer_cfg = cfg.TrainerConfig(
        recovery_max_trials=0,
        optimizer_config=sgd_constant_lr())
    config = cfg.ExperimentConfig(trainer=no_retry_trainer_cfg)
    task = mock_task.MockTask(config.task, logging_dir=model_dir)

    def build_losses(labels, model_outputs, aux_losses=None):
      # Ignore the real inputs; the loss is forced to NaN.
      del labels, model_outputs
      nan_loss = tf.constant([np.nan], tf.float32)
      return nan_loss + aux_losses

    task.build_losses = build_losses
    nan_model = task.build_model()
    nan_optimizer = trainer_lib.create_optimizer(config.trainer, config.runtime)
    trainer = trainer_lib.Trainer(
        config, task, model=nan_model, optimizer=nan_optimizer)
    # Reuses the checkpoint manager created for the first trainer.
    trainer.add_recovery(config.trainer, checkpoint_manager=checkpoint_manager)
    two_steps = tf.convert_to_tensor(2, dtype=tf.int32)
    with self.assertRaises(RuntimeError):
      trainer.train(two_steps)
示例#3
0
def create_trainer(params: config_definitions.ExperimentConfig,
                   task: base_task.Task,
                   train: bool,
                   evaluate: bool,
                   checkpoint_exporter: Any = None) -> base_trainer.Trainer:
  """Creates the default `base_trainer.Trainer` for a task.

  The model comes from `task.build_model()` and the optimizer from
  `base_trainer.create_optimizer(params.trainer, params.runtime)`.

  Args:
    params: Experiment config; passed to the trainer, and its `trainer` and
      `runtime` fields are used to create the optimizer.
    task: Task that supplies the model and is handed to the trainer.
    train: Forwarded unchanged as the trainer's `train` argument.
    evaluate: Forwarded unchanged as the trainer's `evaluate` argument.
    checkpoint_exporter: Forwarded unchanged as the trainer's
      `checkpoint_exporter` argument; defaults to None.

  Returns:
    The newly constructed `base_trainer.Trainer`.
  """
  logging.info('Running default trainer.')
  # Keyword arguments are evaluated left to right, so the model is still
  # built before the optimizer is created.
  return base_trainer.Trainer(
      params,
      task,
      model=task.build_model(),
      optimizer=base_trainer.create_optimizer(params.trainer, params.runtime),
      train=train,
      evaluate=evaluate,
      checkpoint_exporter=checkpoint_exporter)