Exemplo n.º 1
0
def eval_questions(questions, top_k=None, specific_ks=None):
    """Aggregate per-question retrieval stats and print score tables.

    For every question, per-k metrics are collected via
    ``get_stats_for_sample`` and accumulated both overall and per question
    type ('bridge' vs. comparison). The mean of each metric is then printed
    as a PrettyTable per group, plus a second set of tables restricted to
    ``specific_ks`` when given.

    Args:
        questions: question dicts with 'supporting_facts', 'paragraph_pairs'
            and 'type' fields.
        top_k: forwarded to ``get_stats_for_sample``.
        specific_ks: optional collection of k values to print separately.
    """
    overall_scores = {}
    bridge_scores = {}
    comparison_scores = {}
    for question in questions:
        gold_titles = {fact[0] for fact in question['supporting_facts']}
        pair_titles = [[par_name_to_title(name) for name in pair]
                       for pair in question['paragraph_pairs']]
        q_stats = get_stats_for_sample(gold_titles,
                                       sorted_title_pairs=pair_titles,
                                       top_k=top_k)
        by_type = (bridge_scores
                   if question['type'] == 'bridge' else comparison_scores)
        for k, metrics in q_stats.items():
            for target in (overall_scores, by_type):
                bucket = target.setdefault(k, {key: [] for key in metrics})
                for key, val in metrics.items():
                    bucket[key].append(val)

    # Reduce each metric's list of per-question values to its mean.
    for score_dict in (overall_scores, bridge_scores, comparison_scores):
        for k in score_dict.keys():
            score_dict[k] = {
                key: np.mean(vals)
                for key, vals in score_dict[k].items()
            }

    def _print_scores(score_dict, name, only_ks=None, sort=False):
        # Build and print one table; q_stats metrics are assumed to carry
        # 'hits', 'perfect' and 'at_least_one' keys (as used below).
        table = prettytable.PrettyTable(
            ['Top K Pairs', 'Hits', 'Perfect Questions', 'At Least One'])
        for k, k_dict in score_dict.items():
            if only_ks is not None and k not in only_ks:
                continue
            table.add_row(
                [k, k_dict['hits'], k_dict['perfect'], k_dict['at_least_one']])
        if sort:
            table.sortby = 'Top K Pairs'
        print(f"{name} scores:")
        print(table)
        print('\n**********************************************\n')

    named_groups = [(bridge_scores, 'Bridge'),
                    (comparison_scores, 'Comparison'),
                    (overall_scores, 'Overall')]
    for score_dict, name in named_groups:
        _print_scores(score_dict, name, sort=True)

    if specific_ks is not None:
        # Matches the original behavior: these filtered tables are unsorted.
        for score_dict, name in named_groups:
            _print_scores(score_dict, name, only_ks=specific_ks)
 def parname_to_text(par_name):
     """Return the raw text of the paragraph named *par_name*.

     Paragraph names appear to follow a '<title>_<index>' convention
     (the trailing '_<index>' selects the paragraph within the title) —
     TODO confirm against `par_name_to_title`.
     """
     # `documents` (title -> list of paragraphs) comes from the enclosing
     # scope; this snippet is a fragment of a larger function.
     par_title = par_name_to_title(par_name)
     par_num = int(par_name.split('_')[-1])
     return documents[par_title][par_num]
