def get_exp_config():
    """Build an ExperimentConfig pairing a BERT pretrainer with a progressive trainer.

    Returns:
      A `cfg.ExperimentConfig` whose task is a `cfg.TaskConfig` wrapping a
      default `bert.PretrainerConfig`, and whose trainer is a
      `ProgressiveTrainerConfig` with `export_checkpoint=True`,
      `export_checkpoint_interval=1` and `export_only_final_stage_ckpt=False`.
    """
    task_config = cfg.TaskConfig(model=bert.PretrainerConfig())
    trainer_config = trainer_lib.ProgressiveTrainerConfig(
        export_checkpoint=True,
        export_checkpoint_interval=1,
        export_only_final_stage_ckpt=False,
    )
    return cfg.ExperimentConfig(task=task_config, trainer=trainer_config)
def get_exp_config():
    """Build the BERT distillation ExperimentConfig and apply flag overrides.

    Returns:
      The result of `config_override(params, FLAGS)`, where `params` is a
      `cfg.ExperimentConfig` holding a `BertDistillationTaskConfig` (with
      default train data and non-training validation data) and a
      `ProgressiveTrainerConfig` set to 740000 train steps with a checkpoint
      interval of 20000.
    """
    train_data = pretrain_dataloader.BertPretrainDataConfig()
    validation_data = pretrain_dataloader.BertPretrainDataConfig(
        is_training=False)
    task_config = distillation.BertDistillationTaskConfig(
        train_data=train_data, validation_data=validation_data)

    # `optimization_config` is resolved from the enclosing module scope.
    trainer_config = prog_trainer_lib.ProgressiveTrainerConfig(
        progressive=distillation.BertDistillationProgressiveConfig(),
        optimizer_config=optimization_config,
        train_steps=740000,
        checkpoint_interval=20000,
    )

    experiment = cfg.ExperimentConfig(task=task_config, trainer=trainer_config)
    return config_override(experiment, FLAGS)
def get_exp_config():
    """Build the masked-LM progressive-stacking ExperimentConfig with overrides.

    Returns:
      The result of `utils.config_override(params, FLAGS)`, where `params` is
      a `cfg.ExperimentConfig` holding a `MaskedLMConfig` task (train, small
      train and non-training validation data), a `ProgressiveTrainerConfig`
      set to 1000000 train steps, and restrictions requiring `is_training` to
      be set on both the train and validation data configs.
    """
    task_config = masked_lm.MaskedLMConfig(
        train_data=pretrain_dataloader.BertPretrainDataConfig(),
        small_train_data=pretrain_dataloader.BertPretrainDataConfig(),
        validation_data=pretrain_dataloader.BertPretrainDataConfig(
            is_training=False),
    )
    trainer_config = prog_trainer_lib.ProgressiveTrainerConfig(
        progressive=masked_lm.ProgStackingConfig(),
        optimizer_config=BertOptimizationConfig(),
        train_steps=1000000,
    )
    # Restriction expressions are passed through verbatim as strings.
    restrictions = [
        'task.train_data.is_training != None',
        'task.validation_data.is_training != None'
    ]
    experiment = cfg.ExperimentConfig(
        task=task_config, trainer=trainer_config, restrictions=restrictions)
    return utils.config_override(experiment, FLAGS)
def test_configure_optimizer(self, mixed_precision_dtype, loss_scale):
    """Check the optimizer type and that a short training run reports a loss.

    Args:
      mixed_precision_dtype: Value forwarded to
        `RuntimeConfig.mixed_precision_dtype`.
      loss_scale: Value forwarded to `RuntimeConfig.loss_scale`.
    """
    task_config = cfg.TaskConfig(model=bert.PretrainerConfig())
    runtime_config = cfg.RuntimeConfig(
        mixed_precision_dtype=mixed_precision_dtype, loss_scale=loss_scale)
    trainer_config = trainer_lib.ProgressiveTrainerConfig(
        export_checkpoint=True,
        export_checkpoint_interval=1,
        export_only_final_stage_ckpt=False,
    )
    config = cfg.ExperimentConfig(
        task=task_config, runtime=runtime_config, trainer=trainer_config)

    task = TestPolicy(None, config.task)
    trainer = trainer_lib.ProgressiveTrainer(config, task, self.get_temp_dir())

    # The optimizer type is asserted unless float16 is combined with an
    # explicit loss scale; in that case no type check is made.
    if mixed_precision_dtype != 'float16' or loss_scale is None:
        self.assertIsInstance(trainer.optimizer, tf.keras.optimizers.SGD)

    num_steps = tf.convert_to_tensor(5, dtype=tf.int32)
    metrics = trainer.train(num_steps)
    self.assertIn('training_loss', metrics)
