Beispiel #1
0
    def prepare_config(self, teacher_block_num, student_block_num,
                       transfer_teacher_layers):
        """Builds a small distillation experiment config for testing.

        Besides assembling the config, this creates a teacher pretrainer,
        saves its checkpoint under `self.get_temp_dir()`, and points
        `task.teacher_model_init_checkpoint` at that directory.

        Args:
          teacher_block_num: `num_blocks` for the teacher MobileBERT encoder.
          student_block_num: `num_blocks` for the student MobileBERT encoder.
          transfer_teacher_layers: value assigned to
            `layer_wise_distill_config.transfer_teacher_layers`.

        Returns:
          The assembled `cfg.ExperimentConfig`.
        """

        def _pretrainer_config(num_blocks, mlm_activation):
            # Teacher and student differ only in block count and MLM activation.
            next_sentence_head = bert.ClsHeadConfig(
                inner_dim=256,
                num_classes=2,
                dropout_rate=0.1,
                name='next_sentence')
            encoder_config = encoders.EncoderConfig(
                type='mobilebert',
                mobilebert=encoders.MobileBertEncoderConfig(
                    num_blocks=num_blocks))
            return bert.PretrainerConfig(
                encoder=encoder_config,
                cls_heads=[next_sentence_head],
                mlm_activation=mlm_activation)

        def _data_config():
            # Train and validation use the same dummy settings, but each gets
            # its own config instance.
            return pretrain_dataloader.BertPretrainDataConfig(
                input_path='dummy',
                max_predictions_per_seq=76,
                seq_length=512,
                global_batch_size=10)

        # Small models keep the test light.
        task_config = distillation.BertDistillationTaskConfig(
            teacher_model=_pretrainer_config(teacher_block_num, 'gelu'),
            student_model=_pretrainer_config(student_block_num, 'relu'),
            train_data=_data_config(),
            validation_data=_data_config())

        # Every distillation stage runs for a single step.
        progressive_config = distillation.BertDistillationProgressiveConfig()
        layer_wise_config = progressive_config.layer_wise_distill_config
        layer_wise_config.transfer_teacher_layers = transfer_teacher_layers
        layer_wise_config.num_steps = 1
        progressive_config.pretrain_distill_config.num_steps = 1

        lamb_optimizer = optimization.OptimizerConfig(
            type='lamb',
            lamb=optimization.LAMBConfig(
                weight_decay_rate=0.0001,
                exclude_from_weight_decay=[
                    'LayerNorm', 'layer_norm', 'bias', 'no_norm'
                ]))
        polynomial_lr = optimization.LrConfig(
            type='polynomial',
            polynomial=optimization.PolynomialLrConfig(
                initial_learning_rate=1.5e-3,
                decay_steps=10000,
                end_learning_rate=1.5e-3))
        linear_warmup = optimization.WarmupConfig(
            type='linear',
            linear=optimization.LinearWarmupConfig(warmup_learning_rate=0))
        optimization_config = optimization.OptimizationConfig(
            optimizer=lamb_optimizer,
            learning_rate=polynomial_lr,
            warmup=linear_warmup)

        trainer_config = prog_trainer_lib.ProgressiveTrainerConfig(
            progressive=progressive_config,
            optimizer_config=optimization_config)
        exp_config = cfg.ExperimentConfig(
            task=task_config, trainer=trainer_config)

        # Build the teacher pretrainer whose weights become the checkpoint.
        teacher_config = task_config.teacher_model
        teacher_encoder = encoders.build_encoder(teacher_config.encoder)
        teacher_cls_heads = [
            layers.ClassificationHead(**head_config.as_dict())
            for head_config in (teacher_config.cls_heads or [])
        ]
        mlm_initializer = tf.keras.initializers.TruncatedNormal(
            stddev=teacher_config.mlm_initializer_range)
        masked_lm = layers.MobileBertMaskedLM(
            embedding_table=teacher_encoder.get_embedding_table(),
            activation=tf_utils.get_activation(teacher_config.mlm_activation),
            initializer=mlm_initializer,
            name='cls/predictions')
        teacher_pretrainer = models.BertPretrainerV2(
            encoder_network=teacher_encoder,
            classification_heads=teacher_cls_heads,
            customized_masked_lm=masked_lm)

        # The model variables will be created after the forward call.
        _ = teacher_pretrainer(teacher_pretrainer.inputs)

        checkpoint = tf.train.Checkpoint(**teacher_pretrainer.checkpoint_items)
        checkpoint.save(
            os.path.join(self.get_temp_dir(), 'teacher_model.ckpt'))
        # The task is given the directory, not the checkpoint file prefix.
        exp_config.task.teacher_model_init_checkpoint = self.get_temp_dir()

        return exp_config
# Module-level optimization settings: LAMB optimizer, polynomial learning-rate
# schedule, and linear warmup.
optimization_config = optimization.OptimizationConfig(
    optimizer=optimization.OptimizerConfig(
        type='lamb',
        lamb=optimization.LAMBConfig(
            weight_decay_rate=0.01,
            # Presumably matched against variable names -- confirm against
            # the LAMB optimizer implementation.
            exclude_from_weight_decay=[
                'LayerNorm',
                'bias',
                'norm',
            ],
            clipnorm=1.0,
        ),
    ),
    learning_rate=optimization.LrConfig(
        type='polynomial',
        # Initial and end learning rates are set to the same value.
        polynomial=optimization.PolynomialLrConfig(
            initial_learning_rate=1.5e-3,
            decay_steps=10000,
            end_learning_rate=1.5e-3,
        ),
    ),
    warmup=optimization.WarmupConfig(
        type='linear',
        linear=optimization.LinearWarmupConfig(
            warmup_learning_rate=0,
        ),
    ),
)


# Copied from progressive/utils.py because the original is not visible here
# (private visibility).
def config_override(params, flags_obj):
    """Override ExperimentConfig according to flags."""
    # Change runtime.tpu to the real tpu.
    params.override({'runtime': {
        'tpu': flags_obj.tpu,
    }})

    # Get the first level of override from `--config_file`.
    #   `--config_file` is typically used as a template that specifies the common
    #   override for a particular experiment.
    for config_file in flags_obj.config_file or []:
        params = hyperparams.override_params_dict(params,