示例#1
0
def create_mobilebert_pretrainer(bert_config):
    """Build a `BertPretrainerV2` around a `MobileBERTEncoder`.

    Args:
        bert_config: Configuration object. Every encoder hyper-parameter is
            read from its attributes (`vocab_size`, `embedding_size`,
            `hidden_size`, ...); `hidden_act` and `initializer_range` are
            also used for the masked-LM head.

    Returns:
        The `models.BertPretrainerV2` instance, after it has been called once
        on its own `inputs`.
    """
    # Collect the encoder arguments first so the mapping from config
    # attributes to constructor keywords is readable in one place.
    encoder_kwargs = dict(
        word_vocab_size=bert_config.vocab_size,
        word_embed_size=bert_config.embedding_size,
        type_vocab_size=bert_config.type_vocab_size,
        max_sequence_length=bert_config.max_position_embeddings,
        num_blocks=bert_config.num_hidden_layers,
        hidden_size=bert_config.hidden_size,
        num_attention_heads=bert_config.num_attention_heads,
        intermediate_size=bert_config.intermediate_size,
        intermediate_act_fn=tf_utils.get_activation(bert_config.hidden_act),
        hidden_dropout_prob=bert_config.hidden_dropout_prob,
        attention_probs_dropout_prob=bert_config.attention_probs_dropout_prob,
        intra_bottleneck_size=bert_config.intra_bottleneck_size,
        initializer_range=bert_config.initializer_range,
        use_bottleneck_attention=bert_config.use_bottleneck_attention,
        key_query_shared_bottleneck=bert_config.key_query_shared_bottleneck,
        num_feedforward_networks=bert_config.num_feedforward_networks,
        normalization_type=bert_config.normalization_type,
        classifier_activation=bert_config.classifier_activation,
    )
    encoder = networks.MobileBERTEncoder(**encoder_kwargs)

    # The masked-LM head receives the encoder's embedding table and uses the
    # same activation name and initializer range as the encoder config.
    embedding_table = encoder.get_embedding_table()
    mlm_activation = tf_utils.get_activation(bert_config.hidden_act)
    mlm_initializer = tf.keras.initializers.TruncatedNormal(
        stddev=bert_config.initializer_range)
    mlm_head = layers.MobileBertMaskedLM(
        embedding_table=embedding_table,
        activation=mlm_activation,
        initializer=mlm_initializer,
        name="cls/predictions")

    model = models.BertPretrainerV2(
        encoder_network=encoder,
        customized_masked_lm=mlm_head)

    # Call the model once on its own inputs so that its variables exist
    # before it is handed back to the caller.
    model(model.inputs)
    return model
示例#2
0
  def _build_pretrainer(self, pretrainer_cfg: bert.PretrainerConfig, name: str):
    """Construct a `BertPretrainerV2` from a pretrainer config.

    Args:
      pretrainer_cfg: Config supplying the encoder config, the optional
        classification-head configs, and the masked-LM activation and
        initializer range.
      name: Name passed through to `models.BertPretrainerV2`.

    Returns:
      The assembled `models.BertPretrainerV2`.
    """
    backbone = encoders.build_encoder(pretrainer_cfg.encoder)

    # A falsy `cls_heads` setting yields no classification heads.
    head_cfgs = pretrainer_cfg.cls_heads or []
    heads = [
        layers.ClassificationHead(**head_cfg.as_dict())
        for head_cfg in head_cfgs
    ]

    # Masked-LM head built from the encoder's embedding table.
    embedding_table = backbone.get_embedding_table()
    mlm_activation = tf_utils.get_activation(pretrainer_cfg.mlm_activation)
    mlm_initializer = tf.keras.initializers.TruncatedNormal(
        stddev=pretrainer_cfg.mlm_initializer_range)
    mlm_head = layers.MobileBertMaskedLM(
        embedding_table=embedding_table,
        activation=mlm_activation,
        initializer=mlm_initializer,
        name='cls/predictions')

    return models.BertPretrainerV2(
        encoder_network=backbone,
        classification_heads=heads,
        customized_masked_lm=mlm_head,
        name=name)