def test_end_to_end(self, distribution_strategy, flag_mode, run_post_eval):
    """Run an experiment in `flag_mode`, then exercise continuous evaluation.

    Args:
      distribution_strategy: Strategy whose scope is used to build the task
        and which is forwarded to `train_lib.run_experiment`.
      flag_mode: Mode string forwarded to `train_lib.run_experiment`. When it
        is 'eval' the test stops after the log assertions.
      run_post_eval: Forwarded to `train_lib.run_experiment`; also decides
        whether the returned logs are expected to be non-empty.
    """
    model_dir = self.get_temp_dir()
    base_config = cfg.ExperimentConfig(
        trainer=prog_trainer_lib.ProgressiveTrainerConfig(),
        task=ProgTaskConfig())
    experiment_config = params_dict.override_params_dict(
        base_config, self._test_config, is_strict=False)

    with distribution_strategy.scope():
        task = task_factory.get_task(
            experiment_config.task, logging_dir=model_dir)

    # Arguments common to both run_experiment invocations below.
    shared_kwargs = dict(
        distribution_strategy=distribution_strategy,
        task=task,
        params=experiment_config,
        model_dir=model_dir,
        run_post_eval=run_post_eval,
    )

    _, logs = train_lib.run_experiment(mode=flag_mode, **shared_kwargs)

    check_logs = self.assertNotEmpty if run_post_eval else self.assertEmpty
    check_logs(logs)

    if flag_mode == 'eval':
        return

    checkpoint_matches = tf.io.gfile.glob(os.path.join(model_dir, 'checkpoint'))
    self.assertNotEmpty(checkpoint_matches)

    # Exercise continuous evaluation against the same model directory.
    _, logs = train_lib.run_experiment(mode='continuous_eval', **shared_kwargs)
    print(logs)
def prepare_config(self, teacher_block_num, student_block_num,
                   transfer_teacher_layers):
    """Build a small distillation ExperimentConfig and save a teacher checkpoint.

    Args:
      teacher_block_num: `num_blocks` for the teacher MobileBERT encoder.
      student_block_num: `num_blocks` for the student MobileBERT encoder.
      transfer_teacher_layers: Value assigned to
        `layer_wise_distill_config.transfer_teacher_layers`.

    Returns:
      A `cfg.ExperimentConfig` whose `task.teacher_model_init_checkpoint`
      points at the temp directory in which a freshly built teacher
      pretrainer checkpoint was saved.
    """

    def _pretrainer_config(num_blocks, mlm_activation):
        # Teacher and student differ only in block count and MLM activation.
        return bert.PretrainerConfig(
            encoder=encoders.EncoderConfig(
                type='mobilebert',
                mobilebert=encoders.MobileBertEncoderConfig(
                    num_blocks=num_blocks)),
            cls_heads=[
                bert.ClsHeadConfig(
                    inner_dim=256,
                    num_classes=2,
                    dropout_rate=0.1,
                    name='next_sentence')
            ],
            mlm_activation=mlm_activation)

    def _data_config():
        # Train and validation use the same settings but distinct instances.
        return pretrain_dataloader.BertPretrainDataConfig(
            input_path='dummy',
            max_predictions_per_seq=76,
            seq_length=512,
            global_batch_size=10)

    # Use a small model for testing.
    task_config = distillation.BertDistillationTaskConfig(
        teacher_model=_pretrainer_config(teacher_block_num, 'gelu'),
        student_model=_pretrainer_config(student_block_num, 'relu'),
        train_data=_data_config(),
        validation_data=_data_config())

    # Run only one step in each stage.
    progressive_config = distillation.BertDistillationProgressiveConfig()
    layer_wise_config = progressive_config.layer_wise_distill_config
    layer_wise_config.transfer_teacher_layers = transfer_teacher_layers
    layer_wise_config.num_steps = 1
    progressive_config.pretrain_distill_config.num_steps = 1

    optimizer_config = optimization.OptimizerConfig(
        type='lamb',
        lamb=optimization.LAMBConfig(
            weight_decay_rate=0.0001,
            exclude_from_weight_decay=[
                'LayerNorm', 'layer_norm', 'bias', 'no_norm'
            ]))
    learning_rate_config = optimization.LrConfig(
        type='polynomial',
        polynomial=optimization.PolynomialLrConfig(
            initial_learning_rate=1.5e-3,
            decay_steps=10000,
            end_learning_rate=1.5e-3))
    warmup_config = optimization.WarmupConfig(
        type='linear',
        linear=optimization.LinearWarmupConfig(warmup_learning_rate=0))
    optimization_config = optimization.OptimizationConfig(
        optimizer=optimizer_config,
        learning_rate=learning_rate_config,
        warmup=warmup_config)

    exp_config = cfg.ExperimentConfig(
        task=task_config,
        trainer=prog_trainer_lib.ProgressiveTrainerConfig(
            progressive=progressive_config,
            optimizer_config=optimization_config))

    # Create a teacher model checkpoint.
    teacher_config = task_config.teacher_model
    teacher_encoder = encoders.build_encoder(teacher_config.encoder)
    # An unset or empty `cls_heads` yields no classification heads.
    teacher_cls_heads = [
        layers.ClassificationHead(**head_config.as_dict())
        for head_config in (teacher_config.cls_heads or [])
    ]
    teacher_masked_lm = layers.MobileBertMaskedLM(
        embedding_table=teacher_encoder.get_embedding_table(),
        activation=tf_utils.get_activation(teacher_config.mlm_activation),
        initializer=tf.keras.initializers.TruncatedNormal(
            stddev=teacher_config.mlm_initializer_range),
        name='cls/predictions')
    teacher_pretrainer = models.BertPretrainerV2(
        encoder_network=teacher_encoder,
        classification_heads=teacher_cls_heads,
        customized_masked_lm=teacher_masked_lm)

    # The model variables will be created after the forward call.
    _ = teacher_pretrainer(teacher_pretrainer.inputs)

    teacher_checkpoint = tf.train.Checkpoint(
        **teacher_pretrainer.checkpoint_items)
    teacher_ckpt_path = os.path.join(self.get_temp_dir(), 'teacher_model.ckpt')
    teacher_checkpoint.save(teacher_ckpt_path)

    # The task is given the directory that holds the saved checkpoint, not
    # the checkpoint file prefix itself.
    exp_config.task.teacher_model_init_checkpoint = self.get_temp_dir()
    return exp_config