def run_continuous_finetune(
    mode: str,
    params: config_definitions.ExperimentConfig,
    model_dir: str,
    run_post_eval: bool = False,
    pretrain_steps: Optional[int] = None,
) -> Mapping[str, Any]:
  """Run modes with continuous training.

  Currently only supports continuous_train_and_eval.

  Args:
    mode: A 'str', specifying the mode. continuous_train_and_eval - monitors a
      checkpoint directory. Once a new checkpoint is discovered, loads the
      checkpoint, finetunes the model by training it (probably on another
      dataset or with another task), then evaluates the finetuned model.
    params: ExperimentConfig instance.
    model_dir: A 'str', a path to store model checkpoints and summaries.
    run_post_eval: Whether to run one post-training evaluation pass; when True,
      its metrics logs are returned.
    pretrain_steps: Optional, the number of total training steps for the
      pretraining job.

  Returns:
    eval logs: returns eval metrics logs when run_post_eval is set to True,
      otherwise, returns {}.
  """
  assert mode == 'continuous_train_and_eval', (
      'Only continuous_train_and_eval is supported by continuous_finetune. '
      'Got mode: {}'.format(mode))

  # Sets the mixed_precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
  # can have a significant impact on model speed by utilizing float16 on GPUs
  # and bfloat16 on TPUs. loss_scale takes effect only when the dtype is
  # float16.
  if params.runtime.mixed_precision_dtype:
    performance.set_mixed_precision_policy(
        params.runtime.mixed_precision_dtype, params.runtime.loss_scale)
  distribution_strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=params.runtime.distribution_strategy,
      all_reduce_alg=params.runtime.all_reduce_alg,
      num_gpus=params.runtime.num_gpus,
      tpu_address=params.runtime.tpu)

  retry_times = 0
  while not tf.io.gfile.isdir(params.task.init_checkpoint):
    # Waits for the init_checkpoint directory to be created, for up to an hour.
    if retry_times >= 60:
      raise ValueError(
          'ExperimentConfig.task.init_checkpoint must be a directory for '
          'continuous_train_and_eval mode.')
    retry_times += 1
    time.sleep(60)

  summary_writer = tf.summary.create_file_writer(
      os.path.join(model_dir, 'eval'))

  global_step = 0

  def timeout_fn():
    if pretrain_steps and global_step < pretrain_steps:
      # Keeps waiting for another timeout period.
      logging.info(
          'Continue waiting for new checkpoint as current pretrain '
          'global_step=%d and target is %d.', global_step, pretrain_steps)
      return False
    # Quits the loop.
    return True

  for pretrain_ckpt in tf.train.checkpoints_iterator(
      checkpoint_dir=params.task.init_checkpoint,
      min_interval_secs=10,
      timeout=params.trainer.continuous_eval_timeout,
      timeout_fn=timeout_fn):

    with distribution_strategy.scope():
      global_step = train_utils.read_global_step_from_checkpoint(
          pretrain_ckpt)
    # Replaces params.task.init_checkpoint to make sure that we load exactly
    # this pretrain checkpoint.
    if params.trainer.best_checkpoint_export_subdir:
      best_ckpt_subdir = '{}_{}'.format(
          params.trainer.best_checkpoint_export_subdir, global_step)
      params_replaced = params.replace(
          task={'init_checkpoint': pretrain_ckpt},
          trainer={'best_checkpoint_export_subdir': best_ckpt_subdir})
    else:
      params_replaced = params.replace(
          task={'init_checkpoint': pretrain_ckpt})
    params_replaced.lock()
    logging.info('Running finetuning with params: %s', params_replaced)

    with distribution_strategy.scope():
      if isinstance(params, configs.MultiEvalExperimentConfig):
        task = task_factory.get_task(params_replaced.task)
        eval_tasks = multitask.MultiTask.from_config(
            params_replaced.eval_tasks)
        (_, eval_metrics
        ) = multitask_train_lib.run_experiment_wtih_multitask_eval(
            distribution_strategy=distribution_strategy,
            train_task=task,
            eval_tasks=eval_tasks,
            mode='train_and_eval',
            params=params_replaced,
            model_dir=model_dir,
            run_post_eval=True,
            save_summary=False)
      else:
        task = task_factory.get_task(
            params_replaced.task, logging_dir=model_dir)
        _, eval_metrics = train_lib.run_experiment(
            distribution_strategy=distribution_strategy,
            task=task,
            mode='train_and_eval',
            params=params_replaced,
            model_dir=model_dir,
            run_post_eval=True,
            save_summary=False)
    logging.info('Evaluation finished. Pretrain global_step: %d', global_step)

    train_utils.write_json_summary(model_dir, global_step, eval_metrics)

    if not os.path.basename(model_dir):  # if model_dir.endswith('/')
      summary_grp = os.path.dirname(model_dir) + '_' + task.name
    else:
      summary_grp = os.path.basename(model_dir) + '_' + task.name
    summaries = {}
    for name, value in _flatten_dict(eval_metrics).items():
      summaries[summary_grp + '/' + '-'.join(name)] = value
    train_utils.write_summary(summary_writer, global_step, summaries)

    train_utils.remove_ckpts(model_dir)
    # In TF2, the resource life cycle is bound with the python object life
    # cycle. Force trigger python garbage collection here so those resources
    # can be deallocated in time, so it doesn't cause OOM when allocating new
    # objects.
    # TODO(b/169178664): Fix cycle reference in Keras model and revisit to see
    # if we need gc here.
    gc.collect()

  if run_post_eval:
    return eval_metrics
  return {}
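

# A minimal usage sketch for run_continuous_finetune, not part of the library
# API. The experiment name, override values, and paths below are hypothetical
# placeholders; in practice the ExperimentConfig is typically built from
# flags/YAML by the Model Garden utilities, and the chosen experiment must
# already be registered with exp_factory before get_exp_config is called.
def _example_continuous_finetune_run():  # pragma: no cover
  """Hypothetical driver sketch; all concrete values are placeholders."""
  from official.core import exp_factory

  # Placeholder experiment name; substitute a registered experiment.
  params = exp_factory.get_exp_config('bert/sentence_prediction')
  params.override(
      {
          # Directory the pretraining job writes checkpoints into.
          'task': {'init_checkpoint': '/path/to/pretrain_ckpt_dir'},
          # Stop waiting for new checkpoints after this many seconds.
          'trainer': {'continuous_eval_timeout': 3600},
      },
      is_strict=False)

  return run_continuous_finetune(
      mode='continuous_train_and_eval',
      params=params,
      model_dir='/path/to/finetune_model_dir',
      run_post_eval=True,
      pretrain_steps=1000000)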