def test_train_model_locally(self):
  """Tests training a model locally with one step."""
  teacher_model = model_builder.build_bert_pretrainer(
      pretrainer_cfg=self.experiment_params.teacher_model, name='teacher')
  _ = teacher_model(teacher_model.inputs)
  student_model = model_builder.build_bert_pretrainer(
      pretrainer_cfg=self.experiment_params.student_model, name='student')
  _ = student_model(student_model.inputs)
  trainer = mobilebert_edgetpu_trainer.MobileBERTEdgeTPUDistillationTrainer(
      teacher_model=teacher_model,
      student_model=student_model,
      strategy=self.strategy,
      experiment_params=self.experiment_params)

  # Rebuild a dummy dataset, since loading the real dataset would cause a
  # timeout error.
  trainer.train_dataset = _dummy_dataset()
  trainer.eval_dataset = _dummy_dataset()
  train_dataset_iter = iter(trainer.train_dataset)
  eval_dataset_iter = iter(trainer.eval_dataset)

  trainer.train_loop_begin()
  trainer.train_step(train_dataset_iter)
  trainer.eval_step(eval_dataset_iter)
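
# The test above depends on a `_dummy_dataset` helper that is not shown in
# this section. The sketch below is an assumption of what it might look like:
# an infinitely repeating dataset of all-zero BERT pretraining features. The
# feature names and shapes (sequence length 64, 20 masked positions) are
# illustrative guesses, not confirmed by the source, and it assumes
# `import tensorflow as tf` at the top of the module.
def _dummy_dataset():
  """Returns an infinitely repeating dataset of dummy pretraining features."""

  def dummy_data(_):
    dummy_ids = tf.zeros((1, 64), dtype=tf.int32)  # [batch, seq_len]
    dummy_lm = tf.zeros((1, 20), dtype=tf.int32)   # [batch, masked_positions]
    return dict(
        input_word_ids=dummy_ids,
        input_mask=dummy_ids,
        input_type_ids=dummy_ids,
        masked_lm_positions=dummy_lm,
        masked_lm_ids=dummy_lm,
        masked_lm_weights=tf.cast(dummy_lm, dtype=tf.float32),
        next_sentence_labels=tf.zeros((1, 1), dtype=tf.int32))

  return tf.data.Dataset.range(1).repeat().map(
      dummy_data, num_parallel_calls=tf.data.AUTOTUNE)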
def main(_):
  # Set up experiment params and load the configs from file/files.
  experiment_params = params.EdgeTPUBERTCustomParams()
  experiment_params = utils.config_override(experiment_params, FLAGS)
  model_dir = utils.get_model_dir(experiment_params, FLAGS)

  distribution_strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=experiment_params.runtime.distribution_strategy,
      all_reduce_alg=experiment_params.runtime.all_reduce_alg,
      num_gpus=experiment_params.runtime.num_gpus,
      tpu_address=experiment_params.runtime.tpu_address)

  with distribution_strategy.scope():
    teacher_model = model_builder.build_bert_pretrainer(
        pretrainer_cfg=experiment_params.teacher_model,
        quantization_friendly=False,
        name='teacher')
    student_model = model_builder.build_bert_pretrainer(
        pretrainer_cfg=experiment_params.student_model,
        quantization_friendly=True,
        name='student')

    # Load model weights.
    teacher_ckpt_dir_or_file = experiment_params.teacher_model_init_checkpoint
    if not teacher_ckpt_dir_or_file:
      raise ValueError('`teacher_model_init_checkpoint` is not specified.')
    utils.load_checkpoint(teacher_model, teacher_ckpt_dir_or_file)

    student_ckpt_dir_or_file = experiment_params.student_model_init_checkpoint
    if not student_ckpt_dir_or_file:
      # Makes sure the pretrainer variables are created.
      _ = student_model(student_model.inputs)
      logging.warning(
          'No student checkpoint is provided; training might take much '
          'longer before converging.')
    else:
      utils.load_checkpoint(student_model, student_ckpt_dir_or_file)

  runner = mobilebert_edgetpu_trainer.MobileBERTEdgeTPUDistillationTrainer(
      teacher_model=teacher_model,
      student_model=student_model,
      strategy=distribution_strategy,
      experiment_params=experiment_params,
      export_ckpt_path=model_dir)

  # Save checkpoints for preemption handling. Checkpoints for downstream
  # tasks are saved separately inside the runner's train_loop_end() function.
  checkpoint = tf.train.Checkpoint(
      teacher_model=runner.teacher_model,
      student_model=runner.student_model,
      layer_wise_optimizer=runner.layer_wise_optimizer,
      e2e_optimizer=runner.e2e_optimizer,
      current_step=runner.current_step)
  checkpoint_manager = tf.train.CheckpointManager(
      checkpoint,
      directory=model_dir,
      max_to_keep=5,
      step_counter=runner.current_step,
      checkpoint_interval=20000,
      init_fn=None)

  controller = orbit.Controller(
      trainer=runner,
      evaluator=runner,
      global_step=runner.current_step,
      strategy=distribution_strategy,
      steps_per_loop=experiment_params.orbit_config.steps_per_loop,
      summary_dir=os.path.join(model_dir, 'train'),
      eval_summary_dir=os.path.join(model_dir, 'eval'),
      checkpoint_manager=checkpoint_manager)

  if FLAGS.mode == 'train':
    controller.train(steps=experiment_params.orbit_config.total_steps)
  else:
    raise ValueError('Unsupported mode; only `train` is supported.')
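
# The module presumably ends with a standard absl entry point; a minimal
# sketch, assuming `from absl import app` is imported at the top of the file
# alongside the other dependencies:
if __name__ == '__main__':
  app.run(main)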