def get_predictions(self, batch):
    """Predict answer spans for ``batch`` by majority vote over the ensemble.

    Each (checkpoint, flags) pair in ``zip(self.ckpts, self.flags)`` gets a
    freshly built ``QAModel`` (``is_training=False``) restored from its
    checkpoint inside its own ``tf.Session``. The default graph is reset
    after every model, even on error, so the next model starts from an
    empty graph.

    The per-model start and end positions are combined element-wise with
    ``scipy.stats.mode``. When several values tie, SciPy returns the
    smallest one.

    Args:
        batch: input batch, passed unchanged to
            ``QAModel.get_start_end_pos``.

    Returns:
        Tuple ``(starts, ends)`` of integer numpy arrays with the voted
        start and end positions. Presumably there is one entry per example
        in ``batch``; TODO confirm against ``get_start_end_pos``.
    """
    def _vote(per_model_preds):
        # Stack to (n_models, ...) and take the most frequent value per
        # position across models.
        stacked = np.array(per_model_preds)
        voted, _ = stats.mode(stacked, axis=0)
        voted = np.asarray(voted)
        # Older SciPy releases keep the reduced model axis (leading dim of
        # size 1); newer releases drop it by default. Strip it only when it
        # is still present, so that indexing never picks a single scalar.
        if voted.ndim == stacked.ndim:
            voted = voted[0]
        # np.int was an alias of builtin int and is removed in modern NumPy.
        return voted.astype(int)

    starts, ends = [], []
    for ckpt, model_flags in zip(self.ckpts, self.flags):
        qa_model = None
        try:
            qa_model = QAModel(model_flags, self.id2word, self.word2id,
                               self.emb_matrix, self.id2idf,
                               is_training=False)
            with tf.Session(config=self.tf_config) as session:
                # True -> the checkpoint is expected to exist.
                qa_model.initialize_from_checkpoint(session, ckpt, True)
                pred_start_pos, pred_end_pos = qa_model.get_start_end_pos(
                    session, batch)
            starts.append(pred_start_pos)
            ends.append(pred_end_pos)
        finally:
            # Drop this model and its graph before building the next one.
            del qa_model
            tf.reset_default_graph()

    return _vote(starts), _vote(ends)
