Beispiel #1
0
 def test_export_best_ckpt(self, distribution):
   """Verifies that train + evaluate leaves a best-checkpoint `info.json`.

   Builds an experiment config whose trainer sets
   `best_checkpoint_export_subdir='best_ckpt'` and
   `best_checkpoint_eval_metric='acc'`, calls `train` and then `evaluate`
   with an int32 scalar tensor of 1, and asserts that
   `<model_dir>/best_ckpt/info.json` exists afterwards.

   Args:
     distribution: Not referenced by the test body; presumably injected by a
       parameterized-test decorator outside this view.
   """
   export_subdir = 'best_ckpt'
   optimization = cfg.OptimizationConfig({
       'optimizer': {'type': 'sgd'},
       'learning_rate': {'type': 'constant'},
   })
   trainer_config = cfg.TrainerConfig(
       best_checkpoint_export_subdir=export_subdir,
       best_checkpoint_eval_metric='acc',
       optimizer_config=optimization)
   config = cfg.ExperimentConfig(trainer=trainer_config)

   # The temp dir doubles as the task logging dir and the export root.
   model_dir = self.get_temp_dir()
   task = mock_task.MockTask(config.task, logging_dir=model_dir)
   exporter = train_lib.maybe_create_best_ckpt_exporter(config, model_dir)
   model = task.build_model()
   trainer = trainer_lib.Trainer(
       config, task, model=model, checkpoint_exporter=exporter)

   train_steps = tf.convert_to_tensor(1, dtype=tf.int32)
   trainer.train(train_steps)
   eval_steps = tf.convert_to_tensor(1, dtype=tf.int32)
   trainer.evaluate(eval_steps)

   info_path = os.path.join(model_dir, export_subdir, 'info.json')
   self.assertTrue(tf.io.gfile.exists(info_path))
Beispiel #2
0
 def create_test_trainer(self, config, model_dir=None):
   """Builds a `trainer_lib.Trainer` around a `MockTask` for use in tests.

   Args:
     config: Experiment config; its `task`, `trainer` and `runtime` fields are
       read here.
     model_dir: Optional directory handed to the mock task as `logging_dir`
       and to the best-checkpoint exporter factory.

   Returns:
     The constructed `trainer_lib.Trainer`.
   """
   mock = mock_task.MockTask(config.task, logging_dir=model_dir)
   exporter = train_lib.maybe_create_best_ckpt_exporter(config, model_dir)
   model = mock.build_model()
   optimizer = trainer_lib.create_optimizer(config.trainer, config.runtime)
   return trainer_lib.Trainer(
       config,
       mock,
       model=model,
       optimizer=optimizer,
       checkpoint_exporter=exporter)
Beispiel #3
0
def run_experiment(distribution_strategy: tf.distribute.Strategy,
                   task: base_task.Task,
                   mode: str,
                   params: config_definitions.ExperimentConfig,
                   model_dir: str,
                   run_post_eval: bool = False,
                   save_summary: bool = True
                   ) -> Tuple[tf.keras.Model, Mapping[str, Any]]:
    """Runs train/eval with the progressive trainer, as configured by `params`.

    Args:
      distribution_strategy: A distribution strategy; the trainer is built and
        the selected mode is executed inside its scope.
      task: A Task instance.
      mode: A 'str', specifying the mode. Can be 'train', 'eval',
        'train_and_eval' or 'continuous_eval'.
      params: ExperimentConfig instance.
      model_dir: A 'str', a path to store model checkpoints and summaries.
      run_post_eval: Whether to run post eval once after training, metrics
        logs are returned.
      save_summary: Whether to save train and validation summary.

    Returns:
      A 2-tuple of (model, eval_logs).
        model: `tf.keras.Model` instance.
        eval_logs: returns eval metrics logs when run_post_eval is set to
          True, otherwise, returns {}.

    Raises:
      NotImplementedError: If `mode` is not one of the supported modes. The
        check runs first, before the trainer or any other object is built.
    """
    # Fail fast: reject an unsupported mode before constructing the trainer,
    # checkpoint manager and controller.
    supported_modes = ('train', 'train_and_eval', 'eval', 'continuous_eval')
    if mode not in supported_modes:
        raise NotImplementedError('The mode is not implemented: %s' % mode)

    with distribution_strategy.scope():
        logging.info('Running progressive trainer.')
        trainer = prog_trainer_lib.ProgressiveTrainer(
            params,
            task,
            ckpt_dir=model_dir,
            train='train' in mode,
            evaluate=('eval' in mode) or run_post_eval,
            checkpoint_exporter=base_train_lib.maybe_create_best_ckpt_exporter(
                params, model_dir))

    # A checkpoint manager is only created when the trainer exposes a
    # checkpoint object.
    if trainer.checkpoint:
        checkpoint_manager = tf.train.CheckpointManager(
            trainer.checkpoint,
            directory=model_dir,
            max_to_keep=params.trainer.max_to_keep,
            step_counter=trainer.global_step,
            checkpoint_interval=params.trainer.checkpoint_interval,
            init_fn=trainer.initialize)
    else:
        checkpoint_manager = None

    # All summary-related controller arguments are None when summaries are
    # turned off.
    if save_summary:
        summary_dir = os.path.join(model_dir, 'train')
        eval_summary_dir = os.path.join(model_dir, 'validation')
        summary_interval = params.trainer.summary_interval
    else:
        summary_dir = None
        eval_summary_dir = None
        summary_interval = None

    controller = orbit.Controller(
        strategy=distribution_strategy,
        trainer=trainer if 'train' in mode else None,
        evaluator=trainer,
        global_step=trainer.global_step,
        steps_per_loop=params.trainer.steps_per_loop,
        checkpoint_manager=checkpoint_manager,
        summary_dir=summary_dir,
        eval_summary_dir=eval_summary_dir,
        summary_interval=summary_interval)

    logging.info('Starts to execute mode: %s', mode)
    with distribution_strategy.scope():
        if mode == 'train':
            controller.train(steps=params.trainer.train_steps)
        elif mode == 'train_and_eval':
            controller.train_and_evaluate(
                train_steps=params.trainer.train_steps,
                eval_steps=params.trainer.validation_steps,
                eval_interval=params.trainer.validation_interval)
        elif mode == 'eval':
            controller.evaluate(steps=params.trainer.validation_steps)
        elif mode == 'continuous_eval':

            def timeout_fn():
                # True once the global step has reached the configured
                # number of training steps.
                return bool(
                    trainer.global_step.numpy() >= params.trainer.train_steps)

            controller.evaluate_continuously(
                steps=params.trainer.validation_steps,
                timeout=params.trainer.continuous_eval_timeout,
                timeout_fn=timeout_fn)
        else:
            # Defensive guard: only reachable if `supported_modes` and the
            # branches above drift apart.
            raise NotImplementedError('The mode is not implemented: %s' % mode)

    if not run_post_eval:
        return trainer.model, {}

    with distribution_strategy.scope():
        return trainer.model, trainer.evaluate(
            tf.convert_to_tensor(params.trainer.validation_steps))