Example #1
def get_all_word_count(num_workers, db_path, out_path, titles_only=False):
    db = DocDB(db_path)
    all_titles = db.get_doc_titles()

    # Setup worker pool
    workers = ProcessPool(
        num_workers,
        initializer=init,
        initargs=[db_path]
    )

    word_counts = Counter()

    if not titles_only:
        with tqdm(total=len(all_titles), desc='word count') as pbar:
            for w_count in workers.imap_unordered(get_word_count, all_titles):
                word_counts.update(w_count)
                pbar.update()
    else:
        with tqdm(total=len(all_titles), desc='title word count') as pbar:
            for tok_title in workers.imap_unordered(tokenize, all_titles):
                word_counts.update(tok_title.words())
                pbar.update()

    # Shut down the worker pool before writing results
    workers.close()
    workers.join()

    with open(out_path, 'w') as f:
        for k, v in tqdm(word_counts.most_common(), desc='writing counts'):
            f.write(f"{k}\t{v}\n")
Example #2
def init(db_path, full_doc_db_path, ranker_path):
    global PROCESS_TOK, PROCESS_DB, PROCESS_RANKER, PROCESS_FULL_DOC_DB, DOC_TITLES
    PROCESS_TOK = CoreNLPTokenizer()
    Finalize(PROCESS_TOK, PROCESS_TOK.shutdown, exitpriority=100)
    PROCESS_DB = DocDB(db_path)
    Finalize(PROCESS_DB, PROCESS_DB.close, exitpriority=100)
    PROCESS_FULL_DOC_DB = DocDB(full_doc_db_path, full_docs=True)
    Finalize(PROCESS_FULL_DOC_DB, PROCESS_FULL_DOC_DB.close, exitpriority=100)
    PROCESS_RANKER = TfidfDocRanker(ranker_path)
    DOC_TITLES = PROCESS_FULL_DOC_DB.get_doc_titles()
Example #3
def init(top_k, get_text):
    global PROCESS_DB, PROCESS_RANKER, TOP_K, GET_TEXT
    PROCESS_DB = DocDB()
    Finalize(PROCESS_DB, PROCESS_DB.close, exitpriority=100)
    PROCESS_RANKER = TfidfDocRanker()
    TOP_K = top_k
    GET_TEXT = get_text
Example #4
def main():  # todo: add scores to questions & paragraphs
    parser = argparse.ArgumentParser(description="Preprocess Hotpot Questions")
    parser.add_argument("--train_file", default=config.HOTPOT_TRAIN_FILE)
    parser.add_argument("--dev_file", default=config.HOTPOT_DEV_DISTRACTOR_FILE)
    parser.add_argument("--doc_db", default=None)
    parser.add_argument('--num-workers', type=int, default=1, help='Number of CPU processes')

    if not exists(join(config.CORPUS_DIR, 'hotpot')):
        mkdir(join(config.CORPUS_DIR, 'hotpot'))

    args = parser.parse_args()

    # target_dir = config.CORPUS_DIR
    # if exists(target_dir) and len(listdir(target_dir)) > 0:
    #     raise ValueError("Files already exist in " + target_dir)

    if args.num_workers > 1:
        print(f"Multiprocessing with {args.num_workers} threads...")
        print("Parsing train...")
        train = parse_hotpot_data_async(args.train_file, args.doc_db, args.num_workers)

        print("Parsing dev...")
        dev = parse_hotpot_data_async(args.dev_file, args.doc_db, args.num_workers)

    else:
        tokenizer = CoreNLPTokenizer()
        docdb = DocDB(args.doc_db)

        print("Parsing train...")
        train = parse_hotpot_data(args.train_file, tokenizer, docdb)

        print("Parsing dev...")
        dev = parse_hotpot_data(args.dev_file, tokenizer, docdb)

    print("Saving...")
    HotpotQuestions.make_corpus(train, dev)
    print("Done")
Example #5
def get_count_matrix():
    """Form a sparse word to document count matrix (inverted index).

    M[i, j] = # times word i appears in document j.
    """
    # Map doc_ids to indexes
    global DOC2IDX
    with DocDB(args.db_path) as doc_db:
        doc_titles = doc_db.get_doc_titles()
    DOC2IDX = {doc_title: i for i, doc_title in enumerate(doc_titles)}

    # Setup worker pool
    workers = ProcessPool(args.num_workers,
                          initializer=init,
                          initargs=[args.db_path])

    # Compute the count matrix in steps (to bound memory usage)
    logger.info('Mapping...')
    row, col, data = [], [], []
    step = max(int(len(doc_titles) / 10), 1)
    batches = [doc_titles[i:i + step] for i in range(0, len(doc_titles), step)]
    _count = partial(count, args.ngram, args.hash_size)
    for i, batch in enumerate(batches):
        logger.info('-' * 25 + 'Batch %d/%d' % (i + 1, len(batches)) +
                    '-' * 25)
        for b_row, b_col, b_data in workers.imap_unordered(_count, batch):
            row.extend(b_row)
            col.extend(b_col)
            data.extend(b_data)
    workers.close()
    workers.join()

    logger.info('Creating sparse matrix...')
    count_matrix = sp.csr_matrix((data, (row, col)),
                                 shape=(args.hash_size, len(doc_titles)))
    count_matrix.sum_duplicates()
    return count_matrix, (DOC2IDX, doc_titles)
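
The `count` worker mapped over each batch is not shown. A hedged sketch of what it could look like, following the hashed inverted-index scheme in the docstring; the murmurhash bucketing and the unigram simplification are assumptions, not necessarily the project's exact code:

from collections import Counter
from sklearn.utils import murmurhash3_32

def count(ngram, hash_size, doc_title):
    """Map one document to (row, col, data) triples for the count matrix."""
    global PROCESS_TOK, PROCESS_DB  # per-process handles set by init()
    text = ' '.join(PROCESS_DB.get_doc_sentences(doc_title))
    words = PROCESS_TOK.tokenize(text).words()
    # Unigram case for brevity; the real worker would also hash n-grams.
    counts = Counter(murmurhash3_32(w, positive=True) % hash_size for w in words)
    row = list(counts.keys())                 # hashed word ids
    col = [DOC2IDX[doc_title]] * len(counts)  # document index (global DOC2IDX)
    data = list(counts.values())              # occurrence counts
    return row, col, data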
Example #6
def init():
    global PROCESS_TOK, PROCESS_DB
    PROCESS_TOK = CoreNLPTokenizer()
    Finalize(PROCESS_TOK, PROCESS_TOK.shutdown, exitpriority=100)
    PROCESS_DB = DocDB()
    Finalize(PROCESS_DB, PROCESS_DB.close, exitpriority=100)
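
`Finalize` here is `multiprocessing.util.Finalize`: it registers a cleanup callback (tokenizer shutdown, DB close) that runs when the worker process exits, with higher `exitpriority` values running earlier. A minimal sketch of wiring this initializer to a pool, using the `tokenize_words` helper from Example #8 (placeholder input text):

from multiprocessing import Pool as ProcessPool

workers = ProcessPool(4, initializer=init, initargs=[])
# Each worker process now owns its own PROCESS_TOK / PROCESS_DB globals;
# tokenize_words (Example #8) reads PROCESS_TOK inside the worker.
tokenized = workers.map(tokenize_words, ["some placeholder text"])
workers.close()
workers.join()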
Example #7
def encode_from_file(docs_file,
                     questions_file,
                     encodings_dir,
                     encoder_model,
                     num_workers,
                     hotpot: bool,
                     long_batch: int,
                     short_batch: int,
                     use_chars: bool,
                     use_ema: bool,
                     checkpoint: str,
                     document_chunk_size=1000,
                     samples=None,
                     encode_all_db=False):
    """

    :param out_file: .npz file to dump the encodings
    :param docs_file: path to json file whose structure is [{title: list of paragraphs}, ...]
    :return:
    """
    doc_encs_handler = DocumentEncodingHandler(encodings_dir)
    # Setup worker pool
    workers = ProcessPool(num_workers, initializer=init, initargs=[])

    if docs_file is not None:
        with open(docs_file, 'r') as f:
            documents = json.load(f)
        documents = {
            k: v
            for k, v in documents.items()
            if k not in doc_encs_handler.titles2filenames
        }

        tokenized_documents = {}
        tupled_doc_list = [(title, pars) for title, pars in documents.items()]

        if samples is not None:
            print(f"sampling {samples} samples")
            tupled_doc_list = tupled_doc_list[:samples]

        print("Tokenizing from file...")
        with tqdm(total=len(tupled_doc_list), ncols=80) as pbar:
            for tok_doc in workers.imap_unordered(tokenize_document,
                                                  tupled_doc_list):
                tokenized_documents.update(tok_doc)
                pbar.update()
    else:
        if questions_file is not None:
            with open(questions_file, 'r') as f:
                questions = json.load(f)
            all_titles = list(
                set([title for q in questions for title in q['top_titles']]))
        else:
            print("encoding all DB!")
            all_titles = DocDB().get_doc_titles()

        if samples is not None:
            print(f"sampling {samples} samples")
            all_titles = all_titles[:samples]

        all_titles = [
            t for t in all_titles if t not in doc_encs_handler.titles2filenames
        ]
        tokenized_documents = {}

        print("Tokenizing from DB...")
        with tqdm(total=len(all_titles), ncols=80) as pbar:
            for tok_doc in workers.imap_unordered(tokenize_from_db,
                                                  all_titles):
                tokenized_documents.update(tok_doc)
                pbar.update()

    workers.close()
    workers.join()

    voc = set()
    for paragraphs in tokenized_documents.values():
        for par in paragraphs:
            voc.update(par)

    if not hotpot:
        spec = QuestionAndParagraphsSpec(batch_size=None,
                                         max_num_contexts=1,
                                         max_num_question_words=None,
                                         max_num_context_words=None)
        encoder = SentenceEncoderSingleContext(model_dir_path=encoder_model,
                                               vocabulary=voc,
                                               spec=spec,
                                               loader=ResourceLoader(),
                                               use_char_inputs=use_chars,
                                               use_ema=use_ema,
                                               checkpoint=checkpoint)
    else:
        spec = QuestionAndParagraphsSpec(batch_size=None,
                                         max_num_contexts=2,
                                         max_num_question_words=None,
                                         max_num_context_words=None)
        encoder = SentenceEncoderIterativeModel(model_dir_path=encoder_model,
                                                vocabulary=voc,
                                                spec=spec,
                                                loader=ResourceLoader(),
                                                use_char_inputs=use_chars,
                                                use_ema=use_ema,
                                                checkpoint=checkpoint)

    tokenized_documents_items = list(tokenized_documents.items())
    for tokenized_doc_chunk in tqdm([
            tokenized_documents_items[i:i + document_chunk_size] for i in
            range(0, len(tokenized_documents_items), document_chunk_size)
    ],
                                    ncols=80):
        flattened_pars_with_names = [(f"{title}_{i}", par)
                                     for title, pars in tokenized_doc_chunk
                                     for i, par in enumerate(pars)]

        # filtering out empty paragraphs (probably had some short string the tokenization removed)
        # note that the filtered paragraphs get no representation, but they keep their
        # indices in the paragraph numbering, for consistency with the docs.
        flattened_pars_with_names = [(name, par)
                                     for name, par in flattened_pars_with_names
                                     if len(par) > 0]

        # sort such that longer paragraphs are first to identify OOMs early on
        flattened_pars_with_names = sorted(flattened_pars_with_names,
                                           key=lambda x: len(x[1]),
                                           reverse=True)
        long_paragraphs_ids = [
            i for i, name_par in enumerate(flattened_pars_with_names)
            if len(name_par[1]) >= 900
        ]
        short_paragraphs_ids = [
            i for i, name_par in enumerate(flattened_pars_with_names)
            if len(name_par[1]) < 900
        ]

        # print(f"Encoding {len(flattened_pars_with_names)} paragraphs...")
        name2enc = {}
        dummy_question = "Hello Hello".split()
        if not hotpot:
            model_paragraphs = [
                BinaryQuestionAndParagraphs(question=dummy_question,
                                            paragraphs=[x],
                                            label=1,
                                            num_distractors=0,
                                            question_id='dummy')
                for _, x in flattened_pars_with_names
            ]
        else:
            # todo allow precomputed sentence segments
            model_paragraphs = [
                IterativeQuestionAndParagraphs(question=dummy_question,
                                               paragraphs=[x, dummy_question],
                                               first_label=1,
                                               second_label=1,
                                               question_id='dummy',
                                               sentence_segments=None)
                for _, x in flattened_pars_with_names
            ]

        # print("Encoding long paragraphs...")
        long_pars = [model_paragraphs[i] for i in long_paragraphs_ids]
        name2enc.update({
            flattened_pars_with_names[long_paragraphs_ids[i]][0]: enc
            for i, enc in enumerate(
                encoder.encode_paragraphs(
                    long_pars, batch_size=long_batch, show_progress=True
                ) if not hotpot else encoder.encode_first_paragraphs(
                    long_pars, batch_size=long_batch, show_progress=True))
        })

        # print("Encoding short paragraphs...")
        short_pars = [model_paragraphs[i] for i in short_paragraphs_ids]
        name2enc.update({
            flattened_pars_with_names[short_paragraphs_ids[i]][0]: enc
            for i, enc in enumerate(
                encoder.encode_paragraphs(
                    short_pars, batch_size=short_batch, show_progress=True
                ) if not hotpot else encoder.encode_first_paragraphs(
                    short_pars, batch_size=short_batch, show_progress=True))
        })

        doc_encs_handler.save_multiple_documents(name2enc)
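
The `tokenize_document` worker used by the pool above is not shown; from the way `tokenized_documents` is consumed (a dict mapping a title to one list of tokens per paragraph), a plausible sketch is:

def tokenize_document(title_and_pars):
    """Tokenize one (title, [paragraphs]) pair inside a worker process."""
    global PROCESS_TOK  # set by the pool initializer, as in Example #6
    title, pars = title_and_pars
    return {title: [PROCESS_TOK.tokenize(p).words() for p in pars]}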
Example #8
    Finalize(PROCESS_TOK, PROCESS_TOK.shutdown, exitpriority=100)


def tokenize_words(text):
    global PROCESS_TOK
    return PROCESS_TOK.tokenize(text).words()


def tokenize_sentences(sentences):
    global PROCESS_TOK
    return [PROCESS_TOK.tokenize(s).words() if s != '' else [] for s in sentences]


print("Loading TF-IDF...")
tfidf_ranker = TfidfDocRanker()
db = DocDB()

loader = ResourceLoader()
# loader = HotpotQuestions().get_resource_loader()
word_counts = load_counts(join(LOCAL_DATA_DIR, 'hotpot', 'wiki_word_counts.txt'))
title_counts = load_counts(join(LOCAL_DATA_DIR, 'hotpot', 'wiki_title_word_counts.txt'))
word_counts.update(title_counts)
voc = set(word_counts.keys())

print("Loading encoder...")

spec = QuestionAndParagraphsSpec(batch_size=None, max_num_contexts=2,
                                 max_num_question_words=None, max_num_context_words=None)
encoder = SentenceEncoderIterativeModel(model_dir_path=args.encoder_model, vocabulary=voc,
                                        spec=spec, loader=loader, use_char_inputs=False,
                                        use_ema=not args.no_ema,
Example #9
File: eval.py  Project: sjliu0920/MUPPET
def modified_main(dataset_path,
                  k_list_to_check,
                  ranker_path=None,
                  normalize_ranker=False,
                  num_workers=1,
                  tokenizer='corenlp',
                  docdb_path=None,
                  out=None):
    dataset = load_dataset(dataset_path)
    ranker = TfidfDocRanker(tfidf_path=ranker_path,
                            normalize_vectors=normalize_ranker,
                            tokenizer=tokenizer)
    docdb = DocDB(docdb_path)
    print("Building modified queries...")
    ranked_gold_dict = build_ranked_golds(dataset, docdb=docdb, ranker=ranker)
    regular_table = prettytable.PrettyTable([
        'Top K', 'Second Paragraph Hits',
        'Second Paragraph Hits Modified Query'
    ])
    cat_table_dict = {
        cat: prettytable.PrettyTable([
            'Top K', 'Second Paragraph Hits',
            'Second Paragraph Hits Modified Query'
        ])
        for cat in CATEGORIES
    }
    max_k = max(k_list_to_check)
    print(f"Retrieving top {max_k} ...")
    start = time.time()
    reg_result_dict, ranked_result_dict = get_ranked_top_k(
        dataset, ranked_gold_dict, ranker, max_k, num_workers)
    print(f"Done, took {time.time()-start} ms.")
    for k in k_list_to_check:
        print(f"Calculating scores for top {k}...")
        start = time.time()
        reg_scores, reg_category_scores = modified_top_k_coverage_score(
            ranked_gold_dict, reg_result_dict, k)
        mod_scores, mod_category_scores = modified_top_k_coverage_score(
            ranked_gold_dict, ranked_result_dict, k)
        print(f"Done, took {time.time()-start} ms.")
        regular_table.add_row([
            k, reg_scores['Second Paragraph Hits'],
            mod_scores['Second Paragraph Hits']
        ])
        for cat in cat_table_dict:
            cat_table_dict[cat].add_row([
                k, reg_category_scores[cat]['Second Paragraph Hits'],
                mod_category_scores[cat]['Second Paragraph Hits']
            ])
    output_str = 'Overall Results:\n'
    output_str += str(regular_table) + '\n'
    for cat, table in cat_table_dict.items():
        output_str += '\n**********************************************\n'
        output_str += f"Category: {cat} Results:\n"
        output_str += str(table) + '\n'

    if out is None:
        print(output_str)
    else:
        with open(out, 'w') as f:
            f.write(output_str)
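
A hypothetical invocation of `modified_main` (the dataset path, k values, and output file below are placeholders, not taken from the project):

modified_main('data/hotpot/dev.json',
              k_list_to_check=[1, 2, 5, 10, 20],
              num_workers=8,
              out='modified_query_eval.txt')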
Example #10
def build_openqa_iterative_top_titles(
        base_dir, questions_file, docs_file, encodings_dir, encoder_model,
        k1_list: List[int], k2_list: List[int], n1_list: List[int],
        n2_list: List[int], evaluate: bool, reformulate_from_text: bool,
        use_ema: bool, checkpoint: str, safety_mult: int):
    print('Loading data...')
    s = time.time()
    with open(questions_file, 'r') as f:
        questions = json.load(f)
    if docs_file is not None:
        with open(docs_file, 'r') as f:
            documents = json.load(f)
    else:
        docs_db = DocDB()
    print(f'Done, took {time.time()-s} seconds.')

    if n1_list is not None and n2_list is not None:
        for q in questions:
            q['top_titles'] = q['top_titles'][:max(max(n1_list), max(n2_list))]

    # Setup worker pool
    workers = ProcessPool(16, initializer=init, initargs=[])

    qid2tokenized = {}
    tupled_questions = [(q['qid'], q['question']) for q in questions]
    print("Tokenizing questions...")
    with tqdm(total=len(tupled_questions)) as pbar:
        for tok_q in workers.imap_unordered(tokenize_question,
                                            tupled_questions):
            qid2tokenized.update(tok_q)
            pbar.update()

    voc = set()
    for question in qid2tokenized.values():
        voc.update(question)

    workers.close()
    workers.join()

    # all_titles = list(set([title for q in questions for title in q['top_titles']]))

    def parname_to_text(par_name):
        par_title = par_name_to_title(par_name)
        par_num = int(par_name.split('_')[-1])
        if docs_file is not None:
            return documents[par_title][par_num]
        return ' '.join(docs_db.get_doc_sentences(par_title))

    # print(f"Gathering documents...")
    # Setup worker pool
    workers = ProcessPool(16,
                          initializer=init_encoding_handler,
                          initargs=[encodings_dir])
    # title2encs = {}
    # title2idx2par_name = {}
    # with tqdm(total=len(all_titles)) as pbar:
    #     for t2enc, t2id2p in tqdm(workers.imap_unordered(get_title_mappings_from_saver, all_titles)):
    #         title2encs.update(t2enc)
    #         title2idx2par_name.update(t2id2p)
    #         pbar.update()
    # title2par_name2idxs = {}
    # for title, id2par in title2idx2par_name.items():
    #     par2idxs = {}
    #     for idx, parname in id2par.items():
    #         if parname in par2idxs:
    #             par2idxs[parname].append(idx)
    #         else:
    #             par2idxs[parname] = [idx]
    #     title2par_name2idxs[title] = {par: sorted(idxs) for par, idxs in par2idxs.items()}

    print("Loading encoder...")
    spec = QuestionAndParagraphsSpec(batch_size=None,
                                     max_num_contexts=2,
                                     max_num_question_words=None,
                                     max_num_context_words=None)
    encoder = SentenceEncoderIterativeModel(model_dir_path=encoder_model,
                                            vocabulary=voc,
                                            spec=spec,
                                            loader=ResourceLoader(),
                                            use_char_inputs=False,
                                            use_ema=use_ema,
                                            checkpoint=checkpoint)

    print("Encoding questions...")
    q_original_encodings = encoder.encode_text_questions(
        [qid2tokenized[q['qid']] for q in questions],
        return_search_vectors=False,
        show_progress=True)
    q_search_encodings = encoder.question_rep_to_search_vector(
        question_encodings=q_original_encodings)

    init()  # for initializing the tokenizer

    total_num = len(n1_list) * len(n2_list) * len(k1_list) * len(k2_list)
    print("Calculating similarities...")
    for n1, n2, k1, k2 in tqdm(itertools.product(n1_list, n2_list, k1_list,
                                                 k2_list),
                               total=total_num,
                               ncols=80):
        questions = iterative_retrieval(encoder, questions, qid2tokenized,
                                        q_search_encodings, workers,
                                        parname_to_text, reformulate_from_text,
                                        n1, n2, k1, k2, safety_mult)
        dir_path = os.path.join(base_dir, f"n2-{n2}", f"n1-{n1}")
        os.makedirs(dir_path, exist_ok=True)
        out_file = os.path.join(dir_path,
                                f"n1-{n1}_n2-{n2}_k1-{k1}_k2-{k2}.json")
        questions_copy = deepcopy(questions)
        for question in questions_copy:
            question.pop('top_titles')
        with open(out_file, 'w') as f:
            json.dump(questions_copy, f)

        if evaluate:
            eval_questions(questions_copy)
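
`parname_to_text` relies on a `par_name_to_title` helper. Since Example #7 builds paragraph names as f"{title}_{i}", its inverse can be sketched as below; stripping only the last underscore-separated segment keeps titles that themselves contain underscores intact:

def par_name_to_title(par_name):
    # Drop the trailing "_<par_num>" suffix; titles may contain underscores.
    return '_'.join(par_name.split('_')[:-1])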
Example #11
def main():
    parser = argparse.ArgumentParser(
        description='Full ranking evaluation on Hotpot')
    parser.add_argument('model', help='model directory to evaluate')
    parser.add_argument(
        'output',
        type=str,
        help="Store the per-paragraph results in csv format in this file, "
        "or the json prediction if in test mode")
    parser.add_argument('-n',
                        '--sample_questions',
                        type=int,
                        default=None,
                        help="(for testing) run on a subset of questions")
    parser.add_argument(
        '-b',
        '--batch_size',
        type=int,
        default=64,
        help="Batch size, larger sizes can be faster but uses more memory")
    parser.add_argument(
        '-s',
        '--step',
        default=None,
        help="Weights to load, can be a checkpoint step or 'latest'")
    parser.add_argument('-a',
                        '--answer_bound',
                        type=int,
                        default=8,
                        help="Max answer span length")
    parser.add_argument('-c',
                        '--corpus',
                        choices=[
                            "distractors", "gold", "hotpot_file",
                            "retrieval_file", "top_titles"
                        ],
                        default="distractors")
    parser.add_argument('-t',
                        '--tokens',
                        type=int,
                        default=None,
                        help="Max tokens per a paragraph")
    parser.add_argument('--input_file', type=str, default=None)
    parser.add_argument('--docs_file', type=str, default=None)
    parser.add_argument('--num_workers',
                        type=int,
                        default=16,
                        help='Number of workers for tokenizing')
    parser.add_argument('--no_ema',
                        action="store_true",
                        help="Don't use EMA weights even if they exist")
    parser.add_argument('--no_sp',
                        action="store_true",
                        help="Don't predict supporting facts")
    parser.add_argument('--test_mode',
                        action='store_true',
                        help="produce a prediction file, no answers given")
    args = parser.parse_args()

    model_dir = ModelDir(args.model)
    batcher = ClusteredBatcher(args.batch_size,
                               multiple_contexts_len,
                               truncate_batches=True)
    loader = ResourceLoader()

    if args.corpus not in {"distractors", "gold"} and args.input_file is None:
        raise ValueError(
            "Must pass an input file if not using precomputed dataset")

    if args.corpus in {"distractors", "gold"} and args.test_mode:
        raise ValueError(
            "Test mode not available in 'distractors' or 'gold' mode")

    if args.corpus in {"distractors", "gold"}:

        corpus = HotpotQuestions()
        loader = corpus.get_resource_loader()
        questions = corpus.get_dev()

        question_preprocessor = HotpotTextLengthPreprocessorWithSpans(
            args.tokens)
        # preprocess each question once, dropping those it rejects
        questions = [
            preprocessed for preprocessed in
            (question_preprocessor.preprocess(x) for x in questions)
            if preprocessed is not None
        ]

        if args.sample_questions:
            questions = sorted(questions, key=lambda x: x.question_id)
            np.random.RandomState(0).shuffle(questions)
            questions = questions[:args.sample_questions]

        data = HotpotFullQADistractorsDataset(questions, batcher)
        gold_idxs = set(data.gold_idxs)
        if args.corpus == 'gold':
            data.samples = [data.samples[i] for i in data.gold_idxs]
        qid2samples = {}
        qid2idx = {}
        for i, sample in enumerate(data.samples):
            key = sample.question_id
            if key in qid2samples:
                qid2samples[key].append(sample)
                qid2idx[key].append(i)
            else:
                qid2samples[key] = [sample]
                qid2idx[key] = [i]
        questions = []
        print("Ranking pairs...")
        gold_ranks = []
        for qid, samples in tqdm(qid2samples.items()):
            question = " ".join(samples[0].question)
            pars = [" ".join(x.paragraphs[0]) for x in samples]
            ranks = get_paragraph_ranks(question, pars)
            for sample, rank, idx in zip(samples, ranks, qid2idx[qid]):
                questions.append(
                    RankedQAPair(question=sample.question,
                                 paragraphs=sample.paragraphs,
                                 spans=np.zeros((0, 2), dtype=np.int32),
                                 question_id=sample.question_id,
                                 answer=sample.answer,
                                 rank=rank,
                                 q_type=sample.q_type,
                                 sentence_segments=sample.sentence_segments))
                if idx in gold_idxs:
                    gold_ranks.append(rank + 1)
        print(f"Mean rank: {np.mean(gold_ranks)}")
        ranks_counter = Counter(gold_ranks)
        for i in range(1, 11):
            print(f"Hits at {i}: {ranks_counter[i]}")

    elif args.corpus == 'hotpot_file':  # a hotpot json format input file. We rank the pairs with tf-idf
        with open(args.input_file, 'r') as f:
            hotpot_data = json.load(f)
        if args.sample_questions:
            hotpot_data = sorted(hotpot_data, key=lambda x: x['_id'])
            np.random.RandomState(0).shuffle(hotpot_data)
            hotpot_data = hotpot_data[:args.sample_questions]
        title2sentences = {
            context[0]: context[1]
            for q in hotpot_data for context in q['context']
        }
        question_tok_texts = tokenize_texts(
            [q['question'] for q in hotpot_data], num_workers=args.num_workers)
        sentences_tok = tokenize_texts(list(title2sentences.values()),
                                       num_workers=args.num_workers,
                                       sentences=True)
        if args.tokens is not None:
            sentences_tok = [
                truncate_paragraph(p, args.tokens) for p in sentences_tok
            ]
        title2tok_sents = {
            title: sentences
            for title, sentences in zip(title2sentences.keys(), sentences_tok)
        }
        questions = []
        for idx, question in enumerate(tqdm(hotpot_data,
                                            desc='tf-idf ranking')):
            q_titles = [title for title, _ in question['context']]
            par_pairs = [(title1, title2) for i, title1 in enumerate(q_titles)
                         for title2 in q_titles[i + 1:]]
            if len(par_pairs) == 0:
                continue
            ranks = get_paragraph_ranks(question['question'], [
                ' '.join(title2sentences[t1] + title2sentences[t2])
                for t1, t2 in par_pairs
            ])
            for rank, par_pair in zip(ranks, par_pairs):
                sent_tok_pair = title2tok_sents[par_pair[0]] + title2tok_sents[
                    par_pair[1]]
                sentence_segments, _ = get_segments_from_sentences_fix_sup(
                    sent_tok_pair, np.zeros(0))
                missing_sent_idx = [[
                    i for i, sent in enumerate(title2tok_sents[title])
                    if len(sent) == 0
                ] for title in par_pair]
                questions.append(
                    RankedQAPair(
                        question=question_tok_texts[idx],
                        paragraphs=[flatten_iterable(sent_tok_pair)],
                        spans=np.zeros((0, 2), dtype=np.int32),
                        question_id=question['_id'],
                        answer='noanswer'
                        if args.test_mode else question['answer'],
                        rank=rank,
                        q_type='null' if args.test_mode else question['type'],
                        sentence_segments=[sentence_segments],
                        par_titles_num_sents=[
                            (title,
                             sum(1 for sent in title2tok_sents[title]
                                 if len(sent) > 0)) for title in par_pair
                        ],
                        missing_sent_idxs=missing_sent_idx,
                        true_sp=[]
                        if args.test_mode else question['supporting_facts']))
    elif args.corpus == 'retrieval_file' or args.corpus == 'top_titles':
        if args.docs_file is None:
            print("Using DB documents")
            doc_db = DocDB(config.DOC_DB, full_docs=False)
        else:
            with open(args.docs_file, 'r') as f:
                docs = json.load(f)
        with open(args.input_file, 'r') as f:
            retrieval_data = json.load(f)
        if args.sample_questions:
            retrieval_data = sorted(retrieval_data, key=lambda x: x['qid'])
            np.random.RandomState(0).shuffle(retrieval_data)
            retrieval_data = retrieval_data[:args.sample_questions]

        def parname_to_text(par_name):
            par_title = par_name_to_title(par_name)
            par_num = int(par_name.split('_')[-1])
            if args.docs_file is None:
                return doc_db.get_doc_sentences(par_title)
            return docs[par_title][par_num]

        if args.corpus == 'top_titles':
            print("Top TF-IDF!")
            for q in retrieval_data:
                top_titles = q['top_titles'][:10]
                q['paragraph_pairs'] = [(title1 + '_0', title2 + '_0')
                                        for i, title1 in enumerate(top_titles)
                                        for title2 in top_titles[i + 1:]]

        question_tok_texts = tokenize_texts(
            [q['question'] for q in retrieval_data],
            num_workers=args.num_workers)
        all_parnames = list(
            set([
                parname for q in retrieval_data
                for pair in q['paragraph_pairs'] for parname in pair
            ]))
        texts_tok = tokenize_texts([parname_to_text(x) for x in all_parnames],
                                   num_workers=args.num_workers,
                                   sentences=True)
        if args.tokens is not None:
            texts_tok = [truncate_paragraph(p, args.tokens) for p in texts_tok]
        parname2tok_text = {
            parname: text
            for parname, text in zip(all_parnames, texts_tok)
        }
        questions = []
        for idx, question in enumerate(retrieval_data):
            for rank, par_pair in enumerate(question['paragraph_pairs']):
                tok_pair = parname2tok_text[par_pair[0]] + parname2tok_text[
                    par_pair[1]]
                sentence_segments, _ = get_segments_from_sentences_fix_sup(
                    tok_pair, np.zeros(0))
                missing_sent_idx = [[
                    i for i, sent in enumerate(parname2tok_text[parname])
                    if len(sent) == 0
                ] for parname in par_pair]
                questions.append(
                    RankedQAPair(
                        question=question_tok_texts[idx],
                        paragraphs=[flatten_iterable(tok_pair)],
                        spans=np.zeros((0, 2), dtype=np.int32),
                        question_id=question['qid'],
                        answer='noanswer'
                        if args.test_mode else question['answers'][0],
                        rank=rank,
                        q_type='null' if args.test_mode else question['type'],
                        sentence_segments=[sentence_segments],
                        par_titles_num_sents=[
                            (par_name_to_title(parname),
                             sum(1 for sent in parname2tok_text[parname]
                                 if len(sent) > 0)) for parname in par_pair
                        ],
                        missing_sent_idxs=missing_sent_idx,
                        true_sp=[]
                        if args.test_mode else question['supporting_facts']))
    else:
        raise NotImplementedError()

    data = DummyDataset(questions, batcher)
    evaluators = [
        RecordHotpotQAPrediction(args.answer_bound,
                                 True,
                                 sp_prediction=not args.no_sp)
    ]

    if args.step is not None:
        if args.step == "latest":
            checkpoint = model_dir.get_latest_checkpoint()
        else:
            checkpoint = model_dir.get_checkpoint(int(args.step))
    else:
        checkpoint = model_dir.get_best_weights()
        if checkpoint is not None:
            print("Using best weights")
        else:
            print("Using latest checkpoint")
            checkpoint = model_dir.get_latest_checkpoint()

    model = model_dir.get_model()

    evaluation = trainer.test(model, evaluators, {args.corpus: data}, loader,
                              checkpoint, not args.no_ema, 10)[args.corpus]

    print("Saving result")
    output_file = args.output

    df = pd.DataFrame(evaluation.per_sample)

    df.sort_values(["question_id", "rank"], inplace=True, ascending=True)
    group_by = ["question_id"]

    def get_ranked_scores(score_name):
        filtered_df = df[df.type == 'comparison'] if "Cp" in score_name else \
            df[df.type == 'bridge'] if "Br" in score_name else df
        target_prefix = 'joint' if 'joint' in score_name else 'sp' if 'sp' in score_name else 'text'
        target_score = f"{target_prefix}_{'em' if 'EM' in score_name else 'f1'}"
        return compute_ranked_scores_with_yes_no(
            filtered_df,
            span_q_col="span_question_scores",
            yes_no_q_col="yes_no_question_scores",
            yes_no_scores_col="yes_no_confidence_scores",
            span_scores_col="predicted_score",
            span_target_score=target_score,
            group_cols=group_by)

    if not args.test_mode:
        score_names = ["EM", "F1", "Br EM", "Br F1", "Cp EM", "Cp F1"]
        if not args.no_sp:
            score_names.extend([
                f"{prefix} {name}" for prefix in ['sp', 'joint']
                for name in score_names
            ])

        table = [["N Paragraphs"] + score_names]
        scores = [get_ranked_scores(score_name) for score_name in score_names]
        table += list([str(i + 1), *["%.4f" % x for x in score_vals]]
                      for i, score_vals in enumerate(zip(*scores)))
        print_table(table)

        df.to_csv(output_file, index=False)

    else:
        df_to_pred(df, output_file)
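
To see how the final results table is assembled: each element of `scores` holds one value per top-N-paragraphs cutoff, and `zip(*scores)` transposes the per-metric lists into per-N rows. A miniature with dummy numbers:

scores = [[0.41, 0.44, 0.45],   # e.g. EM at N = 1, 2, 3 (dummy values)
          [0.52, 0.55, 0.57]]   # e.g. F1 at N = 1, 2, 3 (dummy values)
table = [["N Paragraphs", "EM", "F1"]]
table += [[str(i + 1), *["%.4f" % x for x in row]]
          for i, row in enumerate(zip(*scores))]
# table rows: ['1', '0.4100', '0.5200'], ['2', '0.4400', '0.5500'], ...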