def test_predict_and_realize_insertion_batch(self, sources, prediction, gold):
  """Test predicting and realizing insertion with fake tensorflow models."""
  # Turn the expected insertion tokens into one-hot "logits" the dummy
  # insertion model will echo back.
  fake_logits = [
      _convert_to_one_hot(
          [self._vocab_to_id[token] for token in prediction],
          len(self._vocab_tokens))
  ]
  felix_predictor = predict.FelixPredictor(
      bert_config_insertion=self._bert_test_tagging_config,
      bert_config_tagging=self._bert_test_tagging_config,
      vocab_file=self._vocab_file,
      model_tagging_filepath=None,
      model_insertion_filepath=None,
      label_map_file=self._label_map_path,
      sequence_length=self._max_sequence_length,
      is_pointing=True,
      do_lowercase=True,
      use_open_vocab=True)
  # Swap in the fake insertion model so no real checkpoint is needed.
  felix_predictor._insertion_model = DummyPredictorInsertion(fake_logits)
  realized = felix_predictor._predict_and_realize_batch(
      sources, is_insertion=True)
  self.assertEqual(realized[0], gold)
def test_predict_insertion_batch(self):
  """Tests insertion prediction with a randomly initialized model."""
  batch_size = 11
  felix_predictor = predict.FelixPredictor(
      bert_config_insertion=self._bert_test_tagging_config,
      bert_config_tagging=self._bert_test_tagging_config,
      vocab_file=self._vocab_file,
      model_tagging_filepath=None,
      model_insertion_filepath=None,
      label_map_file=self._label_map_path,
      sequence_length=self._max_sequence_length,
      is_pointing=True,
      do_lowercase=True,
      use_open_vocab=True)
  # Random sentences of growing length, drawn from the vocab while
  # skipping the leading special tokens.
  source_batch = [
      ' '.join(random.choices(self._vocab_tokens[7:], k=i + 1))
      for i in range(batch_size)
  ]
  # Uses a randomly initialized tagging model.
  batch_inputs = felix_predictor._convert_source_sentences_into_batch(
      source_batch, is_insertion=True)[1]
  predictions = felix_predictor._predict_batch(
      batch_inputs, is_insertion=True)
  self.assertLen(predictions, batch_size)
  for prediction in predictions:
    self.assertLen(prediction, self._max_predictions)
def test_convert_source_sentences_into_tagging_batch(self):
  """Tests conversion of raw sentences into tagging model inputs."""
  batch_size = 11
  felix_predictor = predict.FelixPredictor(
      bert_config_insertion=self._bert_test_tagging_config,
      bert_config_tagging=self._bert_test_tagging_config,
      model_tagging_filepath=None,
      model_insertion_filepath=None,
      label_map_file=self._label_map_path,
      sequence_length=self._max_sequence_length,
      is_pointing=True,
      do_lowercase=True,
      vocab_file=self._vocab_file,
      use_open_vocab=True)
  # Produce random sentences from the vocab (excluding special tokens).
  source_batch = [
      ' '.join(random.choices(self._vocab_tokens[7:], k=i + 1))
      for i in range(batch_size)
  ]
  batch_dictionaries, batch_list = (
      felix_predictor._convert_source_sentences_into_batch(
          source_batch, is_insertion=False))
  # Each input should be of the size (batch_size, max_sequence_length).
  for tensor in batch_list.values():
    self.assertEqual(tensor.shape, (batch_size, self._max_sequence_length))
  # The per-example dictionaries must mirror the batched tensors.
  self.assertLen(batch_dictionaries, batch_size)
  for example in batch_dictionaries:
    for feature in example.values():
      self.assertLen(feature, self._max_sequence_length)
def test_predict_end_to_end_batch_fake(self, pred, raw_points, sources, gold,
                                       gold_with_deletions, insertions):
  """Test end-to-end with fake tensorflow models."""
  felix_predictor = predict.FelixPredictor(
      bert_config_insertion=self._bert_test_tagging_config,
      bert_config_tagging=self._bert_test_tagging_config,
      vocab_file=self._vocab_file,
      model_tagging_filepath=None,
      model_insertion_filepath=None,
      label_map_file=self._label_map_path,
      sequence_length=self._max_sequence_length,
      is_pointing=True,
      do_lowercase=True,
      use_open_vocab=True)
  # Fake tagging model: returns the given tags (one-hot) and pointer logits.
  tagging_model = DummyPredictorTagging(
      _convert_to_one_hot(pred, len(self._label_map)), raw_points)
  # Fake insertion model: echoes the expected insertion tokens as one-hot.
  insertion_logits = [
      _convert_to_one_hot(
          [self._vocab_to_id[token] for token in insertions],
          len(self._vocab_tokens))
  ]
  felix_predictor._tagging_model = tagging_model
  felix_predictor._insertion_model = DummyPredictorInsertion(insertion_logits)
  taggings_outputs, insertion_outputs = (
      felix_predictor.predict_end_to_end_batch(sources))
  self.assertEqual(taggings_outputs[0], gold_with_deletions)
  self.assertEqual(insertion_outputs[0], gold)
