def eval_questions(questions, top_k=None, specific_ks=None):
    """Aggregate per-question retrieval stats and print score tables.

    For each question, per-k stats (hits / perfect / at-least-one) are
    computed from its retrieved ``paragraph_pairs`` against the gold
    ``supporting_facts`` titles, accumulated overall and per question type
    ('bridge' vs. comparison), then averaged and printed as PrettyTables.

    :param questions: list of question dicts with 'supporting_facts',
        'paragraph_pairs' and 'type' keys.
    :param top_k: optional cap passed through to ``get_stats_for_sample``.
    :param specific_ks: optional collection of k values; when given, a
        second set of tables restricted to these k values is printed.
    """
    k_scores = {}
    bridge_k_scores = {}
    comparison_k_scores = {}  # fixed local-name typo ("comparision")
    for question in questions:
        q_stats = get_stats_for_sample(
            {x[0] for x in question['supporting_facts']},
            sorted_title_pairs=[[par_name_to_title(x) for x in pair]
                                for pair in question['paragraph_pairs']],
            top_k=top_k)
        type_scores = (bridge_k_scores if question['type'] == 'bridge'
                       else comparison_k_scores)
        # Accumulate every metric value into lists, both overall and per type.
        for k in q_stats:
            for score_dict in [k_scores, type_scores]:
                if k not in score_dict:
                    score_dict[k] = {key: [val] for key, val in q_stats[k].items()}
                else:
                    for key, val in q_stats[k].items():
                        score_dict[k][key].append(val)
    # Reduce the accumulated lists to their means.
    for score_dict in [k_scores, bridge_k_scores, comparison_k_scores]:
        for k in score_dict.keys():
            score_dict[k] = {key: np.mean(val)
                             for key, val in score_dict[k].items()}

    def _print_tables(ks_filter=None, sort=False):
        # One table per category; optionally restricted to ks_filter and
        # optionally sorted by k (matches the original two printing passes).
        for score_dict, name in [(bridge_k_scores, 'Bridge'),
                                 (comparison_k_scores, 'Comparison'),
                                 (k_scores, 'Overall')]:
            results = prettytable.PrettyTable(
                ['Top K Pairs', 'Hits', 'Perfect Questions', 'At Least One'])
            for k, k_dict in score_dict.items():
                if ks_filter is not None and k not in ks_filter:
                    continue
                results.add_row([k, k_dict['hits'], k_dict['perfect'],
                                 k_dict['at_least_one']])
            if sort:
                results.sortby = 'Top K Pairs'
            print(f"{name} scores:")
            print(results)
            print('\n**********************************************\n')

    _print_tables(sort=True)
    if specific_ks is not None:
        _print_tables(ks_filter=specific_ks)
def parname_to_text(par_name):
    """Resolve a paragraph name like ``<title>_<n>`` to its text.

    Looks the paragraph up in the module-level ``documents`` mapping
    (title -> list of paragraphs). The trailing ``_<n>`` suffix selects
    the paragraph index within the title's document.
    """
    index = int(par_name.rsplit('_', 1)[-1])
    title = par_name_to_title(par_name)
    return documents[title][index]
