def get_all_word_count(num_workers, db_path, out_path, titles_only=False):
    db = DocDB(db_path)
    all_titles = db.get_doc_titles()

    # Set up the worker pool
    workers = ProcessPool(num_workers, initializer=init, initargs=[db_path])

    word_counts = Counter()
    if not titles_only:
        with tqdm(total=len(all_titles), desc='word count') as pbar:
            for w_count in workers.imap_unordered(get_word_count, all_titles):
                word_counts.update(w_count)
                pbar.update()
    else:
        with tqdm(total=len(all_titles), desc='title word count') as pbar:
            for tok_title in workers.imap_unordered(tokenize, all_titles):
                word_counts.update(tok_title.words())
                pbar.update()

    with open(out_path, 'w') as f:
        for k, v in tqdm(word_counts.most_common(), desc='writing counts'):
            f.write(f"{k}\t{v}\n")
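
# A minimal usage sketch (the paths here are hypothetical, not the
# repository's defaults): write the corpus-wide word counts and the
# title-only counts that are read back via load_counts further below.
get_all_word_count(num_workers=8, db_path='data/db/docs.db',
                   out_path='data/hotpot/wiki_word_counts.txt')
get_all_word_count(num_workers=8, db_path='data/db/docs.db',
                   out_path='data/hotpot/wiki_title_word_counts.txt',
                   titles_only=True)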
def init(db_path, full_doc_db_path, ranker_path):
    global PROCESS_TOK, PROCESS_DB, PROCESS_RANKER, PROCESS_FULL_DOC_DB, DOC_TITLES
    PROCESS_TOK = CoreNLPTokenizer()
    Finalize(PROCESS_TOK, PROCESS_TOK.shutdown, exitpriority=100)
    PROCESS_DB = DocDB(db_path)
    Finalize(PROCESS_DB, PROCESS_DB.close, exitpriority=100)
    PROCESS_FULL_DOC_DB = DocDB(full_doc_db_path, full_docs=True)
    Finalize(PROCESS_FULL_DOC_DB, PROCESS_FULL_DOC_DB.close, exitpriority=100)
    PROCESS_RANKER = TfidfDocRanker(ranker_path)
    DOC_TITLES = PROCESS_FULL_DOC_DB.get_doc_titles()
def init(top_k, get_text):
    global PROCESS_DB, PROCESS_RANKER, TOP_K, GET_TEXT
    PROCESS_DB = DocDB()
    Finalize(PROCESS_DB, PROCESS_DB.close, exitpriority=100)
    PROCESS_RANKER = TfidfDocRanker()
    TOP_K = top_k
    GET_TEXT = get_text
def main():
    # todo: add scores to questions & paragraphs
    parser = argparse.ArgumentParser("Preprocess Hotpot Questions")
    parser.add_argument("--train_file", default=config.HOTPOT_TRAIN_FILE)
    parser.add_argument("--dev_file", default=config.HOTPOT_DEV_DISTRACTOR_FILE)
    parser.add_argument("--doc_db", default=None)
    parser.add_argument('--num-workers', type=int, default=1,
                        help='Number of CPU processes')
    args = parser.parse_args()

    if not exists(join(config.CORPUS_DIR, 'hotpot')):
        mkdir(join(config.CORPUS_DIR, 'hotpot'))

    # target_dir = config.CORPUS_DIR
    # if exists(target_dir) and len(listdir(target_dir)) > 0:
    #     raise ValueError("Files already exist in " + target_dir)

    if args.num_workers > 1:
        print(f"Multiprocessing with {args.num_workers} processes...")
        print("Parsing train...")
        train = parse_hotpot_data_async(args.train_file, args.doc_db, args.num_workers)
        print("Parsing dev...")
        dev = parse_hotpot_data_async(args.dev_file, args.doc_db, args.num_workers)
    else:
        tokenizer = CoreNLPTokenizer()
        docdb = DocDB(args.doc_db)
        print("Parsing train...")
        train = parse_hotpot_data(args.train_file, tokenizer, docdb)
        print("Parsing dev...")
        dev = parse_hotpot_data(args.dev_file, tokenizer, docdb)

    print("Saving...")
    HotpotQuestions.make_corpus(train, dev)
    print("Done")
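
# Standard CLI entry point (assumed; the script is presumably run directly,
# e.g. `python preprocess_hotpot.py --num-workers 8`, where the file name is
# illustrative).
if __name__ == '__main__':
    main()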
def get_count_matrix():
    """Form a sparse word to document count matrix (inverted index).

    M[i, j] = # times word i appears in document j.
    """
    # Map doc titles to indexes
    global DOC2IDX
    with DocDB(args.db_path) as doc_db:
        doc_titles = doc_db.get_doc_titles()
    DOC2IDX = {doc_title: i for i, doc_title in enumerate(doc_titles)}

    # Set up the worker pool
    workers = ProcessPool(args.num_workers, initializer=init, initargs=[args.db_path])

    # Compute the count matrix in steps (to keep in memory)
    logger.info('Mapping...')
    row, col, data = [], [], []
    step = max(int(len(doc_titles) / 10), 1)
    batches = [doc_titles[i:i + step] for i in range(0, len(doc_titles), step)]
    _count = partial(count, args.ngram, args.hash_size)
    for i, batch in enumerate(batches):
        logger.info('-' * 25 + 'Batch %d/%d' % (i + 1, len(batches)) + '-' * 25)
        for b_row, b_col, b_data in workers.imap_unordered(_count, batch):
            row.extend(b_row)
            col.extend(b_col)
            data.extend(b_data)
    workers.close()
    workers.join()

    logger.info('Creating sparse matrix...')
    count_matrix = sp.csr_matrix((data, (row, col)),
                                 shape=(args.hash_size, len(doc_titles)))
    count_matrix.sum_duplicates()
    return count_matrix, (DOC2IDX, doc_titles)
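
# A minimal sketch of the `count` worker referenced above, assuming DrQA-style
# feature hashing with murmurhash3_32; the repository's real implementation may
# differ (e.g. n-gram filtering). PROCESS_TOK and PROCESS_DB are the worker
# globals set up in init(), and get_doc_text is a stand-in for whatever text
# accessor the DocDB actually exposes.
from collections import Counter
from sklearn.utils import murmurhash3_32

def count(ngram, hash_size, doc_title):
    """Return hashed (row, col, data) triples for one document's n-gram counts."""
    global PROCESS_TOK, PROCESS_DB, DOC2IDX
    words = PROCESS_TOK.tokenize(PROCESS_DB.get_doc_text(doc_title)).words()
    # All n-grams from unigrams up to `ngram`-grams.
    ngrams = [' '.join(words[i:i + n])
              for n in range(1, ngram + 1)
              for i in range(len(words) - n + 1)]
    counts = Counter(murmurhash3_32(g, positive=True) % hash_size for g in ngrams)
    row = list(counts.keys())                 # hashed feature ids
    col = [DOC2IDX[doc_title]] * len(counts)  # this document's column
    data = list(counts.values())              # raw counts
    return row, col, data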
def init():
    global PROCESS_TOK, PROCESS_DB
    PROCESS_TOK = CoreNLPTokenizer()
    Finalize(PROCESS_TOK, PROCESS_TOK.shutdown, exitpriority=100)
    PROCESS_DB = DocDB()
    Finalize(PROCESS_DB, PROCESS_DB.close, exitpriority=100)
def encode_from_file(docs_file, questions_file, encodings_dir, encoder_model,
                     num_workers, hotpot: bool, long_batch: int, short_batch: int,
                     use_chars: bool, use_ema: bool, checkpoint: str,
                     document_chunk_size=1000, samples=None, encode_all_db=False):
    """
    :param docs_file: path to a json file whose structure is [{title: list of paragraphs}, ...]
    :param encodings_dir: directory in which to dump the .npz encodings
    :return:
    """
    doc_encs_handler = DocumentEncodingHandler(encodings_dir)

    # Set up the worker pool
    workers = ProcessPool(num_workers, initializer=init, initargs=[])

    if docs_file is not None:
        with open(docs_file, 'r') as f:
            documents = json.load(f)
        documents = {k: v for k, v in documents.items()
                     if k not in doc_encs_handler.titles2filenames}

        tokenized_documents = {}
        tupled_doc_list = [(title, pars) for title, pars in documents.items()]

        if samples is not None:
            print(f"sampling {samples} samples")
            tupled_doc_list = tupled_doc_list[:samples]

        print("Tokenizing from file...")
        with tqdm(total=len(tupled_doc_list), ncols=80) as pbar:
            for tok_doc in workers.imap_unordered(tokenize_document, tupled_doc_list):
                tokenized_documents.update(tok_doc)
                pbar.update()
    else:
        if questions_file is not None:
            with open(questions_file, 'r') as f:
                questions = json.load(f)
            all_titles = list({title for q in questions for title in q['top_titles']})
        else:
            print("encoding all DB!")
            all_titles = DocDB().get_doc_titles()

        if samples is not None:
            print(f"sampling {samples} samples")
            all_titles = all_titles[:samples]

        all_titles = [t for t in all_titles
                      if t not in doc_encs_handler.titles2filenames]

        tokenized_documents = {}
        print("Tokenizing from DB...")
        with tqdm(total=len(all_titles), ncols=80) as pbar:
            for tok_doc in workers.imap_unordered(tokenize_from_db, all_titles):
                tokenized_documents.update(tok_doc)
                pbar.update()

    workers.close()
    workers.join()

    voc = set()
    for paragraphs in tokenized_documents.values():
        for par in paragraphs:
            voc.update(par)

    if not hotpot:
        spec = QuestionAndParagraphsSpec(batch_size=None, max_num_contexts=1,
                                         max_num_question_words=None,
                                         max_num_context_words=None)
        encoder = SentenceEncoderSingleContext(model_dir_path=encoder_model,
                                               vocabulary=voc, spec=spec,
                                               loader=ResourceLoader(),
                                               use_char_inputs=use_chars,
                                               use_ema=use_ema,
                                               checkpoint=checkpoint)
    else:
        spec = QuestionAndParagraphsSpec(batch_size=None, max_num_contexts=2,
                                         max_num_question_words=None,
                                         max_num_context_words=None)
        encoder = SentenceEncoderIterativeModel(model_dir_path=encoder_model,
                                                vocabulary=voc, spec=spec,
                                                loader=ResourceLoader(),
                                                use_char_inputs=use_chars,
                                                use_ema=use_ema,
                                                checkpoint=checkpoint)

    tokenized_documents_items = list(tokenized_documents.items())
    for tokenized_doc_chunk in tqdm(
            [tokenized_documents_items[i:i + document_chunk_size]
             for i in range(0, len(tokenized_documents_items), document_chunk_size)],
            ncols=80):
        flattened_pars_with_names = [(f"{title}_{i}", par)
                                     for title, pars in tokenized_doc_chunk
                                     for i, par in enumerate(pars)]

        # Filter out empty paragraphs (these probably contained some short
        # string that the tokenization removed). Note that the filtered
        # paragraphs get no representation, but they keep their index in the
        # paragraph numbering, for consistency with the docs.
        flattened_pars_with_names = [(name, par)
                                     for name, par in flattened_pars_with_names
                                     if len(par) > 0]

        # Sort such that longer paragraphs come first, to identify OOMs early on
        flattened_pars_with_names = sorted(flattened_pars_with_names,
                                           key=lambda x: len(x[1]), reverse=True)
        long_paragraphs_ids = [i for i, name_par in enumerate(flattened_pars_with_names)
                               if len(name_par[1]) >= 900]
        short_paragraphs_ids = [i for i, name_par in enumerate(flattened_pars_with_names)
                                if len(name_par[1]) < 900]

        name2enc = {}
        dummy_question = "Hello Hello".split()
        if not hotpot:
            model_paragraphs = [BinaryQuestionAndParagraphs(question=dummy_question,
                                                            paragraphs=[x], label=1,
                                                            num_distractors=0,
                                                            question_id='dummy')
                                for _, x in flattened_pars_with_names]
        else:
            # todo: allow precomputed sentence segments
            model_paragraphs = [IterativeQuestionAndParagraphs(question=dummy_question,
                                                               paragraphs=[x, dummy_question],
                                                               first_label=1,
                                                               second_label=1,
                                                               question_id='dummy',
                                                               sentence_segments=None)
                                for _, x in flattened_pars_with_names]

        # Encode the long paragraphs first, with the smaller batch size.
        long_pars = [model_paragraphs[i] for i in long_paragraphs_ids]
        name2enc.update(
            {flattened_pars_with_names[long_paragraphs_ids[i]][0]: enc
             for i, enc in enumerate(
                encoder.encode_paragraphs(long_pars, batch_size=long_batch,
                                          show_progress=True)
                if not hotpot else
                encoder.encode_first_paragraphs(long_pars, batch_size=long_batch,
                                                show_progress=True))})

        # Encode the short paragraphs with the larger batch size.
        short_pars = [model_paragraphs[i] for i in short_paragraphs_ids]
        name2enc.update(
            {flattened_pars_with_names[short_paragraphs_ids[i]][0]: enc
             for i, enc in enumerate(
                encoder.encode_paragraphs(short_pars, batch_size=short_batch,
                                          show_progress=True)
                if not hotpot else
                encoder.encode_first_paragraphs(short_pars, batch_size=short_batch,
                                                show_progress=True))})

        doc_encs_handler.save_multiple_documents(name2enc)
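
# Hypothetical invocation of encode_from_file (paths, batch sizes, and the
# checkpoint name are illustrative, not the repository's defaults). Paragraphs
# of 900+ tokens get a smaller batch size so any OOM surfaces early, while the
# short ones are encoded with a larger, faster batch.
encode_from_file(docs_file=None,
                 questions_file='data/hotpot/dev_top_titles.json',
                 encodings_dir='encodings/hotpot_dev',
                 encoder_model='models/iterative_encoder',
                 num_workers=16,
                 hotpot=True,
                 long_batch=8,
                 short_batch=64,
                 use_chars=False,
                 use_ema=True,
                 checkpoint='best')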
Finalize(PROCESS_TOK, PROCESS_TOK.shutdown, exitpriority=100)

def tokenize_words(text):
    global PROCESS_TOK
    return PROCESS_TOK.tokenize(text).words()

def tokenize_sentences(sentences):
    global PROCESS_TOK
    return [PROCESS_TOK.tokenize(s).words() if s != '' else [] for s in sentences]

print("Loading TF-IDF...")
tfidf_ranker = TfidfDocRanker()
db = DocDB()
loader = ResourceLoader()
# loader = HotpotQuestions().get_resource_loader()
word_counts = load_counts(join(LOCAL_DATA_DIR, 'hotpot', 'wiki_word_counts.txt'))
title_counts = load_counts(join(LOCAL_DATA_DIR, 'hotpot', 'wiki_title_word_counts.txt'))
word_counts.update(title_counts)
voc = set(word_counts.keys())

print("Loading encoder...")
spec = QuestionAndParagraphsSpec(batch_size=None, max_num_contexts=2,
                                 max_num_question_words=None,
                                 max_num_context_words=None)
encoder = SentenceEncoderIterativeModel(model_dir_path=args.encoder_model,
                                        vocabulary=voc, spec=spec,
                                        loader=loader,
                                        use_char_inputs=False,
                                        use_ema=not args.no_ema,
                                        checkpoint=args.checkpoint)  # assumed: the source was truncated here; completed per the signature used elsewhere
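
# A minimal sketch of the load_counts helper used above, assuming the
# tab-separated "<word>\t<count>" lines written by get_all_word_count; the
# repository's own helper may differ in details such as encoding handling.
from collections import Counter

def load_counts(path):
    counts = Counter()
    with open(path, 'r') as f:
        for line in f:
            word, count = line.rstrip('\n').rsplit('\t', 1)
            counts[word] = int(count)
    return counts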
def modified_main(dataset_path, k_list_to_check, ranker_path=None,
                  normalize_ranker=False, num_workers=1, tokenizer='corenlp',
                  docdb_path=None, out=None):
    dataset = load_dataset(dataset_path)
    ranker = TfidfDocRanker(tfidf_path=ranker_path,
                            normalize_vectors=normalize_ranker,
                            tokenizer=tokenizer)
    docdb = DocDB(docdb_path)

    print("Building modified queries...")
    ranked_gold_dict = build_ranked_golds(dataset, docdb=docdb, ranker=ranker)

    regular_table = prettytable.PrettyTable([
        'Top K', 'Second Paragraph Hits', 'Second Paragraph Hits Modified Query'
    ])
    cat_table_dict = {
        cat: prettytable.PrettyTable([
            'Top K', 'Second Paragraph Hits', 'Second Paragraph Hits Modified Query'
        ])
        for cat in CATEGORIES
    }

    max_k = max(k_list_to_check)
    print(f"Retrieving top {max_k} ...")
    start = time.time()
    reg_result_dict, ranked_result_dict = get_ranked_top_k(
        dataset, ranked_gold_dict, ranker, max_k, num_workers)
    print(f"Done, took {time.time() - start} seconds.")

    for k in k_list_to_check:
        print(f"Calculating scores for top {k}...")
        start = time.time()
        reg_scores, reg_category_scores = modified_top_k_coverage_score(
            ranked_gold_dict, reg_result_dict, k)
        mod_scores, mod_category_scores = modified_top_k_coverage_score(
            ranked_gold_dict, ranked_result_dict, k)
        print(f"Done, took {time.time() - start} seconds.")
        regular_table.add_row([k, reg_scores['Second Paragraph Hits'],
                               mod_scores['Second Paragraph Hits']])
        for cat in cat_table_dict:
            cat_table_dict[cat].add_row([
                k, reg_category_scores[cat]['Second Paragraph Hits'],
                mod_category_scores[cat]['Second Paragraph Hits']
            ])

    output_str = 'Overall Results:\n'
    output_str += str(regular_table) + '\n'
    for cat, table in cat_table_dict.items():
        output_str += '\n**********************************************\n'
        output_str += f"Category: {cat} Results:\n"
        output_str += str(table) + '\n'

    if out is None:
        print(output_str)
    else:
        with open(out, 'w') as f:
            f.write(output_str)
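
# Hypothetical invocation (the dataset path and k values are illustrative):
# compare second-paragraph retrieval hits for the original queries against the
# gold-reformulated queries at several cutoffs.
modified_main(dataset_path='data/hotpot/hotpot_dev_distractor_v1.json',
              k_list_to_check=[1, 2, 5, 10, 20, 50],
              num_workers=8,
              out='modified_query_coverage.txt')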
def build_openqa_iterative_top_titles(
        base_dir, questions_file, docs_file, encodings_dir, encoder_model,
        k1_list: List[int], k2_list: List[int],
        n1_list: List[int], n2_list: List[int],
        evaluate: bool, reformulate_from_text: bool,
        use_ema: bool, checkpoint: str, safety_mult: int):
    print('Loading data...')
    s = time.time()
    with open(questions_file, 'r') as f:
        questions = json.load(f)
    if docs_file is not None:
        with open(docs_file, 'r') as f:
            documents = json.load(f)
    else:
        docs_db = DocDB()
    print(f'Done, took {time.time() - s} seconds.')

    if n1_list is not None and n2_list is not None:
        for q in questions:
            q['top_titles'] = q['top_titles'][:max(max(n1_list), max(n2_list))]

    # Set up the worker pool for tokenization
    workers = ProcessPool(16, initializer=init, initargs=[])

    qid2tokenized = {}
    tupled_questions = [(q['qid'], q['question']) for q in questions]
    print("Tokenizing questions...")
    with tqdm(total=len(tupled_questions)) as pbar:
        for tok_q in workers.imap_unordered(tokenize_question, tupled_questions):
            qid2tokenized.update(tok_q)
            pbar.update()

    voc = set()
    for question in qid2tokenized.values():
        voc.update(question)

    workers.close()
    workers.join()

    # all_titles = list(set([title for q in questions for title in q['top_titles']]))

    def parname_to_text(par_name):
        par_title = par_name_to_title(par_name)
        par_num = int(par_name.split('_')[-1])
        if docs_file is not None:
            return documents[par_title][par_num]
        return ' '.join(docs_db.get_doc_sentences(par_title))

    # Set up the worker pool for reading the precomputed encodings
    workers = ProcessPool(16, initializer=init_encoding_handler, initargs=[encodings_dir])

    # title2encs = {}
    # title2idx2par_name = {}
    # with tqdm(total=len(all_titles)) as pbar:
    #     for t2enc, t2id2p in workers.imap_unordered(get_title_mappings_from_saver, all_titles):
    #         title2encs.update(t2enc)
    #         title2idx2par_name.update(t2id2p)
    #         pbar.update()
    # title2par_name2idxs = {}
    # for title, id2par in title2idx2par_name.items():
    #     par2idxs = {}
    #     for idx, parname in id2par.items():
    #         if parname in par2idxs:
    #             par2idxs[parname].append(idx)
    #         else:
    #             par2idxs[parname] = [idx]
    #     title2par_name2idxs[title] = {par: sorted(idxs) for par, idxs in par2idxs.items()}

    print("Loading encoder...")
    spec = QuestionAndParagraphsSpec(batch_size=None, max_num_contexts=2,
                                     max_num_question_words=None,
                                     max_num_context_words=None)
    encoder = SentenceEncoderIterativeModel(model_dir_path=encoder_model,
                                            vocabulary=voc, spec=spec,
                                            loader=ResourceLoader(),
                                            use_char_inputs=False,
                                            use_ema=use_ema,
                                            checkpoint=checkpoint)

    print("Encoding questions...")
    q_original_encodings = encoder.encode_text_questions(
        [qid2tokenized[q['qid']] for q in questions],
        return_search_vectors=False, show_progress=True)
    q_search_encodings = encoder.question_rep_to_search_vector(
        question_encodings=q_original_encodings)

    init()  # for initializing the tokenizer

    total_num = len(n1_list) * len(n2_list) * len(k1_list) * len(k2_list)
    print("Calculating similarities...")
    for n1, n2, k1, k2 in tqdm(itertools.product(n1_list, n2_list, k1_list, k2_list),
                               total=total_num, ncols=80):
        questions = iterative_retrieval(encoder, questions, qid2tokenized,
                                        q_search_encodings, workers,
                                        parname_to_text, reformulate_from_text,
                                        n1, n2, k1, k2, safety_mult)
        dir_path = os.path.join(base_dir, f"n2-{n2}", f"n1-{n1}")
        os.makedirs(dir_path, exist_ok=True)
        out_file = os.path.join(dir_path, f"n1-{n1}_n2-{n2}_k1-{k1}_k2-{k2}.json")
        questions_copy = deepcopy(questions)
        for question in questions_copy:
            question.pop('top_titles')
        with open(out_file, 'w') as f:
            json.dump(questions_copy, f)
        if evaluate:
            eval_questions(questions_copy)
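
# A minimal sketch of the paragraph-naming convention assumed by
# parname_to_text above: paragraph names look like "<title>_<par index>", so
# the title is everything before the last underscore (the repository's own
# par_name_to_title may handle edge cases differently).
def par_name_to_title(par_name: str) -> str:
    return par_name.rsplit('_', 1)[0]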
def main():
    parser = argparse.ArgumentParser(description='Full ranking evaluation on Hotpot')
    parser.add_argument('model', help='model directory to evaluate')
    parser.add_argument('output', type=str,
                        help="Store the per-paragraph results in csv format in this file, "
                             "or the json prediction if in test mode")
    parser.add_argument('-n', '--sample_questions', type=int, default=None,
                        help="(for testing) run on a subset of questions")
    parser.add_argument('-b', '--batch_size', type=int, default=64,
                        help="Batch size, larger sizes can be faster but use more memory")
    parser.add_argument('-s', '--step', default=None,
                        help="Weights to load, can be a checkpoint step or 'latest'")
    parser.add_argument('-a', '--answer_bound', type=int, default=8,
                        help="Max answer span length")
    parser.add_argument('-c', '--corpus',
                        choices=["distractors", "gold", "hotpot_file",
                                 "retrieval_file", "top_titles"],
                        default="distractors")
    parser.add_argument('-t', '--tokens', type=int, default=None,
                        help="Max tokens per paragraph")
    parser.add_argument('--input_file', type=str, default=None)
    parser.add_argument('--docs_file', type=str, default=None)
    parser.add_argument('--num_workers', type=int, default=16,
                        help='Number of workers for tokenizing')
    parser.add_argument('--no_ema', action="store_true",
                        help="Don't use EMA weights even if they exist")
    parser.add_argument('--no_sp', action="store_true",
                        help="Don't predict supporting facts")
    parser.add_argument('--test_mode', action='store_true',
                        help="produce a prediction file, no answers given")
    args = parser.parse_args()

    model_dir = ModelDir(args.model)
    batcher = ClusteredBatcher(args.batch_size, multiple_contexts_len, truncate_batches=True)
    loader = ResourceLoader()

    if args.corpus not in {"distractors", "gold"} and args.input_file is None:
        raise ValueError("Must pass an input file if not using a precomputed dataset")
    if args.corpus in {"distractors", "gold"} and args.test_mode:
        raise ValueError("Test mode not available in 'distractors' or 'gold' mode")

    if args.corpus in {"distractors", "gold"}:
        corpus = HotpotQuestions()
        loader = corpus.get_resource_loader()
        questions = corpus.get_dev()
        question_preprocessor = HotpotTextLengthPreprocessorWithSpans(args.tokens)
        # Preprocess once per question and drop the ones that were filtered out.
        questions = [q for q in (question_preprocessor.preprocess(x) for x in questions)
                     if q is not None]

        if args.sample_questions:
            # Sort for determinism, then shuffle with a fixed seed and sample.
            questions = sorted(questions, key=lambda x: x.question_id)
            np.random.RandomState(0).shuffle(questions)
            questions = questions[:args.sample_questions]

        data = HotpotFullQADistractorsDataset(questions, batcher)
        gold_idxs = set(data.gold_idxs)
        if args.corpus == 'gold':
            data.samples = [data.samples[i] for i in data.gold_idxs]

        qid2samples = {}
        qid2idx = {}
        for i, sample in enumerate(data.samples):
            key = sample.question_id
            if key in qid2samples:
                qid2samples[key].append(sample)
                qid2idx[key].append(i)
            else:
                qid2samples[key] = [sample]
                qid2idx[key] = [i]

        questions = []
        print("Ranking pairs...")
        gold_ranks = []
        for qid, samples in tqdm(qid2samples.items()):
            question = " ".join(samples[0].question)
            pars = [" ".join(x.paragraphs[0]) for x in samples]
            ranks = get_paragraph_ranks(question, pars)
            for sample, rank, idx in zip(samples, ranks, qid2idx[qid]):
                questions.append(RankedQAPair(question=sample.question,
                                              paragraphs=sample.paragraphs,
                                              spans=np.zeros((0, 2), dtype=np.int32),
                                              question_id=sample.question_id,
                                              answer=sample.answer, rank=rank,
                                              q_type=sample.q_type,
                                              sentence_segments=sample.sentence_segments))
                if idx in gold_idxs:
                    gold_ranks.append(rank + 1)

        print(f"Mean rank: {np.mean(gold_ranks)}")
        ranks_counter = Counter(gold_ranks)
        for i in range(1, 11):
            print(f"Hits at {i}: {ranks_counter[i]}")

    elif args.corpus == 'hotpot_file':
        # A hotpot json format input file. We rank the pairs with tf-idf.
        with open(args.input_file, 'r') as f:
            hotpot_data = json.load(f)
        if args.sample_questions:
            hotpot_data = sorted(hotpot_data, key=lambda x: x['_id'])
            np.random.RandomState(0).shuffle(hotpot_data)
            hotpot_data = hotpot_data[:args.sample_questions]

        title2sentences = {context[0]: context[1]
                           for q in hotpot_data for context in q['context']}
        question_tok_texts = tokenize_texts([q['question'] for q in hotpot_data],
                                            num_workers=args.num_workers)
        sentences_tok = tokenize_texts(list(title2sentences.values()),
                                       num_workers=args.num_workers, sentences=True)
        if args.tokens is not None:
            sentences_tok = [truncate_paragraph(p, args.tokens) for p in sentences_tok]
        title2tok_sents = {title: sentences
                           for title, sentences in zip(title2sentences.keys(), sentences_tok)}

        questions = []
        for idx, question in enumerate(tqdm(hotpot_data, desc='tf-idf ranking')):
            q_titles = [title for title, _ in question['context']]
            par_pairs = [(title1, title2) for i, title1 in enumerate(q_titles)
                         for title2 in q_titles[i + 1:]]
            if len(par_pairs) == 0:
                continue
            ranks = get_paragraph_ranks(
                question['question'],
                [' '.join(title2sentences[t1] + title2sentences[t2])
                 for t1, t2 in par_pairs])
            for rank, par_pair in zip(ranks, par_pairs):
                sent_tok_pair = title2tok_sents[par_pair[0]] + title2tok_sents[par_pair[1]]
                sentence_segments, _ = get_segments_from_sentences_fix_sup(
                    sent_tok_pair, np.zeros(0))
                missing_sent_idx = [[i for i, sent in enumerate(title2tok_sents[title])
                                     if len(sent) == 0]
                                    for title in par_pair]
                questions.append(RankedQAPair(
                    question=question_tok_texts[idx],
                    paragraphs=[flatten_iterable(sent_tok_pair)],
                    spans=np.zeros((0, 2), dtype=np.int32),
                    question_id=question['_id'],
                    answer='noanswer' if args.test_mode else question['answer'],
                    rank=rank,
                    q_type='null' if args.test_mode else question['type'],
                    sentence_segments=[sentence_segments],
                    par_titles_num_sents=[
                        (title, sum(1 for sent in title2tok_sents[title] if len(sent) > 0))
                        for title in par_pair],
                    missing_sent_idxs=missing_sent_idx,
                    true_sp=[] if args.test_mode else question['supporting_facts']))

    elif args.corpus in {'retrieval_file', 'top_titles'}:
        if args.docs_file is None:
            print("Using DB documents")
            doc_db = DocDB(config.DOC_DB, full_docs=False)
        else:
            with open(args.docs_file, 'r') as f:
                docs = json.load(f)
        with open(args.input_file, 'r') as f:
            retrieval_data = json.load(f)
        if args.sample_questions:
            retrieval_data = sorted(retrieval_data, key=lambda x: x['qid'])
            np.random.RandomState(0).shuffle(retrieval_data)
            retrieval_data = retrieval_data[:args.sample_questions]

        def parname_to_text(par_name):
            par_title = par_name_to_title(par_name)
            par_num = int(par_name.split('_')[-1])
            if args.docs_file is None:
                return doc_db.get_doc_sentences(par_title)
            return docs[par_title][par_num]

        if args.corpus == 'top_titles':
            print("Top TF-IDF!")
            for q in retrieval_data:
                top_titles = q['top_titles'][:10]
                q['paragraph_pairs'] = [(title1 + '_0', title2 + '_0')
                                        for i, title1 in enumerate(top_titles)
                                        for title2 in top_titles[i + 1:]]

        question_tok_texts = tokenize_texts([q['question'] for q in retrieval_data],
                                            num_workers=args.num_workers)
        all_parnames = list({parname for q in retrieval_data
                             for pair in q['paragraph_pairs'] for parname in pair})
        texts_tok = tokenize_texts([parname_to_text(x) for x in all_parnames],
                                   num_workers=args.num_workers, sentences=True)
        if args.tokens is not None:
            texts_tok = [truncate_paragraph(p, args.tokens) for p in texts_tok]
        parname2tok_text = {parname: text
                            for parname, text in zip(all_parnames, texts_tok)}

        questions = []
        for idx, question in enumerate(retrieval_data):
            for rank, par_pair in enumerate(question['paragraph_pairs']):
                tok_pair = parname2tok_text[par_pair[0]] + parname2tok_text[par_pair[1]]
                sentence_segments, _ = get_segments_from_sentences_fix_sup(
                    tok_pair, np.zeros(0))
                missing_sent_idx = [[i for i, sent in enumerate(parname2tok_text[parname])
                                     if len(sent) == 0]
                                    for parname in par_pair]
                questions.append(RankedQAPair(
                    question=question_tok_texts[idx],
                    paragraphs=[flatten_iterable(tok_pair)],
                    spans=np.zeros((0, 2), dtype=np.int32),
                    question_id=question['qid'],
                    answer='noanswer' if args.test_mode else question['answers'][0],
                    rank=rank,
                    q_type='null' if args.test_mode else question['type'],
                    sentence_segments=[sentence_segments],
                    par_titles_num_sents=[
                        (par_name_to_title(parname),
                         sum(1 for sent in parname2tok_text[parname] if len(sent) > 0))
                        for parname in par_pair],
                    missing_sent_idxs=missing_sent_idx,
                    true_sp=[] if args.test_mode else question['supporting_facts']))
    else:
        raise NotImplementedError()

    data = DummyDataset(questions, batcher)
    evaluators = [RecordHotpotQAPrediction(args.answer_bound, True,
                                           sp_prediction=not args.no_sp)]

    if args.step is not None:
        if args.step == "latest":
            checkpoint = model_dir.get_latest_checkpoint()
        else:
            checkpoint = model_dir.get_checkpoint(int(args.step))
    else:
        checkpoint = model_dir.get_best_weights()
        if checkpoint is not None:
            print("Using best weights")
        else:
            print("Using latest checkpoint")
            checkpoint = model_dir.get_latest_checkpoint()

    model = model_dir.get_model()

    evaluation = trainer.test(model, evaluators, {args.corpus: data}, loader,
                              checkpoint, not args.no_ema, 10)[args.corpus]

    print("Saving result")
    output_file = args.output

    df = pd.DataFrame(evaluation.per_sample)
    df.sort_values(["question_id", "rank"], inplace=True, ascending=True)
    group_by = ["question_id"]

    def get_ranked_scores(score_name):
        filtered_df = df[df.type == 'comparison'] if "Cp" in score_name else \
            df[df.type == 'bridge'] if "Br" in score_name else df
        target_prefix = 'joint' if 'joint' in score_name else \
            'sp' if 'sp' in score_name else 'text'
        target_score = f"{target_prefix}_{'em' if 'EM' in score_name else 'f1'}"
        return compute_ranked_scores_with_yes_no(
            filtered_df,
            span_q_col="span_question_scores",
            yes_no_q_col="yes_no_question_scores",
            yes_no_scores_col="yes_no_confidence_scores",
            span_scores_col="predicted_score",
            span_target_score=target_score,
            group_cols=group_by)

    if not args.test_mode:
        score_names = ["EM", "F1", "Br EM", "Br F1", "Cp EM", "Cp F1"]
        if not args.no_sp:
            score_names.extend([f"{prefix} {name}"
                                for prefix in ['sp', 'joint'] for name in score_names])
        table = [["N Paragraphs"] + score_names]
        scores = [get_ranked_scores(score_name) for score_name in score_names]
        table += list([str(i + 1), *["%.4f" % x for x in score_vals]]
                      for i, score_vals in enumerate(zip(*scores)))
        print_table(table)
        df.to_csv(output_file, index=False)
    else:
        df_to_pred(df, output_file)