Exemplo n.º 1
0
    def test_predict_and_realize_insertion_batch(self, sources, prediction,
                                                 gold):
        """Realizes a stubbed insertion prediction and compares against gold."""
        # Encode the expected insertion tokens as a single one-hot batch.
        one_hot_prediction = [
            _convert_to_one_hot(
                [self._vocab_to_id[tok] for tok in prediction],
                len(self._vocab_tokens))
        ]

        predictor = predict.FelixPredictor(
            bert_config_insertion=self._bert_test_tagging_config,
            bert_config_tagging=self._bert_test_tagging_config,
            label_map_file=self._label_map_path,
            vocab_file=self._vocab_file,
            model_tagging_filepath=None,
            model_insertion_filepath=None,
            sequence_length=self._max_sequence_length,
            do_lowercase=True,
            is_pointing=True,
            use_open_vocab=True)

        # Swap in a dummy model that always emits the canned prediction.
        predictor._insertion_model = DummyPredictorInsertion(one_hot_prediction)
        realized = predictor._predict_and_realize_batch(
            sources, is_insertion=True)
        self.assertEqual(realized[0], gold)
Exemplo n.º 2
0
    def test_predict_insertion_batch(self):
        """Smoke-tests the insertion path and checks prediction sizes."""
        batch_size = 11
        predictor = predict.FelixPredictor(
            bert_config_insertion=self._bert_test_tagging_config,
            bert_config_tagging=self._bert_test_tagging_config,
            label_map_file=self._label_map_path,
            vocab_file=self._vocab_file,
            model_tagging_filepath=None,
            model_insertion_filepath=None,
            sequence_length=self._max_sequence_length,
            do_lowercase=True,
            is_pointing=True,
            use_open_vocab=True)
        # Random sentences of increasing length, drawn from the non-special
        # portion of the vocab.
        source_batch = [
            ' '.join(random.choices(self._vocab_tokens[7:], k=n + 1))
            for n in range(batch_size)
        ]
        # The insertion model here is randomly initialized; only the output
        # shape is meaningful.
        insertion_inputs = predictor._convert_source_sentences_into_batch(
            source_batch, is_insertion=True)[1]
        predictions = predictor._predict_batch(
            insertion_inputs, is_insertion=True)
        self.assertLen(predictions, batch_size)

        for predicted in predictions:
            self.assertLen(predicted, self._max_predictions)
Exemplo n.º 3
0
    def test_convert_source_sentences_into_tagging_batch(self):
        """Checks shapes produced by tagging-batch conversion."""
        batch_size = 11
        predictor = predict.FelixPredictor(
            bert_config_insertion=self._bert_test_tagging_config,
            bert_config_tagging=self._bert_test_tagging_config,
            label_map_file=self._label_map_path,
            vocab_file=self._vocab_file,
            model_tagging_filepath=None,
            model_insertion_filepath=None,
            sequence_length=self._max_sequence_length,
            do_lowercase=True,
            is_pointing=True,
            use_open_vocab=True)
        # Produce random sentences from the vocab (excluding special tokens).
        source_batch = [
            ' '.join(random.choices(self._vocab_tokens[7:], k=n + 1))
            for n in range(batch_size)
        ]
        batch_dictionaries, batch_list = (
            predictor._convert_source_sentences_into_batch(
                source_batch, is_insertion=False))
        # Each input should be of the size (batch_size, max_sequence_length).
        expected_shape = (batch_size, self._max_sequence_length)
        for tensor in batch_list.values():
            self.assertEqual(tensor.shape, expected_shape)

        self.assertLen(batch_dictionaries, batch_size)
        for item in batch_dictionaries:
            for sequence in item.values():
                self.assertLen(sequence, self._max_sequence_length)
Exemplo n.º 4
0
 def test_predict_end_to_end_batch_fake(self, pred, raw_points, sources,
                                        gold, gold_with_deletions,
                                        insertions):
     """Runs end-to-end prediction with stubbed tagging/insertion models."""
     predictor = predict.FelixPredictor(
         bert_config_insertion=self._bert_test_tagging_config,
         bert_config_tagging=self._bert_test_tagging_config,
         label_map_file=self._label_map_path,
         vocab_file=self._vocab_file,
         model_tagging_filepath=None,
         model_insertion_filepath=None,
         sequence_length=self._max_sequence_length,
         do_lowercase=True,
         is_pointing=True,
         use_open_vocab=True)
     # Canned tagging output: one-hot tag scores plus pointer logits.
     stub_tagger = DummyPredictorTagging(
         _convert_to_one_hot(pred, len(self._label_map)), raw_points)
     # Canned insertion output: one-hot vocab scores for the inserted tokens.
     one_hot_insertions = [
         _convert_to_one_hot(
             [self._vocab_to_id[tok] for tok in insertions],
             len(self._vocab_tokens))
     ]
     predictor._tagging_model = stub_tagger
     predictor._insertion_model = DummyPredictorInsertion(one_hot_insertions)
     tagging_outputs, insertion_outputs = (
         predictor.predict_end_to_end_batch(sources))
     self.assertEqual(tagging_outputs[0], gold_with_deletions)
     self.assertEqual(insertion_outputs[0], gold)