def main(unused_argv):
    """Entry point: load embeddings, then train, inspect or evaluate a model.

    Dispatches on ``FLAGS.mode``:
      * ``train``: train a single ``QAModel``, checkpointing into
        ``FLAGS.train_dir`` and writing ``log.txt`` / ``flags.json`` there.
      * ``show_examples``: restore the best checkpoint and print 10 dev
        examples with their F1/EM scores.
      * ``eval``: log train and dev F1/EM, either for a single model
        restored from ``FLAGS.ckpt_load_dir`` or for the ensemble.
      * ``official_eval``: predict answers for ``FLAGS.json_in_path`` and
        write the uuid -> answer mapping to ``FLAGS.json_out_path``.

    When ``FLAGS.ensumble`` is set, predictions come from an ``Ensumbler``.
    Only the ``eval`` and ``official_eval`` modes support it.

    Args:
        unused_argv: arguments left over after flag parsing. Must have
            exactly one element (presumably the program name).

    Raises:
        Exception: leftover arguments, a non-Python-2 interpreter, an
            unknown mode, missing required flags, or an ensemble requested
            in a mode that only supports a single model.
    """
    # Print an error message if you've entered flags incorrectly
    if len(unused_argv) != 1:
        raise Exception(
            "There is a problem with how you entered flags: %s" % unused_argv)

    # Check for Python 2
    if sys.version_info[0] != 2:
        raise Exception(
            "ERROR: You must use Python 2 but you are running Python %i"
            % sys.version_info[0])

    # --- Validate flags up front, before the expensive GloVe load. ---
    if FLAGS.mode not in ("train", "show_examples", "eval", "official_eval"):
        raise Exception("Unexpected value of FLAGS.mode: %s" % FLAGS.mode)

    ensumble = FLAGS.ensumble
    print(ensumble)

    # train/show_examples only ever build a single QAModel; with an ensemble
    # `qa_model` would never be bound and the branch would hit a NameError.
    if ensumble and FLAGS.mode in ("train", "show_examples"):
        raise Exception(
            "--ensumble is only supported in eval and official_eval modes, "
            "not %s" % FLAGS.mode)

    if (not ensumble and not FLAGS.attn_layer and not FLAGS.train_dir
            and FLAGS.mode != "official_eval"):
        raise Exception(
            "You need to specify either --attn_layer or --train_dir")

    if FLAGS.mode == "official_eval":
        # The input JSON is read for both single-model and ensemble runs.
        if FLAGS.json_in_path == "":
            raise Exception(
                "For official_eval mode, you need to specify --json_in_path")
        if not ensumble and FLAGS.ckpt_load_dir == "":
            raise Exception(
                "For official_eval mode, you need to specify --ckpt_load_dir")

    # Define path for glove vecs
    FLAGS.glove_path = FLAGS.glove_path or os.path.join(
        DEFAULT_DATA_DIR, "glove.6B.{}d.txt".format(FLAGS.embedding_size))

    # Load embedding matrix and vocab mappings
    timer.start("glove_getter")
    emb_matrix, word2id, id2word = get_glove(FLAGS.glove_path,
                                             FLAGS.embedding_size)
    id2idf = get_idf(os.path.abspath(FLAGS.idf_path), word2id)
    logger.warning("Get glove embedding of size {} takes {:.2f} s".format(
        FLAGS.embedding_size, timer.stop("glove_getter")))

    # Define train_dir
    if not FLAGS.experiment_name:
        FLAGS.experiment_name = "A_{}_E_{}_D_{}".format(
            FLAGS.attn_layer, FLAGS.embedding_size, FLAGS.dropout)
    checkptr_name = FLAGS.experiment_name + "/glove{}".format(
        FLAGS.embedding_size)
    FLAGS.train_dir = FLAGS.train_dir or os.path.join(EXPERIMENTS_DIR,
                                                      checkptr_name)

    # Initialize bestmodel directory
    bestmodel_dir = os.path.join(FLAGS.train_dir, "best_checkpoint")

    # Get filepaths to train/dev datafiles for tokenized queries, contexts
    # and answers
    train_context_path = os.path.join(FLAGS.data_dir, "train.context")
    train_qn_path = os.path.join(FLAGS.data_dir, "train.question")
    train_ans_path = os.path.join(FLAGS.data_dir, "train.span")
    dev_context_path = os.path.join(FLAGS.data_dir, "dev.context")
    dev_qn_path = os.path.join(FLAGS.data_dir, "dev.question")
    dev_ans_path = os.path.join(FLAGS.data_dir, "dev.span")

    # Some GPU settings
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    is_training = (FLAGS.mode == "train")
    if not ensumble:
        # Initialize model
        qa_model = QAModel(FLAGS, id2word, word2id, emb_matrix, id2idf,
                           is_training)
    else:
        ensumbler = Ensumbler(ensumble, config, id2word, word2id, emb_matrix,
                              id2idf)

    # Split by mode
    if FLAGS.mode == "train":
        # Setup train dir and logfile
        if not os.path.exists(FLAGS.train_dir):
            os.makedirs(FLAGS.train_dir)
        file_handler = logging.FileHandler(
            os.path.join(FLAGS.train_dir, "log.txt"))
        logging.getLogger().addHandler(file_handler)

        # Save a record of flags as a .json file in train_dir
        with open(os.path.join(FLAGS.train_dir, "flags.json"), 'w') as fout:
            json.dump(FLAGS.__flags, fout)

        # Make bestmodel dir if necessary
        if not os.path.exists(bestmodel_dir):
            os.makedirs(bestmodel_dir)

        with tf.Session(config=config) as sess:
            # Load most recent model
            qa_model.initialize_from_checkpoint(sess, FLAGS.train_dir,
                                                expect_exists=False)
            # Train. NOTE(review): dev question path is passed before dev
            # context path; presumably this matches QAModel.train's
            # signature -- verify.
            qa_model.train(sess, train_context_path, train_qn_path,
                           train_ans_path, dev_qn_path, dev_context_path,
                           dev_ans_path)

    elif FLAGS.mode == "show_examples":
        with tf.Session(config=config) as sess:
            # Load best model
            qa_model.initialize_from_checkpoint(sess, bestmodel_dir,
                                                expect_exists=True)
            # Show examples with F1/EM scores
            f1, em = qa_model.check_f1_em(sess, dev_context_path,
                                          dev_qn_path, dev_ans_path, "dev",
                                          num_samples=10,
                                          print_to_screen=True)
            # Separate placeholders so EM reports em, not f1 a second time.
            logger.info("Dev: F1 = {:.3}, EM = {:.3}".format(f1, em))

    elif FLAGS.mode == "eval":
        if ensumble:
            # train
            train_f1, train_em = ensumbler.check_f1_em(
                train_context_path, train_qn_path, train_ans_path, "train",
                FLAGS.n_eval)
            # dev
            dev_f1, dev_em = ensumbler.check_f1_em(
                dev_context_path, dev_qn_path, dev_ans_path, "dev",
                FLAGS.n_eval)
        else:
            with tf.Session(config=config) as sess:
                # Load best model
                qa_model.initialize_from_checkpoint(sess, FLAGS.ckpt_load_dir,
                                                    expect_exists=True)
                logger.info("Model initialzed from checkpoint")
                # NOTE(review): the single-model path samples a fixed 10
                # examples, while the ensemble path uses FLAGS.n_eval --
                # confirm this difference is intended.
                # train
                train_f1, train_em = qa_model.check_f1_em(
                    sess, train_context_path, train_qn_path, train_ans_path,
                    "train", num_samples=10, print_to_screen=False)
                # dev
                dev_f1, dev_em = qa_model.check_f1_em(
                    sess, dev_context_path, dev_qn_path, dev_ans_path, "dev",
                    num_samples=10, print_to_screen=False)
        logger.error("Train: F1 = {:.3}, EM = {:.3}".format(train_f1,
                                                           train_em))
        logger.error("Dev: F1 = {:.3}, EM = {:.3}".format(dev_f1, dev_em))

    elif FLAGS.mode == "official_eval":
        # Read the JSON data from file
        qn_uuid_data, context_token_data, qn_token_data = get_json_data(
            FLAGS.json_in_path)

        if ensumble:
            answers_dict = ensumbler.generate_answers(
                qn_uuid_data, context_token_data, qn_token_data)
        else:
            with tf.Session(config=config) as sess:
                # Load model from ckpt_load_dir
                qa_model.initialize_from_checkpoint(sess, FLAGS.ckpt_load_dir,
                                                    expect_exists=True)
                # Get a predicted answer for each example in the data.
                # Return a mapping answers_dict from uuid to answer.
                answers_dict = generate_answers(sess, qa_model, word2id,
                                                id2idf, qn_uuid_data,
                                                context_token_data,
                                                qn_token_data)

        # Write the uuid->answer mapping to a json file
        print("Writing predictions to %s..." % FLAGS.json_out_path)
        with io.open(FLAGS.json_out_path, 'w', encoding='utf-8') as f:
            f.write(unicode(json.dumps(answers_dict, ensure_ascii=False)))
        print("Wrote predictions to %s" % FLAGS.json_out_path)

    else:
        # Unreachable after the up-front mode check; kept as a safety net.
        raise Exception("Unexpected value of FLAGS.mode: %s" % FLAGS.mode)