示例#3
0
    def prepare_config(self, teacher_block_num, student_block_num,
                       transfer_teacher_layers):
        """Assemble a distillation experiment config and save a teacher checkpoint.

        Args:
            teacher_block_num: `num_blocks` used for the teacher MobileBERT
                encoder config.
            student_block_num: `num_blocks` used for the student MobileBERT
                encoder config.
            transfer_teacher_layers: Value assigned to
                `layer_wise_distill_config.transfer_teacher_layers`.

        Returns:
            The `cfg.ExperimentConfig`. Its
            `task.teacher_model_init_checkpoint` is set to
            `self.get_temp_dir()`, where a freshly built teacher pretrainer
            checkpoint has been saved.
        """

        def make_pretrainer_config(num_blocks, mlm_activation):
            """Return a MobileBERT pretrainer config with a 'next_sentence' head."""
            return bert.PretrainerConfig(
                encoder=encoders.EncoderConfig(
                    type='mobilebert',
                    mobilebert=encoders.MobileBertEncoderConfig(
                        num_blocks=num_blocks)),
                cls_heads=[
                    bert.ClsHeadConfig(
                        inner_dim=256,
                        num_classes=2,
                        dropout_rate=0.1,
                        name='next_sentence')
                ],
                mlm_activation=mlm_activation)

        def make_data_config():
            """Return the placeholder data config shared by train and validation."""
            return pretrain_dataloader.BertPretrainDataConfig(
                input_path='dummy',
                max_predictions_per_seq=76,
                seq_length=512,
                global_batch_size=10)

        # Small models are used for testing. Teacher and student differ only
        # in block count and masked-LM activation.
        task_config = distillation.BertDistillationTaskConfig(
            teacher_model=make_pretrainer_config(teacher_block_num, 'gelu'),
            student_model=make_pretrainer_config(student_block_num, 'relu'),
            train_data=make_data_config(),
            validation_data=make_data_config())

        # Each distillation stage runs for a single step only.
        progressive_config = distillation.BertDistillationProgressiveConfig()
        layer_wise_config = progressive_config.layer_wise_distill_config
        layer_wise_config.transfer_teacher_layers = transfer_teacher_layers
        layer_wise_config.num_steps = 1
        progressive_config.pretrain_distill_config.num_steps = 1

        # LAMB optimizer, polynomial schedule with identical start and end
        # learning rates, and linear warmup starting from zero.
        lamb_config = optimization.LAMBConfig(
            weight_decay_rate=0.0001,
            exclude_from_weight_decay=[
                'LayerNorm', 'layer_norm', 'bias', 'no_norm'
            ])
        optimizer_config = optimization.OptimizerConfig(
            type='lamb', lamb=lamb_config)
        learning_rate_config = optimization.LrConfig(
            type='polynomial',
            polynomial=optimization.PolynomialLrConfig(
                initial_learning_rate=1.5e-3,
                decay_steps=10000,
                end_learning_rate=1.5e-3))
        warmup_config = optimization.WarmupConfig(
            type='linear',
            linear=optimization.LinearWarmupConfig(warmup_learning_rate=0))
        optimization_config = optimization.OptimizationConfig(
            optimizer=optimizer_config,
            learning_rate=learning_rate_config,
            warmup=warmup_config)

        trainer_config = prog_trainer_lib.ProgressiveTrainerConfig(
            progressive=progressive_config,
            optimizer_config=optimization_config)
        exp_config = cfg.ExperimentConfig(
            task=task_config, trainer=trainer_config)

        # Build a teacher pretrainer so that a checkpoint can be written.
        teacher_encoder = encoders.build_encoder(
            task_config.teacher_model.encoder)
        teacher_model_config = task_config.teacher_model
        head_configs = teacher_model_config.cls_heads or []
        teacher_cls_heads = [
            layers.ClassificationHead(**head_config.as_dict())
            for head_config in head_configs
        ]

        embedding_table = teacher_encoder.get_embedding_table()
        mlm_activation = tf_utils.get_activation(
            teacher_model_config.mlm_activation)
        mlm_initializer = tf.keras.initializers.TruncatedNormal(
            stddev=teacher_model_config.mlm_initializer_range)
        teacher_masked_lm = layers.MobileBertMaskedLM(
            embedding_table=embedding_table,
            activation=mlm_activation,
            initializer=mlm_initializer,
            name='cls/predictions')
        teacher_pretrainer = models.BertPretrainerV2(
            encoder_network=teacher_encoder,
            classification_heads=teacher_cls_heads,
            customized_masked_lm=teacher_masked_lm)

        # The model variables are created by the forward call, so run it
        # before saving.
        teacher_pretrainer(teacher_pretrainer.inputs)
        teacher_checkpoint = tf.train.Checkpoint(
            **teacher_pretrainer.checkpoint_items)
        teacher_ckpt_path = os.path.join(self.get_temp_dir(),
                                         'teacher_model.ckpt')
        teacher_checkpoint.save(teacher_ckpt_path)

        # The task is given the temp directory, not the checkpoint prefix.
        exp_config.task.teacher_model_init_checkpoint = self.get_temp_dir()

        return exp_config