Exemplo n.º 3
0
def build_openqa_iterative_top_titles(out_file, questions_file, docs_file,
                                      encodings_dir, encoder_model, k1, k2, n1,
                                      n2, evaluate: bool,
                                      reformulate_from_text: bool,
                                      use_ema: bool, checkpoint: str):
    """Run two-step iterative paragraph-pair retrieval for open-domain QA.

    For each question: (1) retrieve the top-k1 paragraphs for the question's
    search vector among its top-n1 titles, (2) reformulate the question once
    per retrieved paragraph and retrieve a second paragraph among the top-n2
    titles, keeping the top-k2 unique (par1, par2) name pairs. The questions
    (with a new 'paragraph_pairs' field) are dumped to ``out_file`` as JSON.

    Args:
        out_file: path of the JSON file to write the updated questions to.
        questions_file: JSON file of question dicts (each needs 'qid',
            'question' and 'top_titles').
        docs_file: JSON file mapping title -> list of paragraph texts.
        encodings_dir: directory with precomputed paragraph encodings.
        encoder_model: model directory for the iterative sentence encoder.
        k1: number of unique paragraphs kept after iteration 1.
        k2: number of unique paragraph pairs kept after iteration 2.
        n1: number of top titles searched in iteration 1.
        n2: number of top titles searched in iteration 2.
        evaluate: if True, print retrieval metrics via ``eval_questions``.
        reformulate_from_text: reformulate from raw paragraph text instead
            of precomputed paragraph encodings.
        use_ema: whether the encoder should load EMA weights.
        checkpoint: checkpoint identifier forwarded to the encoder.
    """
    print('Loading data...')
    s = time.time()
    with open(questions_file, 'r') as f:
        questions = json.load(f)
    with open(docs_file, 'r') as f:
        documents = json.load(f)
    print(f'Done, took {time.time()-s} seconds.')

    # Keep only titles that can participate in either retrieval iteration.
    if n1 is not None and n2 is not None:
        for q in questions:
            q['top_titles'] = q['top_titles'][:max(n1, n2)]

    # Setup worker pool
    workers = ProcessPool(16, initializer=init, initargs=[])

    qid2tokenized = {}
    tupled_questions = [(q['qid'], q['question']) for q in questions]
    print("Tokenizing questions...")
    with tqdm(total=len(tupled_questions)) as pbar:
        for tok_q in tqdm(
                workers.imap_unordered(tokenize_question, tupled_questions)):
            qid2tokenized.update(tok_q)
            pbar.update()

    # Vocabulary of all question tokens, passed to the encoder below.
    voc = set()
    for question in qid2tokenized.values():
        voc.update(question)

    all_titles = list(
        set([title for q in questions for title in q['top_titles']]))

    def parname_to_text(par_name):
        # Map a '<title>_<index>'-style paragraph name back to its raw text.
        par_title = par_name_to_title(par_name)
        par_num = int(par_name.split('_')[-1])
        return documents[par_title][par_num]

    print(f"Gathering documents...")
    # Setup worker pool
    workers = ProcessPool(32,
                          initializer=init_encoding_handler,
                          initargs=[encodings_dir])
    # title2encs: title -> matrix of paragraph representations;
    # title2idx2par_name: title -> row index -> paragraph name.
    title2encs = {}
    title2idx2par_name = {}
    with tqdm(total=len(all_titles)) as pbar:
        for t2enc, t2id2p in tqdm(
                workers.imap_unordered(get_title_mappings_from_saver,
                                       all_titles)):
            title2encs.update(t2enc)
            title2idx2par_name.update(t2id2p)
            pbar.update()
    # Invert idx -> par_name into par_name -> sorted row indices, per title.
    title2par_name2idxs = {}
    for title, id2par in title2idx2par_name.items():
        par2idxs = {}
        for idx, parname in id2par.items():
            if parname in par2idxs:
                par2idxs[parname].append(idx)
            else:
                par2idxs[parname] = [idx]
        title2par_name2idxs[title] = {
            par: sorted(idxs)
            for par, idxs in par2idxs.items()
        }

    print("Loading encoder...")
    spec = QuestionAndParagraphsSpec(batch_size=None,
                                     max_num_contexts=2,
                                     max_num_question_words=None,
                                     max_num_context_words=None)
    encoder = SentenceEncoderIterativeModel(model_dir_path=encoder_model,
                                            vocabulary=voc,
                                            spec=spec,
                                            loader=ResourceLoader(),
                                            use_char_inputs=False,
                                            use_ema=use_ema,
                                            checkpoint=checkpoint)

    print("Encoding questions...")
    q_original_encodings = encoder.encode_text_questions(
        [qid2tokenized[q['qid']] for q in questions],
        return_search_vectors=False,
        show_progress=True)
    # Search vectors are the form compared against paragraph representations.
    q_search_encodings = encoder.question_rep_to_search_vector(
        question_encodings=q_original_encodings)

    init()  # for initializing the tokenizer

    print("Calculating similarities...")
    for idx, question in tqdm(enumerate(questions), total=len(questions)):
        # --- Iteration 1: retrieve top paragraphs among top-n1 titles ---
        # Stack the candidate titles' paragraph reps into one matrix and
        # remember each title's row offset so hits map back to paragraphs.
        title2ids = {}
        all_par_reps = []
        total_sentences = 0
        titles_offset_dict = {}
        for title in question['top_titles'][:n1]:
            titles_offset_dict[title] = total_sentences
            rep = title2encs[title]
            title2ids[title] = list(
                range(total_sentences, total_sentences + len(rep)))
            all_par_reps.append(rep)
            total_sentences += len(rep)
        id2title = {i: title for title, ids in title2ids.items() for i in ids}
        all_par_reps = np.concatenate(all_par_reps, axis=0)

        q_enc = q_search_encodings[idx]
        # Over-retrieve (k1 * 2) so that k1 unique names survive dedup below.
        top_k = simple_numpy_knn(np.expand_dims(q_enc, 0), all_par_reps,
                                 k1 * 2)[0]

        def id_to_par_name(rep_id):
            # Translate a global row id into that title's local paragraph name.
            return title2idx2par_name[id2title[rep_id]][
                rep_id - titles_offset_dict[id2title[rep_id]]]

        # Order-preserving dedup: `seen.add` returns None, so the filter is
        # only True the first time each name appears.
        seen = set()
        p_names = [
            id_to_par_name(x) for x in top_k
            if not (id_to_par_name(x) in seen or seen.add(id_to_par_name(x)))
        ][:k1]
        iteration1_paragraphs = \
            [title2encs[par_name_to_title(pname)][title2par_name2idxs[par_name_to_title(pname)][pname], :]
             for pname in p_names]
        # Reformulate the question once per retrieved paragraph, either from
        # precomputed paragraph encodings or from the raw paragraph text.
        if not reformulate_from_text:
            reformulations = encoder.reformulate_questions(
                questions_rep=np.tile(q_original_encodings[idx],
                                      reps=(len(p_names), 1)),
                paragraphs_rep=iteration1_paragraphs,
                return_search_vectors=True)
        else:
            tok_q = tokenize(question['question']).words()
            par_texts = [
                tokenize(parname_to_text(pname)).words() for pname in p_names
            ]
            reformulations = encoder.reformulate_questions_from_texts(
                tokenized_questions=[tok_q for _ in range(len(par_texts))],
                tokenized_pars=par_texts,
                return_search_vectors=True)

        # --- Iteration 2: retrieve the paired paragraph among top-n2 titles.
        # Same matrix-stacking bookkeeping as above, rebound so the
        # `id_to_par_name` closure now resolves against the n2-title index.
        title2ids = {}
        all_par_reps = []
        total_sentences = 0
        titles_offset_dict = {}
        for title in question['top_titles'][:n2]:
            titles_offset_dict[title] = total_sentences
            rep = title2encs[title]
            title2ids[title] = list(
                range(total_sentences, total_sentences + len(rep)))
            all_par_reps.append(rep)
            total_sentences += len(rep)
        id2title = {i: title for title, ids in title2ids.items() for i in ids}
        all_par_reps = np.concatenate(all_par_reps, axis=0)

        # Global knn across all reformulations; x1 indexes the iteration-1
        # paragraph, x2 the candidate second paragraph.
        top_k_second = numpy_global_knn(reformulations, all_par_reps, k2 * k1)
        seen = set()
        final_p_name_pairs = [
            (p_names[x1], id_to_par_name(x2)) for x1, x2 in top_k_second
            if not ((p_names[x1], id_to_par_name(x2)) in seen or seen.add(
                (p_names[x1], id_to_par_name(x2))))
        ][:k2]

        # important to note that in the iterative dataset the paragraphs of each question are in pairs
        question['paragraph_pairs'] = final_p_name_pairs

    with open(out_file, 'w') as f:
        json.dump(questions, f)

    if evaluate:
        eval_questions(questions)
