Example #1
    def __init__(self, albert_config, max_seq_length, init_checkpoint, start_n_top, end_n_top, dropout=0.1, **kwargs):
        super(ALBertQAModel, self).__init__(**kwargs)
        self.albert_config = copy.deepcopy(albert_config)
        self.initializer = tf.keras.initializers.TruncatedNormal(
            stddev=self.albert_config.initializer_range)
        float_type = tf.float32

        input_word_ids = tf.keras.layers.Input(
            shape=(max_seq_length,), dtype=tf.int32, name='input_word_ids')
        input_mask = tf.keras.layers.Input(
            shape=(max_seq_length,), dtype=tf.int32, name='input_mask')
        input_type_ids = tf.keras.layers.Input(
            shape=(max_seq_length,), dtype=tf.int32, name='input_type_ids')

        albert_layer = AlbertModel(config=albert_config, float_type=float_type)

        _, sequence_output = albert_layer(
            input_word_ids, input_mask, input_type_ids)

        self.albert_model = tf.keras.Model(inputs=[input_word_ids, input_mask, input_type_ids],
                                           outputs=[sequence_output])
        if init_checkpoint is not None:
            self.albert_model.load_weights(init_checkpoint)

        self.qalayer = ALBertQALayer(self.albert_config.hidden_size, start_n_top, end_n_top,
                                     self.initializer, dropout)
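
A minimal construction sketch for the class above; the module layout and the `AlbertConfig` constructor details are assumptions (only `from_json_file` and the attributes used above appear in these snippets), the call signature of the full model is not shown, so only the encoder sub-model is exercised, and the checkpoint is skipped so the sketch runs without weights:

import tensorflow as tf

albert_config = AlbertConfig.from_json_file('albert_config.json')  # placeholder path
qa_model = ALBertQAModel(albert_config,
                         max_seq_length=384,
                         init_checkpoint=None,  # skip weight loading in this sketch
                         start_n_top=5, end_n_top=5)

# exercise the internal encoder sub-model with dummy int32 batches
dummy = tf.zeros((2, 384), dtype=tf.int32)
sequence_output = qa_model.albert_model([dummy, dummy, dummy])
print(sequence_output.shape)  # (2, 384, hidden_size)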
Example #2
    def __init__(self,
                 albert_config,
                 max_seq_length,
                 init_checkpoint,
                 start_n_top,
                 end_n_top,
                 dropout=0.1,
                 **kwargs):
        super(ALBertQAModel_v2, self).__init__(**kwargs)
        self.albert_config = copy.deepcopy(albert_config)
        self.initializer = tf.keras.initializers.TruncatedNormal(
            stddev=self.albert_config.initializer_range)
        float_type = tf.float32

        input_word_ids = tf.keras.layers.Input(shape=(max_seq_length, ),
                                               dtype=tf.int32,
                                               name='input_word_ids')
        input_mask = tf.keras.layers.Input(shape=(max_seq_length, ),
                                           dtype=tf.int32,
                                           name='input_mask')
        input_type_ids = tf.keras.layers.Input(shape=(max_seq_length, ),
                                               dtype=tf.int32,
                                               name='input_type_ids')

        albert_layer = AlbertModel(config=albert_config, float_type=float_type)

        pooled_output, sequence_output = albert_layer(input_word_ids,
                                                      input_mask,
                                                      input_type_ids)

        bilstm = tf.keras.layers.Bidirectional(
            tf.keras.layers.LSTM(512, return_sequences=True))(sequence_output)
        bilstm_self = attention.SeqSelfAttention(
            attention_activation='sigmoid')(bilstm)

        bigru = tf.keras.layers.Bidirectional(
            tf.keras.layers.GRU(256, return_sequences=True))(bilstm)
        bigru_self = attention.SeqSelfAttention(
            attention_activation='sigmoid')(bigru)

        conc = tf.keras.layers.Concatenate()([bilstm_self, bigru_self])

        self.albert_model = tf.keras.Model(
            inputs=[input_word_ids, input_mask, input_type_ids],
            outputs=[pooled_output, sequence_output])
        if init_checkpoint is not None:
            print('init_checkpoint loading ...')
            self.albert_model.load_weights(init_checkpoint)

        self.albert_lstm = tf.keras.Model(
            inputs=[input_word_ids, input_mask, input_type_ids], outputs=conc)

        self.qalayer = ALBertQALayer(conc.shape[-1], start_n_top, end_n_top,
                                     self.initializer, dropout)
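
The recurrent head above relies on `attention.SeqSelfAttention` (this looks like the keras-self-attention package); below is a standalone sketch of the same head shape using only built-in Keras layers, so it runs without that dependency (sizes copied from the constructor above, the dot-product attention is a stand-in, not the same layer):

import tensorflow as tf

seq_len, hidden = 384, 768  # stand-ins for max_seq_length and the ALBERT hidden size
sequence_output = tf.keras.layers.Input(shape=(seq_len, hidden))

bilstm = tf.keras.layers.Bidirectional(
    tf.keras.layers.LSTM(512, return_sequences=True))(sequence_output)
bigru = tf.keras.layers.Bidirectional(
    tf.keras.layers.GRU(256, return_sequences=True))(bilstm)

# built-in dot-product self-attention as a stand-in for SeqSelfAttention
bilstm_self = tf.keras.layers.Attention()([bilstm, bilstm])
bigru_self = tf.keras.layers.Attention()([bigru, bigru])

conc = tf.keras.layers.Concatenate()([bilstm_self, bigru_self])
head = tf.keras.Model(sequence_output, conc)
head.summary()  # last dim is 1024 + 512 = 1536, i.e. the conc.shape[-1] passed to ALBertQALayer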
Example #3
def get_fine_tune_model(tinybert_config, albert_config, seq_length):
    input_word_ids = tf.keras.layers.Input(
        shape=(seq_length,), name='input_word_ids', dtype=tf.int32)
    input_mask = tf.keras.layers.Input(
        shape=(seq_length,), name='input_mask', dtype=tf.int32)
    input_type_ids = tf.keras.layers.Input(
        shape=(seq_length,), name='input_type_ids', dtype=tf.int32)
    
    ## ALBERT teacher (frozen)
    float_type = tf.float32
    albert_encoder = "albert_model"
    albert_layer = AlbertModel(config=albert_config, float_type=float_type, name=albert_encoder, trainable=False)
    albert_pooled_output, albert_sequence_output, albert_attention_scores, albert_embed_tensor = albert_layer(input_word_ids, input_mask, input_type_ids)
    ##tinybert student
    float_type = tf.float32
    tinybert_encoder = "tinybert_model"
    tinybert_layer = TinybertModel(config=tinybert_config, float_type=float_type, name=tinybert_encoder)
    tinybert_pooled_output, tinybert_sequence_output, tinybert_attention_scores, tinybert_embed_tensor = tinybert_layer(input_word_ids, input_mask, input_type_ids)
    
    albert_teacher = tf.keras.Model(
        inputs = [input_word_ids, input_mask, input_type_ids],
        outputs = [albert_pooled_output] + albert_sequence_output + albert_attention_scores + [albert_embed_tensor],
        name = 'albert'
    )
    
    tinybert_student = tf.keras.Model(
        inputs = [input_word_ids, input_mask, input_type_ids],
        outputs = [tinybert_pooled_output] + tinybert_sequence_output + tinybert_attention_scores + [tinybert_embed_tensor],
        name = 'tinybert'
    )
    
    albert_teacher_outputs = albert_teacher([input_word_ids, input_mask, input_type_ids])
    tinybert_student_outputs = tinybert_student([input_word_ids, input_mask, input_type_ids])
    
    # output slicing below assumes a 12-layer teacher and a 6-layer student
    teacher_pooled_output = albert_teacher_outputs[0]
    teacher_sequence_output = albert_teacher_outputs[1:13]
    teacher_atten_score = albert_teacher_outputs[13:25]
    teacher_embed_tensor = albert_teacher_outputs[25]
    student_pooled_output = tinybert_student_outputs[0]
    student_sequence_output = tinybert_student_outputs[1:7]
    student_atten_score = tinybert_student_outputs[7:13]
    student_embed_tensor = tinybert_student_outputs[13]
    
    tinybert_loss_layer = TinybertLossLayer(albert_config, tinybert_config, initializer=None, name="dislit")
    distil_loss = tinybert_loss_layer([albert_embed_tensor, tinybert_embed_tensor, teacher_pooled_output, student_pooled_output,
                                       teacher_sequence_output, student_sequence_output, teacher_atten_score, student_atten_score])
    
    return tf.keras.Model(
        inputs={
            'input_word_ids': input_word_ids,
            'input_mask': input_mask,
            'input_type_ids': input_type_ids,},
        outputs = [distil_loss],
        name = 'fine_tune_model'), albert_teacher, tinybert_student
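
A hedged sketch of driving the returned model: its only output already is the distillation loss, so a pass-through loss forwards it unchanged. `TinybertConfig` is assumed to mirror `AlbertConfig.from_json_file` (that class is not shown in these snippets), and the paths are placeholders:

import tensorflow as tf

albert_config = AlbertConfig.from_json_file('albert_config.json')        # placeholder path
tinybert_config = TinybertConfig.from_json_file('tinybert_config.json')  # assumed class

fine_tune_model, albert_teacher, tinybert_student = get_fine_tune_model(
    tinybert_config, albert_config, seq_length=128)

# the output tensor is the loss itself; dummy labels are still required by fit()
fine_tune_model.compile(optimizer=tf.keras.optimizers.Adam(1e-4),
                        loss=lambda y_true, y_pred: y_pred)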
Example #4
def _create_albert_model(cfg):
  """Creates a BERT keras core model from BERT configuration.
  Args:
    cfg: A `BertConfig` to create the core model.
  Returns:
    A keras model.
  """
  max_seq_length = 256
  albert_layer = AlbertModel(config=cfg, float_type=tf.float32)
  input_word_ids = tf.keras.layers.Input(
      shape=(max_seq_length,), dtype=tf.int32, name='input_word_ids')
  input_mask = tf.keras.layers.Input(
      shape=(max_seq_length,), dtype=tf.int32, name='input_mask')
  input_type_ids = tf.keras.layers.Input(
      shape=(max_seq_length,), dtype=tf.int32, name='input_type_ids')
  pooled_output, sequence_output = albert_layer(input_word_ids, input_mask,
                                                input_type_ids)
  albert_model = tf.keras.Model(
      inputs=[input_word_ids, input_mask, input_type_ids],
      outputs=[pooled_output, sequence_output])
  
  return albert_model
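
A short usage sketch, assuming `AlbertConfig` is the same config class used elsewhere in this file (the path is a placeholder):

cfg = AlbertConfig.from_json_file('albert_config.json')
albert_model = _create_albert_model(cfg)
albert_model.summary()
# pooled_output: (None, hidden_size); sequence_output: (None, 256, hidden_size)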
Example #5
def get_model(albert_config, max_seq_length, num_labels, init_checkpoint,
              learning_rate, num_train_steps, num_warmup_steps,
              loss_multiplier):
    """Returns keras fuctional model"""
    float_type = tf.float32
    hidden_dropout_prob = FLAGS.classifier_dropout  # as per original code relased
    input_word_ids = tf.keras.layers.Input(shape=(max_seq_length, ),
                                           dtype=tf.int32,
                                           name='input_word_ids')
    input_mask = tf.keras.layers.Input(shape=(max_seq_length, ),
                                       dtype=tf.int32,
                                       name='input_mask')
    input_type_ids = tf.keras.layers.Input(shape=(max_seq_length, ),
                                           dtype=tf.int32,
                                           name='input_type_ids')

    albert_layer = AlbertModel(config=albert_config, float_type=float_type)

    pooled_output, _ = albert_layer(input_word_ids, input_mask, input_type_ids)

    albert_model = tf.keras.Model(
        inputs=[input_word_ids, input_mask, input_type_ids],
        outputs=[pooled_output])

    albert_model.load_weights(init_checkpoint)

    initializer = tf.keras.initializers.TruncatedNormal(
        stddev=albert_config.initializer_range)

    output = tf.keras.layers.Dropout(rate=hidden_dropout_prob)(pooled_output)

    output = tf.keras.layers.Dense(num_labels,
                                   kernel_initializer=initializer,
                                   name='output',
                                   dtype=float_type)(output)
    model = tf.keras.Model(inputs={
        'input_word_ids': input_word_ids,
        'input_mask': input_mask,
        'input_type_ids': input_type_ids
    },
                           outputs=output)

    learning_rate_fn = tf.keras.optimizers.schedules.PolynomialDecay(
        initial_learning_rate=learning_rate,
        decay_steps=num_train_steps,
        end_learning_rate=0.0)
    if num_warmup_steps:
        learning_rate_fn = WarmUp(initial_learning_rate=learning_rate,
                                  decay_schedule_fn=learning_rate_fn,
                                  warmup_steps=num_warmup_steps)
    if FLAGS.optimizer == "LAMB":
        optimizer_fn = LAMB
    else:
        optimizer_fn = AdamWeightDecay

    optimizer = optimizer_fn(learning_rate=learning_rate_fn,
                             weight_decay_rate=FLAGS.weight_decay,
                             beta_1=0.9,
                             beta_2=0.999,
                             epsilon=FLAGS.adam_epsilon,
                             exclude_from_weight_decay=['layer_norm', 'bias'])

    if FLAGS.task_name.lower() == 'sts':
        loss_fct = tf.keras.losses.MeanSquaredError()
        model.compile(optimizer=optimizer, loss=loss_fct, metrics=['mse'])
    else:
        loss_fct = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True)
        model.compile(optimizer=optimizer, loss=loss_fct, metrics=['accuracy'])

    return model
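
get_model reads several absl flags (classifier_dropout, optimizer, weight_decay, adam_epsilon, task_name), so those have to be defined and parsed before this runs; a sketch with placeholder values, not the repo's defaults:

import tensorflow as tf

albert_config = AlbertConfig.from_json_file('albert_config.json')  # placeholder path
model = get_model(albert_config,
                  max_seq_length=128,
                  num_labels=2,
                  init_checkpoint='tf2_model.h5',  # placeholder checkpoint
                  learning_rate=2e-5,
                  num_train_steps=1000,
                  num_warmup_steps=100,
                  loss_multiplier=1.0)

features = {  # dummy batch keyed exactly like the Input layers above
    'input_word_ids': tf.zeros((8, 128), tf.int32),
    'input_mask': tf.ones((8, 128), tf.int32),
    'input_type_ids': tf.zeros((8, 128), tf.int32),
}
labels = tf.zeros((8,), tf.int32)
model.fit(features, labels, epochs=1)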
Example #6
def get_model_v1(albert_config, max_seq_length, init_checkpoint, learning_rate,
                 num_train_steps, num_warmup_steps):
    """Returns keras fuctional model"""
    float_type = tf.float32
    # hidden_dropout_prob = 0.9 # as per original code relased
    unique_ids = tf.keras.layers.Input(
        shape=(1,), dtype=tf.int32, name='unique_ids')
    input_word_ids = tf.keras.layers.Input(
        shape=(max_seq_length,), dtype=tf.int32, name='input_word_ids')
    input_mask = tf.keras.layers.Input(
        shape=(max_seq_length,), dtype=tf.int32, name='input_mask')
    input_type_ids = tf.keras.layers.Input(
        shape=(max_seq_length,), dtype=tf.int32, name='input_type_ids')

    albert_layer = AlbertModel(config=albert_config, float_type=float_type)

    _, sequence_output = albert_layer(
        input_word_ids, input_mask, input_type_ids)

    albert_model = tf.keras.Model(inputs=[input_word_ids, input_mask, input_type_ids],
                                  outputs=[sequence_output])

    if init_checkpoint is not None:
        albert_model.load_weights(init_checkpoint)

    initializer = tf.keras.initializers.TruncatedNormal(
        stddev=albert_config.initializer_range)

    squad_logits_layer = ALBertSquadLogitsLayer(
        initializer=initializer, float_type=float_type, name='squad_logits')

    start_logits, end_logits = squad_logits_layer(sequence_output)

    squad_model = tf.keras.Model(
        inputs={
            'unique_ids': unique_ids,
            'input_ids': input_word_ids,
            'input_mask': input_mask,
            'segment_ids': input_type_ids,
        },
        outputs=[unique_ids, start_logits, end_logits],
        name='squad_model')

    learning_rate_fn = tf.keras.optimizers.schedules.PolynomialDecay(initial_learning_rate=learning_rate,
                                                                     decay_steps=num_train_steps, end_learning_rate=0.0)
    if num_warmup_steps:
        learning_rate_fn = WarmUp(initial_learning_rate=learning_rate,
                                  decay_schedule_fn=learning_rate_fn,
                                  warmup_steps=num_warmup_steps)

    if FLAGS.optimizer == "LAMB":
        optimizer_fn = LAMB
    else:
        optimizer_fn = AdamWeightDecay

    optimizer = optimizer_fn(
        learning_rate=learning_rate_fn,
        weight_decay_rate=FLAGS.weight_decay,
        beta_1=0.9,
        beta_2=0.999,
        epsilon=FLAGS.adam_epsilon,
        exclude_from_weight_decay=['layer_norm', 'bias'])

    squad_model.optimizer = optimizer

    return squad_model
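
get_model_v1 returns the SQuAD model uncompiled, with the optimizer attached as an attribute, which points at a custom training loop; a hedged sketch of one step (the averaged start/end sparse cross-entropy is the usual SQuAD objective, not something defined in this snippet):

import tensorflow as tf

loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

@tf.function
def train_step(squad_model, features, start_positions, end_positions):
    with tf.GradientTape() as tape:
        _, start_logits, end_logits = squad_model(features, training=True)
        loss = (loss_fn(start_positions, start_logits) +
                loss_fn(end_positions, end_logits)) / 2.0
    grads = tape.gradient(loss, squad_model.trainable_variables)
    squad_model.optimizer.apply_gradients(zip(grads, squad_model.trainable_variables))
    return loss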
Example #7
def train_tinybert_model(tinybert_config,
                         albert_config,
                         seq_length,
                         max_predictions_per_seq,
                         initializer=None):
    input_word_ids = tf.keras.layers.Input(
        shape=(seq_length,), name='input_word_ids', dtype=tf.int32)
    input_mask = tf.keras.layers.Input(
        shape=(seq_length,), name='input_mask', dtype=tf.int32)
    input_type_ids = tf.keras.layers.Input(
        shape=(seq_length,), name='input_type_ids', dtype=tf.int32)
    masked_lm_positions = tf.keras.layers.Input(
        shape=(max_predictions_per_seq,),
        name='masked_lm_positions',
        dtype=tf.int32)
    masked_lm_weights = tf.keras.layers.Input(
        shape=(max_predictions_per_seq,),
        name='masked_lm_weights',
        dtype=tf.int32)
    next_sentence_labels = tf.keras.layers.Input(
        shape=(1,), name='next_sentence_labels', dtype=tf.int32)
    masked_lm_ids = tf.keras.layers.Input(
        shape=(max_predictions_per_seq,), name='masked_lm_ids', dtype=tf.int32)
    
    ## ALBERT teacher (frozen)
    float_type = tf.float32
    albert_encoder = "albert_model"
    albert_layer = AlbertModel(config=albert_config, float_type=float_type, name=albert_encoder, trainable=False)
    albert_pooled_output, albert_sequence_output, albert_attention_scores, albert_embed_tensor = albert_layer(input_word_ids, input_mask, input_type_ids)
    ##tinybert student
    float_type = tf.float32
    tinybert_encoder = "tinybert_model"
    tinybert_layer = TinybertModel(config=tinybert_config, float_type=float_type, name=tinybert_encoder)
    tinybert_pooled_output, tinybert_sequence_output, tinybert_attention_scores, tinybert_embed_tensor = tinybert_layer(input_word_ids, input_mask, input_type_ids)
    
    albert_teacher = tf.keras.Model(
        inputs = [input_word_ids, input_mask, input_type_ids],
        outputs = [albert_pooled_output] + albert_sequence_output + albert_attention_scores + [albert_embed_tensor],
        name = 'albert'
    )
    
    tinybert_student = tf.keras.Model(
        inputs = [input_word_ids, input_mask, input_type_ids],
        outputs = [tinybert_pooled_output] + tinybert_sequence_output + tinybert_attention_scores + [tinybert_embed_tensor],
        name = 'tinybert'
    )
    
    albert_teacher_outputs = albert_teacher([input_word_ids, input_mask, input_type_ids])
    tinybert_student_outputs = tinybert_student([input_word_ids, input_mask, input_type_ids])
    
    # output slicing below assumes a 12-layer teacher and a 6-layer student
    teacher_pooled_output = albert_teacher_outputs[0]
    teacher_sequence_output = albert_teacher_outputs[1:13]
    teacher_atten_score = albert_teacher_outputs[13:25]
    teacher_embed_tensor = albert_teacher_outputs[25]
    student_pooled_output = tinybert_student_outputs[0]
    student_sequence_output = tinybert_student_outputs[1:7]
    student_atten_score = tinybert_student_outputs[7:13]
    student_embed_tensor = tinybert_student_outputs[13]
    # distillation loss
    
    tinybert_pretrain_layer = TinyBertPretrainLayer(tinybert_config,
                                                    tinybert_layer.embedding_lookup.embeddings,
                                                    initializer=initializer,
                                                    name='tinybert_cls')
    tinybert_lm_output, tinybert_sentence_output, tinybert_logits = tinybert_pretrain_layer(student_pooled_output, student_sequence_output[-1], masked_lm_positions)
   
    albert_pretrain_layer = ALBertPretrainLayer(albert_config,
                                                albert_layer.embedding_lookup.embeddings,
                                                initializer=initializer,
                                                name='albert_cls',
                                                trainable=False)
    albert_lm_output, albert_sentence_output, albert_logits = albert_pretrain_layer(teacher_pooled_output, teacher_sequence_output[-1], masked_lm_positions)
    
    tinybert_loss_layer = TinybertLossLayer(albert_config, tinybert_config, initializer=initializer, name="dislit")
    
    # loss_output = tinybert_loss_layer(albert_embedding_table=albert_embed_tensor,
    #                                   tinybert_embedding_table=tinybert_embed_tensor,
    #                                   albert_pooled_output=teacher_pooled_output,
    #                                   tinybert_pooled_output=student_pooled_output,
    #                                   albert_seq_output=teacher_sequence_output,
    #                                   tinybert_seq_output=student_sequence_output,
    #                                   albert_atten_score=teacher_atten_score,
    #                                   tinybert_atten_score=student_atten_score,
    #                                   albert_logits=albert_logits,
    #                                   tinybert_logits=tinybert_logits,
    #                                   lm_label_ids=masked_lm_ids,
    #                                   lm_label_weights=masked_lm_weights)
    distil_loss = tinybert_loss_layer([albert_embed_tensor, tinybert_embed_tensor, teacher_pooled_output, student_pooled_output,
                                       teacher_sequence_output, student_sequence_output, teacher_atten_score, student_atten_score,
                                       albert_lm_output, tinybert_lm_output, masked_lm_ids, masked_lm_weights])
    
    # pretrain_loss
    tinybert_pretrain_loss_metrics_layer = TinyBertPretrainLossAndMetricLayer(tinybert_config, name="metric")
    
    pretrain_loss = tinybert_pretrain_loss_metrics_layer(tinybert_lm_output, tinybert_sentence_output, masked_lm_ids,
                                                         masked_lm_weights, next_sentence_labels)

    return tf.keras.Model(
        inputs={
            'input_word_ids': input_word_ids,
            'input_mask': input_mask,
            'input_type_ids': input_type_ids,
            'masked_lm_positions': masked_lm_positions,
            'masked_lm_ids': masked_lm_ids,
            'masked_lm_weights': masked_lm_weights,
            'next_sentence_labels': next_sentence_labels},
        outputs = [distil_loss, pretrain_loss],
        name = 'train_model'), albert_teacher, tinybert_student
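
Both outputs of the training model are loss tensors, so one way to drive it is to compile with identity losses and feed dummy labels; a sketch assuming `albert_config` and `tinybert_config` are already constructed as in the fine-tune sketch above:

import tensorflow as tf

train_model, albert_teacher, tinybert_student = train_tinybert_model(
    tinybert_config, albert_config, seq_length=128, max_predictions_per_seq=20)

# forward the precomputed distillation and pretraining losses unchanged
train_model.compile(optimizer=tf.keras.optimizers.Adam(1e-4),
                    loss=[lambda y_true, y_pred: y_pred,
                          lambda y_true, y_pred: y_pred])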
Example #8
def main(_):

    tfhub_model_path = FLAGS.tf_hub_path
    max_seq_length = 512
    float_type = tf.float32

    input_word_ids = tf.keras.layers.Input(shape=(max_seq_length, ),
                                           dtype=tf.int32,
                                           name='input_word_ids')
    input_mask = tf.keras.layers.Input(shape=(max_seq_length, ),
                                       dtype=tf.int32,
                                       name='input_mask')
    input_type_ids = tf.keras.layers.Input(shape=(max_seq_length, ),
                                           dtype=tf.int32,
                                           name='input_type_ids')

    if FLAGS.version == 2:
        albert_config = AlbertConfig.from_json_file(
            os.path.join(tfhub_model_path, "assets", "albert_config.json"))
    else:
        albert_config = AlbertConfig.from_json_file(
            os.path.join("model_configs", FLAGS.model, "config.json"))

    tags = []

    stock_values = {}

    with tf.Graph().as_default():
        sm = tf.compat.v2.saved_model.load(tfhub_model_path, tags=tags)
        with tf.compat.v1.Session() as sess:
            sess.run(tf.compat.v1.global_variables_initializer())
            stock_values = {
                v.name.split(":")[0]: v.read_value()
                for v in sm.variables
            }
            stock_values = sess.run(stock_values)

    loaded_weights = set()
    skip_count = 0
    weight_value_tuples = []
    skipped_weight_value_tuples = []

    if FLAGS.model_type == "albert_encoder":
        albert_layer = AlbertModel(config=albert_config, float_type=float_type)

        pooled_output, sequence_output = albert_layer(input_word_ids,
                                                      input_mask,
                                                      input_type_ids)
        albert_model = tf.keras.Model(
            inputs=[input_word_ids, input_mask, input_type_ids],
            outputs=[pooled_output, sequence_output])
        albert_params = albert_model.weights
        param_values = tf.keras.backend.batch_get_value(albert_model.weights)
    else:
        albert_full_model, _ = pretrain_model(albert_config,
                                              max_seq_length,
                                              max_predictions_per_seq=20)
        albert_layer = albert_full_model.get_layer("albert_model")
        albert_params = albert_full_model.weights
        param_values = tf.keras.backend.batch_get_value(
            albert_full_model.weights)

    for ndx, (param_value, param) in enumerate(zip(param_values,
                                                   albert_params)):
        stock_name = weight_map[param.name]

        if stock_name in stock_values:
            ckpt_value = stock_values[stock_name]

            if param_value.shape != ckpt_value.shape:
                print(
                    "loader: Skipping weight:[{}] as the weight shape:[{}] is not compatible "
                    "with the checkpoint:[{}] shape:{}".format(
                        param.name, param.shape, stock_name, ckpt_value.shape))
                skipped_weight_value_tuples.append((param, ckpt_value))
                continue

            weight_value_tuples.append((param, ckpt_value))
            loaded_weights.add(stock_name)
        else:
            print("loader: No value for:[{}], i.e.:[{}] in:[{}]".format(
                param.name, stock_name, tfhub_model_path))
            skip_count += 1
    tf.keras.backend.batch_set_value(weight_value_tuples)

    print("Done loading {} ALBERT weights from: {} into {} (prefix:{}). "
          "Count of weights not found in the checkpoint was: [{}]. "
          "Count of weights with mismatched shape: [{}]".format(
              len(weight_value_tuples), tfhub_model_path, albert_layer,
              "albert", skip_count, len(skipped_weight_value_tuples)))
    print(
        "Unused weights from saved model:", "\n\t" + "\n\t".join(
            sorted(set(stock_values.keys()).difference(loaded_weights))))

    if FLAGS.model_type == "albert_encoder":
        albert_model.save_weights(f"{tfhub_model_path}/tf2_model.h5")
    else:
        albert_full_model.save_weights(f"{tfhub_model_path}/tf2_model_full.h5")
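
The heart of this script is the `weight_map` lookup (defined elsewhere in the file) plus `tf.keras.backend.batch_get_value`/`batch_set_value`; a toy sketch of the same pattern on a throwaway model, just to make the mechanics explicit:

import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(3,), name='dense')])
# pretend these came out of the SavedModel, keyed by stock variable names
stock_values = {'dense/kernel': np.ones((3, 4), np.float32),
                'dense/bias': np.zeros((4,), np.float32)}
# identity mapping here; the real weight_map translates Keras names to stock names
weight_map = {w.name: w.name.split(':')[0] for w in model.weights}

pairs = [(w, stock_values[weight_map[w.name]])
         for w in model.weights if weight_map[w.name] in stock_values]
tf.keras.backend.batch_set_value(pairs)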
Example #9
def pretrain_model(albert_config,
                   seq_length,
                   max_predictions_per_seq,
                   initializer=None):
    """Returns model to be used for pre-training.
  Args:
      albert_config: Configuration that defines the core ALBERT model.
      seq_length: Maximum sequence length of the training data.
      max_predictions_per_seq: Maximum number of tokens in sequence to mask out
        and use for pretraining.
      initializer: Initializer for weights in BertPretrainLayer.
  Returns:
      Pretraining model as well as core BERT submodel from which to save
      weights after pretraining.
  """

    input_word_ids = tf.keras.layers.Input(shape=(seq_length, ),
                                           name='input_word_ids',
                                           dtype=tf.int32)
    input_mask = tf.keras.layers.Input(shape=(seq_length, ),
                                       name='input_mask',
                                       dtype=tf.int32)
    input_type_ids = tf.keras.layers.Input(shape=(seq_length, ),
                                           name='input_type_ids',
                                           dtype=tf.int32)
    masked_lm_positions = tf.keras.layers.Input(
        shape=(max_predictions_per_seq, ),
        name='masked_lm_positions',
        dtype=tf.int32)
    masked_lm_weights = tf.keras.layers.Input(
        shape=(max_predictions_per_seq, ),
        name='masked_lm_weights',
        dtype=tf.int32)
    next_sentence_labels = tf.keras.layers.Input(shape=(1, ),
                                                 name='next_sentence_labels',
                                                 dtype=tf.int32)
    masked_lm_ids = tf.keras.layers.Input(shape=(max_predictions_per_seq, ),
                                          name='masked_lm_ids',
                                          dtype=tf.int32)

    float_type = tf.float32
    albert_encoder = "albert_model"
    albert_layer = AlbertModel(config=albert_config,
                               float_type=float_type,
                               name=albert_encoder)
    pooled_output, sequence_output = albert_layer(input_word_ids, input_mask,
                                                  input_type_ids)
    albert_submodel = tf.keras.Model(
        inputs=[input_word_ids, input_mask, input_type_ids],
        outputs=[pooled_output, sequence_output])

    pooled_output = albert_submodel.outputs[0]
    sequence_output = albert_submodel.outputs[1]

    pretrain_layer = ALBertPretrainLayer(
        albert_config,
        albert_submodel.get_layer(albert_encoder),
        initializer=initializer,
        name='cls')
    lm_output, sentence_output = pretrain_layer(pooled_output, sequence_output,
                                                masked_lm_positions)

    pretrain_loss_layer = ALBertPretrainLossAndMetricLayer(albert_config)
    output_loss = pretrain_loss_layer(lm_output, sentence_output,
                                      masked_lm_ids, masked_lm_weights,
                                      next_sentence_labels)

    return tf.keras.Model(inputs={
        'input_word_ids': input_word_ids,
        'input_mask': input_mask,
        'input_type_ids': input_type_ids,
        'masked_lm_positions': masked_lm_positions,
        'masked_lm_ids': masked_lm_ids,
        'masked_lm_weights': masked_lm_weights,
        'next_sentence_labels': next_sentence_labels,
    },
                          outputs=output_loss), albert_submodel
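
A closing sketch of how the pre-training model could be compiled: its single output is already the combined masked-LM and sentence-level loss, so the loss function just forwards it; paths are placeholders:

import tensorflow as tf

albert_config = AlbertConfig.from_json_file('albert_config.json')  # placeholder path
pretrain_net, albert_submodel = pretrain_model(albert_config,
                                               seq_length=128,
                                               max_predictions_per_seq=20)

pretrain_net.compile(optimizer=tf.keras.optimizers.Adam(1e-4),
                     loss=lambda y_true, y_pred: y_pred)

# after pretraining, typically only the encoder weights are persisted
albert_submodel.save_weights('albert_encoder.h5')  # placeholder filename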