def test_export_best_ckpt(self, distribution):
    """Checks that training then evaluating exports best-checkpoint info.

    Builds a trainer with a best-checkpoint exporter configured for the
    'acc' metric, calls `train` and `evaluate` once each with an int32
    tensor of 1, and asserts that `best_ckpt/info.json` exists under the
    temporary model directory.

    NOTE(review): `distribution` is not referenced in this body; presumably
    it is supplied by a test-parameterization decorator outside this view.
    Confirm before removing it.
    """
    optimization = cfg.OptimizationConfig({
        'optimizer': {
            'type': 'sgd'
        },
        'learning_rate': {
            'type': 'constant'
        }
    })
    trainer_config = cfg.TrainerConfig(
        best_checkpoint_export_subdir='best_ckpt',
        best_checkpoint_eval_metric='acc',
        optimizer_config=optimization)
    config = cfg.ExperimentConfig(trainer=trainer_config)

    model_dir = self.get_temp_dir()
    task = mock_task.MockTask(config.task, logging_dir=model_dir)
    exporter = train_lib.maybe_create_best_ckpt_exporter(config, model_dir)
    model = task.build_model()
    trainer = trainer_lib.Trainer(
        config, task, model=model, checkpoint_exporter=exporter)

    # A fresh tensor is created for each call, exactly as the calls need it.
    train_arg = tf.convert_to_tensor(1, dtype=tf.int32)
    trainer.train(train_arg)
    eval_arg = tf.convert_to_tensor(1, dtype=tf.int32)
    trainer.evaluate(eval_arg)

    info_path = os.path.join(model_dir, 'best_ckpt', 'info.json')
    self.assertTrue(tf.io.gfile.exists(info_path))
def create_test_trainer(self, config, model_dir=None):
    """Builds a `trainer_lib.Trainer` backed by a `MockTask`.

    Args:
      config: Experiment config; its `task`, `trainer` and `runtime` fields
        are read here.
      model_dir: Optional directory passed to the mock task as its
        `logging_dir` and to the best-checkpoint exporter factory.

    Returns:
      The constructed `trainer_lib.Trainer`.
    """
    mock = mock_task.MockTask(config.task, logging_dir=model_dir)
    exporter = train_lib.maybe_create_best_ckpt_exporter(config, model_dir)
    model = mock.build_model()
    optimizer = trainer_lib.create_optimizer(config.trainer, config.runtime)
    return trainer_lib.Trainer(
        config,
        mock,
        model=model,
        optimizer=optimizer,
        checkpoint_exporter=exporter)
def run_experiment(
    distribution_strategy: tf.distribute.Strategy,
    task: base_task.Task,
    mode: str,
    params: config_definitions.ExperimentConfig,
    model_dir: str,
    run_post_eval: bool = False,
    save_summary: bool = True,
) -> Tuple[tf.keras.Model, Mapping[str, Any]]:
    """Runs progressive training and/or evaluation for an experiment.

    Args:
      distribution_strategy: The `tf.distribute.Strategy` whose scope is used
        to build the trainer and to execute the selected mode.
      task: A Task instance.
      mode: One of 'train', 'eval', 'train_and_eval' or 'continuous_eval'.
      params: ExperimentConfig instance.
      model_dir: Directory for checkpoints; summaries are written to its
        'train' and 'validation' subdirectories.
      run_post_eval: When true, run one more evaluation after the selected
        mode finishes and return its logs.
      save_summary: When false, no summary directories or summary interval
        are handed to the controller.

    Returns:
      A `(model, eval_logs)` tuple. `eval_logs` is the result of the post
      evaluation when `run_post_eval` is set, otherwise an empty dict.

    Raises:
      NotImplementedError: If `mode` is not one of the supported values. This
        is raised only after the trainer and controller have been built.
    """
    with distribution_strategy.scope():
        logging.info('Running progressive trainer.')
        trainer = prog_trainer_lib.ProgressiveTrainer(
            params,
            task,
            ckpt_dir=model_dir,
            train='train' in mode,
            evaluate=('eval' in mode) or run_post_eval,
            checkpoint_exporter=base_train_lib.maybe_create_best_ckpt_exporter(
                params, model_dir))

    # Only manage checkpoints when the trainer actually exposes one.
    checkpoint_manager = None
    if trainer.checkpoint:
        checkpoint_manager = tf.train.CheckpointManager(
            trainer.checkpoint,
            directory=model_dir,
            max_to_keep=params.trainer.max_to_keep,
            step_counter=trainer.global_step,
            checkpoint_interval=params.trainer.checkpoint_interval,
            init_fn=trainer.initialize)

    # Summary settings are all-or-nothing, driven by `save_summary`.
    if save_summary:
        train_summary_dir = os.path.join(model_dir, 'train')
        eval_summary_dir = os.path.join(model_dir, 'validation')
        summary_interval = params.trainer.summary_interval
    else:
        train_summary_dir = None
        eval_summary_dir = None
        summary_interval = None

    controller = orbit.Controller(
        strategy=distribution_strategy,
        trainer=trainer if 'train' in mode else None,
        evaluator=trainer,
        global_step=trainer.global_step,
        steps_per_loop=params.trainer.steps_per_loop,
        checkpoint_manager=checkpoint_manager,
        summary_dir=train_summary_dir,
        eval_summary_dir=eval_summary_dir,
        summary_interval=summary_interval)

    def timeout_fn():
        # True once the global step has reached the configured train steps.
        return bool(
            trainer.global_step.numpy() >= params.trainer.train_steps)

    logging.info('Starts to execute mode: %s', mode)
    with distribution_strategy.scope():
        if mode == 'train':
            controller.train(steps=params.trainer.train_steps)
        elif mode == 'train_and_eval':
            controller.train_and_evaluate(
                train_steps=params.trainer.train_steps,
                eval_steps=params.trainer.validation_steps,
                eval_interval=params.trainer.validation_interval)
        elif mode == 'eval':
            controller.evaluate(steps=params.trainer.validation_steps)
        elif mode == 'continuous_eval':
            controller.evaluate_continuously(
                steps=params.trainer.validation_steps,
                timeout=params.trainer.continuous_eval_timeout,
                timeout_fn=timeout_fn)
        else:
            raise NotImplementedError(
                'The mode is not implemented: %s' % mode)

    if not run_post_eval:
        return trainer.model, {}

    with distribution_strategy.scope():
        model = trainer.model
        eval_logs = trainer.evaluate(
            tf.convert_to_tensor(params.trainer.validation_steps))
        return model, eval_logs