Exemplo n.º 4
0
def reformulation_retrieval(encoder,
                            workers,
                            questions: List,
                            doc_db: DocDB,
                            k2: int,
                            n2: int,
                            safety_mult: int = 1):
    """Extend each question's title tuples with one more retrieved title.

    Reformulates each question against each of its current title tuples
    (``top_pars_titles``) and retrieves the next paragraph among the
    question's top-n2 titles, keeping the top-k2 unique extended tuples.

    Args:
        encoder: model exposing ``reformulate_questions_from_texts``.
        workers: worker pool used for tokenization and encoding lookups.
        questions: question dicts with 'question', 'top_titles' and
            'top_pars_titles' (tuples of titles retrieved so far).
        doc_db: document store used to fetch paragraph sentences by title.
        k2: number of extended tuples kept per question.
        n2: number of top titles searched for the next paragraph.
        safety_mult: over-retrieval factor so k2 unique tuples survive dedup.

    Returns:
        The same ``questions`` list, with 'top_pars_titles' replaced by the
        extended tuples (the dicts are mutated in place).
    """
    def title_to_text(title):
        # A title's text is the concatenation of its sentences.
        return ' '.join(doc_db.get_doc_sentences(title))

    tokenized_qs = [
        tok_q.words()
        for tok_q in workers.imap(tokenize, [q['question'] for q in questions])
    ]
    # One concatenated, tokenized text per (question, title-tuple) pair,
    # flattened across all questions in question order.
    par_texts = [
        x for x in workers.imap(tokenize_and_concat,
                                [[title_to_text(title) for title in titles]
                                 for q in questions
                                 for titles in q['top_pars_titles']])
    ]
    # Cumulative tuple counts delimit each question's slice of the flat list.
    pnames_end_idxs = list(
        itertools.accumulate([len(q['top_pars_titles']) for q in questions]))
    # Pair every tokenized question with each of its paragraph texts.
    q_with_p = list(
        zip([
            tokenized_qs[idx] for idx, q in enumerate(questions)
            for _ in q['top_pars_titles']
        ], par_texts))
    # Tag each pair with its original position, then sort longest-first so
    # long inputs can be encoded with a smaller batch size below.
    q_with_p = [(x, i) for i, x in enumerate(q_with_p)]
    sorted_q_with_p = sorted(q_with_p,
                             key=lambda x: (len(x[0][1]), len(x[0][0]), x[1]),
                             reverse=True)
    sorted_qs, sorted_ps = zip(*[x for x, _ in sorted_q_with_p])
    # Split at the last paragraph with >= 900 tokens: those get max_batch=8
    # (memory), the rest max_batch=64 (throughput).
    last_long_index = max(
        [i for i, x in enumerate(sorted_ps) if len(x) >= 900] + [-1])
    if last_long_index != -1:
        reformulations_long = encoder.reformulate_questions_from_texts(
            tokenized_questions=sorted_qs[:last_long_index + 1],
            tokenized_pars=sorted_ps[:last_long_index + 1],
            return_search_vectors=True,
            show_progress=False,
            max_batch=8)
        reformulations_short = encoder.reformulate_questions_from_texts(
            tokenized_questions=sorted_qs[last_long_index + 1:],
            tokenized_pars=sorted_ps[last_long_index + 1:],
            return_search_vectors=True,
            show_progress=False,
            max_batch=64)
        reformulations = np.concatenate(
            [reformulations_long, reformulations_short], axis=0)
    else:
        reformulations = encoder.reformulate_questions_from_texts(
            tokenized_questions=sorted_qs,
            tokenized_pars=sorted_ps,
            return_search_vectors=True,
            show_progress=False,
            max_batch=64)
    # Undo the length sort: restore the original (question, tuple) order.
    reformulations = reformulations[np.argsort([i
                                                for _, i in sorted_q_with_p])]
    pnames_end_idxs = [0] + pnames_end_idxs
    reformulations_per_question = [
        reformulations[pnames_end_idxs[i]:pnames_end_idxs[i + 1]]
        for i in range(len(questions))
    ]
    for q_idx, question in enumerate(questions):
        q_titles = question['top_titles'][:n2]
        # Fetch paragraph-representation matrices and row->paragraph-name
        # maps for this question's candidate titles.
        title2encs = {}
        title2idx2par_name = {}
        for t2enc, t2id2p in workers.imap_unordered(
                get_title_mappings_from_saver, q_titles):
            title2encs.update(t2enc)
            title2idx2par_name.update(t2id2p)
        # NOTE(review): title2par_name2idxs is built but not read anywhere
        # below in this function — possibly kept for parity with the other
        # retrieval passes; confirm before removing.
        title2par_name2idxs = {}
        for title, id2par in title2idx2par_name.items():
            par2idxs = {}
            for idx, parname in id2par.items():
                if parname in par2idxs:
                    par2idxs[parname].append(idx)
                else:
                    par2idxs[parname] = [idx]
            title2par_name2idxs[title] = {
                par: sorted(idxs)
                for par, idxs in par2idxs.items()
            }
        # Stack the candidate titles' paragraph reps into one matrix and
        # remember each title's row offset so hits map back to paragraphs.
        title2ids = {}
        all_par_reps = []
        total_sentences = 0
        titles_offset_dict = {}
        for title in question['top_titles'][:n2]:
            titles_offset_dict[title] = total_sentences
            rep = title2encs[title]
            title2ids[title] = list(
                range(total_sentences, total_sentences + len(rep)))
            all_par_reps.append(rep)
            total_sentences += len(rep)
        id2title = {i: title for title, ids in title2ids.items() for i in ids}
        all_par_reps = np.concatenate(all_par_reps, axis=0)

        def id_to_par_name(rep_id):
            # Translate a global row id into that title's local paragraph name.
            return title2idx2par_name[id2title[rep_id]][
                rep_id - titles_offset_dict[id2title[rep_id]]]

        # Global knn across this question's reformulations; x1 indexes the
        # existing tuple, x2 the candidate next paragraph.
        top_k_second = numpy_global_knn(reformulations_per_question[q_idx],
                                        all_par_reps, k2 * safety_mult)
        seen = set()
        p_names = question['top_pars_titles']
        # Order-preserving dedup of the extended tuples (seen.add is None).
        final_p_name_pairs = [
            (*p_names[x1], par_name_to_title(id_to_par_name(x2)))
            for x1, x2 in top_k_second if not (
                (*p_names[x1],
                 par_name_to_title(id_to_par_name(x2))) in seen or seen.add(
                     (*p_names[x1], par_name_to_title(id_to_par_name(x2)))))
        ][:k2]

        # important to note that in the iterative dataset the paragraphs of each question are in pairs
        question['top_pars_titles'] = final_p_name_pairs
    return questions
