Example #1
# Assumed imports, based on the TF Model Garden / FELIX module layout:
import numpy as np

from official.nlp.bert import configs
from official.nlp.modeling import networks

from felix import felix_tagger
    def test_forward_pass(self,
                          use_pointing=False,
                          query_transformer=False,
                          is_training=True):
        """Randomly generate and run different configuarations for Felix Tagger."""
        # Ensures reproducibility.
        np.random.seed(42)

        # Setup.
        sequence_length = 7
        vocab_size = 11
        bert_hidden_size = 13
        bert_num_hidden_layers = 1
        bert_num_attention_heads = 1
        bert_intermediate_size = 4
        bert_type_vocab_size = 2
        bert_max_position_embeddings = sequence_length
        bert_encoder = networks.BertEncoder(
            vocab_size=vocab_size,
            hidden_size=bert_hidden_size,
            num_layers=bert_num_hidden_layers,
            num_attention_heads=bert_num_attention_heads,
            intermediate_size=bert_intermediate_size,
            sequence_length=sequence_length,
            max_sequence_length=bert_max_position_embeddings,
            type_vocab_size=bert_type_vocab_size)
        bert_config = configs.BertConfig(
            vocab_size,
            hidden_size=bert_hidden_size,
            num_hidden_layers=bert_num_hidden_layers,
            num_attention_heads=bert_num_attention_heads,
            intermediate_size=bert_intermediate_size,
            type_vocab_size=bert_type_vocab_size,
            max_position_embeddings=bert_max_position_embeddings)
        batch_size = 17
        edit_tags_size = 19
        bert_config.num_classes = edit_tags_size
        bert_config.query_size = 23
        bert_config.query_transformer = query_transformer

        tagger = felix_tagger.FelixTagger(bert_encoder,
                                          bert_config=bert_config,
                                          seq_length=sequence_length,
                                          use_pointing=use_pointing,
                                          is_training=is_training)

        # Create inputs.
        input_word_ids = np.random.randint(vocab_size - 1,
                                           size=(batch_size, sequence_length))
        # np.random.randint's upper bound is exclusive, so the bound must be 2
        # to draw a random 0/1 attention mask.
        input_mask = np.random.randint(2, size=(batch_size, sequence_length))
        # Single-segment type ids; integer dtype to match the embedding lookup.
        input_type_ids = np.ones((batch_size, sequence_length), dtype=np.int32)
        edit_tags = np.random.randint(edit_tags_size - 2,
                                      size=(batch_size, sequence_length))

        # Run the model.
        if is_training:
            output = tagger(
                [input_word_ids, input_type_ids, input_mask, edit_tags])
        else:
            output = tagger([input_word_ids, input_type_ids, input_mask])

        # Check output shapes.
        if use_pointing:
            tag_logits, pointing_logits = output
            self.assertEqual(pointing_logits.shape,
                             (batch_size, sequence_length, sequence_length))
        else:
            tag_logits = output[0]
        self.assertEqual(tag_logits.shape,
                         (batch_size, sequence_length, edit_tags_size))
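
The keyword arguments on test_forward_pass suggest the test is driven by a parameterization decorator that is not shown here. A minimal sketch of how the eight flag combinations might be enumerated with absl's parameterized helper (the decorator, class name, and flag order are assumptions, not part of the original):

import itertools

from absl.testing import parameterized


class FelixTaggerTest(parameterized.TestCase):

    # Hypothetical: run test_forward_pass once per combination of the three
    # boolean flags (2**3 = 8 configurations).
    @parameterized.parameters(
        *itertools.product((False, True), (False, True), (False, True)))
    def test_forward_pass(self, use_pointing, query_transformer, is_training):
        ...  # body as in Example #1 above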
Example #2
# Assumed imports, based on the TF Model Garden / FELIX module layout;
# FelixTagLoss is defined alongside this function in the FELIX codebase.
import tensorflow as tf

from official.modeling import activations
from official.nlp.modeling import networks

from felix import felix_tagger
def get_tagging_model(bert_config,
                      seq_length,
                      use_pointing=True,
                      pointing_weight=1.0,
                      is_training=True):
  """Returns model to be used for pre-training.

  Args:
      bert_config: Configuration that defines the core BERT model.
      seq_length: Maximum sequence length of the training data.
      use_pointing: If FELIX should use a pointer (reordering) model.
      pointing_weight: How much to weigh the pointing loss, in contrast to
        tagging loss. Note, if pointing is set to false this is ignored.
      is_training: Will the model be trained or is it inferance time.

  Returns:
      Felix model as well as core BERT submodel from which to save
      weights after pretraining.
  """
  input_word_ids = tf.keras.layers.Input(
      shape=(seq_length,), name='input_word_ids', dtype=tf.int32)
  input_mask = tf.keras.layers.Input(
      shape=(seq_length,), name='input_mask', dtype=tf.int32)
  input_type_ids = tf.keras.layers.Input(
      shape=(seq_length,), name='input_type_ids', dtype=tf.int32)

  bert_encoder = networks.BertEncoder(
      vocab_size=bert_config.vocab_size,
      hidden_size=bert_config.hidden_size,
      num_layers=bert_config.num_hidden_layers,
      num_attention_heads=bert_config.num_attention_heads,
      intermediate_size=bert_config.intermediate_size,
      activation=activations.gelu,
      dropout_rate=bert_config.hidden_dropout_prob,
      attention_dropout_rate=bert_config.attention_probs_dropout_prob,
      sequence_length=seq_length,
      max_sequence_length=bert_config.max_position_embeddings,
      type_vocab_size=bert_config.type_vocab_size,
      initializer=tf.keras.initializers.TruncatedNormal(
          stddev=bert_config.initializer_range))

  felix_model = felix_tagger.FelixTagger(
      bert_encoder,
      seq_length=seq_length,
      use_pointing=use_pointing,
      bert_config=bert_config,
      is_training=is_training)
  felix_inputs = [input_word_ids, input_mask, input_type_ids]
  if is_training:
    edit_tags = tf.keras.layers.Input(
        shape=(seq_length,), name='edit_tags', dtype=tf.int32)

    felix_inputs.append(edit_tags)
    felix_outputs = felix_model(felix_inputs)
    labels_mask = tf.keras.layers.Input(
        shape=(seq_length,), name='labels_mask', dtype=tf.float32)
    felix_inputs.append(labels_mask)
    if use_pointing:
      pointers = tf.keras.layers.Input(
          shape=(seq_length,), name='pointers', dtype=tf.int32)
      tag_logits, pointing_logits = felix_outputs
      felix_inputs.append(pointers)
    else:
      tag_logits = felix_outputs[0]
      pointing_logits = None
      pointers = None
    loss_function = FelixTagLoss(use_pointing, pointing_weight)
    felix_tag_loss = loss_function(tag_logits, edit_tags, input_mask,
                                   labels_mask, pointing_logits, pointers)
    keras_model = tf.keras.Model(inputs=felix_inputs, outputs=felix_tag_loss)
  else:
    felix_outputs = felix_model(felix_inputs)
    if use_pointing:
      tag_logits, pointing_logits = felix_outputs
      keras_model = tf.keras.Model(
          inputs=felix_inputs, outputs=[tag_logits, pointing_logits])
    else:
      tag_logits = felix_outputs[0]
      pointing_logits = None
      keras_model = tf.keras.Model(inputs=felix_inputs, outputs=tag_logits)
  return keras_model, bert_encoder
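
A minimal usage sketch for get_tagging_model (the config path, sequence length, optimizer, and compile setup are illustrative assumptions, not part of the original):

import tensorflow as tf

from official.nlp.bert import configs

# Hypothetical path to a standard BERT config file.
bert_config = configs.BertConfig.from_json_file('bert_config.json')

model, bert_encoder = get_tagging_model(
    bert_config, seq_length=128, use_pointing=True, pointing_weight=1.0,
    is_training=True)

# In training mode the model's single output is already the combined tagging
# (+ pointing) loss, so an identity-style loss can be passed to compile().
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=3e-5),
    loss=lambda _, model_loss: model_loss)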