예제 #1
0
def write_question_answering(task,
                             model,
                             input_file,
                             output_file,
                             predict_batch_size,
                             seq_length,
                             tokenization,
                             vocab_file,
                             do_lower_case,
                             version_2_with_negative=False,
                             doc_stride=128,
                             query_length=64):
  """Makes question answering predictions and writes them to a JSON file.

  Builds a non-training `QADataConfig` (`is_training=False`,
  `drop_remainder=False`) from the arguments, runs
  `question_answering.predict` with it, and serializes the first element of
  the returned triple to `output_file` as indented JSON followed by a newline.
  The other two elements returned by `predict` are discarded.

  Args:
    task: Passed as the first argument to `question_answering.predict`.
    model: Passed as the third argument to `question_answering.predict`.
    input_file: Forwarded to the data config as `input_path`.
    output_file: Path opened for writing through `tf.io.gfile.GFile`.
    predict_batch_size: Forwarded to the data config as `global_batch_size`.
    seq_length: Forwarded to the data config as `seq_length`.
    tokenization: Forwarded to the data config as `tokenization`.
    vocab_file: Forwarded to the data config as `vocab_file`.
    do_lower_case: Forwarded to the data config as `do_lower_case`.
    version_2_with_negative: Forwarded to the data config under the same name.
      NOTE(review): presumably selects SQuAD v2-style data with unanswerable
      questions; confirm against the dataloader.
    doc_stride: Forwarded to the data config as `doc_stride`. Defaults to 128,
      the value this function previously hard-coded.
    query_length: Forwarded to the data config as `query_length`. Defaults to
      64, the value this function previously hard-coded.
  """
  data_config = question_answering_dataloader.QADataConfig(
      do_lower_case=do_lower_case,
      doc_stride=doc_stride,
      # Keep the final partial batch so that every input gets a prediction.
      drop_remainder=False,
      global_batch_size=predict_batch_size,
      input_path=input_file,
      is_training=False,
      query_length=query_length,
      seq_length=seq_length,
      tokenization=tokenization,
      version_2_with_negative=version_2_with_negative,
      vocab_file=vocab_file)
  # Only the first element of the returned triple is written out.
  all_predictions, _, _ = question_answering.predict(task, data_config, model)
  with tf.io.gfile.GFile(output_file, 'w') as writer:
    writer.write(json.dumps(all_predictions, indent=4) + '\n')
    def test_predict(self, version_2_with_negative):
        """Checks the sizes of the three results returned by `predict`.

        Builds a question answering task and model from the fixture's encoder
        and training-data configs, runs `question_answering.predict` on the
        validation data config, and asserts that:

        * the predictions and the n-best results each have length 1;
        * the score differences have length 1 when
          `version_2_with_negative` is truthy, and are empty otherwise.

        Args:
          version_2_with_negative: Forwarded to
            `self._get_validation_data_config`; also selects which assertion
            is applied to the score differences.
        """
        eval_data_config = self._get_validation_data_config(
            version_2_with_negative=version_2_with_negative)

        task_config = question_answering.QuestionAnsweringConfig(
            model=question_answering.ModelConfig(encoder=self._encoder_config),
            train_data=self._train_data_config,
            validation_data=eval_data_config)
        qa_task = question_answering.QuestionAnsweringTask(task_config)
        qa_model = qa_task.build_model()

        predictions, nbest, scores_diff = question_answering.predict(
            qa_task, eval_data_config, qa_model)

        # Predictions first, then n-best: same order as the separate checks.
        for result in (predictions, nbest):
            self.assertLen(result, 1)

        if not version_2_with_negative:
            self.assertEmpty(scores_diff)
        else:
            self.assertLen(scores_diff, 1)