Exemplo n.º 5
0
def initial_retrieval(encoder,
                      workers,
                      questions: List,
                      k1: int,
                      n1: int,
                      safety_mult: int = 1):
    """First retrieval step: pick each question's top-k1 paragraphs.

    Encodes every question into a search vector, runs knn against the
    stacked paragraph representations of the question's top-n1 titles, and
    stores the resulting unique titles as one-element tuples in
    ``question['top_pars_titles']`` (to be extended by later steps).

    Args:
        encoder: model exposing ``encode_text_questions``.
        workers: worker pool used for tokenization and encoding lookups.
        questions: question dicts with 'question' and 'top_titles'.
        k1: number of unique paragraphs kept per question.
        n1: number of top titles searched.
        safety_mult: over-retrieval factor so k1 unique names survive dedup.

    Returns:
        The same ``questions`` list (the dicts are mutated in place).
    """
    question_texts = [q['question'] for q in questions]
    tokenized_qs = [tok.words()
                    for tok in workers.imap(tokenize, question_texts)]
    q_search_encodings = encoder.encode_text_questions(
        tokenized_qs, return_search_vectors=True, show_progress=False)

    for q_idx, question in enumerate(questions):
        candidate_titles = question['top_titles'][:n1]
        # Fetch paragraph-representation matrices and row->paragraph-name
        # maps for this question's candidate titles.
        title2encs = {}
        title2idx2par_name = {}
        for enc_map, idx2name_map in workers.imap_unordered(
                get_title_mappings_from_saver, candidate_titles):
            title2encs.update(enc_map)
            title2idx2par_name.update(idx2name_map)

        # NOTE(review): this inverse map is built but never read below —
        # kept for parity with the original implementation.
        title2par_name2idxs = {}
        for title, idx2name in title2idx2par_name.items():
            name2idxs = {}
            for rep_idx, name in idx2name.items():
                name2idxs.setdefault(name, []).append(rep_idx)
            title2par_name2idxs[title] = {
                name: sorted(idxs) for name, idxs in name2idxs.items()
            }

        # Stack all candidate paragraph reps into one matrix, remembering
        # each title's row offset so hits map back to paragraph names.
        title2ids = {}
        titles_offset_dict = {}
        rep_blocks = []
        total_rows = 0
        for title in question['top_titles'][:n1]:
            titles_offset_dict[title] = total_rows
            rep = title2encs[title]
            title2ids[title] = list(range(total_rows, total_rows + len(rep)))
            rep_blocks.append(rep)
            total_rows += len(rep)
        id2title = {row: title
                    for title, rows in title2ids.items() for row in rows}
        all_par_reps = np.concatenate(rep_blocks, axis=0)

        def id_to_par_name(rep_id):
            # Translate a global row id into that title's local paragraph name.
            owner = id2title[rep_id]
            return title2idx2par_name[owner][rep_id -
                                             titles_offset_dict[owner]]

        # Over-retrieve (k1 * safety_mult) so k1 unique names survive dedup.
        query = np.expand_dims(q_search_encodings[q_idx], 0)
        top_hits = simple_numpy_knn(query, all_par_reps,
                                    k1 * safety_mult)[0]

        # Order-preserving dedup, stopping once k1 unique names are found.
        p_names = []
        seen = set()
        for hit in top_hits:
            name = id_to_par_name(hit)
            if name not in seen:
                seen.add(name)
                p_names.append(name)
                if len(p_names) == k1:
                    break

        question['top_pars_titles'] = [(par_name_to_title(name), )
                                       for name in p_names]
    return questions