def build_openqa_iterative_top_titles(out_file, questions_file, docs_file, encodings_dir,
                                      encoder_model, k1, k2, n1, n2, evaluate: bool,
                                      reformulate_from_text: bool, use_ema: bool,
                                      checkpoint: str):
    """Run the full two-hop (iterative) retrieval pipeline over a question file.

    Loads questions and documents, tokenizes the questions with a worker pool,
    loads precomputed paragraph encodings per title, encodes questions with the
    iterative sentence encoder, then for each question: (1) KNN over the top-n1
    titles' paragraph reps to pick k1 first-hop paragraphs, (2) reformulates the
    question per first-hop paragraph, (3) KNN over the top-n2 titles to pick k2
    final paragraph pairs. Writes the questions (with 'paragraph_pairs' added)
    to ``out_file`` as JSON and optionally evaluates them.
    """
    print('Loading data...')
    s = time.time()
    with open(questions_file, 'r') as f:
        questions = json.load(f)
    with open(docs_file, 'r') as f:
        documents = json.load(f)
    print(f'Done, took {time.time()-s} seconds.')
    # Only the union of the two candidate-title budgets is ever consulted.
    if n1 is not None and n2 is not None:
        for q in questions:
            q['top_titles'] = q['top_titles'][:max(n1, n2)]
    # Setup worker pool
    workers = ProcessPool(16, initializer=init, initargs=[])
    qid2tokenized = {}
    tupled_questions = [(q['qid'], q['question']) for q in questions]
    print("Tokenizing questions...")
    with tqdm(total=len(tupled_questions)) as pbar:
        for tok_q in tqdm(workers.imap_unordered(tokenize_question, tupled_questions)):
            qid2tokenized.update(tok_q)
            pbar.update()
    # Vocabulary of all question tokens, used to restrict the encoder's embeddings.
    voc = set()
    for question in qid2tokenized.values():
        voc.update(question)
    all_titles = list(set([title for q in questions for title in q['top_titles']]))

    def parname_to_text(par_name):
        # Paragraph names look like '<title>_<n>'; only used in the
        # reformulate-from-text branch below.
        par_title = par_name_to_title(par_name)
        par_num = int(par_name.split('_')[-1])
        return documents[par_title][par_num]

    print(f"Gathering documents...")
    # Setup worker pool
    workers = ProcessPool(32, initializer=init_encoding_handler, initargs=[encodings_dir])
    # title2encs: title -> paragraph representation matrix;
    # title2idx2par_name: title -> row index -> paragraph name.
    title2encs = {}
    title2idx2par_name = {}
    with tqdm(total=len(all_titles)) as pbar:
        for t2enc, t2id2p in tqdm(workers.imap_unordered(get_title_mappings_from_saver,
                                                         all_titles)):
            title2encs.update(t2enc)
            title2idx2par_name.update(t2id2p)
            pbar.update()
    # Invert idx->name into name->sorted row indices (a paragraph may span
    # several encoding rows).
    title2par_name2idxs = {}
    for title, id2par in title2idx2par_name.items():
        par2idxs = {}
        for idx, parname in id2par.items():
            if parname in par2idxs:
                par2idxs[parname].append(idx)
            else:
                par2idxs[parname] = [idx]
        title2par_name2idxs[title] = {par: sorted(idxs)
                                      for par, idxs in par2idxs.items()}
    print("Loading encoder...")
    spec = QuestionAndParagraphsSpec(batch_size=None, max_num_contexts=2,
                                     max_num_question_words=None,
                                     max_num_context_words=None)
    encoder = SentenceEncoderIterativeModel(model_dir_path=encoder_model, vocabulary=voc,
                                            spec=spec, loader=ResourceLoader(),
                                            use_char_inputs=False,
                                            use_ema=use_ema, checkpoint=checkpoint)
    print("Encoding questions...")
    q_original_encodings = encoder.encode_text_questions(
        [qid2tokenized[q['qid']] for q in questions],
        return_search_vectors=False, show_progress=True)
    q_search_encodings = encoder.question_rep_to_search_vector(
        question_encodings=q_original_encodings)
    init()  # for initializing the tokenizer
    print("Calculating similarities...")
    for idx, question in tqdm(enumerate(questions), total=len(questions)):
        # --- Hop 1: stack the top-n1 titles' reps into one matrix, keeping
        # offsets so a global row id can be mapped back to (title, paragraph).
        title2ids = {}
        all_par_reps = []
        total_sentences = 0
        titles_offset_dict = {}
        for title in question['top_titles'][:n1]:
            titles_offset_dict[title] = total_sentences
            rep = title2encs[title]
            title2ids[title] = list(range(total_sentences, total_sentences + len(rep)))
            all_par_reps.append(rep)
            total_sentences += len(rep)
        id2title = {i: title for title, ids in title2ids.items() for i in ids}
        all_par_reps = np.concatenate(all_par_reps, axis=0)
        q_enc = q_search_encodings[idx]
        # Over-retrieve (k1 * 2) because duplicates are dropped below.
        top_k = simple_numpy_knn(np.expand_dims(q_enc, 0), all_par_reps, k1 * 2)[0]

        def id_to_par_name(rep_id):
            # Global row id -> paragraph name via its title's local offset.
            return title2idx2par_name[id2title[rep_id]][
                rep_id - titles_offset_dict[id2title[rep_id]]]

        # Order-preserving de-duplication trick: set.add returns None (falsy).
        seen = set()
        p_names = [id_to_par_name(x) for x in top_k
                   if not (id_to_par_name(x) in seen or seen.add(id_to_par_name(x)))][:k1]
        iteration1_paragraphs = \
            [title2encs[par_name_to_title(pname)][title2par_name2idxs[par_name_to_title(pname)][pname], :]
             for pname in p_names]
        if not reformulate_from_text:
            # Reformulate from precomputed reps: tile the question rep once
            # per retrieved first-hop paragraph.
            reformulations = encoder.reformulate_questions(
                questions_rep=np.tile(q_original_encodings[idx], reps=(len(p_names), 1)),
                paragraphs_rep=iteration1_paragraphs,
                return_search_vectors=True)
        else:
            # Reformulate from raw text: re-tokenize question and paragraphs.
            tok_q = tokenize(question['question']).words()
            par_texts = [tokenize(parname_to_text(pname)).words() for pname in p_names]
            reformulations = encoder.reformulate_questions_from_texts(
                tokenized_questions=[tok_q for _ in range(len(par_texts))],
                tokenized_pars=par_texts,
                return_search_vectors=True)
        # --- Hop 2: rebuild the candidate matrix over the top-n2 titles.
        title2ids = {}
        all_par_reps = []
        total_sentences = 0
        titles_offset_dict = {}
        for title in question['top_titles'][:n2]:
            titles_offset_dict[title] = total_sentences
            rep = title2encs[title]
            title2ids[title] = list(range(total_sentences, total_sentences + len(rep)))
            all_par_reps.append(rep)
            total_sentences += len(rep)
        id2title = {i: title for title, ids in title2ids.items() for i in ids}
        all_par_reps = np.concatenate(all_par_reps, axis=0)
        # Global KNN across all reformulations at once; rows are
        # (reformulation index, candidate row id) pairs.
        top_k_second = numpy_global_knn(reformulations, all_par_reps, k2 * k1)
        seen = set()
        final_p_name_pairs = [(p_names[x1], id_to_par_name(x2)) for x1, x2 in top_k_second
                              if not ((p_names[x1], id_to_par_name(x2)) in seen or
                                      seen.add((p_names[x1], id_to_par_name(x2))))][:k2]
        # important to note that in the iterative dataset the paragraphs of each question are in pairs
        question['paragraph_pairs'] = final_p_name_pairs
    with open(out_file, 'w') as f:
        json.dump(questions, f)
    if evaluate:
        eval_questions(questions)
