Code example #1
def get_exp_config():
    return cfg.ExperimentConfig(
        task=cfg.TaskConfig(model=bert.PretrainerConfig()),
        trainer=trainer_lib.ProgressiveTrainerConfig(
            export_checkpoint=True,
            export_checkpoint_interval=1,
            export_only_final_stage_ckpt=False))
Code example #2
 def setUp(self):
   super(ProgressiveMaskedLMTest, self).setUp()
   self.task_config = progressive_masked_lm.ProgMaskedLMConfig(
       model=bert.PretrainerConfig(
           encoder=encoders.EncoderConfig(
               bert=encoders.BertEncoderConfig(vocab_size=30522,
                                               num_layers=2)),
           cls_heads=[
               bert.ClsHeadConfig(
                   inner_dim=10, num_classes=2, name="next_sentence")
           ]),
       train_data=pretrain_dataloader.BertPretrainDataConfig(
           input_path="dummy",
           max_predictions_per_seq=20,
           seq_length=128,
           global_batch_size=1),
       validation_data=pretrain_dataloader.BertPretrainDataConfig(
           input_path="dummy",
           max_predictions_per_seq=20,
           seq_length=128,
           global_batch_size=1),
       stage_list=[
           progressive_masked_lm.StackingStageConfig(
               num_layers=1, num_steps=4),
           progressive_masked_lm.StackingStageConfig(
               num_layers=2, num_steps=8),
           ],
       )
   self.exp_config = cfg.ExperimentConfig(
       task=self.task_config,
       trainer=prog_trainer_lib.ProgressiveTrainerConfig())
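The stage_list above controls how long each progressive stage trains. As a small aid, a hypothetical helper (not in the original test) that sums the per-stage step budgets of a config like self.task_config:

def total_progressive_steps(task_config):
  # Sum num_steps over every StackingStageConfig in the task's stage_list.
  return sum(stage.num_steps for stage in task_config.stage_list)

# For the two stages defined above this evaluates to 4 + 8 = 12 steps.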
Code example #3
def wmt_transformer_large_progressive() -> cfg.ExperimentConfig:
  """WMT Transformer Larger with progressive training.

  Please refer to
  tensorflow_models/official/nlp/data/train_sentencepiece.py
  to generate sentencepiece_model
  and pass
  --params_override=task.sentencepiece_model_path='YOUR_PATH'
  to the train script.
  """
  hidden_size = 1024
  train_steps = 300000
  token_batch_size = 24576
  encdecoder = translation.EncDecoder(
      num_attention_heads=16, intermediate_size=hidden_size * 4)
  config = cfg.ExperimentConfig(
      task=progressive_translation.ProgTranslationConfig(
          model=translation.ModelConfig(
              encoder=encdecoder,
              decoder=encdecoder,
              embedding_width=hidden_size,
              padded_decode=True,
              decode_max_length=100),
          train_data=wmt_dataloader.WMTDataConfig(
              tfds_name='wmt14_translate/de-en',
              tfds_split='train',
              src_lang='en',
              tgt_lang='de',
              is_training=True,
              global_batch_size=token_batch_size,
              static_batch=True,
              max_seq_length=64
          ),
          validation_data=wmt_dataloader.WMTDataConfig(
              tfds_name='wmt14_translate/de-en',
              tfds_split='test',
              src_lang='en',
              tgt_lang='de',
              is_training=False,
              global_batch_size=32,
              static_batch=True,
              max_seq_length=100,
          ),
          sentencepiece_model_path=None,
      ),
      trainer=prog_trainer_lib.ProgressiveTrainerConfig(
          train_steps=train_steps,
          validation_steps=-1,
          steps_per_loop=1000,
          summary_interval=1000,
          checkpoint_interval=5000,
          validation_interval=5000,
          optimizer_config=None,
      ),
      restrictions=[
          'task.train_data.is_training != None',
          'task.sentencepiece_model_path != None',
      ])
  return config
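As the docstring notes, task.sentencepiece_model_path must be supplied at launch time, and the restrictions above enforce that it is not None. A minimal sketch of the equivalent programmatic override, assuming the Config.override() helper that ExperimentConfig inherits from official.modeling.hyperparams; the model path below is only a placeholder:

config = wmt_transformer_large_progressive()
config.override(
    {'task': {'sentencepiece_model_path': '/path/to/sentencepiece.model'}},
    is_strict=True)
assert config.task.sentencepiece_model_path is not None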
Code example #4
 def setUp(self):
     super(ProgressiveTranslationTest, self).setUp()
     self._temp_dir = self.get_temp_dir()
     src_lines = ["abc ede fg", "bbcd ef a g", "de f a a g"]
     tgt_lines = ["dd cc a ef  g", "bcd ef a g", "gef cd ba"]
     self._record_input_path = os.path.join(self._temp_dir, "train.record")
     _generate_record_file(self._record_input_path, src_lines, tgt_lines)
     self._sentencepeice_input_path = os.path.join(self._temp_dir,
                                                   "inputs.txt")
     _generate_line_file(self._sentencepeice_input_path,
                         src_lines + tgt_lines)
     sentencepeice_model_prefix = os.path.join(self._temp_dir, "sp")
     _train_sentencepiece(self._sentencepeice_input_path, 11,
                          sentencepeice_model_prefix)
     self._sentencepeice_model_path = "{}.model".format(
         sentencepeice_model_prefix)
     encdecoder = translation.EncDecoder(num_attention_heads=2,
                                         intermediate_size=8)
     self.task_config = progressive_translation.ProgTranslationConfig(
         model=translation.ModelConfig(encoder=encdecoder,
                                       decoder=encdecoder,
                                       embedding_width=8,
                                       padded_decode=True,
                                       decode_max_length=100),
         train_data=wmt_dataloader.WMTDataConfig(
             input_path=self._record_input_path,
             is_training=True,
             global_batch_size=24,
             static_batch=True,
             src_lang="en",
             tgt_lang="reverse_en",
             max_seq_length=12),
         validation_data=wmt_dataloader.WMTDataConfig(
             input_path=self._record_input_path,
             is_training=False,
             global_batch_size=2,
             static_batch=True,
             src_lang="en",
             tgt_lang="reverse_en",
             max_seq_length=12),
         sentencepiece_model_path=self._sentencepeice_model_path,
         stage_list=[
             progressive_translation.StackingStageConfig(
                 num_encoder_layers=1, num_decoder_layers=1, num_steps=4),
             progressive_translation.StackingStageConfig(
                 num_encoder_layers=2, num_decoder_layers=1, num_steps=8),
         ],
     )
     self.exp_config = cfg.ExperimentConfig(
         task=self.task_config,
         trainer=prog_trainer_lib.ProgressiveTrainerConfig())
Code example #5
def get_exp_config():
    """Get ExperimentConfig."""
    params = cfg.ExperimentConfig(
        task=distillation.BertDistillationTaskConfig(
            train_data=pretrain_dataloader.BertPretrainDataConfig(),
            validation_data=pretrain_dataloader.BertPretrainDataConfig(
                is_training=False)),
        trainer=prog_trainer_lib.ProgressiveTrainerConfig(
            progressive=distillation.BertDistillationProgressiveConfig(),
            optimizer_config=optimization_config,
            train_steps=740000,
            checkpoint_interval=20000))

    return config_override(params, FLAGS)
Code example #6
  def test_configure_optimizer(self, mixed_precision_dtype, loss_scale):
    config = cfg.ExperimentConfig(
        task=cfg.TaskConfig(
            model=bert.PretrainerConfig()),
        runtime=cfg.RuntimeConfig(
            mixed_precision_dtype=mixed_precision_dtype, loss_scale=loss_scale),
        trainer=trainer_lib.ProgressiveTrainerConfig(
            export_checkpoint=True,
            export_checkpoint_interval=1,
            export_only_final_stage_ckpt=False))
    task = TestPolicy(None, config.task)
    trainer = trainer_lib.ProgressiveTrainer(config, task, self.get_temp_dir())
    if mixed_precision_dtype != 'float16':
      self.assertIsInstance(trainer.optimizer, tf.keras.optimizers.SGD)
    elif mixed_precision_dtype == 'float16' and loss_scale is None:
      self.assertIsInstance(trainer.optimizer, tf.keras.optimizers.SGD)

    metrics = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
    self.assertIn('training_loss', metrics)
Code example #7
def get_exp_config():
  """Get ExperimentConfig."""

  params = cfg.ExperimentConfig(
      task=masked_lm.MaskedLMConfig(
          train_data=pretrain_dataloader.BertPretrainDataConfig(),
          small_train_data=pretrain_dataloader.BertPretrainDataConfig(),
          validation_data=pretrain_dataloader.BertPretrainDataConfig(
              is_training=False)),
      trainer=prog_trainer_lib.ProgressiveTrainerConfig(
          progressive=masked_lm.ProgStackingConfig(),
          optimizer_config=BertOptimizationConfig(),
          train_steps=1000000),
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None'
      ])

  return utils.config_override(params, FLAGS)
Code example #8
    def test_end_to_end(self, distribution_strategy, flag_mode, run_post_eval):
        model_dir = self.get_temp_dir()
        experiment_config = cfg.ExperimentConfig(
            trainer=prog_trainer_lib.ProgressiveTrainerConfig(),
            task=ProgTaskConfig())
        experiment_config = params_dict.override_params_dict(experiment_config,
                                                             self._test_config,
                                                             is_strict=False)

        with distribution_strategy.scope():
            task = task_factory.get_task(experiment_config.task,
                                         logging_dir=model_dir)

        _, logs = train_lib.run_experiment(
            distribution_strategy=distribution_strategy,
            task=task,
            mode=flag_mode,
            params=experiment_config,
            model_dir=model_dir,
            run_post_eval=run_post_eval)

        if run_post_eval:
            self.assertNotEmpty(logs)
        else:
            self.assertEmpty(logs)

        if flag_mode == 'eval':
            return
        self.assertNotEmpty(
            tf.io.gfile.glob(os.path.join(model_dir, 'checkpoint')))
        # Tests continuous evaluation.
        _, logs = train_lib.run_experiment(
            distribution_strategy=distribution_strategy,
            task=task,
            mode='continuous_eval',
            params=experiment_config,
            model_dir=model_dir,
            run_post_eval=run_post_eval)
        print(logs)
Code example #9
    def setUp(self):
        super(DistillationTest, self).setUp()
        # using small model for testing
        self.model_block_num = 2
        self.task_config = distillation.BertDistillationTaskConfig(
            teacher_model=bert.PretrainerConfig(
                encoder=encoders.EncoderConfig(
                    type='mobilebert',
                    mobilebert=encoders.MobileBertEncoderConfig(
                        num_blocks=self.model_block_num)),
                cls_heads=[
                    bert.ClsHeadConfig(
                        inner_dim=256,
                        num_classes=2,
                        dropout_rate=0.1,
                        name='next_sentence')
                ],
                mlm_activation='gelu'),
            student_model=bert.PretrainerConfig(
                encoder=encoders.EncoderConfig(
                    type='mobilebert',
                    mobilebert=encoders.MobileBertEncoderConfig(
                        num_blocks=self.model_block_num)),
                cls_heads=[
                    bert.ClsHeadConfig(
                        inner_dim=256,
                        num_classes=2,
                        dropout_rate=0.1,
                        name='next_sentence')
                ],
                mlm_activation='relu'),
            train_data=pretrain_dataloader.BertPretrainDataConfig(
                input_path='dummy',
                max_predictions_per_seq=76,
                seq_length=512,
                global_batch_size=10),
            validation_data=pretrain_dataloader.BertPretrainDataConfig(
                input_path='dummy',
                max_predictions_per_seq=76,
                seq_length=512,
                global_batch_size=10))

        # set only 1 step for each stage
        progressive_config = distillation.BertDistillationProgressiveConfig()
        progressive_config.layer_wise_distill_config.num_steps = 1
        progressive_config.pretrain_distill_config.num_steps = 1

        optimization_config = optimization.OptimizationConfig(
            optimizer=optimization.OptimizerConfig(
                type='lamb',
                lamb=optimization.LAMBConfig(weight_decay_rate=0.0001,
                                             exclude_from_weight_decay=[
                                                 'LayerNorm', 'layer_norm',
                                                 'bias', 'no_norm'
                                             ])),
            learning_rate=optimization.LrConfig(
                type='polynomial',
                polynomial=optimization.PolynomialLrConfig(
                    initial_learning_rate=1.5e-3,
                    decay_steps=10000,
                    end_learning_rate=1.5e-3)),
            warmup=optimization.WarmupConfig(
                type='linear',
                linear=optimization.LinearWarmupConfig(
                    warmup_learning_rate=0)))

        self.exp_config = cfg.ExperimentConfig(
            task=self.task_config,
            trainer=prog_trainer_lib.ProgressiveTrainerConfig(
                progressive=progressive_config,
                optimizer_config=optimization_config))

        # Create a teacher model checkpoint.
        teacher_encoder = encoders.build_encoder(
            self.task_config.teacher_model.encoder)
        pretrainer_config = self.task_config.teacher_model
        if pretrainer_config.cls_heads:
            teacher_cls_heads = [
                layers.ClassificationHead(**cfg.as_dict())
                for cfg in pretrainer_config.cls_heads
            ]
        else:
            teacher_cls_heads = []

        masked_lm = layers.MobileBertMaskedLM(
            embedding_table=teacher_encoder.get_embedding_table(),
            activation=tf_utils.get_activation(
                pretrainer_config.mlm_activation),
            initializer=tf.keras.initializers.TruncatedNormal(
                stddev=pretrainer_config.mlm_initializer_range),
            name='cls/predictions')
        teacher_pretrainer = models.BertPretrainerV2(
            encoder_network=teacher_encoder,
            classification_heads=teacher_cls_heads,
            customized_masked_lm=masked_lm)

        # The model variables will be created after the forward call.
        _ = teacher_pretrainer(teacher_pretrainer.inputs)
        teacher_pretrainer_ckpt = tf.train.Checkpoint(
            **teacher_pretrainer.checkpoint_items)
        teacher_ckpt_path = os.path.join(self.get_temp_dir(),
                                         'teacher_model.ckpt')
        teacher_pretrainer_ckpt.save(teacher_ckpt_path)
        self.task_config.teacher_model_init_checkpoint = self.get_temp_dir()
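The checkpoint above is written with tf.train.Checkpoint.save(), and teacher_model_init_checkpoint is then set to the temporary directory rather than a concrete checkpoint prefix. A minimal sketch, using only stock TensorFlow APIs, of how such a directory-valued path can be resolved before restoring; resolve_init_checkpoint is a hypothetical helper, not part of the original test:

import tensorflow as tf

def resolve_init_checkpoint(ckpt_dir_or_prefix):
  """Returns a concrete checkpoint prefix if the argument is a directory."""
  latest = tf.train.latest_checkpoint(ckpt_dir_or_prefix)
  # tf.train.latest_checkpoint returns None when no checkpoint index is found,
  # in which case the argument is assumed to already be a checkpoint prefix.
  return latest if latest is not None else ckpt_dir_or_prefix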