Exemplo n.º 6
0
 def parname_to_text(par_name):
     """Return the text of the paragraph named *par_name*.

     Prefers the preloaded ``documents`` mapping when a docs file was
     given; otherwise joins the title's sentences from ``docs_db``.
     """
     # `docs_file`, `documents` and `docs_db` come from the enclosing
     # scope; this snippet is a fragment of a larger function.
     par_title = par_name_to_title(par_name)
     par_num = int(par_name.split('_')[-1])
     if docs_file is not None:
         return documents[par_title][par_num]
     # NOTE(review): the DB fallback ignores par_num and returns the whole
     # document text — confirm that is intended.
     return ' '.join(docs_db.get_doc_sentences(par_title))
Exemplo n.º 7
0
 def parname_to_text(par_name):
     """Return the paragraph content for *par_name*.

     Uses the DB when no docs file was supplied, otherwise the preloaded
     ``docs`` mapping (title -> list of paragraphs).
     """
     # `args`, `doc_db` and `docs` come from the enclosing scope; this
     # snippet is a fragment of a larger function.
     par_title = par_name_to_title(par_name)
     par_num = int(par_name.split('_')[-1])
     if args.docs_file is None:
         # NOTE(review): DB path returns a list of sentences while the
         # docs-file path returns a single paragraph entry — confirm
         # callers handle both shapes.
         return doc_db.get_doc_sentences(par_title)
     return docs[par_title][par_num]