Exemplo n.º 5
0
    def test_predict_end_to_end_batch_random(self):
        """Smoke-tests end-to-end prediction with randomly initialized models."""
        batch_size = 11
        predictor = predict.FelixPredictor(
            bert_config_insertion=self._bert_test_tagging_config,
            bert_config_tagging=self._bert_test_tagging_config,
            label_map_file=self._label_map_path,
            vocab_file=self._vocab_file,
            model_tagging_filepath=None,
            model_insertion_filepath=None,
            sequence_length=self._max_sequence_length,
            do_lowercase=True,
            is_pointing=True,
            use_open_vocab=True)

        # Random sentences of increasing length from the non-special vocab.
        source_batch = [
            ' '.join(random.choices(self._vocab_tokens[8:], k=n + 1))
            for n in range(batch_size)
        ]
        # Both models are randomly initialized, so only output sizes are
        # checked.
        tagging_outputs, insertion_outputs = (
            predictor.predict_end_to_end_batch(source_batch))
        self.assertLen(tagging_outputs, batch_size)
        self.assertLen(insertion_outputs, batch_size)
Exemplo n.º 6
0
 def test_predict_and_realize_tagging_batch_for_felix_insert(
         self, pred, sources, gold_with_deletions, insert_after_token):
     """Realizes a stubbed tagging prediction in FelixInsert (no-pointing) mode."""
     predictor = predict.FelixPredictor(
         bert_config_insertion=self._bert_test_tagging_config,
         bert_config_tagging=self._bert_test_tagging_config,
         label_map_file=self._label_map_path,
         vocab_file=self._vocab_file,
         model_tagging_filepath=None,
         model_insertion_filepath=None,
         sequence_length=self._max_sequence_length,
         do_lowercase=True,
         is_pointing=False,
         use_open_vocab=True,
         insert_after_token=insert_after_token)
     # Dummy model emitting the canned one-hot tag prediction (no pointers).
     predictor._tagging_model = DummyPredictorTagging(
         _convert_to_one_hot(pred, len(self._label_map)))
     realized = predictor._predict_and_realize_batch(
         sources, is_insertion=False)
     self.assertEqual(realized[0], gold_with_deletions)
Exemplo n.º 7
0
def main(argv):
  """Runs batched Felix prediction and writes results as TSV lines.

  Each output line is: source<TAB>tagging output<TAB>insertion output<TAB>target.

  Args:
    argv: Command-line arguments; only the program name is expected.

  Raises:
    app.UsageError: If extra command-line arguments are supplied.
    ValueError: If --use_open_vocab is False (closed vocab is unsupported).
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  if not FLAGS.use_open_vocab:
    raise ValueError('Currently only use_open_vocab=True is supported')

  label_map = utils.read_label_map(FLAGS.label_map_file)
  bert_config_tagging = configs.BertConfig.from_json_file(
      FLAGS.bert_config_tagging)
  bert_config_insertion = configs.BertConfig.from_json_file(
      FLAGS.bert_config_insertion)
  # Keyword arguments shared by both construction paths below; previously the
  # two branches duplicated this whole block, inviting drift.
  predictor_kwargs = dict(
      bert_config_tagging=bert_config_tagging,
      bert_config_insertion=bert_config_insertion,
      model_tagging_filepath=FLAGS.model_tagging_filepath,
      model_insertion_filepath=FLAGS.model_insertion_filepath,
      vocab_file=FLAGS.vocab_file,
      sequence_length=FLAGS.max_seq_length,
      max_predictions=FLAGS.max_predictions_per_seq,
      do_lowercase=FLAGS.do_lower_case,
      use_open_vocab=FLAGS.use_open_vocab,
      is_pointing=FLAGS.use_pointing,
      insert_after_token=FLAGS.insert_after_token,
      special_glue_string_for_joining_sources=(
          FLAGS.special_glue_string_for_joining_sources))
  if FLAGS.tpu is not None:
    cluster_resolver = distribute_utils.tpu_initialize(FLAGS.tpu)
    strategy = tf.distribute.TPUStrategy(cluster_resolver)
    with strategy.scope():
      # NOTE(review): the TPU path passes the parsed `label_map` object while
      # the non-TPU path passes `label_map_file` — preserved as-is; confirm
      # FelixPredictor accepts both forms.
      predictor = predict.FelixPredictor(
          label_map=label_map, **predictor_kwargs)
  else:
    predictor = predict.FelixPredictor(
        label_map_file=FLAGS.label_map_file, **predictor_kwargs)

  num_predicted = 0
  with tf.io.gfile.GFile(FLAGS.predict_output_file, 'w') as writer:
    for source_batch, target_batch in batch_generator():
      predicted_tags, predicted_inserts = predictor.predict_end_to_end_batch(
          source_batch)
      num_predicted += len(source_batch)
      logging.log_every_n(logging.INFO, f'{num_predicted} predicted.', 200)
      for source_input, target_output, predicted_tag, predicted_insert in zip(
          source_batch, target_batch, predicted_tags, predicted_inserts):
        writer.write(f'{source_input}\t{predicted_tag}\t{predicted_insert}\t'
                     f'{target_output}\n')