Example #1
def _build_pretrainer(
        config: electra.ElectraPretrainerConfig) -> models.ElectraPretrainer:
    """Instantiates ElectraPretrainer from the config."""
    generator_encoder_cfg = config.generator_encoder
    discriminator_encoder_cfg = config.discriminator_encoder
    # Copy discriminator's embeddings to generator for easier model serialization.
    discriminator_network = encoders.build_encoder(discriminator_encoder_cfg)
    if config.tie_embeddings:
        embedding_layer = discriminator_network.get_embedding_layer()
        generator_network = encoders.build_encoder(
            generator_encoder_cfg, embedding_layer=embedding_layer)
    else:
        generator_network = encoders.build_encoder(generator_encoder_cfg)

    generator_encoder_cfg = generator_encoder_cfg.get()
    return models.ElectraPretrainer(
        generator_network=generator_network,
        discriminator_network=discriminator_network,
        vocab_size=generator_encoder_cfg.vocab_size,
        num_classes=config.num_classes,
        sequence_length=config.sequence_length,
        num_token_predictions=config.num_masked_tokens,
        mlm_activation=tf_utils.get_activation(
            generator_encoder_cfg.hidden_activation),
        mlm_initializer=tf.keras.initializers.TruncatedNormal(
            stddev=generator_encoder_cfg.initializer_range),
        classification_heads=[
            layers.ClassificationHead(**cfg.as_dict())
            for cfg in config.cls_heads
        ],
        disallow_correct=config.disallow_correct)
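A minimal sketch of invoking this builder. The field names follow from the function body above, but the concrete values (and any fields omitted here) are illustrative assumptions:

config = electra.ElectraPretrainerConfig(
    generator_encoder=encoders.EncoderConfig(
        type='bert', bert=encoders.BertEncoderConfig(num_layers=1)),
    discriminator_encoder=encoders.EncoderConfig(
        type='bert', bert=encoders.BertEncoderConfig(num_layers=2)),
    tie_embeddings=True,  # shares the discriminator's embeddings, as above
    sequence_length=128,
    num_masked_tokens=20)
pretrainer = _build_pretrainer(config)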
Example #2
  def test_encoder_from_yaml(self):
    config = encoders.EncoderConfig(
        type="bert", bert=encoders.BertEncoderConfig(num_layers=1))
    encoder = encoders.build_encoder(config)
    ckpt = tf.train.Checkpoint(encoder=encoder)
    ckpt_path = ckpt.save(self.get_temp_dir() + "/ckpt")
    params_save_path = os.path.join(self.get_temp_dir(), "params.yaml")
    hyperparams.save_params_dict_to_yaml(config, params_save_path)

    restored_cfg = encoders.EncoderConfig.from_yaml(params_save_path)
    restored_encoder = encoders.build_encoder(restored_cfg)
    status = tf.train.Checkpoint(encoder=restored_encoder).restore(ckpt_path)
    status.assert_consumed()
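Note that status.assert_consumed() raises if any value in the checkpoint was left unmatched, so the test only passes when the config restored from YAML rebuilds exactly the architecture that was saved.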
Example #3
    def build_model(self):
        encoder_network = encoders.build_encoder(
            self.task_config.model.encoder)
        preprocess_dict = {}
        scorer = TFRBertScorer(
            encoder=encoder_network,
            bert_output_dropout=self.task_config.model.dropout_rate)

        example_feature_spec = {
            'input_word_ids': tf.io.FixedLenFeature(shape=(None, ),
                                                    dtype=tf.int64),
            'input_mask': tf.io.FixedLenFeature(shape=(None, ),
                                                dtype=tf.int64),
            'input_type_ids': tf.io.FixedLenFeature(shape=(None, ),
                                                    dtype=tf.int64)
        }
        context_feature_spec = {}

        model_builder = TFRBertModelBuilder(
            input_creator=tfr_model.FeatureSpecInputCreator(
                context_feature_spec, example_feature_spec),
            preprocessor=tfr_model.PreprocessorWithSpec(preprocess_dict),
            scorer=scorer,
            mask_feature_name=self._task_config.train_data.mask_feature_name,
            name='tfrbert_model')
        return model_builder.build()
Example #4
    def test_tfr_bert_model_builder(self):
        encoder_config = encoders.EncoderConfig(
            bert=encoders.BertEncoderConfig(vocab_size=30522, num_layers=1))
        encoder_network = encoders.build_encoder(encoder_config)
        preprocess_dict = {}
        scorer = tfrbert_task.TFRBertScorer(encoder=encoder_network,
                                            bert_output_dropout=0.1)

        example_feature_spec = {
            'input_word_ids': tf.io.FixedLenFeature(shape=(None, ),
                                                    dtype=tf.int64),
            'input_mask': tf.io.FixedLenFeature(shape=(None, ),
                                                dtype=tf.int64),
            'input_type_ids': tf.io.FixedLenFeature(shape=(None, ),
                                                    dtype=tf.int64)
        }
        context_feature_spec = {}

        model_builder = tfrbert_task.TFRBertModelBuilder(
            input_creator=tfr_model.FeatureSpecInputCreator(
                context_feature_spec, example_feature_spec),
            preprocessor=tfr_model.PreprocessorWithSpec(preprocess_dict),
            scorer=scorer,
            mask_feature_name='example_list_mask',
            name='tfrbert_model')
        model = model_builder.build()

        output = model(self._create_input_data())
        self.assertAllEqual(output.shape.as_list(), [12, 10])
Example #5
 def _create_bert_ckpt(self):
     config = encoders.EncoderConfig(
         type='bert', bert=encoders.BertEncoderConfig(num_layers=1))
     encoder = encoders.build_encoder(config)
     ckpt = tf.train.Checkpoint(encoder=encoder)
     ckpt_path = ckpt.save(os.path.join(self._logging_dir, 'ckpt'))
     return ckpt_path
Example #6
 def test_build_teams(self):
     config = encoders.EncoderConfig(
         type="any", any=teams.TeamsEncoderConfig(num_layers=1))
     encoder = encoders.build_encoder(config)
     self.assertIsInstance(encoder, networks.EncoderScaffold)
     self.assertIsInstance(encoder.embedding_network,
                           networks.PackedSequenceEmbedding)
Example #7
 def test_initialization_with_encoder(self):
     """Initializes pretrainer model with an existing encoder network."""
     encoder = encoders.build_encoder(config=encoders.EncoderConfig(
         type='mobilebert'))
     pretrainer = model_builder.build_bert_pretrainer(
         pretrainer_cfg=self.pretrainer_config, encoder=encoder)
     encoder_network = pretrainer.encoder_network
     self.assertEqual(encoder_network, encoder)
Example #8
 def _export_bert_tfhub(self):
   encoder = encoders.build_encoder(
       encoders.EncoderConfig(
           bert=encoders.BertEncoderConfig(vocab_size=30522, num_layers=1)))
   encoder_inputs_dict = {x.name: x for x in encoder.inputs}
   encoder_output_dict = encoder(encoder_inputs_dict)
   core_model = tf.keras.Model(
       inputs=encoder_inputs_dict, outputs=encoder_output_dict)
   hub_destination = os.path.join(self.get_temp_dir(), "hub")
   core_model.save(hub_destination, include_optimizer=False, save_format="tf")
   return hub_destination
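The exported SavedModel can then be consumed like any hub model. A minimal sketch (inside the same test class, assuming the usual `import tensorflow_hub as hub` and `import numpy as np`), mirroring the hub.KerasLayer usage in the export test further below:

hub_layer = hub.KerasLayer(self._export_bert_tfhub())
dummy_ids = np.zeros((2, 10), dtype=np.int32)
outputs = hub_layer(dict(input_word_ids=dummy_ids,
                         input_mask=dummy_ids,
                         input_type_ids=dummy_ids))  # dict of encoder outputs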
Example #9
 def build_model(self):
     if self._hub_module:
         encoder_network = utils.get_encoder_from_hub(self._hub_module)
     else:
         encoder_network = encoders.build_encoder(
             self.task_config.model.encoder)
     encoder_cfg = self.task_config.model.encoder.get()
     # Currently, we only support bert-style question answering finetuning.
     return models.BertSpanLabeler(
         network=encoder_network,
         initializer=tf.keras.initializers.TruncatedNormal(
             stddev=encoder_cfg.initializer_range))
Example #10
    def test_copy_pooler_dense_to_encoder(self):
        encoder_config = encoders.EncoderConfig(
            type="bert",
            bert=encoders.BertEncoderConfig(hidden_size=24,
                                            intermediate_size=48,
                                            num_layers=2))
        cls_heads = [
            layers.ClassificationHead(inner_dim=24,
                                      num_classes=2,
                                      name="next_sentence")
        ]
        encoder = encoders.build_encoder(encoder_config)
        pretrainer = models.BertPretrainerV2(
            encoder_network=encoder,
            classification_heads=cls_heads,
            mlm_activation=tf_utils.get_activation(
                encoder_config.get().hidden_activation))
        # Makes sure the pretrainer variables are created.
        _ = pretrainer(pretrainer.inputs)
        checkpoint = tf.train.Checkpoint(**pretrainer.checkpoint_items)
        model_checkpoint_dir = os.path.join(self.get_temp_dir(), "checkpoint")
        checkpoint.save(os.path.join(model_checkpoint_dir, "test"))

        vocab_file, sp_model_file = _get_vocab_or_sp_model_dummy(
            self.get_temp_dir(), use_sp_model=True)
        export_path = os.path.join(self.get_temp_dir(), "hub")
        export_tfhub_lib.export_model(
            export_path=export_path,
            encoder_config=encoder_config,
            model_checkpoint_path=tf.train.latest_checkpoint(
                model_checkpoint_dir),
            with_mlm=True,
            copy_pooler_dense_to_encoder=True,
            vocab_file=vocab_file,
            sp_model_file=sp_model_file,
            do_lower_case=True)
        # Restores a hub KerasLayer.
        hub_layer = hub.KerasLayer(export_path, trainable=True)
        dummy_ids = np.zeros((2, 10), dtype=np.int32)
        input_dict = dict(input_word_ids=dummy_ids,
                          input_mask=dummy_ids,
                          input_type_ids=dummy_ids)
        hub_pooled_output = hub_layer(input_dict)["pooled_output"]
        encoder_outputs = encoder(input_dict)
        # Verify that hub_layer's pooled_output is the same as the output of next
        # sentence prediction's dense layer.
        pretrained_pooled_output = cls_heads[0].dense(
            (encoder_outputs["sequence_output"][:, 0, :]))
        self.assertAllClose(hub_pooled_output, pretrained_pooled_output)
        # But the encoder's own pooled_output differs from the hub_layer's.
        encoder_pooled_output = encoder_outputs["pooled_output"]
        self.assertNotAllClose(hub_pooled_output, encoder_pooled_output)
Example #11
 def build_model(self):
   if self._hub_module:
     encoder_network = utils.get_encoder_from_hub(self._hub_module)
   else:
     encoder_network = encoders.build_encoder(self.task_config.model.encoder)
   encoder_cfg = self.task_config.model.encoder.get()
   # Currently, we only support bert-style sentence prediction finetuning.
   return models.BertClassifier(
       network=encoder_network,
       num_classes=self.task_config.model.num_classes,
       initializer=tf.keras.initializers.TruncatedNormal(
           stddev=encoder_cfg.initializer_range),
       use_encoder_pooler=self.task_config.model.use_encoder_pooler)
Example #12
  def build_model(self):
    if self._hub_module:
      encoder_network = utils.get_encoder_from_hub(self._hub_module)
    else:
      encoder_network = encoders.build_encoder(self.task_config.model.encoder)

    return models.BertTokenClassifier(
        network=encoder_network,
        num_classes=len(self.task_config.class_names),
        initializer=tf.keras.initializers.TruncatedNormal(
            stddev=self.task_config.model.head_initializer_range),
        dropout_rate=self.task_config.model.head_dropout,
        output='logits')
Example #13
 def build_model(self):
   if self.task_config.hub_module_url and self.task_config.init_checkpoint:
     raise ValueError('At most one of `hub_module_url` and '
                      '`init_checkpoint` can be specified.')
   if self.task_config.hub_module_url:
     encoder_network = utils.get_encoder_from_hub(
         self.task_config.hub_module_url)
   else:
     encoder_network = encoders.build_encoder(self.task_config.model.encoder)
   encoder_cfg = self.task_config.model.encoder.get()
   return models.BertSpanLabeler(
       network=encoder_network,
       initializer=tf.keras.initializers.TruncatedNormal(
           stddev=encoder_cfg.initializer_range))
Example #14
 def build_model(self, params=None):
     config = params or self.task_config.model
     encoder_cfg = config.encoder
     encoder_network = encoders.build_encoder(encoder_cfg)
     cls_heads = [
         layers.ClassificationHead(**cfg.as_dict())
         for cfg in config.cls_heads
     ] if config.cls_heads else []
     return models.BertPretrainerV2(
         mlm_activation=tf_utils.get_activation(config.mlm_activation),
         mlm_initializer=tf.keras.initializers.TruncatedNormal(
             stddev=config.mlm_initializer_range),
         encoder_network=encoder_network,
         classification_heads=cls_heads)
Example #15
  def build_model(self):
    """Interface to build model. Refer to base_task.Task.build_model."""

    if self._hub_module:
      encoder_network = utils.get_encoder_from_hub(self._hub_module)
    else:
      encoder_network = encoders.build_encoder(self.task_config.model.encoder)

    # Currently, we only support bert-style dual encoders.
    return models.DualEncoder(
        network=encoder_network,
        max_seq_length=self.task_config.model.max_sequence_length,
        normalize=self.task_config.model.normalize,
        logit_scale=self.task_config.model.logit_scale,
        logit_margin=self.task_config.model.logit_margin,
        output='logits')
Example #16
    def build_model(self) -> tf.keras.Model:
        if self.task_config.hub_module_url and self.task_config.init_checkpoint:
            raise ValueError('At most one of `hub_module_url` and '
                             '`init_checkpoint` can be specified.')
        if self.task_config.hub_module_url:
            encoder_network = utils.get_encoder_from_hub(
                self.task_config.hub_module_url)
        else:
            encoder_network = encoders.build_encoder(
                self.task_config.model.encoder)

        return models.BertClassifier(
            network=encoder_network,
            num_classes=len(self.task_config.class_names),
            initializer=tf.keras.initializers.TruncatedNormal(
                stddev=self.task_config.model.head_initializer_range),
            dropout_rate=self.task_config.model.head_dropout)
Example #17
 def build_model(self):
     if self.task_config.hub_module_url and self.task_config.init_checkpoint:
         raise ValueError('At most one of `hub_module_url` and '
                          '`init_checkpoint` can be specified.')
     if self.task_config.hub_module_url:
         encoder_network = utils.get_encoder_from_hub(
             self.task_config.hub_module_url)
     else:
         encoder_network = encoders.build_encoder(
             self.task_config.model.encoder)
     encoder_cfg = self.task_config.model.encoder.get()
     # Currently, we only support bert-style sentence prediction finetuning.
     return models.BertClassifier(
         network=encoder_network,
         num_classes=self.task_config.model.num_classes,
         initializer=tf.keras.initializers.TruncatedNormal(
             stddev=encoder_cfg.initializer_range),
         use_encoder_pooler=self.task_config.model.use_encoder_pooler)
Example #18
 def build_model(self):
     if self.task_config.hub_module_url and self.task_config.init_checkpoint:
         raise ValueError('At most one of `hub_module_url` and '
                          '`init_checkpoint` can be specified.')
     if self.task_config.hub_module_url:
         hub_module = hub.load(self.task_config.hub_module_url)
     else:
         hub_module = None
     if hub_module:
         encoder_network = utils.get_encoder_from_hub(hub_module)
     else:
         encoder_network = encoders.build_encoder(
             self.task_config.model.encoder)
     encoder_cfg = self.task_config.model.encoder.get()
     # Currently, we only support bert-style question answering finetuning.
     return models.BertSpanLabeler(
         network=encoder_network,
         initializer=tf.keras.initializers.TruncatedNormal(
             stddev=encoder_cfg.initializer_range))
Example #19
    def build_model(self):
        """Interface to build model. Refer to base_task.Task.build_model."""
        if self.task_config.hub_module_url and self.task_config.init_checkpoint:
            raise ValueError('At most one of `hub_module_url` and '
                             '`init_checkpoint` can be specified.')
        if self.task_config.hub_module_url:
            encoder_network = utils.get_encoder_from_hub(
                self.task_config.hub_module_url)
        else:
            encoder_network = encoders.build_encoder(
                self.task_config.model.encoder)

        # Currently, we only support bert-style dual encoders.
        return models.DualEncoder(
            network=encoder_network,
            max_seq_length=self.task_config.model.max_sequence_length,
            normalize=self.task_config.model.normalize,
            logit_scale=self.task_config.model.logit_scale,
            logit_margin=self.task_config.model.logit_margin,
            output='logits')
Example #20
    def test_task(self):
        # Saves a checkpoint.
        encoder = encoders.build_encoder(self._encoder_config)
        ckpt = tf.train.Checkpoint(encoder=encoder)
        saved_path = ckpt.save(self.get_temp_dir())

        config = tagging.TaggingConfig(
            init_checkpoint=saved_path,
            model=tagging.ModelConfig(encoder=self._encoder_config),
            train_data=self._train_data_config,
            class_names=["O", "B-PER", "I-PER"])
        task = tagging.TaggingTask(config)
        model = task.build_model()
        metrics = task.build_metrics()
        dataset = task.build_inputs(config.train_data)

        iterator = iter(dataset)
        optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
        task.train_step(next(iterator), model, optimizer, metrics=metrics)
        task.validation_step(next(iterator), model, metrics=metrics)
        task.initialize(model)
Example #21
 def __init__(self, model_config: encoders.EncoderConfig,
              sequence_length: int, **kwargs):
     inputs = dict(token_ids=tf.keras.Input((sequence_length, ),
                                            dtype=tf.int32),
                   question_lengths=tf.keras.Input((), dtype=tf.int32))
     encoder = encoders.build_encoder(model_config)
     x = encoder(
         dict(input_word_ids=inputs['token_ids'],
              input_mask=tf.cast(inputs['token_ids'] > 0, tf.int32),
              input_type_ids=1 -
              tf.sequence_mask(inputs['question_lengths'], sequence_length,
                               tf.int32)))['sequence_output']
     logits = TriviaQaHead(
         model_config.get().intermediate_size,
         dropout_rate=model_config.get().dropout_rate,
         attention_dropout_rate=model_config.get().attention_dropout_rate)(
             dict(token_embeddings=x,
                  token_ids=inputs['token_ids'],
                  question_lengths=inputs['question_lengths']))
     super(TriviaQaModel, self).__init__(inputs, logits, **kwargs)
     self._encoder = encoder
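A minimal sketch of instantiating the model; the constructor signature comes from the code above, while the config and sequence length are illustrative assumptions:

model = TriviaQaModel(
    model_config=encoders.EncoderConfig(
        type='bert', bert=encoders.BertEncoderConfig(num_layers=1)),
    sequence_length=512)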
Example #22
 def build_model(self):
   if self.task_config.hub_module_url and self.task_config.init_checkpoint:
     raise ValueError('At most one of `hub_module_url` and '
                      '`init_checkpoint` can be specified.')
   if self.task_config.hub_module_url:
     encoder_network = utils.get_encoder_from_hub(
         self.task_config.hub_module_url)
   else:
     encoder_network = encoders.build_encoder(self.task_config.model.encoder)
   encoder_cfg = self.task_config.model.encoder.get()
   if self.task_config.model.encoder.type == 'xlnet':
     return models.XLNetClassifier(
         network=encoder_network,
         num_classes=self.task_config.model.num_classes,
         initializer=tf.keras.initializers.RandomNormal(
             stddev=encoder_cfg.initializer_range))
   else:
     return models.BertClassifier(
         network=encoder_network,
         num_classes=self.task_config.model.num_classes,
         initializer=tf.keras.initializers.TruncatedNormal(
             stddev=encoder_cfg.initializer_range),
         use_encoder_pooler=self.task_config.model.use_encoder_pooler)
Example #23
  def _build_pretrainer(self, pretrainer_cfg: bert.PretrainerConfig, name: str):
    """Builds pretrainer from config and encoder."""
    encoder = encoders.build_encoder(pretrainer_cfg.encoder)
    if pretrainer_cfg.cls_heads:
      cls_heads = [
          layers.ClassificationHead(**cfg.as_dict())
          for cfg in pretrainer_cfg.cls_heads
      ]
    else:
      cls_heads = []

    masked_lm = layers.MobileBertMaskedLM(
        embedding_table=encoder.get_embedding_table(),
        activation=tf_utils.get_activation(pretrainer_cfg.mlm_activation),
        initializer=tf.keras.initializers.TruncatedNormal(
            stddev=pretrainer_cfg.mlm_initializer_range),
        name='cls/predictions')

    pretrainer = models.BertPretrainerV2(
        encoder_network=encoder,
        classification_heads=cls_heads,
        customized_masked_lm=masked_lm,
        name=name)
    return pretrainer
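A minimal sketch of calling this helper; the PretrainerConfig fields mirror those used in the distillation example further below, and the values are illustrative only:

pretrainer = self._build_pretrainer(
    bert.PretrainerConfig(
        encoder=encoders.EncoderConfig(
            type='mobilebert',
            mobilebert=encoders.MobileBertEncoderConfig(num_blocks=1)),
        mlm_activation='gelu'),
    name='teacher')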
Example #24
    def build_model(self, train_last_layer_only=False):
        """Modified version of official.nlp.tasks.tagging.build_model

        Allows to freeze the underlying bert encoder, such that only the dense
        layer is trained.
        """
        if self.task_config.hub_module_url and self.task_config.init_checkpoint:
            raise ValueError("At most one of `hub_module_url` and "
                             "`init_checkpoint` can be specified.")
        if self.task_config.hub_module_url:
            encoder_network = utils.get_encoder_from_hub(
                self.task_config.hub_module_url)
        else:
            encoder_network = encoders.build_encoder(
                self.task_config.model.encoder)
        encoder_network.trainable = not train_last_layer_only

        return models.BertTokenClassifier(
            network=encoder_network,
            num_classes=len(self.task_config.class_names),
            initializer=tf.keras.initializers.TruncatedNormal(
                stddev=self.task_config.model.head_initializer_range),
            dropout_rate=self.task_config.model.head_dropout,
            output="logits")
Example #25
 def _build_encoder(self, encoder_cfg):
   return encoders.build_encoder(encoder_cfg)
Example #26
def _create_model(
    *,
    bert_config: Optional[configs.BertConfig] = None,
    encoder_config: Optional[encoders.EncoderConfig] = None,
    with_mlm: bool,
) -> Tuple[tf.keras.Model, tf.keras.Model]:
    """Creates the model to export and the model to restore the checkpoint.

    Args:
      bert_config: A legacy `BertConfig` to create a `BertEncoder` object.
        Exactly one of encoder_config and bert_config must be set.
      encoder_config: An `EncoderConfig` to create an encoder of the configured
        type (`BertEncoder` or other).
      with_mlm: A bool to control the second component of the result. If True,
        will create a `BertPretrainerV2` object; otherwise, will create a
        `BertEncoder` object.

    Returns:
      A Tuple of (1) a Keras model that will be exported, (2) a
      `BertPretrainerV2` object or `BertEncoder` object depending on the value
      of `with_mlm` argument, which contains the first model and will be used
      for restoring weights from the checkpoint.
    """
    if (bert_config is not None) == (encoder_config is not None):
        raise ValueError("Exactly one of `bert_config` and `encoder_config` "
                         "can be specified, but got %s and %s" %
                         (bert_config, encoder_config))

    if bert_config is not None:
        encoder = get_bert_encoder(bert_config)
    else:
        encoder = encoders.build_encoder(encoder_config)

    # Convert from list of named inputs to dict of inputs keyed by name.
    # Only the latter accepts a dict of inputs after restoring from SavedModel.
    encoder_inputs_dict = {x.name: x for x in encoder.inputs}
    encoder_output_dict = encoder(encoder_inputs_dict)
    # For interchangeability with other text representations,
    # add "default" as an alias for BERT's whole-input reptesentations.
    encoder_output_dict["default"] = encoder_output_dict["pooled_output"]
    core_model = tf.keras.Model(inputs=encoder_inputs_dict,
                                outputs=encoder_output_dict)

    if with_mlm:
        if bert_config is not None:
            hidden_act = bert_config.hidden_act
        else:
            assert encoder_config is not None
            hidden_act = encoder_config.get().hidden_activation

        pretrainer = models.BertPretrainerV2(
            encoder_network=encoder,
            mlm_activation=tf_utils.get_activation(hidden_act))

        pretrainer_inputs_dict = {x.name: x for x in pretrainer.inputs}
        pretrainer_output_dict = pretrainer(pretrainer_inputs_dict)
        mlm_model = tf.keras.Model(inputs=pretrainer_inputs_dict,
                                   outputs=pretrainer_output_dict)
        # Set `_auto_track_sub_layers` to False, so that the additional weights
        # from `mlm` sub-object will not be included in the core model.
        # TODO(b/169210253): Use a public API when available.
        core_model._auto_track_sub_layers = False  # pylint: disable=protected-access
        core_model.mlm = mlm_model
        return core_model, pretrainer
    else:
        return core_model, encoder
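A minimal sketch of the two calling modes, borrowing the EncoderConfig pattern from the earlier examples:

encoder_config = encoders.EncoderConfig(
    bert=encoders.BertEncoderConfig(vocab_size=30522, num_layers=1))
# With MLM head: the second element is the BertPretrainerV2 used for
# restoring checkpoint weights.
core_model, pretrainer = _create_model(
    encoder_config=encoder_config, with_mlm=True)
# Without MLM head: the second element is the bare encoder.
core_model, encoder = _create_model(
    encoder_config=encoder_config, with_mlm=False)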
Example #27
    def prepare_config(self, teacher_block_num, student_block_num,
                       transfer_teacher_layers):
        # Use a small model for testing.
        task_config = distillation.BertDistillationTaskConfig(
            teacher_model=bert.PretrainerConfig(
                encoder=encoders.EncoderConfig(
                    type='mobilebert',
                    mobilebert=encoders.MobileBertEncoderConfig(
                        num_blocks=teacher_block_num)),
                cls_heads=[
                    bert.ClsHeadConfig(
                        inner_dim=256,
                        num_classes=2,
                        dropout_rate=0.1,
                        name='next_sentence')
                ],
                mlm_activation='gelu'),
            student_model=bert.PretrainerConfig(
                encoder=encoders.EncoderConfig(
                    type='mobilebert',
                    mobilebert=encoders.MobileBertEncoderConfig(
                        num_blocks=student_block_num)),
                cls_heads=[
                    bert.ClsHeadConfig(
                        inner_dim=256,
                        num_classes=2,
                        dropout_rate=0.1,
                        name='next_sentence')
                ],
                mlm_activation='relu'),
            train_data=pretrain_dataloader.BertPretrainDataConfig(
                input_path='dummy',
                max_predictions_per_seq=76,
                seq_length=512,
                global_batch_size=10),
            validation_data=pretrain_dataloader.BertPretrainDataConfig(
                input_path='dummy',
                max_predictions_per_seq=76,
                seq_length=512,
                global_batch_size=10))

        # Set only one step for each stage.
        progressive_config = distillation.BertDistillationProgressiveConfig()
        progressive_config.layer_wise_distill_config.transfer_teacher_layers = (
            transfer_teacher_layers)
        progressive_config.layer_wise_distill_config.num_steps = 1
        progressive_config.pretrain_distill_config.num_steps = 1

        optimization_config = optimization.OptimizationConfig(
            optimizer=optimization.OptimizerConfig(
                type='lamb',
                lamb=optimization.LAMBConfig(weight_decay_rate=0.0001,
                                             exclude_from_weight_decay=[
                                                 'LayerNorm', 'layer_norm',
                                                 'bias', 'no_norm'
                                             ])),
            learning_rate=optimization.LrConfig(
                type='polynomial',
                polynomial=optimization.PolynomialLrConfig(
                    initial_learning_rate=1.5e-3,
                    decay_steps=10000,
                    end_learning_rate=1.5e-3)),
            warmup=optimization.WarmupConfig(
                type='linear',
                linear=optimization.LinearWarmupConfig(
                    warmup_learning_rate=0)))

        exp_config = cfg.ExperimentConfig(
            task=task_config,
            trainer=prog_trainer_lib.ProgressiveTrainerConfig(
                progressive=progressive_config,
                optimizer_config=optimization_config))

        # Create a teacher model checkpoint.
        teacher_encoder = encoders.build_encoder(
            task_config.teacher_model.encoder)
        pretrainer_config = task_config.teacher_model
        if pretrainer_config.cls_heads:
            teacher_cls_heads = [
                layers.ClassificationHead(**cfg.as_dict())
                for cfg in pretrainer_config.cls_heads
            ]
        else:
            teacher_cls_heads = []

        masked_lm = layers.MobileBertMaskedLM(
            embedding_table=teacher_encoder.get_embedding_table(),
            activation=tf_utils.get_activation(
                pretrainer_config.mlm_activation),
            initializer=tf.keras.initializers.TruncatedNormal(
                stddev=pretrainer_config.mlm_initializer_range),
            name='cls/predictions')
        teacher_pretrainer = models.BertPretrainerV2(
            encoder_network=teacher_encoder,
            classification_heads=teacher_cls_heads,
            customized_masked_lm=masked_lm)

        # The model variables will be created after the forward call.
        _ = teacher_pretrainer(teacher_pretrainer.inputs)
        teacher_pretrainer_ckpt = tf.train.Checkpoint(
            **teacher_pretrainer.checkpoint_items)
        teacher_ckpt_path = os.path.join(self.get_temp_dir(),
                                         'teacher_model.ckpt')
        teacher_pretrainer_ckpt.save(teacher_ckpt_path)
        exp_config.task.teacher_model_init_checkpoint = self.get_temp_dir()

        return exp_config