Exemplo n.º 8
0
def main():
    """Full ranking evaluation on Hotpot.

    Builds (question, paragraph-pair) samples from one of several corpus
    sources, ranks the pairs, runs the model over them and reports ranked
    answer/supporting-fact scores — or writes a prediction file in test
    mode.

    Fixes vs. the previous revision:
      * questions were preprocessed twice per element (filter + result);
      * ``np.random.RandomState(0).shuffle(sorted(...))`` shuffled a
        temporary copy and discarded it, so sampling never actually used a
        shuffled order (``shuffle`` is in-place and returns None). The data
        is now sorted, shuffled in place, then sliced.
    """
    parser = argparse.ArgumentParser(
        description='Full ranking evaluation on Hotpot')
    parser.add_argument('model', help='model directory to evaluate')
    parser.add_argument(
        'output',
        type=str,
        help="Store the per-paragraph results in csv format in this file, "
        "or the json prediction if in test mode")
    parser.add_argument('-n',
                        '--sample_questions',
                        type=int,
                        default=None,
                        help="(for testing) run on a subset of questions")
    parser.add_argument(
        '-b',
        '--batch_size',
        type=int,
        default=64,
        help="Batch size, larger sizes can be faster but uses more memory")
    parser.add_argument(
        '-s',
        '--step',
        default=None,
        help="Weights to load, can be a checkpoint step or 'latest'")
    parser.add_argument('-a',
                        '--answer_bound',
                        type=int,
                        default=8,
                        help="Max answer span length")
    parser.add_argument('-c',
                        '--corpus',
                        choices=[
                            "distractors", "gold", "hotpot_file",
                            "retrieval_file", "top_titles"
                        ],
                        default="distractors")
    parser.add_argument('-t',
                        '--tokens',
                        type=int,
                        default=None,
                        help="Max tokens per a paragraph")
    parser.add_argument('--input_file', type=str, default=None)
    parser.add_argument('--docs_file', type=str, default=None)
    parser.add_argument('--num_workers',
                        type=int,
                        default=16,
                        help='Number of workers for tokenizing')
    parser.add_argument('--no_ema',
                        action="store_true",
                        help="Don't use EMA weights even if they exist")
    parser.add_argument('--no_sp',
                        action="store_true",
                        help="Don't predict supporting facts")
    parser.add_argument('--test_mode',
                        action='store_true',
                        help="produce a prediction file, no answers given")
    args = parser.parse_args()

    model_dir = ModelDir(args.model)
    batcher = ClusteredBatcher(args.batch_size,
                               multiple_contexts_len,
                               truncate_batches=True)
    loader = ResourceLoader()

    if args.corpus not in {"distractors", "gold"} and args.input_file is None:
        raise ValueError(
            "Must pass an input file if not using precomputed dataset")

    if args.corpus in {"distractors", "gold"} and args.test_mode:
        raise ValueError(
            "Test mode not available in 'distractors' or 'gold' mode")

    if args.corpus in {"distractors", "gold"}:
        # Precomputed dev questions; pairs are ranked with tf-idf below.
        corpus = HotpotQuestions()
        loader = corpus.get_resource_loader()
        questions = corpus.get_dev()

        question_preprocessor = HotpotTextLengthPreprocessorWithSpans(
            args.tokens)
        # Preprocess each question exactly once, dropping rejected ones.
        questions = [
            q for q in (question_preprocessor.preprocess(x)
                        for x in questions) if q is not None
        ]

        if args.sample_questions:
            # Deterministic sample: stable sort, seeded in-place shuffle,
            # then take the first N.
            questions = sorted(questions, key=lambda x: x.question_id)
            np.random.RandomState(0).shuffle(questions)
            questions = questions[:args.sample_questions]

        data = HotpotFullQADistractorsDataset(questions, batcher)
        gold_idxs = set(data.gold_idxs)
        if args.corpus == 'gold':
            data.samples = [data.samples[i] for i in data.gold_idxs]
        # Group samples (paragraph pairs) by question id, remembering the
        # original sample index so gold pairs can be located after ranking.
        qid2samples = {}
        qid2idx = {}
        for i, sample in enumerate(data.samples):
            key = sample.question_id
            if key in qid2samples:
                qid2samples[key].append(sample)
                qid2idx[key].append(i)
            else:
                qid2samples[key] = [sample]
                qid2idx[key] = [i]
        questions = []
        print("Ranking pairs...")
        gold_ranks = []
        for qid, samples in tqdm(qid2samples.items()):
            question = " ".join(samples[0].question)
            pars = [" ".join(x.paragraphs[0]) for x in samples]
            ranks = get_paragraph_ranks(question, pars)
            for sample, rank, idx in zip(samples, ranks, qid2idx[qid]):
                questions.append(
                    RankedQAPair(question=sample.question,
                                 paragraphs=sample.paragraphs,
                                 spans=np.zeros((0, 2), dtype=np.int32),
                                 question_id=sample.question_id,
                                 answer=sample.answer,
                                 rank=rank,
                                 q_type=sample.q_type,
                                 sentence_segments=sample.sentence_segments))
                if idx in gold_idxs:
                    gold_ranks.append(rank + 1)  # 1-based rank for reporting
        print(f"Mean rank: {np.mean(gold_ranks)}")
        ranks_counter = Counter(gold_ranks)
        for i in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]:
            print(f"Hits at {i}: {ranks_counter[i]}")

    elif args.corpus == 'hotpot_file':  # a hotpot json format input file. We rank the pairs with tf-idf
        with open(args.input_file, 'r') as f:
            hotpot_data = json.load(f)
        if args.sample_questions:
            # Deterministic sample (see note above).
            hotpot_data = sorted(hotpot_data, key=lambda x: x['_id'])
            np.random.RandomState(0).shuffle(hotpot_data)
            hotpot_data = hotpot_data[:args.sample_questions]
        title2sentences = {
            context[0]: context[1]
            for q in hotpot_data for context in q['context']
        }
        question_tok_texts = tokenize_texts(
            [q['question'] for q in hotpot_data], num_workers=args.num_workers)
        sentences_tok = tokenize_texts(list(title2sentences.values()),
                                       num_workers=args.num_workers,
                                       sentences=True)
        if args.tokens is not None:
            sentences_tok = [
                truncate_paragraph(p, args.tokens) for p in sentences_tok
            ]
        title2tok_sents = {
            title: sentences
            for title, sentences in zip(title2sentences.keys(), sentences_tok)
        }
        questions = []
        for idx, question in enumerate(tqdm(hotpot_data,
                                            desc='tf-idf ranking')):
            # Rank every unordered pair of this question's context titles.
            q_titles = [title for title, _ in question['context']]
            par_pairs = [(title1, title2) for i, title1 in enumerate(q_titles)
                         for title2 in q_titles[i + 1:]]
            if len(par_pairs) == 0:
                continue
            ranks = get_paragraph_ranks(question['question'], [
                ' '.join(title2sentences[t1] + title2sentences[t2])
                for t1, t2 in par_pairs
            ])
            for rank, par_pair in zip(ranks, par_pairs):
                sent_tok_pair = title2tok_sents[par_pair[0]] + title2tok_sents[
                    par_pair[1]]
                sentence_segments, _ = get_segments_from_sentences_fix_sup(
                    sent_tok_pair, np.zeros(0))
                # Track sentences that tokenized to nothing, so supporting-
                # fact indices can be realigned downstream.
                missing_sent_idx = [[
                    i for i, sent in enumerate(title2tok_sents[title])
                    if len(sent) == 0
                ] for title in par_pair]
                questions.append(
                    RankedQAPair(
                        question=question_tok_texts[idx],
                        paragraphs=[flatten_iterable(sent_tok_pair)],
                        spans=np.zeros((0, 2), dtype=np.int32),
                        question_id=question['_id'],
                        answer='noanswer'
                        if args.test_mode else question['answer'],
                        rank=rank,
                        q_type='null' if args.test_mode else question['type'],
                        sentence_segments=[sentence_segments],
                        par_titles_num_sents=[
                            (title,
                             sum(1 for sent in title2tok_sents[title]
                                 if len(sent) > 0)) for title in par_pair
                        ],
                        missing_sent_idxs=missing_sent_idx,
                        true_sp=[]
                        if args.test_mode else question['supporting_facts']))
    elif args.corpus == 'retrieval_file' or args.corpus == 'top_titles':
        if args.docs_file is None:
            print("Using DB documents")
            doc_db = DocDB(config.DOC_DB, full_docs=False)
        else:
            with open(args.docs_file, 'r') as f:
                docs = json.load(f)
        with open(args.input_file, 'r') as f:
            retrieval_data = json.load(f)
        if args.sample_questions:
            # Deterministic sample (see note above).
            retrieval_data = sorted(retrieval_data, key=lambda x: x['qid'])
            np.random.RandomState(0).shuffle(retrieval_data)
            retrieval_data = retrieval_data[:args.sample_questions]

        def parname_to_text(par_name):
            # Map a '<title>_<index>'-style paragraph name to its content,
            # from the DB when no docs file was given.
            par_title = par_name_to_title(par_name)
            par_num = int(par_name.split('_')[-1])
            if args.docs_file is None:
                return doc_db.get_doc_sentences(par_title)
            return docs[par_title][par_num]

        if args.corpus == 'top_titles':
            # Build all pairs of the top-10 tf-idf titles (first paragraph
            # of each title only).
            print("Top TF-IDF!")
            for q in retrieval_data:
                top_titles = q['top_titles'][:10]
                q['paragraph_pairs'] = [(title1 + '_0', title2 + '_0')
                                        for i, title1 in enumerate(top_titles)
                                        for title2 in top_titles[i + 1:]]

        question_tok_texts = tokenize_texts(
            [q['question'] for q in retrieval_data],
            num_workers=args.num_workers)
        all_parnames = list(
            set([
                parname for q in retrieval_data
                for pair in q['paragraph_pairs'] for parname in pair
            ]))
        texts_tok = tokenize_texts([parname_to_text(x) for x in all_parnames],
                                   num_workers=args.num_workers,
                                   sentences=True)
        if args.tokens is not None:
            texts_tok = [truncate_paragraph(p, args.tokens) for p in texts_tok]
        parname2tok_text = {
            parname: text
            for parname, text in zip(all_parnames, texts_tok)
        }
        questions = []
        for idx, question in enumerate(retrieval_data):
            # Retrieval files arrive pre-ranked: pair order is the rank.
            for rank, par_pair in enumerate(question['paragraph_pairs']):
                tok_pair = parname2tok_text[par_pair[0]] + parname2tok_text[
                    par_pair[1]]
                sentence_segments, _ = get_segments_from_sentences_fix_sup(
                    tok_pair, np.zeros(0))
                missing_sent_idx = [[
                    i for i, sent in enumerate(parname2tok_text[parname])
                    if len(sent) == 0
                ] for parname in par_pair]
                questions.append(
                    RankedQAPair(
                        question=question_tok_texts[idx],
                        paragraphs=[flatten_iterable(tok_pair)],
                        spans=np.zeros((0, 2), dtype=np.int32),
                        question_id=question['qid'],
                        answer='noanswer'
                        if args.test_mode else question['answers'][0],
                        rank=rank,
                        q_type='null' if args.test_mode else question['type'],
                        sentence_segments=[sentence_segments],
                        par_titles_num_sents=[
                            (par_name_to_title(parname),
                             sum(1 for sent in parname2tok_text[parname]
                                 if len(sent) > 0)) for parname in par_pair
                        ],
                        missing_sent_idxs=missing_sent_idx,
                        true_sp=[]
                        if args.test_mode else question['supporting_facts']))
    else:
        raise NotImplementedError()

    data = DummyDataset(questions, batcher)
    evaluators = [
        RecordHotpotQAPrediction(args.answer_bound,
                                 True,
                                 sp_prediction=not args.no_sp)
    ]

    # Resolve which checkpoint to load: explicit step/'latest', else best
    # weights, else the latest checkpoint.
    if args.step is not None:
        if args.step == "latest":
            checkpoint = model_dir.get_latest_checkpoint()
        else:
            checkpoint = model_dir.get_checkpoint(int(args.step))
    else:
        checkpoint = model_dir.get_best_weights()
        if checkpoint is not None:
            print("Using best weights")
        else:
            print("Using latest checkpoint")
            checkpoint = model_dir.get_latest_checkpoint()

    model = model_dir.get_model()

    evaluation = trainer.test(model, evaluators, {args.corpus: data}, loader,
                              checkpoint, not args.no_ema, 10)[args.corpus]

    print("Saving result")
    output_file = args.output

    df = pd.DataFrame(evaluation.per_sample)

    df.sort_values(["question_id", "rank"], inplace=True, ascending=True)
    group_by = ["question_id"]

    def get_ranked_scores(score_name):
        # Pick the question subset (comparison/bridge/all) and the metric
        # column implied by the score name, then compute ranked scores.
        filtered_df = df[df.type == 'comparison'] if "Cp" in score_name else \
            df[df.type == 'bridge'] if "Br" in score_name else df
        target_prefix = 'joint' if 'joint' in score_name else 'sp' if 'sp' in score_name else 'text'
        target_score = f"{target_prefix}_{'em' if 'EM' in score_name else 'f1'}"
        return compute_ranked_scores_with_yes_no(
            filtered_df,
            span_q_col="span_question_scores",
            yes_no_q_col="yes_no_question_scores",
            yes_no_scores_col="yes_no_confidence_scores",
            span_scores_col="predicted_score",
            span_target_score=target_score,
            group_cols=group_by)

    if not args.test_mode:
        score_names = ["EM", "F1", "Br EM", "Br F1", "Cp EM", "Cp F1"]
        if not args.no_sp:
            score_names.extend([
                f"{prefix} {name}" for prefix in ['sp', 'joint']
                for name in score_names
            ])

        table = [["N Paragraphs"] + score_names]
        scores = [get_ranked_scores(score_name) for score_name in score_names]
        table += list([str(i + 1), *["%.4f" % x for x in score_vals]]
                      for i, score_vals in enumerate(zip(*scores)))
        print_table(table)

        df.to_csv(output_file, index=False)

    else:
        df_to_pred(df, output_file)