def test_predict_end_to_end_batch_random(self):
  """Test the model predictions end-2-end with randomly initialized models."""
  batch_size = 11
  felix_predictor = predict.FelixPredictor(
      bert_config_insertion=self._bert_test_tagging_config,
      bert_config_tagging=self._bert_test_tagging_config,
      vocab_file=self._vocab_file,
      model_tagging_filepath=None,
      model_insertion_filepath=None,
      label_map_file=self._label_map_path,
      sequence_length=self._max_sequence_length,
      is_pointing=True,
      do_lowercase=True,
      use_open_vocab=True)
  # Random sentences of growing length from the non-special vocab entries.
  source_batch = [
      ' '.join(random.choices(self._vocab_tokens[8:], k=i + 1))
      for i in range(batch_size)
  ]
  # Uses a randomly initialized tagging model.
  predictions_tagging, predictions_insertion = (
      felix_predictor.predict_end_to_end_batch(source_batch))
  self.assertLen(predictions_tagging, batch_size)
  self.assertLen(predictions_insertion, batch_size)
def test_predict_and_realize_tagging_batch_for_felix_insert(
    self, pred, sources, gold_with_deletions, insert_after_token):
  """Tests tagging realization for FelixInsert (no pointing)."""
  felix_predictor = predict.FelixPredictor(
      bert_config_insertion=self._bert_test_tagging_config,
      bert_config_tagging=self._bert_test_tagging_config,
      model_tagging_filepath=None,
      model_insertion_filepath=None,
      label_map_file=self._label_map_path,
      sequence_length=self._max_sequence_length,
      is_pointing=False,
      do_lowercase=True,
      vocab_file=self._vocab_file,
      use_open_vocab=True,
      insert_after_token=insert_after_token)
  # Fake tagging model returning the provided tags as one-hot logits.
  felix_predictor._tagging_model = DummyPredictorTagging(
      _convert_to_one_hot(pred, len(self._label_map)))
  realized = felix_predictor._predict_and_realize_batch(
      sources, is_insertion=False)
  self.assertEqual(realized[0], gold_with_deletions)
def main(argv):
  """Runs Felix tagging+insertion prediction over batches and writes TSV output.

  For every (source, target) pair produced by `batch_generator()`, writes one
  line: source, predicted tagging output, predicted insertion output, and the
  reference target, tab-separated.

  Args:
    argv: Command-line arguments; only the program name is expected.

  Raises:
    app.UsageError: If extra command-line arguments are supplied.
    ValueError: If `--use_open_vocab` is False (only open vocab is supported).
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')
  if not FLAGS.use_open_vocab:
    raise ValueError('Currently only use_open_vocab=True is supported')

  label_map = utils.read_label_map(FLAGS.label_map_file)
  bert_config_tagging = configs.BertConfig.from_json_file(
      FLAGS.bert_config_tagging)
  bert_config_insertion = configs.BertConfig.from_json_file(
      FLAGS.bert_config_insertion)

  # Keyword arguments shared by the TPU and non-TPU construction paths,
  # factored out so the two branches cannot silently drift apart.
  predictor_kwargs = dict(
      bert_config_tagging=bert_config_tagging,
      bert_config_insertion=bert_config_insertion,
      model_tagging_filepath=FLAGS.model_tagging_filepath,
      model_insertion_filepath=FLAGS.model_insertion_filepath,
      vocab_file=FLAGS.vocab_file,
      sequence_length=FLAGS.max_seq_length,
      max_predictions=FLAGS.max_predictions_per_seq,
      do_lowercase=FLAGS.do_lower_case,
      use_open_vocab=FLAGS.use_open_vocab,
      is_pointing=FLAGS.use_pointing,
      insert_after_token=FLAGS.insert_after_token,
      special_glue_string_for_joining_sources=(
          FLAGS.special_glue_string_for_joining_sources))

  if FLAGS.tpu is not None:
    cluster_resolver = distribute_utils.tpu_initialize(FLAGS.tpu)
    strategy = tf.distribute.TPUStrategy(cluster_resolver)
    with strategy.scope():
      # NOTE(review): the TPU path passes the parsed `label_map` object while
      # the non-TPU path passes `label_map_file`; preserved from the original
      # code — confirm this asymmetry is intentional.
      predictor = predict.FelixPredictor(
          label_map=label_map, **predictor_kwargs)
  else:
    predictor = predict.FelixPredictor(
        label_map_file=FLAGS.label_map_file, **predictor_kwargs)

  num_predicted = 0
  with tf.io.gfile.GFile(FLAGS.predict_output_file, 'w') as writer:
    for source_batch, target_batch in batch_generator():
      predicted_tags, predicted_inserts = predictor.predict_end_to_end_batch(
          source_batch)
      num_predicted += len(source_batch)
      logging.log_every_n(logging.INFO, f'{num_predicted} predicted.', 200)
      for source_input, target_output, predicted_tag, predicted_insert in zip(
          source_batch, target_batch, predicted_tags, predicted_inserts):
        writer.write(f'{source_input}\t{predicted_tag}\t{predicted_insert}\t'
                     f'{target_output}\n')