def main(_):
  logging.info('Parsing config files...')
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
  params = get_exp_config()

  # Sets the mixed_precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
  # can have a significant impact on model speed by utilizing float16 on GPUs
  # and bfloat16 on TPUs. loss_scale takes effect only when the dtype is
  # float16.
  if params.runtime.mixed_precision_dtype:
    performance.set_mixed_precision_policy(
        params.runtime.mixed_precision_dtype)

  distribution_strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=params.runtime.distribution_strategy,
      all_reduce_alg=params.runtime.all_reduce_alg,
      num_gpus=params.runtime.num_gpus,
      tpu_address=params.runtime.tpu)
  with distribution_strategy.scope():
    task = distillation.BertDistillationTask(
        strategy=distribution_strategy,
        progressive=params.trainer.progressive,
        optimizer_config=params.trainer.optimizer_config,
        task_config=params.task)

  train_lib.run_experiment(
      distribution_strategy=distribution_strategy,
      task=task,
      mode=FLAGS.mode,
      params=params,
      model_dir=FLAGS.model_dir)
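
# For context: the policy applied by performance.set_mixed_precision_policy
# above corresponds to the standard Keras global mixed-precision policy. A
# minimal sketch of the equivalent plain-Keras calls, not the Model Garden
# helper itself; the Adam optimizer is illustrative only.
def _mixed_precision_sketch():
  import tensorflow as tf  # Local import keeps the sketch self-contained.

  # 'mixed_float16': compute in float16, keep variables in float32.
  tf.keras.mixed_precision.set_global_policy('mixed_float16')

  # loss_scale only matters for float16: wrap the optimizer so gradients are
  # scaled up before the backward pass and unscaled before the update.
  optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
  return tf.keras.mixed_precision.LossScaleOptimizer(optimizer)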
def main(_):
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
  params = train_utils.parse_configuration(FLAGS)
  model_dir = FLAGS.model_dir
  if 'train' in FLAGS.mode:
    # Pure eval modes do not output yaml files. Otherwise a continuous eval
    # job may race against the train job for writing the same file.
    train_utils.serialize_config(params, model_dir)

  # Sets the mixed_precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
  # can have a significant impact on model speed by utilizing float16 on GPUs
  # and bfloat16 on TPUs. loss_scale takes effect only when the dtype is
  # float16.
  if params.runtime.mixed_precision_dtype:
    performance.set_mixed_precision_policy(
        params.runtime.mixed_precision_dtype)

  distribution_strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=params.runtime.distribution_strategy,
      all_reduce_alg=params.runtime.all_reduce_alg,
      num_gpus=params.runtime.num_gpus,
      tpu_address=params.runtime.tpu,
      **params.runtime.model_parallelism())
  with distribution_strategy.scope():
    task = task_factory.get_task(params.task, logging_dir=model_dir)

  train_lib.run_experiment(
      distribution_strategy=distribution_strategy,
      task=task,
      mode=FLAGS.mode,
      params=params,
      model_dir=model_dir)

  train_utils.save_gin_config(FLAGS.mode, model_dir)
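
# The flags consumed above (gin_file, gin_params, mode, model_dir, plus
# whatever train_utils.parse_configuration reads) must be defined before
# app.run(main) is invoked. A minimal sketch using absl flags; the exact flag
# set, defaults, and help strings below are assumptions, not the Model Garden
# definitions.
from absl import app
from absl import flags

FLAGS = flags.FLAGS

flags.DEFINE_string('mode', 'train_and_eval',
                    'Job mode: train, eval, train_and_eval, continuous_eval.')
flags.DEFINE_string('model_dir', None,
                    'Directory for checkpoints, summaries, and saved configs.')
flags.DEFINE_multi_string('gin_file', None, 'Paths to gin config files.')
flags.DEFINE_multi_string('gin_params', None, 'Gin parameter bindings.')

if __name__ == '__main__':
  flags.mark_flags_as_required(['model_dir'])
  app.run(main)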
def main(_):
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
  params = get_exp_config()

  distribution_strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=params.runtime.distribution_strategy,
      all_reduce_alg=params.runtime.all_reduce_alg,
      num_gpus=params.runtime.num_gpus,
      tpu_address=params.runtime.tpu)
  with distribution_strategy.scope():
    task = masked_lm.ProgressiveMaskedLM(
        strategy=distribution_strategy,
        progressive_config=params.trainer.progressive,
        optimizer_config=params.trainer.optimizer_config,
        train_data_config=params.task.train_data,
        small_train_data_config=params.task.small_train_data,
        task_config=params.task)

  train_lib.run_experiment(
      distribution_strategy=distribution_strategy,
      task=task,
      mode=FLAGS.mode,
      params=params,
      model_dir=FLAGS.model_dir)
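
# distribution_utils.get_distribution_strategy (called distribute_utils in the
# other scripts) essentially picks a tf.distribute strategy from the runtime
# config. A simplified sketch of that selection covering only the TPU,
# multi-GPU, and single-device cases; the real helper supports more options
# (all-reduce algorithms, parameter servers, etc.).
def _select_strategy_sketch(strategy_name, tpu_address=None, num_gpus=0):
  import tensorflow as tf  # Local import keeps the sketch self-contained.

  if strategy_name == 'tpu':
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
        tpu=tpu_address)
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    return tf.distribute.TPUStrategy(resolver)
  if strategy_name == 'mirrored' and num_gpus > 1:
    # Synchronous all-reduce across the local GPUs.
    return tf.distribute.MirroredStrategy()
  return tf.distribute.OneDeviceStrategy('/gpu:0' if num_gpus else '/cpu:0')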
def test_end_to_end(self, distribution_strategy, flag_mode, run_post_eval):
  model_dir = self.get_temp_dir()
  experiment_config = cfg.ExperimentConfig(
      trainer=prog_trainer_lib.ProgressiveTrainerConfig(),
      task=ProgTaskConfig())
  experiment_config = params_dict.override_params_dict(
      experiment_config, self._test_config, is_strict=False)

  with distribution_strategy.scope():
    task = task_factory.get_task(experiment_config.task,
                                 logging_dir=model_dir)

  _, logs = train_lib.run_experiment(
      distribution_strategy=distribution_strategy,
      task=task,
      mode=flag_mode,
      params=experiment_config,
      model_dir=model_dir,
      run_post_eval=run_post_eval)

  if run_post_eval:
    self.assertNotEmpty(logs)
  else:
    self.assertEmpty(logs)

  if flag_mode == 'eval':
    return
  self.assertNotEmpty(
      tf.io.gfile.glob(os.path.join(model_dir, 'checkpoint')))

  # Tests continuous evaluation.
  _, logs = train_lib.run_experiment(
      distribution_strategy=distribution_strategy,
      task=task,
      mode='continuous_eval',
      params=experiment_config,
      model_dir=model_dir,
      run_post_eval=run_post_eval)
  print(logs)
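
# The test above receives distribution_strategy, flag_mode, and run_post_eval
# as parameters, which implies a combinations-style decorator on the method.
# A sketch of how such a test is typically parameterized; the class name and
# the exact strategy/mode lists below are illustrative assumptions.
from absl.testing import parameterized
import tensorflow as tf
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations


class _TrainLibTestSketch(tf.test.TestCase, parameterized.TestCase):

  @combinations.generate(
      combinations.combine(
          distribution_strategy=[
              strategy_combinations.default_strategy,
              strategy_combinations.one_device_strategy_gpu,
          ],
          flag_mode=['train', 'eval', 'train_and_eval'],
          run_post_eval=[True, False]))
  def test_end_to_end(self, distribution_strategy, flag_mode, run_post_eval):
    # Body as in the test above.
    ...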