def create_test_trainer(self, config, model_dir=None):
  """Creates a test Trainer backed by a MockTask."""
  task = mock_task.MockTask(config.task, logging_dir=model_dir)
  ckpt_exporter = train_lib.maybe_create_best_ckpt_exporter(config, model_dir)
  trainer = trainer_lib.Trainer(
      config,
      task,
      model=task.build_model(),
      optimizer=trainer_lib.create_optimizer(config.trainer, config.runtime),
      checkpoint_exporter=ckpt_exporter)
  return trainer
def test_recovery(self):
  config = cfg.ExperimentConfig(
      trainer=cfg.TrainerConfig(
          loss_upper_bound=0.5,
          recovery_max_trials=2,
          optimizer_config=cfg.OptimizationConfig({
              'optimizer': {'type': 'sgd'},
              'learning_rate': {'type': 'constant'}
          })))
  model_dir = self.get_temp_dir()
  trainer = self.create_test_trainer(config, model_dir=model_dir)
  checkpoint_manager = tf.train.CheckpointManager(
      trainer.checkpoint, self.get_temp_dir(), max_to_keep=2)
  checkpoint_manager.save()
  trainer.add_recovery(config.trainer, checkpoint_manager=checkpoint_manager)

  before_weights = trainer.model.get_weights()
  _ = trainer.train(tf.convert_to_tensor(1, dtype=tf.int32))
  # The training loss is 1.0 and loss_upper_bound is 0.5, so recovery happens
  # and the model weights are restored from the checkpoint.
  after_weights = trainer.model.get_weights()
  for left, right in zip(before_weights, after_weights):
    self.assertAllEqual(left, right)

  # Let the loss be NaN and recovery_max_trials = 0 to trigger a RuntimeError.
  config = cfg.ExperimentConfig(
      trainer=cfg.TrainerConfig(
          recovery_max_trials=0,
          optimizer_config=cfg.OptimizationConfig({
              'optimizer': {'type': 'sgd'},
              'learning_rate': {'type': 'constant'}
          })))
  task = mock_task.MockTask(config.task, logging_dir=model_dir)

  def build_losses(labels, model_outputs, aux_losses=None):
    del labels, model_outputs
    return tf.constant([np.nan], tf.float32) + aux_losses

  task.build_losses = build_losses
  trainer = trainer_lib.Trainer(
      config,
      task,
      model=task.build_model(),
      optimizer=trainer_lib.create_optimizer(config.trainer, config.runtime))
  trainer.add_recovery(config.trainer, checkpoint_manager=checkpoint_manager)
  with self.assertRaises(RuntimeError):
    _ = trainer.train(tf.convert_to_tensor(2, dtype=tf.int32))
def create_trainer(params: config_definitions.ExperimentConfig,
                   task: base_task.Task,
                   train: bool,
                   evaluate: bool,
                   checkpoint_exporter: Any = None) -> base_trainer.Trainer:
  """Create trainer."""
  logging.info('Running default trainer.')
  model = task.build_model()
  optimizer = base_trainer.create_optimizer(params.trainer, params.runtime)
  trainer = base_trainer.Trainer(
      params,
      task,
      model=model,
      optimizer=optimizer,
      train=train,
      evaluate=evaluate,
      checkpoint_exporter=checkpoint_exporter)
  return trainer
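# Usage sketch (assumed, not part of this module): illustrates how
# `create_trainer` is typically wired together in a training driver. The
# `exp_factory.get_exp_config` and `task_factory.get_task` helpers and the
# experiment name below are assumptions about the surrounding codebase.
#
#   params = exp_factory.get_exp_config('my_experiment')
#   task = task_factory.get_task(params.task)
#   trainer = create_trainer(
#       params=params,
#       task=task,
#       train=True,
#       evaluate=True,
#       checkpoint_exporter=None)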