def _build_pretrainer(
    config: electra.ElectraPretrainerConfig) -> models.ElectraPretrainer:
  """Instantiates ElectraPretrainer from the config."""
  generator_encoder_cfg = config.generator_encoder
  discriminator_encoder_cfg = config.discriminator_encoder
  # Copy discriminator's embeddings to generator for easier model serialization.
  discriminator_network = encoders.build_encoder(discriminator_encoder_cfg)
  if config.tie_embeddings:
    embedding_layer = discriminator_network.get_embedding_layer()
    generator_network = encoders.build_encoder(
        generator_encoder_cfg, embedding_layer=embedding_layer)
  else:
    generator_network = encoders.build_encoder(generator_encoder_cfg)

  generator_encoder_cfg = generator_encoder_cfg.get()
  return models.ElectraPretrainer(
      generator_network=generator_network,
      discriminator_network=discriminator_network,
      vocab_size=generator_encoder_cfg.vocab_size,
      num_classes=config.num_classes,
      sequence_length=config.sequence_length,
      num_token_predictions=config.num_masked_tokens,
      mlm_activation=tf_utils.get_activation(
          generator_encoder_cfg.hidden_activation),
      mlm_initializer=tf.keras.initializers.TruncatedNormal(
          stddev=generator_encoder_cfg.initializer_range),
      classification_heads=[
          layers.ClassificationHead(**cfg.as_dict())
          for cfg in config.cls_heads
      ],
      disallow_correct=config.disallow_correct)
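# Usage sketch (illustrative only): builds a toy ELECTRA pretrainer from a
# hand-written config. The helper name, encoder sizes, and the `type='bert'`
# EncoderConfig pattern below are assumptions for illustration, not settings
# taken from the source.


def _example_build_small_electra():
  """Builds a tiny ELECTRA pretrainer; all sizes are illustrative."""
  config = electra.ElectraPretrainerConfig(
      num_masked_tokens=20,
      sequence_length=128,
      generator_encoder=encoders.EncoderConfig(
          type='bert',
          bert=encoders.BertEncoderConfig(hidden_size=64, num_layers=1)),
      discriminator_encoder=encoders.EncoderConfig(
          type='bert',
          bert=encoders.BertEncoderConfig(hidden_size=64, num_layers=2)))
  return _build_pretrainer(config)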
def build_model(self, params=None):
  """Builds a BertPretrainerV2 model from the task's model config."""
  config = params or self.task_config.model
  encoder_cfg = config.encoder
  encoder_network = self._build_encoder(encoder_cfg)
  cls_heads = [
      layers.ClassificationHead(**cfg.as_dict()) for cfg in config.cls_heads
  ] if config.cls_heads else []
  return models.BertPretrainerV2(
      mlm_activation=tf_utils.get_activation(config.mlm_activation),
      mlm_initializer=tf.keras.initializers.TruncatedNormal(
          stddev=config.mlm_initializer_range),
      encoder_network=encoder_network,
      classification_heads=cls_heads)
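# Usage sketch (assumption: `build_model` above belongs to a pretraining task
# class whose `task_config.model` carries `encoder`, `mlm_activation`,
# `mlm_initializer_range`, and optional `cls_heads`). A hypothetical call site:
#
#   pretrainer = task.build_model()
#   _ = pretrainer(pretrainer.inputs)  # Variables are created on the first call.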
def _build_pretrainer(self, pretrainer_cfg: bert.PretrainerConfig, name: str):
  """Builds pretrainer from config and encoder."""
  encoder = encoders.build_encoder(pretrainer_cfg.encoder)
  if pretrainer_cfg.cls_heads:
    cls_heads = [
        layers.ClassificationHead(**cfg.as_dict())
        for cfg in pretrainer_cfg.cls_heads
    ]
  else:
    cls_heads = []

  masked_lm = layers.MobileBertMaskedLM(
      embedding_table=encoder.get_embedding_table(),
      activation=tf_utils.get_activation(pretrainer_cfg.mlm_activation),
      initializer=tf.keras.initializers.TruncatedNormal(
          stddev=pretrainer_cfg.mlm_initializer_range),
      name='cls/predictions')

  pretrainer = models.BertPretrainerV2(
      encoder_network=encoder,
      classification_heads=cls_heads,
      customized_masked_lm=masked_lm,
      name=name)
  return pretrainer
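# Usage sketch (attribute names and the 'teacher'/'student' names are
# assumptions, not quoted from the source): the distillation task can build
# both pretrainers with distinct `name`s so their variables stay separate,
# e.g. inside the task constructor:
#
#   self.teacher_model = self._build_pretrainer(
#       self.task_config.teacher_model, name='teacher')
#   self.student_model = self._build_pretrainer(
#       self.task_config.student_model, name='student')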
def prepare_config(self, teacher_block_num, student_block_num,
                   transfer_teacher_layers):
  """Prepares an ExperimentConfig and teacher checkpoint for testing."""
  # Use a small model for testing.
  task_config = distillation.BertDistillationTaskConfig(
      teacher_model=bert.PretrainerConfig(
          encoder=encoders.EncoderConfig(
              type='mobilebert',
              mobilebert=encoders.MobileBertEncoderConfig(
                  num_blocks=teacher_block_num)),
          cls_heads=[
              bert.ClsHeadConfig(
                  inner_dim=256,
                  num_classes=2,
                  dropout_rate=0.1,
                  name='next_sentence')
          ],
          mlm_activation='gelu'),
      student_model=bert.PretrainerConfig(
          encoder=encoders.EncoderConfig(
              type='mobilebert',
              mobilebert=encoders.MobileBertEncoderConfig(
                  num_blocks=student_block_num)),
          cls_heads=[
              bert.ClsHeadConfig(
                  inner_dim=256,
                  num_classes=2,
                  dropout_rate=0.1,
                  name='next_sentence')
          ],
          mlm_activation='relu'),
      train_data=pretrain_dataloader.BertPretrainDataConfig(
          input_path='dummy',
          max_predictions_per_seq=76,
          seq_length=512,
          global_batch_size=10),
      validation_data=pretrain_dataloader.BertPretrainDataConfig(
          input_path='dummy',
          max_predictions_per_seq=76,
          seq_length=512,
          global_batch_size=10))

  # Set only 1 step for each stage.
  progressive_config = distillation.BertDistillationProgressiveConfig()
  progressive_config.layer_wise_distill_config.transfer_teacher_layers = (
      transfer_teacher_layers)
  progressive_config.layer_wise_distill_config.num_steps = 1
  progressive_config.pretrain_distill_config.num_steps = 1

  optimization_config = optimization.OptimizationConfig(
      optimizer=optimization.OptimizerConfig(
          type='lamb',
          lamb=optimization.LAMBConfig(
              weight_decay_rate=0.0001,
              exclude_from_weight_decay=[
                  'LayerNorm', 'layer_norm', 'bias', 'no_norm'
              ])),
      learning_rate=optimization.LrConfig(
          type='polynomial',
          polynomial=optimization.PolynomialLrConfig(
              initial_learning_rate=1.5e-3,
              decay_steps=10000,
              end_learning_rate=1.5e-3)),
      warmup=optimization.WarmupConfig(
          type='linear',
          linear=optimization.LinearWarmupConfig(warmup_learning_rate=0)))

  exp_config = cfg.ExperimentConfig(
      task=task_config,
      trainer=prog_trainer_lib.ProgressiveTrainerConfig(
          progressive=progressive_config,
          optimizer_config=optimization_config))

  # Create a teacher model checkpoint.
  teacher_encoder = encoders.build_encoder(task_config.teacher_model.encoder)
  pretrainer_config = task_config.teacher_model
  if pretrainer_config.cls_heads:
    teacher_cls_heads = [
        layers.ClassificationHead(**cfg.as_dict())
        for cfg in pretrainer_config.cls_heads
    ]
  else:
    teacher_cls_heads = []

  masked_lm = layers.MobileBertMaskedLM(
      embedding_table=teacher_encoder.get_embedding_table(),
      activation=tf_utils.get_activation(pretrainer_config.mlm_activation),
      initializer=tf.keras.initializers.TruncatedNormal(
          stddev=pretrainer_config.mlm_initializer_range),
      name='cls/predictions')
  teacher_pretrainer = models.BertPretrainerV2(
      encoder_network=teacher_encoder,
      classification_heads=teacher_cls_heads,
      customized_masked_lm=masked_lm)

  # The model variables will be created after the forward call.
  _ = teacher_pretrainer(teacher_pretrainer.inputs)
  teacher_pretrainer_ckpt = tf.train.Checkpoint(
      **teacher_pretrainer.checkpoint_items)
  teacher_ckpt_path = os.path.join(self.get_temp_dir(), 'teacher_model.ckpt')
  teacher_pretrainer_ckpt.save(teacher_ckpt_path)
  exp_config.task.teacher_model_init_checkpoint = self.get_temp_dir()

  return exp_config
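# Usage sketch (hypothetical test body; the block counts and layer indices are
# illustrative, not values taken from the source): exercise the helper above
# from a test case to get an ExperimentConfig with a pre-saved teacher
# checkpoint.
#
#   exp_config = self.prepare_config(
#       teacher_block_num=4,
#       student_block_num=4,
#       transfer_teacher_layers=[1, 3])
#   self.assertIsNotNone(exp_config.task.teacher_model_init_checkpoint)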