def reformulation_retrieval(encoder, workers, questions: List, doc_db: DocDB, k2: int,
                            n2: int, safety_mult: int = 1):
    """Second-hop retrieval: reformulate each (question, first-hop titles)
    pair and retrieve k2 extended title tuples per question.

    Expects each question dict to carry 'top_pars_titles' (tuples of titles
    from the first hop) and 'top_titles'; replaces 'top_pars_titles' in place
    with the k2 best tuples extended by one more title, and returns the
    (mutated) questions list.
    """
    def title_to_text(title):
        return ' '.join(doc_db.get_doc_sentences(title))

    tokenized_qs = [tok_q.words()
                    for tok_q in workers.imap(tokenize, [q['question'] for q in questions])]
    # One entry per (question, title-tuple) pair, flattened over all questions.
    par_texts = [x for x in workers.imap(tokenize_and_concat,
                                         [[title_to_text(title) for title in titles]
                                          for q in questions
                                          for titles in q['top_pars_titles']])]
    # Cumulative counts mark where each question's slice ends in the flat list.
    pnames_end_idxs = list(itertools.accumulate([len(q['top_pars_titles'])
                                                 for q in questions]))
    q_with_p = list(zip([tokenized_qs[idx] for idx, q in enumerate(questions)
                         for _ in q['top_pars_titles']], par_texts))
    # Tag with original position so order can be restored after sorting.
    q_with_p = [(x, i) for i, x in enumerate(q_with_p)]
    # Sort longest-paragraph-first so long inputs can be encoded with a
    # smaller batch size (memory), short ones with a larger one (speed).
    sorted_q_with_p = sorted(q_with_p, key=lambda x: (len(x[0][1]), len(x[0][0]), x[1]),
                             reverse=True)
    sorted_qs, sorted_ps = zip(*[x for x, _ in sorted_q_with_p])
    # Boundary between "long" (>= 900 tokens) and "short" paragraphs; -1 if none long.
    last_long_index = max([i for i, x in enumerate(sorted_ps) if len(x) >= 900] + [-1])
    if last_long_index != -1:
        reformulations_long = encoder.reformulate_questions_from_texts(
            tokenized_questions=sorted_qs[:last_long_index + 1],
            tokenized_pars=sorted_ps[:last_long_index + 1],
            return_search_vectors=True, show_progress=False, max_batch=8)
        reformulations_short = encoder.reformulate_questions_from_texts(
            tokenized_questions=sorted_qs[last_long_index + 1:],
            tokenized_pars=sorted_ps[last_long_index + 1:],
            return_search_vectors=True, show_progress=False, max_batch=64)
        reformulations = np.concatenate([reformulations_long, reformulations_short],
                                        axis=0)
    else:
        reformulations = encoder.reformulate_questions_from_texts(
            tokenized_questions=sorted_qs, tokenized_pars=sorted_ps,
            return_search_vectors=True, show_progress=False, max_batch=64)
    # Undo the length sort: restore original (question, tuple) order.
    reformulations = reformulations[np.argsort([i for _, i in sorted_q_with_p])]
    pnames_end_idxs = [0] + pnames_end_idxs
    reformulations_per_question = [reformulations[pnames_end_idxs[i]:pnames_end_idxs[i + 1]]
                                   for i in range(len(questions))]
    for q_idx, question in enumerate(questions):
        q_titles = question['top_titles'][:n2]
        # Load encodings for this question's candidate titles.
        title2encs = {}
        title2idx2par_name = {}
        for t2enc, t2id2p in workers.imap_unordered(get_title_mappings_from_saver,
                                                    q_titles):
            title2encs.update(t2enc)
            title2idx2par_name.update(t2id2p)
        # NOTE(review): title2par_name2idxs is built here but not read again
        # in this function — presumably kept for parity with the sibling
        # retrieval routines; confirm before removing.
        title2par_name2idxs = {}
        for title, id2par in title2idx2par_name.items():
            par2idxs = {}
            for idx, parname in id2par.items():
                if parname in par2idxs:
                    par2idxs[parname].append(idx)
                else:
                    par2idxs[parname] = [idx]
            title2par_name2idxs[title] = {par: sorted(idxs)
                                          for par, idxs in par2idxs.items()}
        # Stack candidate reps into one matrix with per-title offsets.
        title2ids = {}
        all_par_reps = []
        total_sentences = 0
        titles_offset_dict = {}
        for title in question['top_titles'][:n2]:
            titles_offset_dict[title] = total_sentences
            rep = title2encs[title]
            title2ids[title] = list(range(total_sentences, total_sentences + len(rep)))
            all_par_reps.append(rep)
            total_sentences += len(rep)
        id2title = {i: title for title, ids in title2ids.items() for i in ids}
        all_par_reps = np.concatenate(all_par_reps, axis=0)

        def id_to_par_name(rep_id):
            # Global row id -> paragraph name via its title's local offset.
            return title2idx2par_name[id2title[rep_id]][
                rep_id - titles_offset_dict[id2title[rep_id]]]

        top_k_second = numpy_global_knn(reformulations_per_question[q_idx],
                                        all_par_reps, k2 * safety_mult)
        # Order-preserving de-duplication (set.add returns None, i.e. falsy).
        seen = set()
        p_names = question['top_pars_titles']
        final_p_name_pairs = [(*p_names[x1], par_name_to_title(id_to_par_name(x2)))
                              for x1, x2 in top_k_second
                              if not ((*p_names[x1], par_name_to_title(id_to_par_name(x2))) in seen
                                      or seen.add((*p_names[x1],
                                                   par_name_to_title(id_to_par_name(x2)))))][:k2]
        # important to note that in the iterative dataset the paragraphs of each question are in pairs
        question['top_pars_titles'] = final_p_name_pairs
    return questions
