def get_exp_config():
  return cfg.ExperimentConfig(
      task=cfg.TaskConfig(model=bert.PretrainerConfig()),
      trainer=trainer_lib.ProgressiveTrainerConfig(
          export_checkpoint=True,
          export_checkpoint_interval=1,
          export_only_final_stage_ckpt=False))
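# Not part of the original test file: a hedged sketch showing that the config
# built by get_exp_config() can be dumped to plain Python types with the
# hyperparams `as_dict()` method, which is handy when debugging the
# checkpoint-export settings above. The printed fields mirror the
# ProgressiveTrainerConfig arguments used in get_exp_config().
def debug_print_exp_config():
  config = get_exp_config()
  trainer_dict = config.as_dict()['trainer']
  print(trainer_dict['export_checkpoint'],
        trainer_dict['export_checkpoint_interval'],
        trainer_dict['export_only_final_stage_ckpt'])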
def setUp(self):
  super(ProgressiveMaskedLMTest, self).setUp()
  self.task_config = progressive_masked_lm.ProgMaskedLMConfig(
      model=bert.PretrainerConfig(
          encoder=encoders.EncoderConfig(
              bert=encoders.BertEncoderConfig(vocab_size=30522, num_layers=2)),
          cls_heads=[
              bert.ClsHeadConfig(
                  inner_dim=10, num_classes=2, name="next_sentence")
          ]),
      train_data=pretrain_dataloader.BertPretrainDataConfig(
          input_path="dummy",
          max_predictions_per_seq=20,
          seq_length=128,
          global_batch_size=1),
      validation_data=pretrain_dataloader.BertPretrainDataConfig(
          input_path="dummy",
          max_predictions_per_seq=20,
          seq_length=128,
          global_batch_size=1),
      stage_list=[
          progressive_masked_lm.StackingStageConfig(
              num_layers=1, num_steps=4),
          progressive_masked_lm.StackingStageConfig(
              num_layers=2, num_steps=8),
      ],
  )
  self.exp_config = cfg.ExperimentConfig(
      task=self.task_config,
      trainer=prog_trainer_lib.ProgressiveTrainerConfig())
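# Not in the original test: a small hedged check that makes the stage schedule
# above explicit. It assumes each StackingStageConfig contributes `num_steps`
# training steps, so the two stages (1-layer, then 2-layer) run 4 + 8 = 12
# progressive steps in total.
def test_stage_list_total_steps(self):
  total_steps = sum(stage.num_steps for stage in self.task_config.stage_list)
  self.assertEqual(total_steps, 12)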
def wmt_transformer_large_progressive() -> cfg.ExperimentConfig:
  """WMT Transformer Large with progressive training.

  Please refer to tensorflow_models/official/nlp/data/train_sentencepiece.py
  to generate sentencepiece_model and pass
  --params_override=task.sentencepiece_model_path='YOUR_PATH'
  to the train script.
  """
  hidden_size = 1024
  train_steps = 300000
  token_batch_size = 24576
  encdecoder = translation.EncDecoder(
      num_attention_heads=16, intermediate_size=hidden_size * 4)
  config = cfg.ExperimentConfig(
      task=progressive_translation.ProgTranslationConfig(
          model=translation.ModelConfig(
              encoder=encdecoder,
              decoder=encdecoder,
              embedding_width=hidden_size,
              padded_decode=True,
              decode_max_length=100),
          train_data=wmt_dataloader.WMTDataConfig(
              tfds_name='wmt14_translate/de-en',
              tfds_split='train',
              src_lang='en',
              tgt_lang='de',
              is_training=True,
              global_batch_size=token_batch_size,
              static_batch=True,
              max_seq_length=64),
          validation_data=wmt_dataloader.WMTDataConfig(
              tfds_name='wmt14_translate/de-en',
              tfds_split='test',
              src_lang='en',
              tgt_lang='de',
              is_training=False,
              global_batch_size=32,
              static_batch=True,
              max_seq_length=100),
          sentencepiece_model_path=None,
      ),
      trainer=prog_trainer_lib.ProgressiveTrainerConfig(
          train_steps=train_steps,
          validation_steps=-1,
          steps_per_loop=1000,
          summary_interval=1000,
          checkpoint_interval=5000,
          validation_interval=5000,
          optimizer_config=None,
      ),
      restrictions=[
          'task.train_data.is_training != None',
          'task.sentencepiece_model_path != None',
      ])
  return config
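# Illustrative only, not part of the original config file: the docstring above
# says to pass --params_override=task.sentencepiece_model_path='YOUR_PATH' to
# the train script. When building the config directly in Python, the same
# override can be applied with the hyperparams `Config.override()` method;
# `spm_path` below is a placeholder argument, not a real model path.
def wmt_transformer_large_progressive_with_spm(spm_path):
  config = wmt_transformer_large_progressive()
  config.override({'task': {'sentencepiece_model_path': spm_path}},
                  is_strict=True)
  return config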
def setUp(self):
  super(ProgressiveTranslationTest, self).setUp()
  self._temp_dir = self.get_temp_dir()
  src_lines = ["abc ede fg", "bbcd ef a g", "de f a a g"]
  tgt_lines = ["dd cc a ef g", "bcd ef a g", "gef cd ba"]
  self._record_input_path = os.path.join(self._temp_dir, "train.record")
  _generate_record_file(self._record_input_path, src_lines, tgt_lines)
  self._sentencepeice_input_path = os.path.join(self._temp_dir, "inputs.txt")
  _generate_line_file(self._sentencepeice_input_path, src_lines + tgt_lines)
  sentencepeice_model_prefix = os.path.join(self._temp_dir, "sp")
  _train_sentencepiece(self._sentencepeice_input_path, 11,
                       sentencepeice_model_prefix)
  self._sentencepeice_model_path = "{}.model".format(
      sentencepeice_model_prefix)
  encdecoder = translation.EncDecoder(
      num_attention_heads=2, intermediate_size=8)
  self.task_config = progressive_translation.ProgTranslationConfig(
      model=translation.ModelConfig(
          encoder=encdecoder,
          decoder=encdecoder,
          embedding_width=8,
          padded_decode=True,
          decode_max_length=100),
      train_data=wmt_dataloader.WMTDataConfig(
          input_path=self._record_input_path,
          is_training=True,
          global_batch_size=24,
          static_batch=True,
          src_lang="en",
          tgt_lang="reverse_en",
          max_seq_length=12),
      validation_data=wmt_dataloader.WMTDataConfig(
          input_path=self._record_input_path,
          is_training=False,
          global_batch_size=2,
          static_batch=True,
          src_lang="en",
          tgt_lang="reverse_en",
          max_seq_length=12),
      sentencepiece_model_path=self._sentencepeice_model_path,
      stage_list=[
          progressive_translation.StackingStageConfig(
              num_encoder_layers=1, num_decoder_layers=1, num_steps=4),
          progressive_translation.StackingStageConfig(
              num_encoder_layers=2, num_decoder_layers=1, num_steps=8),
      ],
  )
  self.exp_config = cfg.ExperimentConfig(
      task=self.task_config,
      trainer=prog_trainer_lib.ProgressiveTrainerConfig())
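# Not in the original test: a hedged sanity check that the sentencepiece model
# trained in setUp() above was actually written to disk, using only the
# standard tf.io.gfile utilities.
def test_sentencepiece_model_written(self):
  self.assertTrue(tf.io.gfile.exists(self._sentencepeice_model_path))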
def get_exp_config(): """Get ExperimentConfig.""" params = cfg.ExperimentConfig( task=distillation.BertDistillationTaskConfig( train_data=pretrain_dataloader.BertPretrainDataConfig(), validation_data=pretrain_dataloader.BertPretrainDataConfig( is_training=False)), trainer=prog_trainer_lib.ProgressiveTrainerConfig( progressive=distillation.BertDistillationProgressiveConfig(), optimizer_config=optimization_config, train_steps=740000, checkpoint_interval=20000)) return config_override(params, FLAGS)
def test_configure_optimizer(self, mixed_precision_dtype, loss_scale):
  config = cfg.ExperimentConfig(
      task=cfg.TaskConfig(model=bert.PretrainerConfig()),
      runtime=cfg.RuntimeConfig(
          mixed_precision_dtype=mixed_precision_dtype,
          loss_scale=loss_scale),
      trainer=trainer_lib.ProgressiveTrainerConfig(
          export_checkpoint=True,
          export_checkpoint_interval=1,
          export_only_final_stage_ckpt=False))
  task = TestPolicy(None, config.task)
  trainer = trainer_lib.ProgressiveTrainer(config, task, self.get_temp_dir())
  if mixed_precision_dtype != 'float16':
    self.assertIsInstance(trainer.optimizer, tf.keras.optimizers.SGD)
  elif mixed_precision_dtype == 'float16' and loss_scale is None:
    self.assertIsInstance(trainer.optimizer, tf.keras.optimizers.SGD)

  metrics = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
  self.assertIn('training_loss', metrics)
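# Not in the original file: a hedged companion test. It reuses get_exp_config()
# and TestPolicy from above and assumes the progressive trainer exposes a
# `global_step` variable the way the other Model Garden trainers do; that
# attribute is an assumption here, not confirmed by this file.
def test_train_advances_global_step(self):
  config = get_exp_config()
  task = TestPolicy(None, config.task)
  trainer = trainer_lib.ProgressiveTrainer(config, task, self.get_temp_dir())
  start = trainer.global_step.numpy()
  trainer.train(tf.convert_to_tensor(2, dtype=tf.int32))
  self.assertGreater(trainer.global_step.numpy(), start)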
def get_exp_config(): """Get ExperimentConfig.""" params = cfg.ExperimentConfig( task=masked_lm.MaskedLMConfig( train_data=pretrain_dataloader.BertPretrainDataConfig(), small_train_data=pretrain_dataloader.BertPretrainDataConfig(), validation_data=pretrain_dataloader.BertPretrainDataConfig( is_training=False)), trainer=prog_trainer_lib.ProgressiveTrainerConfig( progressive=masked_lm.ProgStackingConfig(), optimizer_config=BertOptimizationConfig(), train_steps=1000000), restrictions=[ 'task.train_data.is_training != None', 'task.validation_data.is_training != None' ]) return utils.config_override(params, FLAGS)
def test_end_to_end(self, distribution_strategy, flag_mode, run_post_eval):
  model_dir = self.get_temp_dir()
  experiment_config = cfg.ExperimentConfig(
      trainer=prog_trainer_lib.ProgressiveTrainerConfig(),
      task=ProgTaskConfig())
  experiment_config = params_dict.override_params_dict(
      experiment_config, self._test_config, is_strict=False)

  with distribution_strategy.scope():
    task = task_factory.get_task(experiment_config.task,
                                 logging_dir=model_dir)

  _, logs = train_lib.run_experiment(
      distribution_strategy=distribution_strategy,
      task=task,
      mode=flag_mode,
      params=experiment_config,
      model_dir=model_dir,
      run_post_eval=run_post_eval)

  if run_post_eval:
    self.assertNotEmpty(logs)
  else:
    self.assertEmpty(logs)

  if flag_mode == 'eval':
    return
  self.assertNotEmpty(
      tf.io.gfile.glob(os.path.join(model_dir, 'checkpoint')))
  # Tests continuous evaluation.
  _, logs = train_lib.run_experiment(
      distribution_strategy=distribution_strategy,
      task=task,
      mode='continuous_eval',
      params=experiment_config,
      model_dir=model_dir,
      run_post_eval=run_post_eval)
  print(logs)
def setUp(self):
  super(DistillationTest, self).setUp()
  # using small model for testing
  self.model_block_num = 2
  self.task_config = distillation.BertDistillationTaskConfig(
      teacher_model=bert.PretrainerConfig(
          encoder=encoders.EncoderConfig(
              type='mobilebert',
              mobilebert=encoders.MobileBertEncoderConfig(
                  num_blocks=self.model_block_num)),
          cls_heads=[
              bert.ClsHeadConfig(
                  inner_dim=256,
                  num_classes=2,
                  dropout_rate=0.1,
                  name='next_sentence')
          ],
          mlm_activation='gelu'),
      student_model=bert.PretrainerConfig(
          encoder=encoders.EncoderConfig(
              type='mobilebert',
              mobilebert=encoders.MobileBertEncoderConfig(
                  num_blocks=self.model_block_num)),
          cls_heads=[
              bert.ClsHeadConfig(
                  inner_dim=256,
                  num_classes=2,
                  dropout_rate=0.1,
                  name='next_sentence')
          ],
          mlm_activation='relu'),
      train_data=pretrain_dataloader.BertPretrainDataConfig(
          input_path='dummy',
          max_predictions_per_seq=76,
          seq_length=512,
          global_batch_size=10),
      validation_data=pretrain_dataloader.BertPretrainDataConfig(
          input_path='dummy',
          max_predictions_per_seq=76,
          seq_length=512,
          global_batch_size=10))

  # set only 1 step for each stage
  progressive_config = distillation.BertDistillationProgressiveConfig()
  progressive_config.layer_wise_distill_config.num_steps = 1
  progressive_config.pretrain_distill_config.num_steps = 1

  optimization_config = optimization.OptimizationConfig(
      optimizer=optimization.OptimizerConfig(
          type='lamb',
          lamb=optimization.LAMBConfig(
              weight_decay_rate=0.0001,
              exclude_from_weight_decay=[
                  'LayerNorm', 'layer_norm', 'bias', 'no_norm'
              ])),
      learning_rate=optimization.LrConfig(
          type='polynomial',
          polynomial=optimization.PolynomialLrConfig(
              initial_learning_rate=1.5e-3,
              decay_steps=10000,
              end_learning_rate=1.5e-3)),
      warmup=optimization.WarmupConfig(
          type='linear',
          linear=optimization.LinearWarmupConfig(warmup_learning_rate=0)))

  self.exp_config = cfg.ExperimentConfig(
      task=self.task_config,
      trainer=prog_trainer_lib.ProgressiveTrainerConfig(
          progressive=progressive_config,
          optimizer_config=optimization_config))

  # Create a teacher model checkpoint.
  teacher_encoder = encoders.build_encoder(
      self.task_config.teacher_model.encoder)
  pretrainer_config = self.task_config.teacher_model
  if pretrainer_config.cls_heads:
    teacher_cls_heads = [
        layers.ClassificationHead(**cfg.as_dict())
        for cfg in pretrainer_config.cls_heads
    ]
  else:
    teacher_cls_heads = []

  masked_lm = layers.MobileBertMaskedLM(
      embedding_table=teacher_encoder.get_embedding_table(),
      activation=tf_utils.get_activation(pretrainer_config.mlm_activation),
      initializer=tf.keras.initializers.TruncatedNormal(
          stddev=pretrainer_config.mlm_initializer_range),
      name='cls/predictions')
  teacher_pretrainer = models.BertPretrainerV2(
      encoder_network=teacher_encoder,
      classification_heads=teacher_cls_heads,
      customized_masked_lm=masked_lm)

  # The model variables will be created after the forward call.
  _ = teacher_pretrainer(teacher_pretrainer.inputs)
  teacher_pretrainer_ckpt = tf.train.Checkpoint(
      **teacher_pretrainer.checkpoint_items)
  teacher_ckpt_path = os.path.join(self.get_temp_dir(), 'teacher_model.ckpt')
  teacher_pretrainer_ckpt.save(teacher_ckpt_path)
  self.task_config.teacher_model_init_checkpoint = self.get_temp_dir()
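# Not part of the original test file: a hedged sanity check that the teacher
# checkpoint written in setUp() can be located again from the directory stored
# in `teacher_model_init_checkpoint`, using only standard tf.train utilities.
def test_teacher_checkpoint_exists(self):
  latest = tf.train.latest_checkpoint(
      self.task_config.teacher_model_init_checkpoint)
  self.assertIsNotNone(latest)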