Example #1
0
 def create_test_trainer(self, distribution):
     """Builds the `ProgressiveTrainer` used by the tests.

     Args:
       distribution: Forwarded as the first argument of `TestPolicy`.

     Returns:
       A `trainer_lib.ProgressiveTrainer` configured with `self._config`, a
       `TestPolicy` built from `distribution` and `self._config.task`, and the
       directory returned by `self.get_temp_dir()` as its checkpoint dir.
     """
     policy = TestPolicy(distribution, self._config.task)
     return trainer_lib.ProgressiveTrainer(
         self._config, prog_task=policy, ckpt_dir=self.get_temp_dir())
Example #2
0
 def create_test_trainer(self, distribution, model_dir, change_train_dataset):
   """Builds a `ProgressiveTrainer` that checkpoints into `model_dir`.

   Args:
     distribution: Forwarded as the first argument of `TestPolicy`.
     model_dir: Directory handed to the trainer as `ckpt_dir`.
     change_train_dataset: Forwarded as the third argument of `TestPolicy`.

   Returns:
     A `trainer_lib.ProgressiveTrainer` configured with `self._config` and
     the freshly built `TestPolicy`.
   """
   policy = TestPolicy(distribution, self._config.task, change_train_dataset)
   return trainer_lib.ProgressiveTrainer(
       self._config, prog_task=policy, ckpt_dir=model_dir)
Example #3
0
  def test_configure_optimizer(self, mixed_precision_dtype, loss_scale):
    """Checks the trainer's optimizer type and that training reports a loss.

    Args:
      mixed_precision_dtype: Value passed to
        `RuntimeConfig.mixed_precision_dtype`.
      loss_scale: Value passed to `RuntimeConfig.loss_scale`.
    """
    task_config = cfg.TaskConfig(model=bert.PretrainerConfig())
    runtime_config = cfg.RuntimeConfig(
        mixed_precision_dtype=mixed_precision_dtype, loss_scale=loss_scale)
    trainer_config = trainer_lib.ProgressiveTrainerConfig(
        export_checkpoint=True,
        export_checkpoint_interval=1,
        export_only_final_stage_ckpt=False)
    experiment = cfg.ExperimentConfig(
        task=task_config, runtime=runtime_config, trainer=trainer_config)

    policy = TestPolicy(None, experiment.task)
    trainer = trainer_lib.ProgressiveTrainer(
        experiment, policy, self.get_temp_dir())

    # The SGD check is skipped only when float16 is combined with a
    # non-None loss_scale; every other combination must yield plain SGD.
    float16_with_loss_scale = (
        mixed_precision_dtype == 'float16' and loss_scale is not None)
    if not float16_with_loss_scale:
      self.assertIsInstance(trainer.optimizer, tf.keras.optimizers.SGD)

    num_steps = tf.convert_to_tensor(5, dtype=tf.int32)
    metrics = trainer.train(num_steps)
    self.assertIn('training_loss', metrics)
Example #4
0
def run_experiment(distribution_strategy: tf.distribute.Strategy,
                   task: base_task.Task,
                   mode: str,
                   params: config_definitions.ExperimentConfig,
                   model_dir: str,
                   run_post_eval: bool = False,
                   save_summary: bool = True
                   ) -> Tuple[tf.keras.Model, Mapping[str, Any]]:
    """Runs progressive training and/or evaluation as selected by `mode`.

    Args:
      distribution_strategy: The `tf.distribute.Strategy` whose scope is used
        to build the trainer and to run the selected mode.
      task: A Task instance handed to the progressive trainer.
      mode: One of 'train', 'eval', 'train_and_eval' or 'continuous_eval'.
      params: ExperimentConfig instance; its `trainer` section supplies step
        counts, intervals and checkpoint settings.
      model_dir: Path under which checkpoints and summaries are stored.
      run_post_eval: If True, runs one extra evaluation after the selected
        mode finishes and returns its logs.
      save_summary: If True, train and validation summaries are written to
        the 'train' and 'validation' sub-directories of `model_dir`.

    Returns:
      A 2-tuple `(model, eval_logs)`:
        model: the trainer's `tf.keras.Model`.
        eval_logs: the post-eval metrics when `run_post_eval` is True,
          otherwise `{}`.

    Raises:
      NotImplementedError: If `mode` is not one of the supported values. The
        check happens after the trainer and controller have been built.
    """
    with distribution_strategy.scope():
        logging.info('Running progressive trainer.')
        # Substring tests: 'train_and_eval' turns on both flags, while
        # 'continuous_eval' only turns on evaluation.
        is_training = 'train' in mode
        trainer = prog_trainer_lib.ProgressiveTrainer(
            params,
            task,
            ckpt_dir=model_dir,
            train=is_training,
            evaluate=('eval' in mode) or run_post_eval,
            checkpoint_exporter=base_train_lib.maybe_create_best_ckpt_exporter(
                params, model_dir))

    # Only manage checkpoints when the trainer actually exposes one.
    ckpt_manager = None
    if trainer.checkpoint:
        ckpt_manager = tf.train.CheckpointManager(
            trainer.checkpoint,
            directory=model_dir,
            max_to_keep=params.trainer.max_to_keep,
            step_counter=trainer.global_step,
            checkpoint_interval=params.trainer.checkpoint_interval,
            init_fn=trainer.initialize)

    # All summary-related settings are switched off together.
    if save_summary:
        train_summary_dir = os.path.join(model_dir, 'train')
        eval_summary_dir = os.path.join(model_dir, 'validation')
        summary_interval = params.trainer.summary_interval
    else:
        train_summary_dir = None
        eval_summary_dir = None
        summary_interval = None

    controller = orbit.Controller(
        strategy=distribution_strategy,
        trainer=trainer if is_training else None,
        evaluator=trainer,
        global_step=trainer.global_step,
        steps_per_loop=params.trainer.steps_per_loop,
        checkpoint_manager=ckpt_manager,
        summary_dir=train_summary_dir,
        eval_summary_dir=eval_summary_dir,
        summary_interval=summary_interval)

    logging.info('Starts to execute mode: %s', mode)
    with distribution_strategy.scope():
        if mode == 'train':
            controller.train(steps=params.trainer.train_steps)
        elif mode == 'train_and_eval':
            controller.train_and_evaluate(
                train_steps=params.trainer.train_steps,
                eval_steps=params.trainer.validation_steps,
                eval_interval=params.trainer.validation_interval)
        elif mode == 'eval':
            controller.evaluate(steps=params.trainer.validation_steps)
        elif mode == 'continuous_eval':

            def reached_train_steps():
                # True once the global step has caught up with train_steps.
                return bool(
                    trainer.global_step.numpy() >= params.trainer.train_steps)

            controller.evaluate_continuously(
                steps=params.trainer.validation_steps,
                timeout=params.trainer.continuous_eval_timeout,
                timeout_fn=reached_train_steps)
        else:
            raise NotImplementedError('The mode is not implemented: %s' % mode)

    if not run_post_eval:
        return trainer.model, {}

    with distribution_strategy.scope():
        return trainer.model, trainer.evaluate(
            tf.convert_to_tensor(params.trainer.validation_steps))