def initial_retrieval(encoder, workers, questions: List, k1: int, n1: int,
                      safety_mult: int = 1):
    """First-hop retrieval: for each question, KNN over the top-n1 titles'
    paragraph encodings and store the k1 best titles.

    Sets question['top_pars_titles'] in place to a list of 1-tuples
    ``(title,)`` (tuples so the second hop can extend them) and returns
    the (mutated) questions list.
    """
    tokenized_qs = [tok_q.words()
                    for tok_q in workers.imap(tokenize, [q['question'] for q in questions])]
    q_search_encodings = encoder.encode_text_questions(tokenized_qs,
                                                       return_search_vectors=True,
                                                       show_progress=False)
    for q_idx, question in enumerate(questions):
        q_titles = question['top_titles'][:n1]
        # Load encodings for this question's candidate titles.
        title2encs = {}
        title2idx2par_name = {}
        for t2enc, t2id2p in workers.imap_unordered(get_title_mappings_from_saver,
                                                    q_titles):
            title2encs.update(t2enc)
            title2idx2par_name.update(t2id2p)
        # NOTE(review): title2par_name2idxs is computed but never read in this
        # function — looks like dead code copied from the sibling routines;
        # confirm before removing.
        title2par_name2idxs = {}
        for title, id2par in title2idx2par_name.items():
            par2idxs = {}
            for idx, parname in id2par.items():
                if parname in par2idxs:
                    par2idxs[parname].append(idx)
                else:
                    par2idxs[parname] = [idx]
            title2par_name2idxs[title] = {par: sorted(idxs)
                                          for par, idxs in par2idxs.items()}
        # Stack candidate reps into one matrix with per-title offsets so a
        # global row id maps back to (title, paragraph).
        title2ids = {}
        all_par_reps = []
        total_sentences = 0
        titles_offset_dict = {}
        for title in question['top_titles'][:n1]:
            titles_offset_dict[title] = total_sentences
            rep = title2encs[title]
            title2ids[title] = list(range(total_sentences, total_sentences + len(rep)))
            all_par_reps.append(rep)
            total_sentences += len(rep)
        id2title = {i: title for title, ids in title2ids.items() for i in ids}
        all_par_reps = np.concatenate(all_par_reps, axis=0)

        def id_to_par_name(rep_id):
            # Global row id -> paragraph name via its title's local offset.
            return title2idx2par_name[id2title[rep_id]][
                rep_id - titles_offset_dict[id2title[rep_id]]]

        q_enc = q_search_encodings[q_idx]
        # Over-retrieve (k1 * safety_mult) because duplicates are dropped below.
        top_k = simple_numpy_knn(np.expand_dims(q_enc, 0), all_par_reps,
                                 k1 * safety_mult)[0]
        # Order-preserving de-duplication (set.add returns None, i.e. falsy).
        seen = set()
        p_names = [id_to_par_name(x) for x in top_k
                   if not (id_to_par_name(x) in seen or seen.add(id_to_par_name(x)))][:k1]
        question['top_pars_titles'] = [(par_name_to_title(p), ) for p in p_names]
    return questions
