def test_forward_pass(self, use_pointing=False, query_transformer=False,
                      is_training=True):
  """Randomly generates and runs different configurations for Felix Tagger."""
  # Ensures reproducibility.
  np.random.seed(42)

  # Setup.
  sequence_length = 7
  vocab_size = 11
  bert_hidden_size = 13
  bert_num_hidden_layers = 1
  bert_num_attention_heads = 1
  bert_intermediate_size = 4
  bert_type_vocab_size = 2
  bert_max_position_embeddings = sequence_length
  bert_encoder = networks.BertEncoder(
      vocab_size=vocab_size,
      hidden_size=bert_hidden_size,
      num_layers=bert_num_hidden_layers,
      num_attention_heads=bert_num_attention_heads,
      intermediate_size=bert_intermediate_size,
      sequence_length=sequence_length,
      max_sequence_length=bert_max_position_embeddings,
      type_vocab_size=bert_type_vocab_size)
  bert_config = configs.BertConfig(
      vocab_size,
      hidden_size=bert_hidden_size,
      num_hidden_layers=bert_num_hidden_layers,
      num_attention_heads=bert_num_attention_heads,
      intermediate_size=bert_intermediate_size,
      type_vocab_size=bert_type_vocab_size,
      max_position_embeddings=bert_max_position_embeddings)
  batch_size = 17
  edit_tags_size = 19
  # Felix-specific settings attached to the BERT config.
  bert_config.num_classes = edit_tags_size
  bert_config.query_size = 23
  bert_config.query_transformer = query_transformer
  tagger = felix_tagger.FelixTagger(
      bert_encoder,
      bert_config=bert_config,
      seq_length=sequence_length,
      use_pointing=use_pointing,
      is_training=is_training)

  # Create inputs.
  input_word_ids = np.random.randint(
      vocab_size - 1, size=(batch_size, sequence_length))
  # Draw from {0, 1} so the mask is actually random; randint(1, ...) would
  # always return 0.
  input_mask = np.random.randint(2, size=(batch_size, sequence_length))
  input_type_ids = np.ones((batch_size, sequence_length))
  edit_tags = np.random.randint(
      edit_tags_size - 2, size=(batch_size, sequence_length))

  # Run the model.
  if is_training:
    output = tagger(
        [input_word_ids, input_type_ids, input_mask, edit_tags])
  else:
    output = tagger([input_word_ids, input_type_ids, input_mask])

  # Check output shapes.
  if use_pointing:
    tag_logits, pointing_logits = output
    self.assertEqual(pointing_logits.shape,
                     (batch_size, sequence_length, sequence_length))
  else:
    tag_logits = output[0]
  self.assertEqual(tag_logits.shape,
                   (batch_size, sequence_length, edit_tags_size))
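
# Illustrative helper, not part of the original test: it exercises every
# combination of the three flags by calling test_forward_pass directly. The
# original presumably drives this sweep with absl's @parameterized decorator
# instead.
def test_forward_pass_all_configurations(self):
  for use_pointing in (False, True):
    for query_transformer in (False, True):
      for is_training in (False, True):
        self.test_forward_pass(
            use_pointing=use_pointing,
            query_transformer=query_transformer,
            is_training=is_training)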
def get_tagging_model(bert_config,
                      seq_length,
                      use_pointing=True,
                      pointing_weight=1.0,
                      is_training=True):
  """Returns the Felix tagging model to be used for pre-training.

  Args:
    bert_config: Configuration that defines the core BERT model.
    seq_length: Maximum sequence length of the training data.
    use_pointing: Whether FELIX should use a pointer (reordering) model.
    pointing_weight: Weight of the pointing loss relative to the tagging loss.
      Ignored if use_pointing is False.
    is_training: Whether the model will be trained or used for inference.

  Returns:
    The Felix model as well as the core BERT submodel, from which weights can
    be saved after pretraining.
  """
  input_word_ids = tf.keras.layers.Input(
      shape=(seq_length,), name='input_word_ids', dtype=tf.int32)
  input_mask = tf.keras.layers.Input(
      shape=(seq_length,), name='input_mask', dtype=tf.int32)
  input_type_ids = tf.keras.layers.Input(
      shape=(seq_length,), name='input_type_ids', dtype=tf.int32)
  bert_encoder = networks.BertEncoder(
      vocab_size=bert_config.vocab_size,
      hidden_size=bert_config.hidden_size,
      num_layers=bert_config.num_hidden_layers,
      num_attention_heads=bert_config.num_attention_heads,
      intermediate_size=bert_config.intermediate_size,
      activation=activations.gelu,
      dropout_rate=bert_config.hidden_dropout_prob,
      attention_dropout_rate=bert_config.attention_probs_dropout_prob,
      sequence_length=seq_length,
      max_sequence_length=bert_config.max_position_embeddings,
      type_vocab_size=bert_config.type_vocab_size,
      initializer=tf.keras.initializers.TruncatedNormal(
          stddev=bert_config.initializer_range))
  felix_model = felix_tagger.FelixTagger(
      bert_encoder,
      seq_length=seq_length,
      use_pointing=use_pointing,
      bert_config=bert_config,
      is_training=is_training)
  felix_inputs = [input_word_ids, input_mask, input_type_ids]
  if is_training:
    edit_tags = tf.keras.layers.Input(
        shape=(seq_length,), name='edit_tags', dtype=tf.int32)
    felix_inputs.append(edit_tags)
    felix_outputs = felix_model(felix_inputs)
    # The loss inputs below are fed to the Keras model only; FelixTagger has
    # already been called at this point.
    labels_mask = tf.keras.layers.Input(
        shape=(seq_length,), name='labels_mask', dtype=tf.float32)
    felix_inputs.append(labels_mask)
    if use_pointing:
      pointers = tf.keras.layers.Input(
          shape=(seq_length,), name='pointers', dtype=tf.int32)
      tag_logits, pointing_logits = felix_outputs
      felix_inputs.append(pointers)
    else:
      tag_logits = felix_outputs[0]
      pointing_logits = None
      pointers = None
    loss_function = FelixTagLoss(use_pointing, pointing_weight)
    felix_tag_loss = loss_function(tag_logits, edit_tags, input_mask,
                                   labels_mask, pointing_logits, pointers)
    keras_model = tf.keras.Model(inputs=felix_inputs, outputs=felix_tag_loss)
  else:
    felix_outputs = felix_model(felix_inputs)
    if use_pointing:
      tag_logits, pointing_logits = felix_outputs
      keras_model = tf.keras.Model(
          inputs=felix_inputs, outputs=[tag_logits, pointing_logits])
    else:
      tag_logits = felix_outputs[0]
      keras_model = tf.keras.Model(inputs=felix_inputs, outputs=tag_logits)
  return keras_model, bert_encoder
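
# A minimal usage sketch, kept out of module scope; the hyperparameter values,
# the Adam optimizer, and the pass-through compile loss are assumptions, not
# part of the original module. The Felix-specific config attributes mirror how
# the test above patches its BertConfig.
def _example_get_tagging_model_usage():
  bert_config = configs.BertConfig(
      vocab_size=30522, max_position_embeddings=128)
  bert_config.num_classes = 19          # Size of the edit-tag vocabulary.
  bert_config.query_size = 64           # Pointer-network query dimension.
  bert_config.query_transformer = True  # Transformer over pointer queries.

  keras_model, bert_encoder = get_tagging_model(
      bert_config, seq_length=128, use_pointing=True, is_training=True)

  # In training mode the Keras model's output is already the combined tagging
  # and pointing loss, so a pass-through loss can be used when compiling.
  keras_model.compile(
      optimizer=tf.keras.optimizers.Adam(learning_rate=3e-5),
      loss=lambda _, model_loss: model_loss)
  return keras_model, bert_encoder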