Example #1
def create_bert_model(bert_config):
    """Creates a BERT keras core model from BERT configuration.

  Args:
    bert_config: A `BertConfig` to create the core model.

  Returns:
    A keras model.
  """
    max_seq_length = bert_config.max_position_embeddings

    # Adds input layers just as placeholders.
    input_word_ids = tf.keras.layers.Input(shape=(max_seq_length, ),
                                           dtype=tf.int32,
                                           name="input_word_ids")
    input_mask = tf.keras.layers.Input(shape=(max_seq_length, ),
                                       dtype=tf.int32,
                                       name="input_mask")
    input_type_ids = tf.keras.layers.Input(shape=(max_seq_length, ),
                                           dtype=tf.int32,
                                           name="input_type_ids")
    core_model = modeling.get_bert_model(input_word_ids,
                                         input_mask,
                                         input_type_ids,
                                         config=bert_config,
                                         name="bert_model",
                                         float_type=tf.float32)
    return core_model
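A minimal usage sketch for the function above, assuming `modeling` is the repository's BERT module (the `official/bert` lineage, which provides `BertConfig` and `get_bert_model`); the config values are illustrative, not the repository's own driver code.

import tensorflow as tf
import modeling  # assumed import: the repo's BERT modeling module

# Hypothetical tiny config; real runs usually load the JSON shipped with a
# checkpoint, e.g. modeling.BertConfig.from_json_file('.../bert_config.json').
bert_config = modeling.BertConfig(vocab_size=30522)
core_model = create_bert_model(bert_config)
core_model.summary()  # three int32 inputs; pooled and sequence outputs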
Example #2
def coqa_model_bert_transformer(config,
                                max_seq_length,
                                max_answer_length,
                                float_type,
                                training=False,
                                initializer=None):
    unique_ids = tf.keras.layers.Input(shape=(1, ),
                                       dtype=tf.int32,
                                       name='unique_ids')
    input_word_ids = tf.keras.layers.Input(shape=(max_seq_length, ),
                                           dtype=tf.int32,
                                           name='input_word_ids')
    input_mask = tf.keras.layers.Input(shape=(max_seq_length, ),
                                       dtype=tf.int32,
                                       name='input_mask')
    input_type_ids = tf.keras.layers.Input(shape=(max_seq_length, ),
                                           dtype=tf.int32,
                                           name='segment_ids')
    decode_ids = tf.keras.layers.Input(shape=(max_answer_length, ),
                                       dtype=tf.int32,
                                       name='decode_ids')
    decode_mask = tf.keras.layers.Input(shape=(max_answer_length, ),
                                        dtype=tf.int32,
                                        name='decode_mask')

    bert_model = bert_modeling.get_bert_model(input_word_ids,
                                              input_mask,
                                              input_type_ids,
                                              config=config,
                                              name='bert_model',
                                              float_type=float_type)

    # `Bert Coqa Model` only uses the sequence_output which
    # has dimensionality (batch_size, sequence_length, num_hidden).
    sequence_output = bert_model.outputs[1]

    if initializer is None:
        initializer = tf.keras.initializers.TruncatedNormal(
            stddev=config.initializer_range)

    coqa_layer = coqalayers.SimpleTransformerDecoder(
        config=config, name='simple_transformer_decoder')

    final_dists = coqa_layer(sequence_output, input_mask, decode_ids,
                             decode_mask)

    coqa = tf.keras.Model(inputs=({
        'unique_ids': unique_ids,
        'input_word_ids': input_word_ids,
        'input_type_ids': input_type_ids,
        'input_mask': input_mask,
        'decode_ids': decode_ids,
        'decode_mask': decode_mask
    }, ),
                          outputs=[unique_ids, final_dists])

    return coqa, bert_model
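A hedged sketch of driving the model above with dummy features; the dict keys must match the `Input` layer names, and `config` is assumed to be a `BertConfig` as in the other examples.

import numpy as np
import tensorflow as tf

batch, seq_len, ans_len = 2, 384, 30  # illustrative sizes
coqa, bert_core = coqa_model_bert_transformer(config, seq_len, ans_len, tf.float32)
features = {
    'unique_ids': np.zeros((batch, 1), np.int32),
    'input_word_ids': np.zeros((batch, seq_len), np.int32),
    'input_mask': np.ones((batch, seq_len), np.int32),
    'input_type_ids': np.zeros((batch, seq_len), np.int32),
    'decode_ids': np.zeros((batch, ans_len), np.int32),
    'decode_mask': np.ones((batch, ans_len), np.int32),
}
ids, final_dists = coqa(features)  # unique_ids pass through unchanged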
Example #3
def coqa_model_bert_span(config,
                         max_seq_length,
                         max_answer_length,
                         float_type,
                         training=False,
                         initializer=None):
    unique_ids = tf.keras.layers.Input(shape=(1, ),
                                       dtype=tf.int32,
                                       name='unique_ids')
    input_word_ids = tf.keras.layers.Input(shape=(max_seq_length, ),
                                           dtype=tf.int32,
                                           name='input_ids')
    input_mask = tf.keras.layers.Input(shape=(max_seq_length, ),
                                       dtype=tf.int32,
                                       name='input_mask')
    input_type_ids = tf.keras.layers.Input(shape=(max_seq_length, ),
                                           dtype=tf.int32,
                                           name='segment_ids')

    bert_model = bert_modeling.get_bert_model(input_word_ids,
                                              input_mask,
                                              input_type_ids,
                                              config=config,
                                              name='bert_model',
                                              float_type=float_type)

    # This model uses both the pooled_output (batch_size, num_hidden) and
    # the sequence_output (batch_size, sequence_length, num_hidden).
    pooled_output = bert_model.outputs[0]
    sequence_output = bert_model.outputs[1]

    yesnounknown_logits_layer = BertYNULogitsLayer(initializer=initializer,
                                                   float_type=float_type,
                                                   name='yesnounknown_logits')

    ynu_logits = yesnounknown_logits_layer(pooled_output)

    span_logits_layer = BertSpanLogitsLayer(initializer=initializer,
                                            float_type=float_type,
                                            name='bert_span_logits')

    start_logits, end_logits = span_logits_layer(sequence_output)

    coqa = tf.keras.Model(
        inputs=({
            'unique_ids': unique_ids,
            'input_ids': input_word_ids,
            'segment_ids': input_type_ids,
            'input_mask': input_mask
        }),
        outputs=[unique_ids, start_logits, end_logits, ynu_logits],
        name='coqa_model')

    return coqa, bert_model
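The usual pattern in this code lineage is to restore a pretrained checkpoint into the returned `bert_model` and then fine-tune the outer `coqa` model. A sketch under that assumption (object-based TF2 restore; the checkpoint path is hypothetical):

import tensorflow as tf

coqa, bert_core = coqa_model_bert_span(config, 384, 30, tf.float32)
checkpoint = tf.train.Checkpoint(model=bert_core)
checkpoint.restore('/path/to/bert_model.ckpt')  # hypothetical checkpoint path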
Example #4
def coqa_model(bert_config, max_seq_length, float_type, initializer=None):
    """Returns BERT Squad model along with core BERT model to import weights.

  Args:
    bert_config: BertConfig, the config defines the core Bert model.
    max_seq_length: integer, the maximum input sequence length.
    float_type: tf.dtype, tf.float32 or tf.bfloat16.
    initializer: Initializer for weights in BertSquadLogitsLayer.

  Returns:
    Two tensors, start logits and end logits, [batch x sequence length].
  """
    unique_ids = tf.keras.layers.Input(shape=(1, ),
                                       dtype=tf.int32,
                                       name='unique_ids')
    input_word_ids = tf.keras.layers.Input(shape=(max_seq_length, ),
                                           dtype=tf.int32,
                                           name='input_ids')
    input_mask = tf.keras.layers.Input(shape=(max_seq_length, ),
                                       dtype=tf.int32,
                                       name='input_mask')
    input_type_ids = tf.keras.layers.Input(shape=(max_seq_length, ),
                                           dtype=tf.int32,
                                           name='segment_ids')

    core_model = modeling.get_bert_model(input_word_ids,
                                         input_mask,
                                         input_type_ids,
                                         config=bert_config,
                                         name='bert_model',
                                         float_type=float_type)

    # `BertCoqaLogitsLayer` only uses the sequence_output, which
    # has dimensionality (batch_size, sequence_length, num_hidden).
    sequence_output = core_model.outputs[1]

    if initializer is None:
        initializer = tf.keras.initializers.TruncatedNormal(
            stddev=bert_config.initializer_range)
    coqa_logits_layer = BertCoqaLogitsLayer(initializer=initializer,
                                            float_type=float_type,
                                            name='coqa_logits')
    start_logits, end_logits = coqa_logits_layer(sequence_output)

    coqa = tf.keras.Model(inputs={
        'unique_ids': unique_ids,
        'input_ids': input_word_ids,
        'input_mask': input_mask,
        'segment_ids': input_type_ids,
    },
                          outputs=[unique_ids, start_logits, end_logits],
                          name='coqa_model')
    return coqa, core_model
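The model echoes `unique_ids` alongside the logits so batched predictions can be matched back to their examples. A small usage sketch with dummy features (`bert_config` assumed as elsewhere):

import numpy as np
import tensorflow as tf

coqa, core_model = coqa_model(bert_config, max_seq_length=384,
                              float_type=tf.float32)
features = {
    'unique_ids': np.arange(2, dtype=np.int32).reshape(2, 1),
    'input_ids': np.zeros((2, 384), np.int32),
    'input_mask': np.ones((2, 384), np.int32),
    'segment_ids': np.zeros((2, 384), np.int32),
}
ids, start_logits, end_logits = coqa(features)
best_start = tf.argmax(start_logits, axis=-1)  # predicted span start per example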
Example #5
def create_bert_model(bert_config: bert_modeling.BertConfig):
  """Creates a BERT keras core model from BERT configuration.

  Args:
    bert_config: A `BertConfig` to create the core model.

  Returns:
    A keras model.
  """
  # Adds input layers just as placeholders.
  input_word_ids = tf.keras.layers.Input(
      shape=(None,), dtype=tf.int32, name="input_word_ids")
  input_mask = tf.keras.layers.Input(
      shape=(None,), dtype=tf.int32, name="input_mask")
  input_type_ids = tf.keras.layers.Input(
      shape=(None,), dtype=tf.int32, name="input_type_ids")
  return bert_modeling.get_bert_model(
      input_word_ids,
      input_mask,
      input_type_ids,
      config=bert_config,
      name="bert_model",
      float_type=tf.float32)
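Unlike Example #1, this variant declares its inputs with shape=(None,), so one model instance accepts any sequence length at call time. A quick illustration, with `bert_config` assumed as above:

import numpy as np

model = create_bert_model(bert_config)
for seq_len in (64, 128):  # same model, two different lengths
    features = {
        'input_word_ids': np.zeros((2, seq_len), np.int32),
        'input_mask': np.ones((2, seq_len), np.int32),
        'input_type_ids': np.zeros((2, seq_len), np.int32),
    }
    # Assumed output order: pooled, then sequence, as used throughout.
    pooled_output, sequence_output = model(features)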
Example #6
def coqa_model_2heads(config,
                      max_seq_length,
                      max_answer_length,
                      float_type,
                      training=False,
                      initializer=None):
    unique_ids = tf.keras.layers.Input(shape=(1, ),
                                       dtype=tf.int32,
                                       name='unique_ids')
    input_word_ids = tf.keras.layers.Input(shape=(max_seq_length, ),
                                           dtype=tf.int32,
                                           name='input_word_ids')
    input_mask = tf.keras.layers.Input(shape=(max_seq_length, ),
                                       dtype=tf.int32,
                                       name='input_mask')
    input_type_ids = tf.keras.layers.Input(shape=(max_seq_length, ),
                                           dtype=tf.int32,
                                           name='segment_ids')
    decode_ids = tf.keras.layers.Input(shape=(max_answer_length, ),
                                       dtype=tf.int32,
                                       name='decode_ids')

    bert_model = bert_modeling.get_bert_model(input_word_ids,
                                              input_mask,
                                              input_type_ids,
                                              config=config,
                                              name='bert_model',
                                              float_type=float_type)

    # `Bert Coqa Pgnet Model` only uses the sequence_output which
    # has dimensionality (batch_size, sequence_length, num_hidden).
    sequence_output = bert_model.outputs[1]

    if initializer is None:
        initializer = tf.keras.initializers.TruncatedNormal(
            stddev=config.initializer_range)

    # Double-headed: trained on both span positions and the final answer

    coqa_logits_layer = bert_models.BertCoqaLogitsLayer(
        initializer=initializer, float_type=float_type, name='coqa_logits')
    start_logits, end_logits = coqa_logits_layer(sequence_output)

    # Figure out the span text from the start and end logits.
    span_text_ids, span_mask = get_best_span_prediction(
        input_word_ids, start_logits, end_logits, max_seq_length)

    # pgnet_model_layer =modeling.PGNetSummaryModel(config=config ,
    #                                                 float_type=float_type,
    #                                                name='pgnet_summary_model')
    # final_dists, attn_dists = pgnet_model_layer(  span_text_ids,
    #                                               span_mask,
    #                                               answer_ids,
    #                                               answer_mask
    #                                             )
    # coqa = tf.keras.Model(
    #     inputs=[
    #         unique_ids,
    #         answer_ids,
    #         answer_mask ],
    #     outputs=[final_dists, attn_dists,start_logits, end_logits ])

    # PGNet only: end-to-end, question + context to answer

    coqa_layer = coqalayers.SimpleLSTMSeq2Seq(config=config,
                                              training=training,
                                              name='simple_lstm_seq2seq')

    final_dists = coqa_layer(
        span_text_ids,
        span_mask,
        decode_ids,
    )

    coqa = tf.keras.Model(
        inputs=({
            'unique_ids': unique_ids,
            'input_word_ids': input_word_ids,
            'input_type_ids': input_type_ids,
            'input_mask': input_mask,
            'decode_ids': decode_ids
        }, ),
        outputs=[unique_ids, final_dists, start_logits, end_logits])

    # Bert+PGNet:  end to end
    # pgnet_model_layer = modeling.PGNetDecoderModel(config=config ,
    #                                                  float_type=float_type,
    #                                                  name='pgnet_decoder_model')
    # final_dists, attn_dists = pgnet_model_layer(sequence_output, answer_ids, answer_mask)
    # coqa = tf.keras.Model(
    #     inputs=[
    #         unique_ids,
    #         answer_ids,
    #         answer_mask],
    #     outputs=[unique_ids, final_dists, attn_dists],
    #     name="pgnet_model")

    return coqa, bert_model
Example #7
def coqa_model_bert_rt_transformer(config,
                                   max_seq_length,
                                   max_answer_length,
                                   float_type,
                                   training=False,
                                   initializer=None):
    unique_ids = tf.keras.layers.Input(shape=(1, ),
                                       dtype=tf.int32,
                                       name='unique_ids')
    input_word_ids = tf.keras.layers.Input(shape=(max_seq_length, ),
                                           dtype=tf.int32,
                                           name='input_word_ids')
    input_mask = tf.keras.layers.Input(shape=(max_seq_length, ),
                                       dtype=tf.int32,
                                       name='input_mask')
    input_type_ids = tf.keras.layers.Input(shape=(max_seq_length, ),
                                           dtype=tf.int32,
                                           name='segment_ids')

    decode_ids = tf.keras.layers.Input(shape=(max_answer_length, ),
                                       dtype=tf.int32,
                                       name='decode_ids')
    decode_mask = tf.keras.layers.Input(shape=(max_answer_length, ),
                                        dtype=tf.int32,
                                        name='decode_mask')

    bert_model = bert_modeling.get_bert_model(input_word_ids,
                                              input_mask,
                                              input_type_ids,
                                              config=config,
                                              name='bert_model',
                                              float_type=float_type)

    # `Bert Coqa Model` only uses the sequence_output, which
    # has dimensionality (batch_size, sequence_length, num_hidden).
    sequence_output = bert_model.outputs[1]

    rationale_tag_layer = BertRationaleTagLogitsLayer(
        config.hidden_size,
        initializer=initializer,
        float_type=float_type,
        name='bert_rationale_tag_logits')

    rt_logits = rationale_tag_layer(sequence_output)

    rt_masked_sequence = sequence_output * tf.expand_dims(
        tf.nn.sigmoid(rt_logits), 2)  # soft-masks each token by its rationale probability

    coqa_layer = coqalayers.SimpleTransformerDecoder(
        config=config, name='simple_transformer_decoder')

    final_dists = coqa_layer(rt_masked_sequence, input_mask, decode_ids,
                             decode_mask)

    coqa = tf.keras.Model(inputs=({
        'unique_ids': unique_ids,
        'input_word_ids': input_word_ids,
        'input_type_ids': input_type_ids,
        'input_mask': input_mask,
        'decode_ids': decode_ids,
        'decode_mask': decode_mask
    }, ),
                          outputs=[unique_ids, final_dists, rt_logits])

    return coqa, bert_model
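The rationale-tagging step soft-masks the encoder output: every token vector is scaled by the sigmoid of its rationale logit before decoding. A self-contained sketch of just that operation, with random stand-ins for the real tensors:

import tensorflow as tf

batch, seq_len, hidden = 2, 8, 16
sequence_output = tf.random.normal((batch, seq_len, hidden))
rt_logits = tf.random.normal((batch, seq_len))  # one rationale logit per token

gate = tf.expand_dims(tf.nn.sigmoid(rt_logits), axis=2)  # (batch, seq_len, 1)
rt_masked_sequence = sequence_output * gate  # low-logit tokens are damped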
Example #8
def coqa_model_span_rationale_tag(config,
                                  max_seq_length,
                                  max_answer_length,
                                  float_type,
                                  training=False,
                                  initializer=None):
    unique_ids = tf.keras.layers.Input(shape=(1, ),
                                       dtype=tf.int32,
                                       name='unique_ids')
    input_word_ids = tf.keras.layers.Input(shape=(max_seq_length, ),
                                           dtype=tf.int32,
                                           name='input_word_ids')
    input_mask = tf.keras.layers.Input(shape=(max_seq_length, ),
                                       dtype=tf.int32,
                                       name='input_mask')
    input_type_ids = tf.keras.layers.Input(shape=(max_seq_length, ),
                                           dtype=tf.int32,
                                           name='segment_ids')

    bert_model = bert_modeling.get_bert_model(input_word_ids,
                                              input_mask,
                                              input_type_ids,
                                              config=config,
                                              name='bert_model',
                                              float_type=float_type)

    # This model uses both the pooled_output (batch_size, num_hidden) and
    # the sequence_output (batch_size, sequence_length, num_hidden).
    pooled_output = bert_model.outputs[0]
    sequence_output = bert_model.outputs[1]

    rationale_tag_layer = BertRationaleTagLogitsLayer(
        config.hidden_size,
        initializer=initializer,
        float_type=float_type,
        name='bert_rationale_tag_logits')

    rt_logits = rationale_tag_layer(sequence_output)

    rt_masked_sequence = sequence_output * tf.expand_dims(
        tf.nn.sigmoid(rt_logits), 2)  # soft-masks each token by its rationale probability

    attention_layer = BertSequenceAttentionLayer(
        config.hidden_size,
        initializer=initializer,
        float_type=float_type,
        name='bert_sequence_attention')

    attention = attention_layer(rt_masked_sequence)

    # Summary vector for classification.
    h = tf.reduce_sum(tf.expand_dims(attention, axis=2) * sequence_output,
                      axis=1)

    yesnounknown_logits_layer = BertYNULogitsLayer(initializer=initializer,
                                                   float_type=float_type,
                                                   name='yesnounknown_logits')

    ynu_logits = yesnounknown_logits_layer(
        tf.concat([pooled_output, h], axis=1))

    span_logits_layer = BertSpanLogitsLayer(initializer=initializer,
                                            float_type=float_type,
                                            name='bert_span_logits')

    start_logits, end_logits = span_logits_layer(sequence_output)

    coqa = tf.keras.Model(
        inputs=({
            'unique_ids': unique_ids,
            'input_word_ids': input_word_ids,
            'input_type_ids': input_type_ids,
            'input_mask': input_mask
        }, ),
        outputs=[unique_ids, start_logits, end_logits, ynu_logits, rt_logits])

    return coqa, bert_model
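The classification summary `h` above is an attention-weighted pooling of the token vectors. In isolation, with a random stand-in for the attention layer's output:

import tensorflow as tf

batch, seq_len, hidden = 2, 8, 16
sequence_output = tf.random.normal((batch, seq_len, hidden))
attention = tf.nn.softmax(tf.random.normal((batch, seq_len)), axis=1)

h = tf.reduce_sum(tf.expand_dims(attention, axis=2) * sequence_output, axis=1)
print(h.shape)  # (2, 16): one summary vector per example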
Example #9
def coqa_model_bert_2heads(config,
                           max_seq_length,
                           max_answer_length,
                           float_type,
                           training=False,
                           initializer=None):
    unique_ids = tf.keras.layers.Input(shape=(1, ),
                                       dtype=tf.int32,
                                       name='unique_ids')
    input_word_ids = tf.keras.layers.Input(shape=(max_seq_length, ),
                                           dtype=tf.int32,
                                           name='input_word_ids')
    input_mask = tf.keras.layers.Input(shape=(max_seq_length, ),
                                       dtype=tf.int32,
                                       name='input_mask')
    input_type_ids = tf.keras.layers.Input(shape=(max_seq_length, ),
                                           dtype=tf.int32,
                                           name='segment_ids')
    decode_ids = tf.keras.layers.Input(shape=(max_answer_length, ),
                                       dtype=tf.int32,
                                       name='decode_ids')
    decode_mask = tf.keras.layers.Input(shape=(max_answer_length, ),
                                        dtype=tf.int32,
                                        name='decode_mask')

    bert_model = bert_modeling.get_bert_model(input_word_ids,
                                              input_mask,
                                              input_type_ids,
                                              config=config,
                                              name='bert_model',
                                              float_type=float_type)

    # `Bert Coqa Model` only uses the sequence_output which
    # has dimensionality (batch_size, sequence_length, num_hidden).
    sequence_output = bert_model.outputs[1]

    span_logits_layer = BertSpanLogitsLayer(initializer=initializer,
                                            float_type=float_type,
                                            name='squad_logits')

    start_logits, end_logits = span_logits_layer(sequence_output)

    span_mask = get_best_span_mask(start_logits, end_logits)

    new_mask = (tf.cast(tf.logical_not(tf.cast(input_type_ids, tf.bool)),
                        tf.int32) + span_mask) * input_mask

    if initializer is None:
        initializer = tf.keras.initializers.TruncatedNormal(
            stddev=config.initializer_range)

    coqa_layer = coqalayers.SimpleTransformer(config=config,
                                              name='simple_transformer')

    final_dists = coqa_layer(input_word_ids, new_mask, input_type_ids,
                             decode_ids, decode_mask)

    coqa = tf.keras.Model(
        inputs=({
            'unique_ids': unique_ids,
            'input_word_ids': input_word_ids,
            'input_type_ids': input_type_ids,
            'input_mask': input_mask,
            'decode_ids': decode_ids,
            'decode_mask': decode_mask
        }, ),
        outputs=[unique_ids, final_dists, start_logits, end_logits])

    return coqa, bert_model
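The `new_mask` arithmetic above keeps the question tokens (segment id 0) plus the predicted answer span, intersected with the padding mask. A worked miniature:

import tensorflow as tf

input_type_ids = tf.constant([[0, 0, 1, 1, 1, 1]])  # 0 = question, 1 = context
input_mask = tf.constant([[1, 1, 1, 1, 1, 0]])      # last position is padding
span_mask = tf.constant([[0, 0, 0, 1, 1, 0]])       # hypothetical best span

question_mask = tf.cast(tf.logical_not(tf.cast(input_type_ids, tf.bool)), tf.int32)
new_mask = (question_mask + span_mask) * input_mask
print(new_mask.numpy())  # [[1 1 0 1 1 0]]: question tokens plus the span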
Example #10
def coqa_model(bert_config,
               max_seq_length,
               max_answer_length,
               max_oov_size,
               float_type,
               training=False,
               initializer=None):
    """Returns BERT Coqa model along with core BERT model to import weights.

  Args:
    bert_config: BertConfig, the config defines the core Bert model.
    max_seq_length: integer, the maximum input sequence length.
    float_type: tf.dtype, tf.float32 or tf.bfloat16.
    initializer: Initializer for weights in BertSquadLogitsLayer.

  Returns:
    Two tensors, start logits and end logits, [batch x sequence length].
  """
    unique_ids = tf.keras.layers.Input(shape=(1, ),
                                       dtype=tf.int32,
                                       name='unique_ids')
    input_word_ids = tf.keras.layers.Input(shape=(max_seq_length, ),
                                           dtype=tf.int32,
                                           name='input_word_ids')
    input_mask = tf.keras.layers.Input(shape=(max_seq_length, ),
                                       dtype=tf.int32,
                                       name='input_mask')
    input_type_ids = tf.keras.layers.Input(shape=(max_seq_length, ),
                                           dtype=tf.int32,
                                           name='segment_ids')
    answer_ids = tf.keras.layers.Input(shape=(max_answer_length, ),
                                       dtype=tf.int32,
                                       name='answer_ids')
    answer_mask = tf.keras.layers.Input(shape=(max_answer_length, ),
                                        dtype=tf.int32,
                                        name='answer_mask')

    core_model = bert_modeling.get_bert_model(input_word_ids,
                                              input_mask,
                                              input_type_ids,
                                              config=bert_config,
                                              name='bert_model',
                                              float_type=float_type)

    # `Bert Coqa Pgnet Model` only uses the sequence_output which
    # has dimensionality (batch_size, sequence_length, num_hidden).
    sequence_output = core_model.outputs[1]

    if initializer is None:
        initializer = tf.keras.initializers.TruncatedNormal(
            stddev=bert_config.initializer_range)

    #
    # Double-headed: trained on both span positions and the final answer
    #
    # coqa_logits_layer = bert_models.BertCoqaLogitsLayer(
    #     initializer=initializer, float_type=float_type, name='coqa_logits')
    # start_logits, end_logits = coqa_logits_layer(sequence_output)
    #
    # #figure out the span text from the start logits and end_logits here.
    # span_text_ids,span_mask=get_best_span_prediction(input_word_ids, start_logits, end_logits,max_seq_length )
    #
    # pgnet_model_layer =modeling.PGNetSummaryModel(config=bert_config ,
    #                                                 float_type=float_type,
    #                                                name='pgnet_summary_model')
    # final_dists, attn_dists = pgnet_model_layer(  span_text_ids,
    #                                               span_mask,
    #                                               answer_ids,
    #                                               answer_mask
    #                                             )
    # coqa = tf.keras.Model(
    #     inputs=[
    #         unique_ids,
    #         answer_ids,
    #         answer_mask ],
    #     outputs=[final_dists, attn_dists,start_logits, end_logits ])

    # PGNet only: end-to-end, question + context to answer

    coqa_layer = CoqaModel(config=bert_config,
                           training=training,
                           float_type=float_type)

    final_dists = coqa_layer(
        input_word_ids,
        input_mask,
        answer_ids,
        answer_mask,
    )

    coqa = tf.keras.Model(inputs=({
        'unique_ids': unique_ids,
        'input_word_ids': input_word_ids,
        'input_mask': input_mask,
        'answer_ids': answer_ids,
        'answer_mask': answer_mask,
        'input_type_ids': input_type_ids
    }, ),
                          outputs=[unique_ids, final_dists])

    coqa.add_loss(coqa_layer.losses)

    # Bert+PGNet:  end to end
    # pgnet_model_layer = modeling.PGNetDecoderModel(config=bert_config ,
    #                                                  float_type=float_type,
    #                                                  name='pgnet_decoder_model')
    # final_dists, attn_dists = pgnet_model_layer(sequence_output, answer_ids, answer_mask)
    # coqa = tf.keras.Model(
    #     inputs=[
    #         unique_ids,
    #         answer_ids,
    #         answer_mask],
    #     outputs=[unique_ids, final_dists, attn_dists],
    #     name="pgnet_model")

    return coqa, core_model
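Because the seq2seq layer's losses are attached with `add_loss`, the outer model can be compiled without an explicit loss function; one hedged way to wire it up, with illustrative sizes and learning rate:

import tensorflow as tf

coqa, core_model = coqa_model(bert_config, max_seq_length=384,
                              max_answer_length=30, max_oov_size=0,
                              float_type=tf.float32, training=True)
# No `loss=` argument: fit() minimizes the losses added inside CoqaModel.
coqa.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=2e-5))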
Example #11
def classifier_model(bert_config,
                     float_type,
                     num_labels,
                     max_seq_length,
                     final_layer_initializer=None,
                     hub_module_url=None):
    """BERT classifier model in functional API style.

  Construct a Keras model for predicting `num_labels` outputs from an input with
  maximum sequence length `max_seq_length`.

  Args:
    bert_config: BertConfig, the config defines the core BERT model.
    float_type: dtype, tf.float32 or tf.bfloat16.
    num_labels: integer, the number of classes.
    max_seq_length: integer, the maximum input sequence length.
    final_layer_initializer: Initializer for the final dense layer. Defaults to
      a TruncatedNormal initializer.
    hub_module_url: (Experimental) TF-Hub path/url to a BERT module.

  Returns:
    Combined prediction model (words, mask, type) -> (label logits)
    BERT sub-model (words, mask, type) -> (bert_outputs)
  """
    input_word_ids = tf.keras.layers.Input(shape=(max_seq_length, ),
                                           dtype=tf.int32,
                                           name='input_word_ids')
    input_mask = tf.keras.layers.Input(shape=(max_seq_length, ),
                                       dtype=tf.int32,
                                       name='input_mask')
    input_type_ids = tf.keras.layers.Input(shape=(max_seq_length, ),
                                           dtype=tf.int32,
                                           name='input_type_ids')
    if hub_module_url:
        bert_model = hub.KerasLayer(hub_module_url, trainable=True)
        pooled_output, _ = bert_model(
            [input_word_ids, input_mask, input_type_ids])
    else:
        bert_model = modeling.get_bert_model(input_word_ids,
                                             input_mask,
                                             input_type_ids,
                                             config=bert_config,
                                             float_type=float_type)
        pooled_output = bert_model.outputs[0]

    if final_layer_initializer is not None:
        initializer = final_layer_initializer
    else:
        initializer = tf.keras.initializers.TruncatedNormal(
            stddev=bert_config.initializer_range)

    output = tf.keras.layers.Dropout(
        rate=bert_config.hidden_dropout_prob)(pooled_output)
    output = tf.keras.layers.Dense(num_labels,
                                   kernel_initializer=initializer,
                                   name='output',
                                   dtype=float_type)(output)
    return tf.keras.Model(inputs={
        'input_word_ids': input_word_ids,
        'input_mask': input_mask,
        'input_type_ids': input_type_ids
    },
                          outputs=output), bert_model
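A typical fine-tuning setup for the classifier above. The head is a Dense layer with no activation, so the outputs are raw logits and a from_logits loss fits; hyperparameters here are illustrative:

import tensorflow as tf

model, bert_core = classifier_model(bert_config, tf.float32,
                                    num_labels=3, max_seq_length=128)
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=2e-5),
    loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'])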
Example #12
def pretrain_model(bert_config,
                   seq_length,
                   max_predictions_per_seq,
                   initializer=None):
    """Returns model to be used for pre-training.

  Args:
      bert_config: Configuration that defines the core BERT model.
      seq_length: Maximum sequence length of the training data.
      max_predictions_per_seq: Maximum number of tokens in sequence to mask out
        and use for pretraining.
      initializer: Initializer for weights in BertPretrainLayer.

  Returns:
      Pretraining model as well as core BERT submodel from which to save
      weights after pretraining.
  """

    input_word_ids = tf.keras.layers.Input(shape=(seq_length, ),
                                           name='input_word_ids',
                                           dtype=tf.int32)
    input_mask = tf.keras.layers.Input(shape=(seq_length, ),
                                       name='input_mask',
                                       dtype=tf.int32)
    input_type_ids = tf.keras.layers.Input(shape=(seq_length, ),
                                           name='input_type_ids',
                                           dtype=tf.int32)
    masked_lm_positions = tf.keras.layers.Input(
        shape=(max_predictions_per_seq, ),
        name='masked_lm_positions',
        dtype=tf.int32)
    masked_lm_weights = tf.keras.layers.Input(
        shape=(max_predictions_per_seq, ),
        name='masked_lm_weights',
        dtype=tf.int32)
    next_sentence_labels = tf.keras.layers.Input(shape=(1, ),
                                                 name='next_sentence_labels',
                                                 dtype=tf.int32)
    masked_lm_ids = tf.keras.layers.Input(shape=(max_predictions_per_seq, ),
                                          name='masked_lm_ids',
                                          dtype=tf.int32)

    bert_submodel_name = 'bert_model'
    bert_submodel = modeling.get_bert_model(input_word_ids,
                                            input_mask,
                                            input_type_ids,
                                            name=bert_submodel_name,
                                            config=bert_config)
    pooled_output = bert_submodel.outputs[0]
    sequence_output = bert_submodel.outputs[1]

    pretrain_layer = BertPretrainLayer(
        bert_config,
        bert_submodel.get_layer(bert_submodel_name),
        initializer=initializer,
        name='cls')
    lm_output, sentence_output = pretrain_layer(pooled_output, sequence_output,
                                                masked_lm_positions)

    pretrain_loss_layer = BertPretrainLossAndMetricLayer(bert_config)
    output_loss = pretrain_loss_layer(lm_output, sentence_output,
                                      masked_lm_ids, masked_lm_weights,
                                      next_sentence_labels)

    return tf.keras.Model(inputs={
        'input_word_ids': input_word_ids,
        'input_mask': input_mask,
        'input_type_ids': input_type_ids,
        'masked_lm_positions': masked_lm_positions,
        'masked_lm_ids': masked_lm_ids,
        'masked_lm_weights': masked_lm_weights,
        'next_sentence_labels': next_sentence_labels,
    },
                          outputs=output_loss), bert_submodel
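Here the combined pretraining loss is itself the model's single output, so one hedged way to train is an identity-style loss that passes that output straight through (a common pattern for loss-as-output models; the input pipeline then supplies dummy labels):

import tensorflow as tf

pretrainer, bert_core = pretrain_model(bert_config, seq_length=128,
                                       max_predictions_per_seq=20)
pretrainer.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss=lambda _dummy_labels, loss_output: loss_output)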