def main(_):
  """Evaluates the ORQA predictor on a JSONL dataset and logs its accuracy.

  Each line of `FLAGS.dataset_path` is parsed as a JSON object with a
  "question" field and an "answer" field (an iterable of reference answers).
  The predictor's answer is compared against the references with
  `eval_utils.is_correct` (non-regex matching). Overall accuracy is logged at
  the end.

  Args:
    _: Unused (presumably argv passed by the app runner).
  """
  predictor = orqa_model.get_predictor(FLAGS.model_dir)
  example_count = 0
  correct_count = 0
  with tf.io.gfile.GFile(FLAGS.dataset_path) as dataset_file:
    for i, line in enumerate(dataset_file):
      example = json.loads(line)
      question = example["question"]
      answers = example["answer"]
      predictions = predictor(question)
      # errors="ignore": undecodable bytes are dropped instead of raising.
      predicted_answer = six.ensure_text(predictions["answer"], errors="ignore")
      is_correct = eval_utils.is_correct(
          answers=[six.ensure_text(a) for a in answers],
          prediction=predicted_answer,
          is_regex=False)
      correct_count += int(is_correct)
      example_count += 1
      # `i & (i - 1) == 0` holds only for i == 0 and powers of two, so samples
      # are logged at exponentially spaced indices.
      if FLAGS.print_prediction_samples and i & (i - 1) == 0:
        logging.info("[%d] '%s' -> '%s'", i, question, predicted_answer)
  if not example_count:
    # An empty dataset would otherwise raise ZeroDivisionError below.
    logging.warning("No examples found in %s; accuracy is undefined.",
                    FLAGS.dataset_path)
    return
  logging.info("Accuracy: %.4f (%d/%d)",
               correct_count / float(example_count),
               correct_count, example_count)
def main(_):
  """Evaluates the ORQA predictor and dumps per-example predictions.

  Reads JSON objects (with "question" and "answer" fields) from
  `FLAGS.dataset_path`. For each one, it writes a JSON line to
  `<FLAGS.model_dir>/predictions.jsonl` containing the question, the predicted
  answer, the retrieved context ("orig_block"), the correctness flag and the
  reference answers. Overall accuracy is logged at the end.

  Args:
    _: Unused (presumably argv passed by the app runner).
  """
  predictor = orqa_model.get_predictor(FLAGS.model_dir)
  example_count = 0
  correct_count = 0
  predictions_path = os.path.join(FLAGS.model_dir, "predictions.jsonl")
  with tf.io.gfile.GFile(predictions_path, "w") as predictions_file:
    with tf.io.gfile.GFile(FLAGS.dataset_path) as dataset_file:
      for line in dataset_file:
        example = json.loads(line)
        question = example["question"]
        answers = example["answer"]
        predictions = predictor(question)
        # errors="ignore": undecodable bytes are dropped instead of raising.
        predicted_answer = six.ensure_text(predictions["answer"],
                                           errors="ignore")
        is_correct = eval_utils.is_correct(
            answers=[six.ensure_text(a) for a in answers],
            prediction=predicted_answer,
            is_regex=False)
        predictions_file.write(
            json.dumps(
                dict(
                    question=question,
                    prediction=predicted_answer,
                    predicted_context=six.ensure_text(
                        predictions["orig_block"], errors="ignore"),
                    correct=is_correct,
                    answer=answers)))
        predictions_file.write("\n")
        correct_count += int(is_correct)
        example_count += 1
  if not example_count:
    # An empty dataset would otherwise raise ZeroDivisionError below, after
    # the (empty) predictions file has already been written.
    tf.logging.warning("No examples found in %s; accuracy is undefined.",
                       FLAGS.dataset_path)
    return
  tf.logging.info("Accuracy: %.4f (%d/%d)",
                  correct_count / float(example_count),
                  correct_count, example_count)
def main(_):
  """Writes the ORQA predictor's answer for every question in a JSONL file.

  Each JSON object read from `FLAGS.dataset_path` gets an added "prediction"
  field holding the predicted answer text. It is then written back out as one
  JSON line to `FLAGS.predictions_path`.

  Args:
    _: Unused (presumably argv passed by the app runner).
  """
  predictor = orqa_model.get_predictor(FLAGS.model_dir)
  # The output file is opened before the input file, as in a nested `with`.
  with tf.io.gfile.GFile(FLAGS.predictions_path, "w") as out_file, \
       tf.io.gfile.GFile(FLAGS.dataset_path) as in_file:
    for index, raw_line in enumerate(in_file):
      record = json.loads(raw_line)
      query = record["question"]
      # errors="ignore": undecodable bytes are dropped instead of raising.
      answer_text = six.ensure_text(
          predictor(query)["answer"], errors="ignore")
      record["prediction"] = answer_text
      out_file.write(json.dumps(record))
      out_file.write("\n")
      # Sample logging fires at index 0 and at every power of two.
      if FLAGS.print_prediction_samples and index & (index - 1) == 0:
        logging.info("[%d] '%s' -> '%s'", index, query, answer_text)
def main(_):
  """Serves the ORQA demo web UI on `FLAGS.port`.

  Templates are loaded with jinja2 from `FLAGS.web_path`, or from the "web"
  directory next to this source file when that flag is empty. A single route
  ("/") is handled by `MainHandler`, which receives the jinja2 environment and
  the predictor as keyword arguments.

  Args:
    _: Unused (presumably argv passed by the app runner).
  """
  predictor = orqa_model.get_predictor(FLAGS.model_dir)

  # Run once to initialize the block records.
  predictor("")

  web_path = FLAGS.web_path or os.path.join(
      os.path.dirname(os.path.realpath(__file__)), "web")
  env = jinja2.Environment(loader=jinja2.FileSystemLoader(web_path))
  # NOTE(review): tornado.wsgi.WSGIApplication was removed in Tornado 6.0,
  # so this requires Tornado < 6.
  application = tornado.wsgi.WSGIApplication([(r"/", MainHandler, {
      "env": env,
      "predictor": predictor,
  })])
  # Host "" binds all interfaces; make_server binds the socket immediately.
  server = simple_server.make_server("", FLAGS.port, application)
  # Log readiness only after the port is bound. If the bind fails (e.g. the
  # port is already in use), no misleading "READY!" is emitted.
  tf.logging.info("READY!")
  server.serve_forever()