def test_default_initialization(self):
     """Builds a pretrainer from scratch; checks name and block count."""
     model = model_builder.build_bert_pretrainer(
         pretrainer_cfg=self.pretrainer_config, name='test_model')
     # Run one forward pass on the model's own symbolic inputs so that the
     # pretrainer variables are created.
     _ = model(model.inputs)
     self.assertEqual(model.name, 'test_model')

     encoder_net = model.encoder_network
     expected_blocks = encoders.MobileBertEncoderConfig().num_blocks
     # Count the encoder sub-layers that are MobileBERT transformer blocks.
     actual_blocks = sum(
         1 for sub_layer in encoder_net.layers
         if isinstance(sub_layer, modeling.layers.MobileBertTransformer))
     self.assertEqual(expected_blocks, actual_blocks)
 def test_initialization_with_mlm(self):
     """Checks that a supplied MLM head is the one the pretrainer exposes."""
     hidden_size = encoders.MobileBertEncoderConfig().hidden_size
     embedding_layer = modeling.layers.MobileBertEmbedding(
         word_vocab_size=30522,
         word_embed_size=128,
         type_vocab_size=2,
         output_embed_size=hidden_size)
     # Call the embedding layer once on a symbolic input before reading its
     # word-embedding table (presumably so the table variable exists).
     token_ids = tf.keras.layers.Input(shape=(None, ), dtype=tf.int32)
     _ = embedding_layer(token_ids)

     existing_mlm = modeling.layers.MobileBertMaskedLM(
         embedding_table=embedding_layer.word_embedding.embeddings)
     built = model_builder.build_bert_pretrainer(
         pretrainer_cfg=self.pretrainer_config, masked_lm=existing_mlm)
     self.assertEqual(built.masked_lm, existing_mlm)
Example #3
0
    def prepare_config(self, teacher_block_num, student_block_num,
                       transfer_teacher_layers):
        """Creates a small distillation experiment config plus teacher ckpt.

        Args:
            teacher_block_num: `num_blocks` of the teacher MobileBERT encoder.
            student_block_num: `num_blocks` of the student MobileBERT encoder.
            transfer_teacher_layers: value stored in
                `layer_wise_distill_config.transfer_teacher_layers`.

        Returns:
            The experiment config. Its task's `teacher_model_init_checkpoint`
            is set to the temp directory into which a teacher checkpoint
            has just been saved.
        """

        def _pretrainer_config(block_num, mlm_activation):
            # Teacher and student differ only in depth and MLM activation.
            return bert.PretrainerConfig(
                encoder=encoders.EncoderConfig(
                    type='mobilebert',
                    mobilebert=encoders.MobileBertEncoderConfig(
                        num_blocks=block_num)),
                cls_heads=[
                    bert.ClsHeadConfig(
                        inner_dim=256,
                        num_classes=2,
                        dropout_rate=0.1,
                        name='next_sentence')
                ],
                mlm_activation=mlm_activation)

        def _data_config():
            # Train and validation share the same dummy dataset settings.
            return pretrain_dataloader.BertPretrainDataConfig(
                input_path='dummy',
                max_predictions_per_seq=76,
                seq_length=512,
                global_batch_size=10)

        # Use a small model for testing.
        task_config = distillation.BertDistillationTaskConfig(
            teacher_model=_pretrainer_config(teacher_block_num, 'gelu'),
            student_model=_pretrainer_config(student_block_num, 'relu'),
            train_data=_data_config(),
            validation_data=_data_config())

        # Run only a single step in each distillation stage.
        progressive_config = distillation.BertDistillationProgressiveConfig()
        progressive_config.layer_wise_distill_config.transfer_teacher_layers = (
            transfer_teacher_layers)
        progressive_config.layer_wise_distill_config.num_steps = 1
        progressive_config.pretrain_distill_config.num_steps = 1

        lamb_config = optimization.LAMBConfig(
            weight_decay_rate=0.0001,
            exclude_from_weight_decay=[
                'LayerNorm', 'layer_norm', 'bias', 'no_norm'
            ])
        lr_schedule = optimization.PolynomialLrConfig(
            initial_learning_rate=1.5e-3,
            decay_steps=10000,
            end_learning_rate=1.5e-3)
        warmup_schedule = optimization.LinearWarmupConfig(
            warmup_learning_rate=0)
        optimization_config = optimization.OptimizationConfig(
            optimizer=optimization.OptimizerConfig(
                type='lamb', lamb=lamb_config),
            learning_rate=optimization.LrConfig(
                type='polynomial', polynomial=lr_schedule),
            warmup=optimization.WarmupConfig(
                type='linear', linear=warmup_schedule))

        trainer_config = prog_trainer_lib.ProgressiveTrainerConfig(
            progressive=progressive_config,
            optimizer_config=optimization_config)
        exp_config = cfg.ExperimentConfig(
            task=task_config, trainer=trainer_config)

        # Build the teacher pretrainer so a checkpoint of it can be written.
        teacher_cfg = task_config.teacher_model
        teacher_encoder = encoders.build_encoder(teacher_cfg.encoder)
        # An unset/empty `cls_heads` yields no classification heads.
        teacher_cls_heads = [
            layers.ClassificationHead(**head_cfg.as_dict())
            for head_cfg in (teacher_cfg.cls_heads or [])
        ]
        teacher_mlm = layers.MobileBertMaskedLM(
            embedding_table=teacher_encoder.get_embedding_table(),
            activation=tf_utils.get_activation(teacher_cfg.mlm_activation),
            initializer=tf.keras.initializers.TruncatedNormal(
                stddev=teacher_cfg.mlm_initializer_range),
            name='cls/predictions')
        teacher_pretrainer = models.BertPretrainerV2(
            encoder_network=teacher_encoder,
            classification_heads=teacher_cls_heads,
            customized_masked_lm=teacher_mlm)

        # The model variables are only created by a forward call, so make one
        # before checkpointing.
        _ = teacher_pretrainer(teacher_pretrainer.inputs)
        checkpoint = tf.train.Checkpoint(
            **teacher_pretrainer.checkpoint_items)
        checkpoint.save(
            os.path.join(self.get_temp_dir(), 'teacher_model.ckpt'))
        # Point the task at the directory holding the saved checkpoint.
        exp_config.task.teacher_model_init_checkpoint = self.get_temp_dir()

        return exp_config