def _build_full_dataset(self):
     samples = []
     for question in self.questions:
         pars_and_scores = list(
             zip(question.supporting_facts + question.distractors,
                 question.gold_scores + question.distractor_scores))
         higher_gold = question.supporting_facts[0] \
             if question.gold_scores[0] >= question.gold_scores[1] else question.supporting_facts[1]
         for p1, score1 in pars_and_scores:
             for p2, score2 in pars_and_scores:
                 first_label, second_label = self._get_labels(
                     is_gold1=p1 in question.supporting_facts,
                     is_gold2=p2 in question.supporting_facts,
                     q_type=question.q_type,
                     are_same=p1 == p2,
                     is_first_higher=higher_gold == p1)
                 samples.append(
                     IterativeQuestionAndParagraphs(
                         question=question.question_tokens,
                         paragraphs=[
                             flatten_iterable(p1.sentences),
                             flatten_iterable(p2.sentences)
                         ],
                         first_label=first_label,
                         second_label=second_label,
                         question_id=question.question_id,
                         q_type=question.q_type,
                         sentence_segments=[
                             get_segments_from_sentences(s)
                             for s in [p1.sentences, p2.sentences]
                         ]))
     return samples
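The samples above carry sentence_segments next to the flattened paragraphs. The project's get_segments_from_sentences is not shown in these examples; the sketch below only illustrates the assumed behaviour (one sentence index per token of the flattened paragraph), not the actual implementation.

def get_segments_from_sentences_sketch(sentences):
    """Illustrative stand-in: map each token of a flattened paragraph to the
    index of the sentence it came from, e.g. [["a", "b"], ["c"]] -> [0, 0, 1]."""
    return [sent_idx for sent_idx, sent in enumerate(sentences) for _ in sent]

assert get_segments_from_sentences_sketch([["a", "b"], ["c"]]) == [0, 0, 1]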
Example 2
    def run_evaluators(self,
                       sess: tf.Session,
                       dataset: Dataset,
                       name,
                       n_sample=None,
                       feed_dict=None) -> Evaluation:
        all_tensors_needed = list(
            set(flatten_iterable(x.values() for x in self.tensors_needed)))

        tensors = {x: [] for x in all_tensors_needed}

        if n_sample is None:
            batches, n_batches = dataset.get_epoch(), len(dataset)
        else:
            batches, n_batches = dataset.get_samples(n_sample)

        data_used = []

        for batch in tqdm(batches, total=n_batches, desc=name, ncols=80):
            feed_dict = self.model.encode(batch, is_train=False)
            output = sess.run(all_tensors_needed, feed_dict=feed_dict)
            data_used += batch
            for i in range(len(all_tensors_needed)):
                tensors[all_tensors_needed[i]].append(output[i])

        # flatten the input
        for k in all_tensors_needed:
            v = tensors[k]
            if len(k.shape) == 0:
                v = np.array(v)  # List of scalars
            elif any(x is None for x in k.shape.as_list()):
                # Variable sized tensors, so convert to flat python-list
                v = flatten_iterable(v)
            else:
                v = np.concatenate(v, axis=0)  # concat along the batch dim
            tensors[k] = v

        percent_filtered = dataset.percent_filtered()
        if percent_filtered is None:
            true_len = len(data_used)
        else:
            true_len = len(data_used) * 1 / (1 - percent_filtered)

        combined = None
        for ev, needed in zip(self.evaluators, self.tensors_needed):
            args = {k: tensors[v] for k, v in needed.items()}
            evaluation = ev.evaluate(data_used, true_len, **args)
            if evaluation is None:
                raise ValueError(ev)
            if combined is None:
                combined = evaluation
            else:
                combined.add(evaluation)

        return combined
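run_evaluators pairs each evaluator in self.evaluators with a dict in self.tensors_needed (keyword name -> TensorFlow tensor) and calls evaluate(data_used, true_len, **fetched_values), accumulating results via add. The stand-in classes below only illustrate that contract; their names and constructor signatures are assumptions, not the project's implementation.

import numpy as np

class EvaluationSketch:
    """Stand-in result container: holds a dict of scalar metrics and supports add()."""
    def __init__(self, scalars):
        self.scalars = scalars

    def add(self, other):
        self.scalars.update(other.scalars)

class MeanLossEvaluatorSketch:
    """Stand-in evaluator: its tensors_needed entry would be {'loss': <loss tensor>},
    so evaluate() receives the fetched per-batch losses as the `loss` keyword."""
    def evaluate(self, data, true_len, loss):
        return EvaluationSketch({"dev/loss": float(np.mean(loss))})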
 def preprocess(self, question: HotpotQuestion):
     for par in question.distractors:
         while len(flatten_iterable(par.sentences)) > self.num_tokens_th:
             par.sentences = par.sentences[:-1]
     for par in question.supporting_facts:
         while len(flatten_iterable(par.sentences)) > self.num_tokens_th:
             if (len(par.sentences) - 1) in par.supporting_sentence_ids:
                 print(
                     "Warning: supporting fact above threshold. removing sample"
                 )
                 return None
             par.sentences = par.sentences[:-1]
     return question
Example 4
def hotpot_question_to_relevance_question(
        hotpot_question: HotpotQuestion) -> RelevanceQuestion:
    return RelevanceQuestion(dataset_name='hotpot',
                             question_id=hotpot_question.question_id,
                             question_tokens=hotpot_question.question_tokens,
                             supporting_facts=[
                                 flatten_iterable(x.sentences)
                                 for x in hotpot_question.supporting_facts
                             ],
                             distractors=[
                                 flatten_iterable(x.sentences)
                                 for x in hotpot_question.distractors
                             ])
Example 5
def evaluate_question_detector(questions: List[HotpotQuestion], word_tokenize, detector,
                               reference_detector=None, compute_f1s=False):
    """ Just for debugging """
    n_no_docs = 0
    answer_per_q = []
    answer_f1s = []

    for question_ix, q in enumerate(tqdm(questions)):
        if q.answer in {'yes', 'no'} and q.q_type == 'comparison':
            continue
        tokenized_aliases = [word_tokenize(q.answer)]
        detector.set_question(tokenized_aliases)

        output = []
        for i, par in enumerate(q.supporting_facts):

            for s, e in detector.any_found(par.sentences):
                output.append((i, s, e))

            if len(output) == 0 and reference_detector is not None:
                reference_detector.set_question(tokenized_aliases)
                detected = []
                for j, par in enumerate(q.supporting_facts):
                    for s, e in reference_detector.any_found(par.sentences):
                        detected.append((j, s, e))

                if len(detected) > 0:
                    print("Found a difference")
                    print(q.answer)
                    print(tokenized_aliases)
                    for p, s, e in detected:
                        token = flatten_iterable(q.supporting_facts[p].sentences)[s:e]
                        print(token)

        answer_per_q.append(output)

        if compute_f1s:
            f1s = []
            for p, s, e in output:
                token = flatten_iterable(q.supporting_facts[p].sentences)[s:e]
                answer = normalize_answer(" ".join(token))
                f1, _, _ = f1_score(answer, normalize_answer(q.answer))
                f1s.append(f1)
            answer_f1s.append(f1s)

    n_answers = sum(len(x) for x in answer_per_q)
    print("Found %d answers (av %.4f)" % (n_answers, n_answers / len(answer_per_q)))
    print("%.4f docs have answers" % np.mean([len(x) > 0 for x in answer_per_q]))
    if len(answer_f1s) > 0:
        print("Average f1 is %.4f" % np.mean(flatten_iterable(answer_f1s)))
Example 6
 def get_word_counts(self):
     count = Counter()
     for q in self.questions:
         count.update(q.question_tokens)
         for para in (q.distractors + q.supporting_facts):
             count.update(flatten_iterable(para.sentences))
     return count
Example 7
    def any_found(self, para: List[List[str]]):
        # Normalize the paragraph
        words = [w.lower().strip(self.strip) for w in flatten_iterable(para)]
        occurrences = []
        for answer_ix, answer in enumerate(self.answer_tokens):
            # Locations where the first word occurs
            if len(answer) == 0:
                continue
            word_starts = [i for i, w in enumerate(words) if answer[0] == w]
            n_tokens = len(answer)

            # Advance forward until we find all the words, skipping over articles
            for start in word_starts:
                end = start + 1
                ans_token = 1
                while ans_token < n_tokens and end < len(words):
                    next_word = words[end]
                    if answer[ans_token] == next_word:
                        ans_token += 1
                        end += 1
                    elif next_word in self.skip:
                        end += 1
                    else:
                        break
                if n_tokens == ans_token:
                    occurrences.append((start, end))
        return list(set(occurrences))
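A self-contained toy version of the matching loop above, to make the skip-word behaviour concrete: answer tokens must appear in order, but tokens listed in skip (typically articles) may be stepped over. This is an illustration, not the class's actual code.

def any_found_sketch(words, answer, skip=("a", "an", "the")):
    found = []
    starts = [i for i, w in enumerate(words) if w == answer[0]]
    for start in starts:
        end, ans_token = start + 1, 1
        while ans_token < len(answer) and end < len(words):
            if words[end] == answer[ans_token]:
                ans_token += 1
                end += 1
            elif words[end] in skip:
                end += 1
            else:
                break
        if ans_token == len(answer):
            found.append((start, end))
    return found

# "war of worlds" still matches "... war of the worlds ..." because "the" is skippable:
assert any_found_sketch("the war of the worlds".split(), "war of worlds".split()) == [(1, 5)]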
Example 8
def merge_paragraphs(paragraphs: List[HotpotParagraph], spans: List[np.ndarray], supporting_fact_idxs: List[List[int]])\
        -> Tuple[List[str], np.ndarray, List[int], np.ndarray]:
    # todo this supporting fact fixing is a hack but easy to handle here so will do for now
    for i in range(len(paragraphs)):
        supporting_fact_idxs[i] = [
            x for x in supporting_fact_idxs[i]
            if x < len(paragraphs[i].sentences)
        ]
    merged_text = []
    merged_spans = np.zeros((0, 2), dtype=np.int32)
    merged_sp_idxs = np.zeros(0, dtype=np.int32)
    merged_sentences = []
    for par, par_spans, par_facts in zip(paragraphs, spans,
                                         supporting_fact_idxs):
        merged_spans = np.concatenate(
            [merged_spans, par_spans + len(merged_text)])
        merged_sp_idxs = np.concatenate([
            merged_sp_idxs,
            np.array(par_facts, dtype=np.int32) + len(merged_sentences)
        ])
        merged_sentences.extend(par.sentences)
        merged_text.extend(flatten_iterable(par.sentences))
    segments, merged_sp_idxs = get_segments_from_sentences_fix_sup(
        merged_sentences, merged_sp_idxs)
    return merged_text, merged_spans, segments, merged_sp_idxs
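The key bookkeeping in merge_paragraphs is offsetting: answer spans are token offsets local to their paragraph, so adding the number of tokens already merged converts them to offsets in the concatenated text (supporting-fact indices are shifted by sentence counts in the same way). A worked toy example of the token-offset arithmetic, with made-up data:

import numpy as np

merged_text = ["Alpha", "Beta", "Gamma"]                  # 3 tokens merged so far
par_spans = np.array([[0, 1], [2, 4]], dtype=np.int32)    # spans local to the next paragraph
shifted = par_spans + len(merged_text)                    # now relative to the merged text
assert shifted.tolist() == [[3, 4], [5, 7]]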
Example 9
 def get_vocab(self):
     voc = set()
     for q in self.questions:
         voc.update(q.question_tokens)
         for para in (q.distractors + q.supporting_facts):
             voc.update(flatten_iterable(para.sentences))
     return voc
 def _build_gold_samples(self):
     gold_samples = []
     for question in self.questions:
         pars = [
             flatten_iterable(question.supporting_facts[0].sentences),
             flatten_iterable(question.supporting_facts[1].sentences)
         ]
         self.random.shuffle(pars)
         gold_samples.append(
             BinaryQuestionAndParagraphs(question.question_tokens,
                                         pars,
                                         1,
                                         num_distractors=0,
                                         question_id=question.question_id,
                                         q_type=question.q_type))
     return gold_samples
    def _sample_first_gold_second_false(self, qid):
        question = self.qid2question[qid]
        rand_par_other_q = self._sample_rand_par_other_q(qid)
        if question.q_type == 'comparison' or self.bridge_as_comparison:
            first_gold_par = question.supporting_facts[self.random.randint(2)]
        else:
            if not self.label_by_span:
                first_gold_idx = 0 if question.gold_scores[0] > question.gold_scores[1] else 1
                first_gold_par = question.supporting_facts[first_gold_idx]
            else:
                gold_idxs = self._get_no_span_containing_golds(
                    question.question_id)
                if len(gold_idxs) == 1:
                    first_gold_par = question.supporting_facts[gold_idxs[0]]
                else:
                    first_gold_par = question.supporting_facts[
                        self.random.randint(2)]

        rand_par = self.random.choice([
            rand_par_other_q, first_gold_par,
            self.random.choice(question.distractors)
        ],
                                      p=[0.05, 0.1, 0.85])

        pars = [
            flatten_iterable(first_gold_par.sentences),
            flatten_iterable(rand_par.sentences)
        ]
        segs = [
            get_segments_from_sentences(first_gold_par.sentences),
            get_segments_from_sentences(rand_par.sentences)
        ]
        return IterativeQuestionAndParagraphs(
            question=question.question_tokens,
            paragraphs=pars,
            first_label=1,
            second_label=0,
            question_id=question.question_id,
            q_type=question.q_type,
            sentence_segments=segs)
 def _build_full_dataset(self):
     samples = []
     for question in self.questions:
         for i, p1 in enumerate(question.distractors +
                                question.supporting_facts):
             for p2 in (question.distractors +
                        question.supporting_facts)[i + 1:]:
                 label = 1 if ((p1 in question.supporting_facts) and
                               (p2 in question.supporting_facts)) else 0
                 num_distractors = sum([
                     p1 in question.distractors, p2 in question.distractors
                 ])
                 samples.append(
                     BinaryQuestionAndParagraphs(
                         question.question_tokens, [
                             flatten_iterable(p1.sentences),
                             flatten_iterable(p2.sentences)
                         ],
                         label,
                         num_distractors=num_distractors,
                         question_id=question.question_id,
                         q_type=question.q_type))
     return samples
Example 13
def assign_scores(question: HotpotQuestion):
    question_spvec = PROCESS_RANKER.text2spvec(question.question_tokens,
                                               tokenized=True)
    paragraphs = [
        flatten_iterable(par.sentences)
        for par in (question.supporting_facts + question.distractors)
    ]
    pars_spvecs = [
        PROCESS_RANKER.text2spvec(x, tokenized=True) for x in paragraphs
    ]
    pars_spvecs = vstack(pars_spvecs)
    scores = pars_spvecs.dot(question_spvec.toarray().squeeze(axis=0))
    question.gold_scores = scores[:len(question.supporting_facts)].tolist()
    question.distractor_scores = scores[len(question.supporting_facts):].tolist()
    return question
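assign_scores relies on PROCESS_RANKER.text2spvec producing sparse TF-IDF vectors, so the paragraph-question dot products act as lexical-overlap scores. The sketch below conveys the same idea with scikit-learn's TfidfVectorizer as a stand-in; it does not reproduce the project's ranker or its hashing and weighting scheme.

from sklearn.feature_extraction.text import TfidfVectorizer

def tfidf_paragraph_scores(question_tokens, paragraphs_tokens):
    """Score each paragraph by TF-IDF dot product with the question (illustrative stand-in)."""
    texts = [" ".join(question_tokens)] + [" ".join(p) for p in paragraphs_tokens]
    matrix = TfidfVectorizer().fit_transform(texts)
    question_vec, par_vecs = matrix[0], matrix[1:]
    return par_vecs.dot(question_vec.T).toarray().squeeze(axis=1)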
 def _sample_false_1(self, qid):
     """ False sample of type 1: all distractors.
     No sampling from other question here, as I think it's less effective in this case"""
     question = self.qid2question[qid]
     two_distractors = self.random.choice(question.distractors,
                                          size=2,
                                          replace=False)
     pars = [flatten_iterable(x.sentences) for x in two_distractors]
     segs = [
         get_segments_from_sentences(x.sentences) for x in two_distractors
     ]
     return IterativeQuestionAndParagraphs(
         question=question.question_tokens,
         paragraphs=pars,
         first_label=0,
         second_label=0,
         question_id=question.question_id,
         q_type=question.q_type,
         sentence_segments=segs)
Example 15
def build_spec(batch_size: int, max_batch_size: int, num_contexts: int,
               data: List[RelevanceQuestion]) -> QuestionAndParagraphsSpec:
    max_ques_size = 0
    max_word_size = 0
    max_para_size = 0
    max_num_contexts = num_contexts
    for data_point in data:
        contexts = data_point.distractors + data_point.supporting_facts
        max_word_size = max(
            max_word_size,
            max(len(word) for word in flatten_iterable(contexts)))
        max_para_size = max(max_para_size,
                            max(len(context) for context in contexts))
        max_ques_size = max(max_ques_size, len(data_point.question_tokens))
        max_word_size = max(
            max_word_size,
            max(len(word) for word in data_point.question_tokens))
    return QuestionAndParagraphsSpec(batch_size, max_num_contexts,
                                     max_ques_size, max_para_size,
                                     max_word_size, max_batch_size)
 def _sample_false_2(self, qid):
     """ False sample of type 2: first distractor, second one of supporting facts """
     question = self.qid2question[qid]
     rand_par_other_q = self._sample_rand_par_other_q(qid)
     distractor = self.random.choice(
         [self.random.choice(question.distractors), rand_par_other_q],
         p=[0.9, 0.1])
     gold = self.random.choice(question.supporting_facts)
     pars = [flatten_iterable(x.sentences) for x in [distractor, gold]]
     segs = [
         get_segments_from_sentences(x.sentences)
         for x in [distractor, gold]
     ]
     return IterativeQuestionAndParagraphs(
         question=question.question_tokens,
         paragraphs=pars,
         first_label=0,
         second_label=0,
         question_id=question.question_id,
         q_type=question.q_type,
         sentence_segments=segs)
Example 17
 def get_vocab(self):
     """ get all-lower cased unique words for this corpus, includes train/dev/test files """
     if not exists(self.dir):
         self.dir = join(config.CORPUS_DIR, self.NAME)
     voc_file = join(self.dir, self.VOCAB_FILE)
     if exists(voc_file):
         with open(voc_file, "r") as f:
             return [x.rstrip() for x in f]
     else:
         voc = set()
         for fn in [self.get_train, self.get_dev, self.get_test]:
             for question in fn():
                 voc.update(x.lower() for x in question.question_tokens)
                 for para in (question.distractors +
                              question.supporting_facts):
                     voc.update(x.lower()
                                for x in flatten_iterable(para.sentences))
         voc_list = sorted(list(voc))
         with open(voc_file, "w") as f:
             for word in voc_list:
                 f.write(word)
                 f.write("\n")
         return voc_list
 def _build_gold_samples(self):
     gold_samples = []
     for question in self.questions:
         if question.q_type == 'comparison' or self.bridge_as_comparison:
             pars_order = [0, 1]
             self.random.shuffle(pars_order)
         else:
             if not self.label_by_span:
                  pars_order = [0, 1] if question.gold_scores[0] > question.gold_scores[1] else [1, 0]
             else:
                 gold_idxs = self._get_no_span_containing_golds(
                     question.question_id)
                 pars_order = [0, 1] if 0 in gold_idxs else [1, 0]
                  if len(gold_idxs) != 1:
                      # either both golds contain the answer or neither does, so they are treated as equal
                      self.random.shuffle(pars_order)
         pars = [
             flatten_iterable(question.supporting_facts[i].sentences)
             for i in pars_order
         ]
         sentence_segs = [
             get_segments_from_sentences(
                 question.supporting_facts[i].sentences) for i in pars_order
         ]
         gold_samples.append(
             IterativeQuestionAndParagraphs(
                 question.question_tokens,
                 pars,
                 first_label=1,
                 second_label=1,
                 question_id=question.question_id,
                 q_type=question.q_type,
                 sentence_segments=sentence_segs))
     return gold_samples
Example 19
def get_answers(question: str, top_title_tuples: List[Tuple[str]]):
    question_tok = tokenize_words(question)
    all_titles = list(set([title for titles in top_title_tuples for title in titles]))
    texts_tok = [x for x in tok_workers.imap(tokenize_sentences, [db.get_doc_sentences(title) for title in all_titles])]
    title2tok_sents = {title: tok_sents for title, tok_sents in zip(all_titles, texts_tok)}
    questions = []
    for rank, title_tuple in enumerate(top_title_tuples):
        tokenized_sents = [sent for t in title_tuple for sent in title2tok_sents[t]]
        sentence_segments, _ = get_segments_from_sentences_fix_sup(tokenized_sents, np.zeros(0))
        missing_sent_idx = [[i for i, sent in enumerate(title2tok_sents[title]) if len(sent) == 0]
                            for title in title_tuple]
        questions.append(RankedQAPair(question=question_tok, paragraphs=[flatten_iterable(tokenized_sents)],
                                      spans=np.zeros((0, 2), dtype=np.int32), question_id='bla',
                                      answer='noanswer',
                                      rank=rank,
                                      q_type='n/a', sentence_segments=[sentence_segments],
                                      par_titles_num_sents=
                                      [(title,
                                        sum(1 for sent in title2tok_sents[title] if len(sent) > 0))
                                       for title in title_tuple],
                                      missing_sent_idxs=missing_sent_idx,
                                      true_sp=[]))
    data = DummyDataset(questions, batcher)
    evaluation = evaluator_runner.run_evaluators(qa_sess, data, 'bla', None, {}, disable_tqdm=True)
    df = pd.DataFrame(evaluation.per_sample)

    df.sort_values(["question_id", "rank"], inplace=True, ascending=True)
    answer_dict, sp_dict = df_to_pred(df, None, return_results=True)
    # sp_raw = [db.get_doc_sentences(title)[idx] for title, idx in sp_dict['bla']]
    title2idxs = {}
    for title, idx in sp_dict['bla']:
        if title not in title2idxs:
            title2idxs[title] = []
        title2idxs[title].append(idx)
    sp_titles_idxs = [(title, idxs) for title, idxs in title2idxs.items()]
    return answer_dict['bla'], sp_titles_idxs
Example 20
 def num_tokens(self):
     return len(flatten_iterable(self.sentences))
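flatten_iterable itself is not defined in these examples; throughout, it is assumed to concatenate a list of token lists into one flat token list, as in the sketch below (illustration only).

def flatten_iterable_sketch(list_of_lists):
    return [item for sublist in list_of_lists for item in sublist]

assert flatten_iterable_sketch([["a", "b"], [], ["c"]]) == ["a", "b", "c"]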
Example 21
    def run_evaluators(self,
                       sess: tf.Session,
                       dataset,
                       name,
                       n_sample,
                       feed_dict,
                       disable_tqdm=False) -> Evaluation:
        all_tensors_needed = list(
            set(flatten_iterable(x.values() for x in self.tensors_needed)))

        tensors = {x: [] for x in all_tensors_needed}

        data_used = []
        if n_sample is None:
            batches, n_batches = dataset.get_epoch(), len(dataset)
        else:
            batches, n_batches = dataset.get_samples(n_sample)

        def enqueue_eval():
            try:
                for data in batches:
                    encoded = self.model.encode(data, False)
                    data_used.append(data)
                    sess.run(self.enqueue_op, encoded)
            except Exception as e:
                sess.run(self.close_queue)  # Crash the main thread
                raise e
            # we should run out of batches and exit gracefully

        th = Thread(target=enqueue_eval)

        th.daemon = True
        th.start()
        for _ in tqdm(range(n_batches),
                      total=n_batches,
                      desc=name,
                      ncols=80,
                      disable=disable_tqdm):
            output = sess.run(all_tensors_needed, feed_dict=feed_dict)
            for i in range(len(all_tensors_needed)):
                tensors[all_tensors_needed[i]].append(output[i])
        th.join()

        if sess.run(self.queue_size) != 0:
            raise RuntimeError("All batches should be been consumed")

        # flatten the input
        for k in all_tensors_needed:
            v = tensors[k]
            if len(k.shape) == 0:
                v = np.array(v)  # List of scalars -> array
            elif any(x is None for x in k.shape.as_list()[1:]):
                # Variable sized tensors, so convert to flat python-list
                v = flatten_iterable(v)
            else:
                v = np.concatenate(v, axis=0)  # concat along the batch dim
            tensors[k] = v

        # flatten the data if it consists of batches
        if isinstance(data_used[0], List):
            data_used = flatten_iterable(data_used)

        if dataset.percent_filtered() is None:
            true_len = len(data_used)
        else:
            true_len = len(data_used) * 1 / (1 - dataset.percent_filtered())

        combined = None
        for ev, needed in zip(self.evaluators, self.tensors_needed):
            args = {k: tensors[v] for k, v in needed.items()}
            evaluation = ev.evaluate(data_used, true_len, **args)
            if combined is None:
                combined = evaluation
            else:
                combined.add(evaluation)

        return combined
 def get_epoch(self, new_epoch=True):
     if self.fixed_dataset:
         new_epoch = False
     if not new_epoch and self.epoch_samples is not None:
         return self.batcher.get_epoch(self.epoch_samples)
     false_samples = []
     for question in self.questions:
         two_distractors = [
             flatten_iterable(x.sentences) for x in self.random.choice(
                 question.distractors, size=2, replace=False)
         ]
         true_and_false_1 = [
             flatten_iterable(question.supporting_facts[0].sentences),
             two_distractors[0]
         ]
         true_and_false_2 = [
             flatten_iterable(question.supporting_facts[1].sentences),
             two_distractors[1]
         ]
         self.random.shuffle(true_and_false_1)
         self.random.shuffle(true_and_false_2)
         false_samples.append(
             BinaryQuestionAndParagraphs(question.question_tokens,
                                         true_and_false_1,
                                         0,
                                         num_distractors=1,
                                         question_id=question.question_id,
                                         q_type=question.q_type))
         false_samples.append(
             BinaryQuestionAndParagraphs(question.question_tokens,
                                         true_and_false_2,
                                         0,
                                         num_distractors=1,
                                         question_id=question.question_id,
                                         q_type=question.q_type))
         false_samples.append(
             BinaryQuestionAndParagraphs(question.question_tokens, [
                 flatten_iterable(x.sentences) for x in self.random.choice(
                     question.distractors, size=2, replace=False)
             ],
                                         0,
                                         num_distractors=2,
                                         question_id=question.question_id,
                                         q_type=question.q_type))
         if self.add_gold_distractor:
             rand_q_idx = self.random.randint(len(self.gold_samples))
             while self.gold_samples[
                     rand_q_idx].question_id == question.question_id:
                 rand_q_idx = self.random.randint(len(self.gold_samples))
             selected_q = self.gold_samples[rand_q_idx]
             self.random.shuffle(selected_q.paragraphs)
             false_samples.append(
                 BinaryQuestionAndParagraphs(
                     question.question_tokens,
                     selected_q.paragraphs,
                     label=0,
                     num_distractors=2,
                     question_id=question.question_id,
                     q_type=question.q_type))
     for gold in self.gold_samples:
         self.random.shuffle(gold.paragraphs)
     self.epoch_samples = self.gold_samples + false_samples
     np.random.shuffle(self.epoch_samples)
     return self.batcher.get_epoch(self.epoch_samples)
Example 23
def truncate_paragraph(tokenized_sentences: List[List[str]], num_tokens):
    while len(flatten_iterable(tokenized_sentences)) > num_tokens:
        tokenized_sentences = tokenized_sentences[:-1]
    return tokenized_sentences
Example 24
def main():
    parser = argparse.ArgumentParser(
        description='Full ranking evaluation on Hotpot')
    parser.add_argument('model', help='model directory to evaluate')
    parser.add_argument(
        'output',
        type=str,
        help="Store the per-paragraph results in csv format in this file, "
        "or the json prediction if in test mode")
    parser.add_argument('-n',
                        '--sample_questions',
                        type=int,
                        default=None,
                        help="(for testing) run on a subset of questions")
    parser.add_argument(
        '-b',
        '--batch_size',
        type=int,
        default=64,
        help="Batch size, larger sizes can be faster but uses more memory")
    parser.add_argument(
        '-s',
        '--step',
        default=None,
        help="Weights to load, can be a checkpoint step or 'latest'")
    parser.add_argument('-a',
                        '--answer_bound',
                        type=int,
                        default=8,
                        help="Max answer span length")
    parser.add_argument('-c',
                        '--corpus',
                        choices=[
                            "distractors", "gold", "hotpot_file",
                            "retrieval_file", "top_titles"
                        ],
                        default="distractors")
    parser.add_argument('-t',
                        '--tokens',
                        type=int,
                        default=None,
                        help="Max tokens per a paragraph")
    parser.add_argument('--input_file', type=str, default=None)
    parser.add_argument('--docs_file', type=str, default=None)
    parser.add_argument('--num_workers',
                        type=int,
                        default=16,
                        help='Number of workers for tokenizing')
    parser.add_argument('--no_ema',
                        action="store_true",
                        help="Don't use EMA weights even if they exist")
    parser.add_argument('--no_sp',
                        action="store_true",
                        help="Don't predict supporting facts")
    parser.add_argument('--test_mode',
                        action='store_true',
                        help="produce a prediction file, no answers given")
    args = parser.parse_args()

    model_dir = ModelDir(args.model)
    batcher = ClusteredBatcher(args.batch_size,
                               multiple_contexts_len,
                               truncate_batches=True)
    loader = ResourceLoader()

    if args.corpus not in {"distractors", "gold"} and args.input_file is None:
        raise ValueError(
            "Must pass an input file if not using precomputed dataset")

    if args.corpus in {"distractors", "gold"} and args.test_mode:
        raise ValueError(
            "Test mode not available in 'distractors' or 'gold' mode")

    if args.corpus in {"distractors", "gold"}:

        corpus = HotpotQuestions()
        loader = corpus.get_resource_loader()
        questions = corpus.get_dev()

        question_preprocessor = HotpotTextLengthPreprocessorWithSpans(
            args.tokens)
        questions = [
            q for q in (question_preprocessor.preprocess(x) for x in questions)
            if q is not None  # preprocess once per question; drop rejected ones
        ]

        if args.sample_questions:
            # sort for a stable base order, then shuffle with a fixed seed for a reproducible sample
            questions = sorted(questions, key=lambda x: x.question_id)
            np.random.RandomState(0).shuffle(questions)
            questions = questions[:args.sample_questions]

        data = HotpotFullQADistractorsDataset(questions, batcher)
        gold_idxs = set(data.gold_idxs)
        if args.corpus == 'gold':
            data.samples = [data.samples[i] for i in data.gold_idxs]
        qid2samples = {}
        qid2idx = {}
        for i, sample in enumerate(data.samples):
            key = sample.question_id
            if key in qid2samples:
                qid2samples[key].append(sample)
                qid2idx[key].append(i)
            else:
                qid2samples[key] = [sample]
                qid2idx[key] = [i]
        questions = []
        print("Ranking pairs...")
        gold_ranks = []
        for qid, samples in tqdm(qid2samples.items()):
            question = " ".join(samples[0].question)
            pars = [" ".join(x.paragraphs[0]) for x in samples]
            ranks = get_paragraph_ranks(question, pars)
            for sample, rank, idx in zip(samples, ranks, qid2idx[qid]):
                questions.append(
                    RankedQAPair(question=sample.question,
                                 paragraphs=sample.paragraphs,
                                 spans=np.zeros((0, 2), dtype=np.int32),
                                 question_id=sample.question_id,
                                 answer=sample.answer,
                                 rank=rank,
                                 q_type=sample.q_type,
                                 sentence_segments=sample.sentence_segments))
                if idx in gold_idxs:
                    gold_ranks.append(rank + 1)
        print(f"Mean rank: {np.mean(gold_ranks)}")
        ranks_counter = Counter(gold_ranks)
        for i in range(1, 11):
            print(f"Hits at {i}: {ranks_counter[i]}")

    elif args.corpus == 'hotpot_file':  # a hotpot json format input file. We rank the pairs with tf-idf
        with open(args.input_file, 'r') as f:
            hotpot_data = json.load(f)
        if args.sample_questions:
            hotpot_data = sorted(hotpot_data, key=lambda x: x['_id'])
            np.random.RandomState(0).shuffle(hotpot_data)
            hotpot_data = hotpot_data[:args.sample_questions]
        title2sentences = {
            context[0]: context[1]
            for q in hotpot_data for context in q['context']
        }
        question_tok_texts = tokenize_texts(
            [q['question'] for q in hotpot_data], num_workers=args.num_workers)
        sentences_tok = tokenize_texts(list(title2sentences.values()),
                                       num_workers=args.num_workers,
                                       sentences=True)
        if args.tokens is not None:
            sentences_tok = [
                truncate_paragraph(p, args.tokens) for p in sentences_tok
            ]
        title2tok_sents = {
            title: sentences
            for title, sentences in zip(title2sentences.keys(), sentences_tok)
        }
        questions = []
        for idx, question in enumerate(tqdm(hotpot_data,
                                            desc='tf-idf ranking')):
            q_titles = [title for title, _ in question['context']]
            par_pairs = [(title1, title2) for i, title1 in enumerate(q_titles)
                         for title2 in q_titles[i + 1:]]
            if len(par_pairs) == 0:
                continue
            ranks = get_paragraph_ranks(question['question'], [
                ' '.join(title2sentences[t1] + title2sentences[t2])
                for t1, t2 in par_pairs
            ])
            for rank, par_pair in zip(ranks, par_pairs):
                sent_tok_pair = title2tok_sents[par_pair[0]] + title2tok_sents[
                    par_pair[1]]
                sentence_segments, _ = get_segments_from_sentences_fix_sup(
                    sent_tok_pair, np.zeros(0))
                missing_sent_idx = [[
                    i for i, sent in enumerate(title2tok_sents[title])
                    if len(sent) == 0
                ] for title in par_pair]
                questions.append(
                    RankedQAPair(
                        question=question_tok_texts[idx],
                        paragraphs=[flatten_iterable(sent_tok_pair)],
                        spans=np.zeros((0, 2), dtype=np.int32),
                        question_id=question['_id'],
                        answer='noanswer'
                        if args.test_mode else question['answer'],
                        rank=rank,
                        q_type='null' if args.test_mode else question['type'],
                        sentence_segments=[sentence_segments],
                        par_titles_num_sents=[
                            (title,
                             sum(1 for sent in title2tok_sents[title]
                                 if len(sent) > 0)) for title in par_pair
                        ],
                        missing_sent_idxs=missing_sent_idx,
                        true_sp=[]
                        if args.test_mode else question['supporting_facts']))
    elif args.corpus == 'retrieval_file' or args.corpus == 'top_titles':
        if args.docs_file is None:
            print("Using DB documents")
            doc_db = DocDB(config.DOC_DB, full_docs=False)
        else:
            with open(args.docs_file, 'r') as f:
                docs = json.load(f)
        with open(args.input_file, 'r') as f:
            retrieval_data = json.load(f)
        if args.sample_questions:
            retrieval_data = sorted(retrieval_data, key=lambda x: x['qid'])
            np.random.RandomState(0).shuffle(retrieval_data)
            retrieval_data = retrieval_data[:args.sample_questions]

        def parname_to_text(par_name):
            par_title = par_name_to_title(par_name)
            par_num = int(par_name.split('_')[-1])
            if args.docs_file is None:
                return doc_db.get_doc_sentences(par_title)
            return docs[par_title][par_num]

        if args.corpus == 'top_titles':
            print("Top TF-IDF!")
            for q in retrieval_data:
                top_titles = q['top_titles'][:10]
                q['paragraph_pairs'] = [(title1 + '_0', title2 + '_0')
                                        for i, title1 in enumerate(top_titles)
                                        for title2 in top_titles[i + 1:]]

        question_tok_texts = tokenize_texts(
            [q['question'] for q in retrieval_data],
            num_workers=args.num_workers)
        all_parnames = list(
            set([
                parname for q in retrieval_data
                for pair in q['paragraph_pairs'] for parname in pair
            ]))
        texts_tok = tokenize_texts([parname_to_text(x) for x in all_parnames],
                                   num_workers=args.num_workers,
                                   sentences=True)
        if args.tokens is not None:
            texts_tok = [truncate_paragraph(p, args.tokens) for p in texts_tok]
        parname2tok_text = {
            parname: text
            for parname, text in zip(all_parnames, texts_tok)
        }
        questions = []
        for idx, question in enumerate(retrieval_data):
            for rank, par_pair in enumerate(question['paragraph_pairs']):
                tok_pair = parname2tok_text[par_pair[0]] + parname2tok_text[
                    par_pair[1]]
                sentence_segments, _ = get_segments_from_sentences_fix_sup(
                    tok_pair, np.zeros(0))
                missing_sent_idx = [[
                    i for i, sent in enumerate(parname2tok_text[parname])
                    if len(sent) == 0
                ] for parname in par_pair]
                questions.append(
                    RankedQAPair(
                        question=question_tok_texts[idx],
                        paragraphs=[flatten_iterable(tok_pair)],
                        spans=np.zeros((0, 2), dtype=np.int32),
                        question_id=question['qid'],
                        answer='noanswer'
                        if args.test_mode else question['answers'][0],
                        rank=rank,
                        q_type='null' if args.test_mode else question['type'],
                        sentence_segments=[sentence_segments],
                        par_titles_num_sents=[
                            (par_name_to_title(parname),
                             sum(1 for sent in parname2tok_text[parname]
                                 if len(sent) > 0)) for parname in par_pair
                        ],
                        missing_sent_idxs=missing_sent_idx,
                        true_sp=[]
                        if args.test_mode else question['supporting_facts']))
    else:
        raise NotImplementedError()

    data = DummyDataset(questions, batcher)
    evaluators = [
        RecordHotpotQAPrediction(args.answer_bound,
                                 True,
                                 sp_prediction=not args.no_sp)
    ]

    if args.step is not None:
        if args.step == "latest":
            checkpoint = model_dir.get_latest_checkpoint()
        else:
            checkpoint = model_dir.get_checkpoint(int(args.step))
    else:
        checkpoint = model_dir.get_best_weights()
        if checkpoint is not None:
            print("Using best weights")
        else:
            print("Using latest checkpoint")
            checkpoint = model_dir.get_latest_checkpoint()

    model = model_dir.get_model()

    evaluation = trainer.test(model, evaluators, {args.corpus: data}, loader,
                              checkpoint, not args.no_ema, 10)[args.corpus]

    print("Saving result")
    output_file = args.output

    df = pd.DataFrame(evaluation.per_sample)

    df.sort_values(["question_id", "rank"], inplace=True, ascending=True)
    group_by = ["question_id"]

    def get_ranked_scores(score_name):
        filtered_df = df[df.type == 'comparison'] if "Cp" in score_name else \
            df[df.type == 'bridge'] if "Br" in score_name else df
        target_prefix = 'joint' if 'joint' in score_name else 'sp' if 'sp' in score_name else 'text'
        target_score = f"{target_prefix}_{'em' if 'EM' in score_name else 'f1'}"
        return compute_ranked_scores_with_yes_no(
            filtered_df,
            span_q_col="span_question_scores",
            yes_no_q_col="yes_no_question_scores",
            yes_no_scores_col="yes_no_confidence_scores",
            span_scores_col="predicted_score",
            span_target_score=target_score,
            group_cols=group_by)

    if not args.test_mode:
        score_names = ["EM", "F1", "Br EM", "Br F1", "Cp EM", "Cp F1"]
        if not args.no_sp:
            score_names.extend([
                f"{prefix} {name}" for prefix in ['sp', 'joint']
                for name in score_names
            ])

        table = [["N Paragraphs"] + score_names]
        scores = [get_ranked_scores(score_name) for score_name in score_names]
        table += list([str(i + 1), *["%.4f" % x for x in score_vals]]
                      for i, score_vals in enumerate(zip(*scores)))
        print_table(table)

        df.to_csv(output_file, index=False)

    else:
        df_to_pred(df, output_file)