def parname_to_text(par_name):
    """Resolve a paragraph name like ``<title>_<n>`` to its text.

    Uses the ``documents`` mapping when a docs file was supplied
    (``docs_file`` is not None); otherwise joins the sentences fetched
    from the ``docs_db`` database for the title.
    """
    title = par_name_to_title(par_name)
    index = int(par_name.rsplit('_', 1)[-1])
    if docs_file is None:
        return ' '.join(docs_db.get_doc_sentences(title))
    return documents[title][index]
def parname_to_text(par_name):
    """Resolve a paragraph name like ``<title>_<n>`` to its content.

    When no docs file was given, returns the title's sentences from the
    DB (a list); otherwise returns the indexed paragraph from ``docs``.
    NOTE(review): the two branches return different types (list vs.
    paragraph entry) — presumably downstream tokenization accepts both;
    confirm against callers.
    """
    title = par_name_to_title(par_name)
    index = int(par_name.rsplit('_', 1)[-1])
    if args.docs_file is not None:
        return docs[title][index]
    return doc_db.get_doc_sentences(title)
def main():
    """CLI entry point: full pair-ranking evaluation on HotpotQA.

    Builds ranked question/paragraph-pair samples from one of several corpus
    modes (distractors/gold, a hotpot-format file, a retrieval output file,
    or raw top TF-IDF titles), runs the QA model over them, and writes either
    a per-sample CSV or (in test mode) a JSON prediction file.
    """
    parser = argparse.ArgumentParser(
        description='Full ranking evaluation on Hotpot')
    parser.add_argument('model', help='model directory to evaluate')
    parser.add_argument(
        'output', type=str,
        help="Store the per-paragraph results in csv format in this file, "
        "or the json prediction if in test mode")
    parser.add_argument('-n', '--sample_questions', type=int, default=None,
                        help="(for testing) run on a subset of questions")
    parser.add_argument(
        '-b', '--batch_size', type=int, default=64,
        help="Batch size, larger sizes can be faster but uses more memory")
    parser.add_argument(
        '-s', '--step', default=None,
        help="Weights to load, can be a checkpoint step or 'latest'")
    parser.add_argument('-a', '--answer_bound', type=int, default=8,
                        help="Max answer span length")
    parser.add_argument('-c', '--corpus',
                        choices=["distractors", "gold", "hotpot_file",
                                 "retrieval_file", "top_titles"],
                        default="distractors")
    parser.add_argument('-t', '--tokens', type=int, default=None,
                        help="Max tokens per a paragraph")
    parser.add_argument('--input_file', type=str, default=None)
    parser.add_argument('--docs_file', type=str, default=None)
    parser.add_argument('--num_workers', type=int, default=16,
                        help='Number of workers for tokenizing')
    parser.add_argument('--no_ema', action="store_true",
                        help="Don't use EMA weights even if they exist")
    parser.add_argument('--no_sp', action="store_true",
                        help="Don't predict supporting facts")
    parser.add_argument('--test_mode', action='store_true',
                        help="produce a prediction file, no answers given")
    args = parser.parse_args()
    model_dir = ModelDir(args.model)
    batcher = ClusteredBatcher(args.batch_size, multiple_contexts_len,
                               truncate_batches=True)
    loader = ResourceLoader()
    if args.corpus not in {"distractors", "gold"} and args.input_file is None:
        raise ValueError(
            "Must pass an input file if not using precomputed dataset")
    if args.corpus in {"distractors", "gold"} and args.test_mode:
        raise ValueError(
            "Test mode not available in 'distractors' or 'gold' mode")
    if args.corpus in {"distractors", "gold"}:
        # Precomputed dev questions; rank all (or only gold) paragraph pairs.
        corpus = HotpotQuestions()
        loader = corpus.get_resource_loader()
        questions = corpus.get_dev()
        question_preprocessor = HotpotTextLengthPreprocessorWithSpans(args.tokens)
        questions = [
            question_preprocessor.preprocess(x) for x in questions
            if (question_preprocessor.preprocess(x) is not None)
        ]
        if args.sample_questions:
            # Deterministic subsample: fixed-seed shuffle over id-sorted questions.
            np.random.RandomState(0).shuffle(
                sorted(questions, key=lambda x: x.question_id))
            questions = questions[:args.sample_questions]
        data = HotpotFullQADistractorsDataset(questions, batcher)
        gold_idxs = set(data.gold_idxs)
        if args.corpus == 'gold':
            data.samples = [data.samples[i] for i in data.gold_idxs]
        # Group samples (and their dataset indices) by question id.
        qid2samples = {}
        qid2idx = {}
        for i, sample in enumerate(data.samples):
            key = sample.question_id
            if key in qid2samples:
                qid2samples[key].append(sample)
                qid2idx[key].append(i)
            else:
                qid2samples[key] = [sample]
                qid2idx[key] = [i]
        questions = []
        print("Ranking pairs...")
        gold_ranks = []
        for qid, samples in tqdm(qid2samples.items()):
            question = " ".join(samples[0].question)
            pars = [" ".join(x.paragraphs[0]) for x in samples]
            ranks = get_paragraph_ranks(question, pars)
            for sample, rank, idx in zip(samples, ranks, qid2idx[qid]):
                questions.append(
                    RankedQAPair(question=sample.question,
                                 paragraphs=sample.paragraphs,
                                 spans=np.zeros((0, 2), dtype=np.int32),
                                 question_id=sample.question_id,
                                 answer=sample.answer,
                                 rank=rank,
                                 q_type=sample.q_type,
                                 sentence_segments=sample.sentence_segments))
                if idx in gold_idxs:
                    gold_ranks.append(rank + 1)  # 1-based rank for reporting
        print(f"Mean rank: {np.mean(gold_ranks)}")
        ranks_counter = Counter(gold_ranks)
        for i in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]:
            print(f"Hits at {i}: {ranks_counter[i]}")
    elif args.corpus == 'hotpot_file':
        # a hotpot json format input file. We rank the pairs with tf-idf
        with open(args.input_file, 'r') as f:
            hotpot_data = json.load(f)
        if args.sample_questions:
            np.random.RandomState(0).shuffle(
                sorted(hotpot_data, key=lambda x: x['_id']))
            hotpot_data = hotpot_data[:args.sample_questions]
        title2sentences = {
            context[0]: context[1]
            for q in hotpot_data for context in q['context']
        }
        question_tok_texts = tokenize_texts(
            [q['question'] for q in hotpot_data], num_workers=args.num_workers)
        sentences_tok = tokenize_texts(list(title2sentences.values()),
                                       num_workers=args.num_workers,
                                       sentences=True)
        if args.tokens is not None:
            sentences_tok = [
                truncate_paragraph(p, args.tokens) for p in sentences_tok
            ]
        title2tok_sents = {
            title: sentences
            for title, sentences in zip(title2sentences.keys(), sentences_tok)
        }
        questions = []
        for idx, question in enumerate(tqdm(hotpot_data, desc='tf-idf ranking')):
            q_titles = [title for title, _ in question['context']]
            # All unordered title pairs from the question's context.
            par_pairs = [(title1, title2)
                         for i, title1 in enumerate(q_titles)
                         for title2 in q_titles[i + 1:]]
            if len(par_pairs) == 0:
                continue
            ranks = get_paragraph_ranks(question['question'], [
                ' '.join(title2sentences[t1] + title2sentences[t2])
                for t1, t2 in par_pairs
            ])
            for rank, par_pair in zip(ranks, par_pairs):
                sent_tok_pair = title2tok_sents[par_pair[0]] + title2tok_sents[par_pair[1]]
                sentence_segments, _ = get_segments_from_sentences_fix_sup(
                    sent_tok_pair, np.zeros(0))
                # Track empty (e.g. fully-truncated) sentences per title.
                missing_sent_idx = [[
                    i for i, sent in enumerate(title2tok_sents[title])
                    if len(sent) == 0
                ] for title in par_pair]
                questions.append(
                    RankedQAPair(
                        question=question_tok_texts[idx],
                        paragraphs=[flatten_iterable(sent_tok_pair)],
                        spans=np.zeros((0, 2), dtype=np.int32),
                        question_id=question['_id'],
                        answer='noanswer' if args.test_mode else question['answer'],
                        rank=rank,
                        q_type='null' if args.test_mode else question['type'],
                        sentence_segments=[sentence_segments],
                        par_titles_num_sents=[
                            (title,
                             sum(1 for sent in title2tok_sents[title]
                                 if len(sent) > 0)) for title in par_pair
                        ],
                        missing_sent_idxs=missing_sent_idx,
                        true_sp=[] if args.test_mode else question['supporting_facts']))
    elif args.corpus == 'retrieval_file' or args.corpus == 'top_titles':
        if args.docs_file is None:
            print("Using DB documents")
            doc_db = DocDB(config.DOC_DB, full_docs=False)
        else:
            with open(args.docs_file, 'r') as f:
                docs = json.load(f)
        with open(args.input_file, 'r') as f:
            retrieval_data = json.load(f)
        if args.sample_questions:
            np.random.RandomState(0).shuffle(
                sorted(retrieval_data, key=lambda x: x['qid']))
            retrieval_data = retrieval_data[:args.sample_questions]

        def parname_to_text(par_name):
            # '<title>_<n>' -> paragraph content from docs file or the DB.
            par_title = par_name_to_title(par_name)
            par_num = int(par_name.split('_')[-1])
            if args.docs_file is None:
                return doc_db.get_doc_sentences(par_title)
            return docs[par_title][par_num]

        if args.corpus == 'top_titles':
            print("Top TF-IDF!")
            # Build pairs from the first paragraph ('_0') of the top-10 titles.
            for q in retrieval_data:
                top_titles = q['top_titles'][:10]
                q['paragraph_pairs'] = [(title1 + '_0', title2 + '_0')
                                        for i, title1 in enumerate(top_titles)
                                        for title2 in top_titles[i + 1:]]
        question_tok_texts = tokenize_texts(
            [q['question'] for q in retrieval_data],
            num_workers=args.num_workers)
        all_parnames = list(
            set([
                parname for q in retrieval_data
                for pair in q['paragraph_pairs'] for parname in pair
            ]))
        texts_tok = tokenize_texts([parname_to_text(x) for x in all_parnames],
                                   num_workers=args.num_workers,
                                   sentences=True)
        if args.tokens is not None:
            texts_tok = [truncate_paragraph(p, args.tokens) for p in texts_tok]
        parname2tok_text = {
            parname: text
            for parname, text in zip(all_parnames, texts_tok)
        }
        questions = []
        for idx, question in enumerate(retrieval_data):
            # Pairs keep their retrieval order; position doubles as rank.
            for rank, par_pair in enumerate(question['paragraph_pairs']):
                tok_pair = parname2tok_text[par_pair[0]] + parname2tok_text[par_pair[1]]
                sentence_segments, _ = get_segments_from_sentences_fix_sup(
                    tok_pair, np.zeros(0))
                missing_sent_idx = [[
                    i for i, sent in enumerate(parname2tok_text[parname])
                    if len(sent) == 0
                ] for parname in par_pair]
                questions.append(
                    RankedQAPair(
                        question=question_tok_texts[idx],
                        paragraphs=[flatten_iterable(tok_pair)],
                        spans=np.zeros((0, 2), dtype=np.int32),
                        question_id=question['qid'],
                        answer='noanswer' if args.test_mode else question['answers'][0],
                        rank=rank,
                        q_type='null' if args.test_mode else question['type'],
                        sentence_segments=[sentence_segments],
                        par_titles_num_sents=[
                            (par_name_to_title(parname),
                             sum(1 for sent in parname2tok_text[parname]
                                 if len(sent) > 0)) for parname in par_pair
                        ],
                        missing_sent_idxs=missing_sent_idx,
                        true_sp=[] if args.test_mode else question['supporting_facts']))
    else:
        raise NotImplementedError()
    data = DummyDataset(questions, batcher)
    evaluators = [
        RecordHotpotQAPrediction(args.answer_bound, True,
                                 sp_prediction=not args.no_sp)
    ]
    # Checkpoint selection: explicit step > best weights > latest checkpoint.
    if args.step is not None:
        if args.step == "latest":
            checkpoint = model_dir.get_latest_checkpoint()
        else:
            checkpoint = model_dir.get_checkpoint(int(args.step))
    else:
        checkpoint = model_dir.get_best_weights()
        if checkpoint is not None:
            print("Using best weights")
        else:
            print("Using latest checkpoint")
            checkpoint = model_dir.get_latest_checkpoint()
    model = model_dir.get_model()
    evaluation = trainer.test(model, evaluators, {args.corpus: data}, loader,
                              checkpoint, not args.no_ema, 10)[args.corpus]
    print("Saving result")
    output_file = args.output
    df = pd.DataFrame(evaluation.per_sample)
    df.sort_values(["question_id", "rank"], inplace=True, ascending=True)
    group_by = ["question_id"]

    def get_ranked_scores(score_name):
        # 'Cp'/'Br' in the name selects comparison/bridge rows; 'joint'/'sp'
        # selects the metric family; 'EM' vs anything else selects em/f1.
        filtered_df = df[df.type == 'comparison'] if "Cp" in score_name else \
            df[df.type == 'bridge'] if "Br" in score_name else df
        target_prefix = 'joint' if 'joint' in score_name else 'sp' if 'sp' in score_name else 'text'
        target_score = f"{target_prefix}_{'em' if 'EM' in score_name else 'f1'}"
        return compute_ranked_scores_with_yes_no(
            filtered_df,
            span_q_col="span_question_scores",
            yes_no_q_col="yes_no_question_scores",
            yes_no_scores_col="yes_no_confidence_scores",
            span_scores_col="predicted_score",
            span_target_score=target_score,
            group_cols=group_by)

    if not args.test_mode:
        score_names = ["EM", "F1", "Br EM", "Br F1", "Cp EM", "Cp F1"]
        if not args.no_sp:
            score_names.extend([
                f"{prefix} {name}" for prefix in ['sp', 'joint']
                for name in score_names
            ])
        # Row i of the table holds every score when keeping the top i+1 pairs.
        table = [["N Paragraphs"] + score_names]
        scores = [get_ranked_scores(score_name) for score_name in score_names]
        table += list([str(i + 1), *["%.4f" % x for x in score_vals]]
                      for i, score_vals in enumerate(zip(*scores)))
        print_table(table)
        df.to_csv(output_file, index=False)
    else:
        df_to_pred(df, output_file)