def iterative_retrieval(encoder, questions, q_original_encodings, q_search_encodings, workers,
                        parname_to_text, reformulate_from_text, n1, n2, k1, k2, safety_mult):
    """Two-hop paragraph retrieval for each question.

    Hop 1 retrieves the top k1 paragraphs for the question; the question is
    then reformulated once per retrieved paragraph, and hop 2 retrieves a
    second paragraph per reformulation, producing up to k2 paragraph-name
    pairs per question.

    Args:
        encoder: provides reformulate_questions (from precomputed encodings)
            and reformulate_questions_from_texts (from tokenized text).
        questions: list of dicts with 'top_titles' and 'question'. Mutated in
            place — each question gains a 'paragraph_pairs' list of
            (hop-1 paragraph name, hop-2 paragraph name) tuples.
        q_original_encodings: per-question original representations, indexed
            by position; used when reformulate_from_text is False.
        q_search_encodings: per-question search vectors for the hop-1 kNN.
        workers: pool whose imap_unordered maps titles (via
            get_title_mappings_from_saver) to (title->encodings,
            title->idx->paragraph-name) dicts.
        parname_to_text: maps a paragraph name to its raw text; only used
            when reformulate_from_text is True.
        reformulate_from_text: if True, re-tokenize question + paragraph text
            for reformulation instead of reusing precomputed encodings.
        n1, n2: number of top titles forming the hop-1 / hop-2 search space.
        k1, k2: paragraphs kept after hop 1 / pairs kept after hop 2.
        safety_mult: over-retrieval factor so enough neighbors survive
            deduplication.

    Returns:
        The same `questions` list, mutated with 'paragraph_pairs'.
    """
    for q_idx, question in tqdm(enumerate(questions), total=len(questions), ncols=80, desc=f"{n1}-{n2}-{k1}-{k2}"):
        # Fetch encodings and idx->paragraph-name maps for every title either
        # hop might touch.
        q_titles = question['top_titles'][:max(n1, n2)]
        title2encs = {}
        title2idx2par_name = {}
        for t2enc, t2id2p in workers.imap_unordered(get_title_mappings_from_saver, q_titles):
            title2encs.update(t2enc)
            title2idx2par_name.update(t2id2p)
        # Invert idx->par_name into par_name->sorted row indices so one
        # paragraph's representation rows can be sliced out of a title matrix.
        title2par_name2idxs = {}
        for title, id2par in title2idx2par_name.items():
            par2idxs = {}
            for idx, parname in id2par.items():
                if parname in par2idxs:
                    par2idxs[parname].append(idx)
                else:
                    par2idxs[parname] = [idx]
            title2par_name2idxs[title] = {par: sorted(idxs) for par, idxs in par2idxs.items()}
        # Hop-1 search space: stack the first n1 titles' representations and
        # record each title's global-id range and starting offset.
        title2ids = {}
        all_par_reps = []
        total_sentences = 0
        titles_offset_dict = {}
        for title in question['top_titles'][:n1]:
            titles_offset_dict[title] = total_sentences
            rep = title2encs[title]
            title2ids[title] = list(range(total_sentences, total_sentences + len(rep)))
            all_par_reps.append(rep)
            total_sentences += len(rep)
        id2title = {i: title for title, ids in title2ids.items() for i in ids}
        all_par_reps = np.concatenate(all_par_reps, axis=0)

        def id_to_par_name(rep_id):
            # NOTE: reads id2title / titles_offset_dict at call time (late
            # binding). After those names are rebuilt for hop 2 below, this
            # same closure resolves hop-2 global ids — do not bind them early.
            return title2idx2par_name[id2title[rep_id]][rep_id - titles_offset_dict[id2title[rep_id]]]

        # Hop 1: kNN of this question's search vector over all hop-1 rows;
        # over-retrieve by safety_mult so k1 distinct paragraphs survive dedup.
        q_enc = q_search_encodings[q_idx]
        top_k = simple_numpy_knn(np.expand_dims(q_enc, 0), all_par_reps, k1 * safety_mult)[0]

        # Order-preserving dedup: set.add returns None, so the condition keeps
        # only the first occurrence of each paragraph name.
        seen = set()
        p_names = [id_to_par_name(x)
                   for x in top_k if not (id_to_par_name(x) in seen or seen.add(id_to_par_name(x)))][:k1]
        # Representation rows of each retrieved paragraph, sliced per title.
        iteration1_paragraphs = \
            [title2encs[par_name_to_title(pname)][title2par_name2idxs[par_name_to_title(pname)][pname], :]
             for pname in p_names]
        if not reformulate_from_text:
            # Reformulate from precomputed encodings: tile one copy of the
            # original question representation per retrieved paragraph.
            reformulations = encoder.reformulate_questions(questions_rep=np.tile(q_original_encodings[q_idx],
                                                                                 reps=(len(p_names), 1)),
                                                           paragraphs_rep=iteration1_paragraphs,
                                                           return_search_vectors=True)
        else:
            # Reformulate from raw text: tokenize the question once, pair it
            # with each retrieved paragraph's tokenized text.
            tok_q = tokenize(question['question']).words()
            par_texts = [tokenize(parname_to_text(pname)).words() for pname in p_names]
            reformulations = encoder.reformulate_questions_from_texts(
                tokenized_questions=[tok_q for _ in range(len(par_texts))],
                tokenized_pars=par_texts,
                return_search_vectors=True
            )

        # Hop-2 search space: rebuild the id maps over the first n2 titles.
        # (id_to_par_name above now resolves against these new maps.)
        title2ids = {}
        all_par_reps = []
        total_sentences = 0
        titles_offset_dict = {}
        for title in question['top_titles'][:n2]:
            titles_offset_dict[title] = total_sentences
            rep = title2encs[title]
            title2ids[title] = list(range(total_sentences, total_sentences + len(rep)))
            all_par_reps.append(rep)
            total_sentences += len(rep)
        id2title = {i: title for title, ids in title2ids.items() for i in ids}
        all_par_reps = np.concatenate(all_par_reps, axis=0)

        # Hop 2: joint kNN over all reformulations; x1 indexes the hop-1
        # paragraph, x2 a hop-2 global row. Dedup pairs order-preservingly
        # (same set.add idiom) and keep the first k2.
        top_k_second = numpy_global_knn(reformulations, all_par_reps, k2 * safety_mult)
        seen = set()
        final_p_name_pairs = [(p_names[x1], id_to_par_name(x2))
                              for x1, x2 in top_k_second
                              if not ((p_names[x1], id_to_par_name(x2)) in seen
                                      or seen.add((p_names[x1], id_to_par_name(x2))))][:k2]

        # important to note that in the iterative dataset the paragraphs of each question are in pairs
        question['paragraph_pairs'] = final_p_name_pairs
    return questions