def pretrain_model(bert_config,
                   seq_length,
                   max_predictions_per_seq,
                   initializer=None,
                   use_next_sentence_label=True,
                   return_core_pretrainer_model=False):
  """Returns model to be used for pre-training.

  Args:
    bert_config: Configuration that defines the core BERT model.
    seq_length: Maximum sequence length of the training data.
    max_predictions_per_seq: Maximum number of tokens in sequence to mask out
      and use for pretraining.
    initializer: Initializer for weights in BertPretrainer.
    use_next_sentence_label: Whether to use the next sentence label.
    return_core_pretrainer_model: Whether to also return the `BertPretrainer`
      object.

  Returns:
    A tuple of (1) pretraining model, (2) core BERT submodel from which to
    save weights after pretraining, and (3) optional core `BertPretrainer`
    object if argument `return_core_pretrainer_model` is True.
  """
  input_word_ids = tf.keras.layers.Input(
      shape=(seq_length,), name='input_word_ids', dtype=tf.int32)
  input_mask = tf.keras.layers.Input(
      shape=(seq_length,), name='input_mask', dtype=tf.int32)
  input_type_ids = tf.keras.layers.Input(
      shape=(seq_length,), name='input_type_ids', dtype=tf.int32)
  masked_lm_positions = tf.keras.layers.Input(
      shape=(max_predictions_per_seq,),
      name='masked_lm_positions',
      dtype=tf.int32)
  masked_lm_ids = tf.keras.layers.Input(
      shape=(max_predictions_per_seq,), name='masked_lm_ids', dtype=tf.int32)
  masked_lm_weights = tf.keras.layers.Input(
      shape=(max_predictions_per_seq,),
      name='masked_lm_weights',
      dtype=tf.int32)

  if use_next_sentence_label:
    next_sentence_labels = tf.keras.layers.Input(
        shape=(1,), name='next_sentence_labels', dtype=tf.int32)
  else:
    next_sentence_labels = None

  transformer_encoder = get_transformer_encoder(bert_config, seq_length)
  if initializer is None:
    initializer = tf.keras.initializers.TruncatedNormal(
        stddev=bert_config.initializer_range)
  pretrainer_model = models.BertPretrainer(
      network=transformer_encoder,
      embedding_table=transformer_encoder.get_embedding_table(),
      num_classes=2,  # The next sentence prediction label has two classes.
      activation=tf_utils.get_activation(bert_config.hidden_act),
      num_token_predictions=max_predictions_per_seq,
      initializer=initializer,
      output='logits')

  outputs = pretrainer_model(
      [input_word_ids, input_mask, input_type_ids, masked_lm_positions])
  lm_output = outputs['masked_lm']
  sentence_output = outputs['classification']
  pretrain_loss_layer = BertPretrainLossAndMetricLayer(
      vocab_size=bert_config.vocab_size)
  output_loss = pretrain_loss_layer(lm_output, sentence_output, masked_lm_ids,
                                    masked_lm_weights, next_sentence_labels)

  inputs = {
      'input_word_ids': input_word_ids,
      'input_mask': input_mask,
      'input_type_ids': input_type_ids,
      'masked_lm_positions': masked_lm_positions,
      'masked_lm_ids': masked_lm_ids,
      'masked_lm_weights': masked_lm_weights,
  }
  if use_next_sentence_label:
    inputs['next_sentence_labels'] = next_sentence_labels

  keras_model = tf.keras.Model(inputs=inputs, outputs=output_loss)
  if return_core_pretrainer_model:
    return keras_model, transformer_encoder, pretrainer_model
  else:
    return keras_model, transformer_encoder
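# Usage sketch (not part of the original module): shows one way to build the
# pretraining model and export only the core encoder weights afterwards. The
# `configs.BertConfig` import path, the config fields, and the checkpoint path
# below are assumptions for illustration, not a prescribed workflow.
if __name__ == '__main__':
  from official.nlp.bert import configs  # assumed location of BertConfig

  demo_bert_config = configs.BertConfig(vocab_size=30522)
  demo_pretrain_model, demo_encoder = pretrain_model(
      bert_config=demo_bert_config,
      seq_length=128,
      max_predictions_per_seq=20)
  demo_pretrain_model.summary()
  # After pretraining, typically only the encoder weights are saved for
  # downstream fine-tuning; the loss/metric head is discarded.
  demo_encoder.save_weights('/tmp/bert_core_encoder')  # hypothetical path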