Пример #1
0
    def test_predict(self):
        """Checks the size and per-item arity of `tagging.predict` output.

        A fake TFRecord dataset is written to the test temp dir and run
        through prediction; the result must contain one entry per example,
        and the first entry must have three elements.
        """
        example_count = 100
        max_seq_len = 16
        label_names = ["O", "B-PER", "I-PER"]

        # Build the task and its model from the fixtures held on the test case.
        config = tagging.TaggingConfig(
            model=tagging.ModelConfig(encoder=self._encoder_config),
            train_data=self._train_data_config,
            class_names=label_names)
        tagging_task = tagging.TaggingTask(config)
        tagging_model = tagging_task.build_model()

        # Write the synthetic evaluation records.
        record_path = os.path.join(self.get_temp_dir(), "test.tf_record")
        _create_fake_dataset(
            record_path,
            seq_length=max_seq_len,
            num_labels=len(config.class_names),
            num_examples=example_count)

        # drop_remainder=False: 100 examples do not divide evenly into
        # batches of 16, yet the length assertion below expects all of them.
        eval_data = tagging_dataloader.TaggingDataConfig(
            input_path=record_path,
            seq_length=max_seq_len,
            is_training=False,
            global_batch_size=16,
            drop_remainder=False,
            include_sentence_id=True)

        predictions = tagging.predict(tagging_task, eval_data, tagging_model)
        self.assertLen(predictions, example_count)
        self.assertLen(predictions[0], 3)
Пример #2
0
def write_tagging(task, model, input_file, output_file, predict_batch_size,
                  seq_length):
  """Runs tagging prediction over `input_file` and writes labels to a file.

  Every predicted label is written on its own line. A single blank line is
  written each time the sentence id changes, so consecutive results that
  share a sentence id form one contiguous group in the output.

  Args:
    task: Task object passed to `tagging.predict`; its
      `task_config.class_names` maps predicted ids to label strings.
    model: Model passed to `tagging.predict`.
    input_file: Path used as `input_path` of the prediction data config.
    output_file: Path opened for writing through `tf.io.gfile.GFile`.
    predict_batch_size: Used as `global_batch_size` of the data config.
    seq_length: Used as `seq_length` of the data config.

  Raises:
    AssertionError: If a result's sentence id is neither equal to the
      previous one nor exactly one larger (the initial "previous" id is -1).
  """
  predict_config = tagging_dataloader.TaggingDataConfig(
      input_path=input_file,
      is_training=False,
      seq_length=seq_length,
      global_batch_size=predict_batch_size,
      drop_remainder=False,
      include_sentence_id=True)
  predictions = tagging.predict(task, predict_config, model)
  label_names = task.task_config.class_names

  # -1 marks "nothing written yet"; it suppresses the separator before the
  # very first sentence.
  previous_id = -1
  with tf.io.gfile.GFile(output_file, 'w') as out:
    for current_id, _, label_ids in predictions:
      labels = [label_names[label_id] for label_id in label_ids]

      # Results must arrive grouped by sentence, with ids advancing by
      # at most one between consecutive entries.
      continues_sentence = current_id == previous_id
      assert continues_sentence or current_id == previous_id + 1

      starts_later_sentence = not continues_sentence and previous_id != -1
      if starts_later_sentence:
        out.write('\n')

      out.write('\n'.join(labels) + '\n')
      previous_id = current_id