def iterative_retrieval(encoder, questions, q_original_encodings, q_search_encodings,
                        workers, parname_to_text, reformulate_from_text,
                        n1, n2, k1, k2, safety_mult):
    """Two-hop retrieval per question (parameterized variant of the pipeline
    in ``build_openqa_iterative_top_titles``).

    Hop 1: KNN over the top-n1 titles' paragraph reps to pick k1 paragraphs.
    Then the question is reformulated per paragraph (from reps or raw text),
    and hop 2 KNNs the reformulations over the top-n2 titles to pick the k2
    best paragraph pairs. Sets question['paragraph_pairs'] in place and
    returns the (mutated) questions list.
    """
    for q_idx, question in tqdm(enumerate(questions), total=len(questions),
                                ncols=80, desc=f"{n1}-{n2}-{k1}-{k2}"):
        # Load encodings for the union of both hops' candidate titles.
        q_titles = question['top_titles'][:max(n1, n2)]
        title2encs = {}
        title2idx2par_name = {}
        for t2enc, t2id2p in workers.imap_unordered(get_title_mappings_from_saver,
                                                    q_titles):
            title2encs.update(t2enc)
            title2idx2par_name.update(t2id2p)
        # Invert idx->name into name->sorted row indices (a paragraph may
        # span several encoding rows).
        title2par_name2idxs = {}
        for title, id2par in title2idx2par_name.items():
            par2idxs = {}
            for idx, parname in id2par.items():
                if parname in par2idxs:
                    par2idxs[parname].append(idx)
                else:
                    par2idxs[parname] = [idx]
            title2par_name2idxs[title] = {par: sorted(idxs)
                                          for par, idxs in par2idxs.items()}
        # --- Hop 1: stack the top-n1 titles' reps with per-title offsets so
        # a global row id maps back to (title, paragraph).
        title2ids = {}
        all_par_reps = []
        total_sentences = 0
        titles_offset_dict = {}
        for title in question['top_titles'][:n1]:
            titles_offset_dict[title] = total_sentences
            rep = title2encs[title]
            title2ids[title] = list(range(total_sentences, total_sentences + len(rep)))
            all_par_reps.append(rep)
            total_sentences += len(rep)
        id2title = {i: title for title, ids in title2ids.items() for i in ids}
        all_par_reps = np.concatenate(all_par_reps, axis=0)

        def id_to_par_name(rep_id):
            # Global row id -> paragraph name via its title's local offset.
            return title2idx2par_name[id2title[rep_id]][
                rep_id - titles_offset_dict[id2title[rep_id]]]

        q_enc = q_search_encodings[q_idx]
        # Over-retrieve (k1 * safety_mult) because duplicates are dropped below.
        top_k = simple_numpy_knn(np.expand_dims(q_enc, 0), all_par_reps,
                                 k1 * safety_mult)[0]
        # Order-preserving de-duplication trick: set.add returns None (falsy).
        seen = set()
        p_names = [id_to_par_name(x) for x in top_k
                   if not (id_to_par_name(x) in seen or seen.add(id_to_par_name(x)))][:k1]
        iteration1_paragraphs = \
            [title2encs[par_name_to_title(pname)][title2par_name2idxs[par_name_to_title(pname)][pname], :]
             for pname in p_names]
        if not reformulate_from_text:
            # Reformulate from precomputed reps: tile the question rep once
            # per retrieved first-hop paragraph.
            reformulations = encoder.reformulate_questions(
                questions_rep=np.tile(q_original_encodings[q_idx],
                                      reps=(len(p_names), 1)),
                paragraphs_rep=iteration1_paragraphs,
                return_search_vectors=True)
        else:
            # Reformulate from raw text via the supplied parname_to_text.
            tok_q = tokenize(question['question']).words()
            par_texts = [tokenize(parname_to_text(pname)).words() for pname in p_names]
            reformulations = encoder.reformulate_questions_from_texts(
                tokenized_questions=[tok_q for _ in range(len(par_texts))],
                tokenized_pars=par_texts,
                return_search_vectors=True
            )
        # --- Hop 2: rebuild the candidate matrix over the top-n2 titles.
        title2ids = {}
        all_par_reps = []
        total_sentences = 0
        titles_offset_dict = {}
        for title in question['top_titles'][:n2]:
            titles_offset_dict[title] = total_sentences
            rep = title2encs[title]
            title2ids[title] = list(range(total_sentences, total_sentences + len(rep)))
            all_par_reps.append(rep)
            total_sentences += len(rep)
        id2title = {i: title for title, ids in title2ids.items() for i in ids}
        all_par_reps = np.concatenate(all_par_reps, axis=0)
        top_k_second = numpy_global_knn(reformulations, all_par_reps,
                                        k2 * safety_mult)
        seen = set()
        final_p_name_pairs = [(p_names[x1], id_to_par_name(x2))
                              for x1, x2 in top_k_second
                              if not ((p_names[x1], id_to_par_name(x2)) in seen or
                                      seen.add((p_names[x1], id_to_par_name(x2))))][:k2]
        # important to note that in the iterative dataset the paragraphs of each question are in pairs
        question['paragraph_pairs'] = final_p_name_pairs
    return questions