def _build_full_dataset(self):
    samples = []
    for question in self.questions:
        pars_and_scores = list(
            zip(question.supporting_facts + question.distractors,
                question.gold_scores + question.distractor_scores))
        higher_gold = question.supporting_facts[0] \
            if question.gold_scores[0] >= question.gold_scores[1] \
            else question.supporting_facts[1]
        for p1, score1 in pars_and_scores:
            for p2, score2 in pars_and_scores:
                first_label, second_label = self._get_labels(
                    is_gold1=p1 in question.supporting_facts,
                    is_gold2=p2 in question.supporting_facts,
                    q_type=question.q_type,
                    are_same=p1 == p2,
                    is_first_higher=higher_gold == p1)
                samples.append(
                    IterativeQuestionAndParagraphs(
                        question=question.question_tokens,
                        paragraphs=[
                            flatten_iterable(p1.sentences),
                            flatten_iterable(p2.sentences)
                        ],
                        first_label=first_label,
                        second_label=second_label,
                        question_id=question.question_id,
                        q_type=question.q_type,
                        sentence_segments=[
                            get_segments_from_sentences(s)
                            for s in [p1.sentences, p2.sentences]
                        ]))
    return samples

def run_evaluators(self, sess: tf.Session, dataset: Dataset, name,
                   n_sample=None, feed_dict=None) -> Evaluation:
    all_tensors_needed = list(
        set(flatten_iterable(x.values() for x in self.tensors_needed)))

    tensors = {x: [] for x in all_tensors_needed}

    if n_sample is None:
        batches, n_batches = dataset.get_epoch(), len(dataset)
    else:
        batches, n_batches = dataset.get_samples(n_sample)

    data_used = []

    for batch in tqdm(batches, total=n_batches, desc=name, ncols=80):
        feed_dict = self.model.encode(batch, is_train=False)
        output = sess.run(all_tensors_needed, feed_dict=feed_dict)
        data_used += batch
        for i in range(len(all_tensors_needed)):
            tensors[all_tensors_needed[i]].append(output[i])

    # flatten the per-batch outputs
    for k in all_tensors_needed:
        v = tensors[k]
        if len(k.shape) == 0:
            v = np.array(v)  # list of scalars
        elif any(x is None for x in k.shape.as_list()):
            # variable-sized tensors, so convert to a flat python list
            v = flatten_iterable(v)
        else:
            v = np.concatenate(v, axis=0)  # concat along the batch dim
        tensors[k] = v

    percent_filtered = dataset.percent_filtered()
    if percent_filtered is None:
        true_len = len(data_used)
    else:
        true_len = len(data_used) / (1 - percent_filtered)

    combined = None
    for ev, needed in zip(self.evaluators, self.tensors_needed):
        args = {k: tensors[v] for k, v in needed.items()}
        evaluation = ev.evaluate(data_used, true_len, **args)
        if evaluation is None:
            raise ValueError(ev)
        if combined is None:
            combined = evaluation
        else:
            combined.add(evaluation)

    return combined

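# The flattening rule above merits a standalone illustration. This is a
# minimal sketch (plain numpy, no TF session, demo-only helper): per-batch
# outputs with variable-sized shapes are chained into one flat list, while
# fixed-shape outputs are concatenated along the batch axis.
def _demo_flatten_rule():
    var_sized = [np.array([1.0, 2.0]), np.array([3.0])]  # variable-length outputs
    flat = [x for batch in var_sized for x in batch]     # -> [1.0, 2.0, 3.0]
    fixed = [np.ones((2, 4)), np.ones((3, 4))]           # fixed trailing dims
    stacked = np.concatenate(fixed, axis=0)              # -> shape (5, 4)
    print(flat, stacked.shape)
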
def preprocess(self, question: HotpotQuestion):
    for par in question.distractors:
        while len(flatten_iterable(par.sentences)) > self.num_tokens_th:
            par.sentences = par.sentences[:-1]
    for par in question.supporting_facts:
        while len(flatten_iterable(par.sentences)) > self.num_tokens_th:
            if (len(par.sentences) - 1) in par.supporting_sentence_ids:
                print("Warning: supporting fact above threshold. Removing sample.")
                return None
            par.sentences = par.sentences[:-1]
    return question

def hotpot_question_to_relevance_question(
        hotpot_question: HotpotQuestion) -> RelevanceQuestion:
    return RelevanceQuestion(
        dataset_name='hotpot',
        question_id=hotpot_question.question_id,
        question_tokens=hotpot_question.question_tokens,
        supporting_facts=[
            flatten_iterable(x.sentences)
            for x in hotpot_question.supporting_facts
        ],
        distractors=[
            flatten_iterable(x.sentences)
            for x in hotpot_question.distractors
        ])

def evaluate_question_detector(questions: List[HotpotQuestion], word_tokenize,
                               detector, reference_detector=None,
                               compute_f1s=False):
    """ Just for debugging """
    answer_per_q = []
    answer_f1s = []

    for question_ix, q in enumerate(tqdm(questions)):
        if q.answer in {'yes', 'no'} and q.q_type == 'comparison':
            continue
        tokenized_aliases = [word_tokenize(q.answer)]
        detector.set_question(tokenized_aliases)

        output = []
        for i, par in enumerate(q.supporting_facts):
            for s, e in detector.any_found(par.sentences):
                output.append((i, s, e))

        if len(output) == 0 and reference_detector is not None:
            reference_detector.set_question(tokenized_aliases)
            detected = []
            for j, par in enumerate(q.supporting_facts):
                for s, e in reference_detector.any_found(par.sentences):
                    detected.append((j, s, e))
            if len(detected) > 0:
                print("Found a difference")
                print(q.answer)
                print(tokenized_aliases)
                for p, s, e in detected:
                    tokens = flatten_iterable(q.supporting_facts[p].sentences)[s:e]
                    print(tokens)

        answer_per_q.append(output)

        if compute_f1s:
            f1s = []
            for p, s, e in output:
                tokens = flatten_iterable(q.supporting_facts[p].sentences)[s:e]
                answer = normalize_answer(" ".join(tokens))
                f1, _, _ = f1_score(answer, normalize_answer(q.answer))
                f1s.append(f1)
            answer_f1s.append(f1s)

    n_answers = sum(len(x) for x in answer_per_q)
    print("Found %d answers (av %.4f)" % (n_answers, n_answers / len(answer_per_q)))
    print("%.4f docs have answers" % np.mean([len(x) > 0 for x in answer_per_q]))
    if len(answer_f1s) > 0:
        print("Average f1 is %.4f" % np.mean(flatten_iterable(answer_f1s)))

def get_word_counts(self):
    count = Counter()
    for q in self.questions:
        count.update(q.question_tokens)
        for para in (q.distractors + q.supporting_facts):
            count.update(flatten_iterable(para.sentences))
    return count

def any_found(self, para: List[List[str]]):
    # Normalize the paragraph
    words = [w.lower().strip(self.strip) for w in flatten_iterable(para)]
    occurrences = []
    for answer in self.answer_tokens:
        if len(answer) == 0:
            continue
        # Locations where the first word occurs
        word_starts = [i for i, w in enumerate(words) if answer[0] == w]
        n_tokens = len(answer)
        # Advance forward until we find all the words, skipping over articles
        for start in word_starts:
            end = start + 1
            ans_token = 1
            while ans_token < n_tokens and end < len(words):
                next_word = words[end]
                if answer[ans_token] == next_word:
                    ans_token += 1
                    end += 1
                elif next_word in self.skip:
                    end += 1
                else:
                    break
            if n_tokens == ans_token:
                occurrences.append((start, end))
    return list(set(occurrences))

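# A minimal usage sketch for any_found. _DemoDetector is a hypothetical
# stand-in (not a class from this repo) that carries only the attributes
# any_found relies on: answer_tokens, strip, and skip.
def _demo_any_found():
    import string

    class _DemoDetector:
        def __init__(self, answer_tokens):
            self.answer_tokens = answer_tokens  # list of tokenized aliases
            self.strip = string.punctuation     # chars stripped during normalization
            self.skip = {'a', 'an', 'the'}      # words skipped mid-match

    _DemoDetector.any_found = any_found  # reuse the function defined above

    det = _DemoDetector([['empire', 'state', 'building']])
    para = [['The', 'Empire', 'State', 'Building', 'is', 'in', 'NYC.']]
    print(det.any_found(para))  # -> [(1, 4)]: token span in the flattened paragraph
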
def merge_paragraphs(paragraphs: List[HotpotParagraph], spans: List[np.ndarray],
                     supporting_fact_idxs: List[List[int]]) \
        -> Tuple[List[str], np.ndarray, List[int], np.ndarray]:
    # TODO: this supporting-fact fixing is a hack, but it is easy to handle
    # here, so it will do for now
    for i in range(len(paragraphs)):
        supporting_fact_idxs[i] = [
            x for x in supporting_fact_idxs[i]
            if x < len(paragraphs[i].sentences)
        ]
    merged_text = []
    merged_spans = np.zeros((0, 2), dtype=np.int32)
    merged_sp_idxs = np.zeros(0, dtype=np.int32)
    merged_sentences = []
    for par, par_spans, par_facts in zip(paragraphs, spans, supporting_fact_idxs):
        merged_spans = np.concatenate(
            [merged_spans, par_spans + len(merged_text)])
        merged_sp_idxs = np.concatenate([
            merged_sp_idxs,
            np.array(par_facts, dtype=np.int32) + len(merged_sentences)
        ])
        merged_sentences.extend(par.sentences)
        merged_text.extend(flatten_iterable(par.sentences))
    segments, merged_sp_idxs = get_segments_from_sentences_fix_sup(
        merged_sentences, merged_sp_idxs)
    return merged_text, merged_spans, segments, merged_sp_idxs

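# The span arithmetic in merge_paragraphs is easy to check in isolation:
# spans from a later paragraph are shifted by the token length of the text
# already merged. A standalone sketch (plain numpy, no repo types):
def _demo_span_offsets():
    par1_tokens = ['one', 'two', 'three']
    par2_spans = np.array([[0, 1]], dtype=np.int32)  # span over par2's first token
    shifted = par2_spans + len(par1_tokens)          # -> [[3, 4]] in the merged text
    print(shifted)
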
def get_vocab(self):
    voc = set()
    for q in self.questions:
        voc.update(q.question_tokens)
        for para in (q.distractors + q.supporting_facts):
            voc.update(flatten_iterable(para.sentences))
    return voc

def _build_gold_samples(self):
    gold_samples = []
    for question in self.questions:
        pars = [
            flatten_iterable(question.supporting_facts[0].sentences),
            flatten_iterable(question.supporting_facts[1].sentences)
        ]
        self.random.shuffle(pars)
        gold_samples.append(
            BinaryQuestionAndParagraphs(question.question_tokens, pars, 1,
                                        num_distractors=0,
                                        question_id=question.question_id,
                                        q_type=question.q_type))
    return gold_samples

def _sample_first_gold_second_false(self, qid):
    question = self.qid2question[qid]
    rand_par_other_q = self._sample_rand_par_other_q(qid)
    if question.q_type == 'comparison' or self.bridge_as_comparison:
        first_gold_par = question.supporting_facts[self.random.randint(2)]
    else:
        if not self.label_by_span:
            first_gold_idx = 0 if question.gold_scores[0] > question.gold_scores[1] else 1
            first_gold_par = question.supporting_facts[first_gold_idx]
        else:
            gold_idxs = self._get_no_span_containing_golds(question.question_id)
            if len(gold_idxs) == 1:
                first_gold_par = question.supporting_facts[gold_idxs[0]]
            else:
                first_gold_par = question.supporting_facts[self.random.randint(2)]
    rand_par = self.random.choice(
        [rand_par_other_q, first_gold_par,
         self.random.choice(question.distractors)],
        p=[0.05, 0.1, 0.85])
    pars = [
        flatten_iterable(first_gold_par.sentences),
        flatten_iterable(rand_par.sentences)
    ]
    segs = [
        get_segments_from_sentences(first_gold_par.sentences),
        get_segments_from_sentences(rand_par.sentences)
    ]
    return IterativeQuestionAndParagraphs(
        question=question.question_tokens,
        paragraphs=pars,
        first_label=1,
        second_label=0,
        question_id=question.question_id,
        q_type=question.q_type,
        sentence_segments=segs)

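# A standalone illustration of the weighted negative sampling above:
# numpy's RandomState.choice draws one candidate per call with the given
# mixture probabilities (5% paragraph from another question, 10% the gold
# paragraph repeated, 85% one of this question's distractors). Demo only.
def _demo_weighted_choice():
    rng = np.random.RandomState(0)
    picks = [rng.choice(['other_q', 'gold_again', 'distractor'],
                        p=[0.05, 0.1, 0.85]) for _ in range(1000)]
    print(Counter(picks))  # roughly 50 / 100 / 850
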
def _build_full_dataset(self):
    samples = []
    for question in self.questions:
        for i, p1 in enumerate(question.distractors + question.supporting_facts):
            for p2 in (question.distractors + question.supporting_facts)[i + 1:]:
                label = 1 if ((p1 in question.supporting_facts) and
                              (p2 in question.supporting_facts)) else 0
                num_distractors = sum(
                    [p1 in question.distractors, p2 in question.distractors])
                samples.append(
                    BinaryQuestionAndParagraphs(
                        question.question_tokens,
                        [flatten_iterable(p1.sentences),
                         flatten_iterable(p2.sentences)],
                        label,
                        num_distractors=num_distractors,
                        question_id=question.question_id,
                        q_type=question.q_type))
    return samples

def assign_scores(question: HotpotQuestion):
    question_spvec = PROCESS_RANKER.text2spvec(question.question_tokens,
                                               tokenized=True)
    paragraphs = [
        flatten_iterable(par.sentences)
        for par in (question.supporting_facts + question.distractors)
    ]
    pars_spvecs = [PROCESS_RANKER.text2spvec(x, tokenized=True) for x in paragraphs]
    pars_spvecs = vstack(pars_spvecs)
    scores = pars_spvecs.dot(question_spvec.toarray().squeeze(axis=0))
    question.gold_scores = scores[:len(question.supporting_facts)].tolist()
    question.distractor_scores = scores[len(question.supporting_facts):].tolist()
    return question

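# A standalone sketch of the same scoring idea using sklearn's
# TfidfVectorizer instead of the repo's PROCESS_RANKER (an assumption made
# purely for illustration): paragraphs sharing more weighted terms with the
# question get higher dot-product scores.
def _demo_tfidf_scores():
    from sklearn.feature_extraction.text import TfidfVectorizer
    question = 'where is the eiffel tower'
    paragraphs = ['the eiffel tower is in paris', 'dogs are mammals']
    vec = TfidfVectorizer()
    par_mat = vec.fit_transform(paragraphs)   # (n_pars, vocab)
    q_vec = vec.transform([question])         # (1, vocab)
    scores = par_mat.dot(q_vec.T).toarray().squeeze(axis=1)
    print(scores)  # the first paragraph scores higher
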
def _sample_false_1(self, qid):
    """ False sample of type 1: all distractors. No sampling from another
    question here, as I think it's less effective in this case. """
    question = self.qid2question[qid]
    two_distractors = self.random.choice(question.distractors, size=2,
                                         replace=False)
    pars = [flatten_iterable(x.sentences) for x in two_distractors]
    segs = [get_segments_from_sentences(x.sentences) for x in two_distractors]
    return IterativeQuestionAndParagraphs(
        question=question.question_tokens,
        paragraphs=pars,
        first_label=0,
        second_label=0,
        question_id=question.question_id,
        q_type=question.q_type,
        sentence_segments=segs)

def build_spec(batch_size: int, max_batch_size: int, num_contexts: int,
               data: List[RelevanceQuestion]) -> QuestionAndParagraphsSpec:
    max_ques_size = 0
    max_word_size = 0
    max_para_size = 0
    max_num_contexts = num_contexts
    for data_point in data:
        contexts = data_point.distractors + data_point.supporting_facts
        max_word_size = max(
            max_word_size,
            max(len(word) for word in flatten_iterable(contexts)))
        max_para_size = max(max_para_size,
                            max(len(context) for context in contexts))
        max_ques_size = max(max_ques_size, len(data_point.question_tokens))
        max_word_size = max(
            max_word_size,
            max(len(word) for word in data_point.question_tokens))
    return QuestionAndParagraphsSpec(batch_size, max_num_contexts,
                                     max_ques_size, max_para_size,
                                     max_word_size, max_batch_size)

def _sample_false_2(self, qid):
    """ False sample of type 2: the first paragraph a distractor, the second
    one of the supporting facts. """
    question = self.qid2question[qid]
    rand_par_other_q = self._sample_rand_par_other_q(qid)
    distractor = self.random.choice(
        [self.random.choice(question.distractors), rand_par_other_q],
        p=[0.9, 0.1])
    gold = self.random.choice(question.supporting_facts)
    pars = [flatten_iterable(x.sentences) for x in [distractor, gold]]
    segs = [get_segments_from_sentences(x.sentences) for x in [distractor, gold]]
    return IterativeQuestionAndParagraphs(
        question=question.question_tokens,
        paragraphs=pars,
        first_label=0,
        second_label=0,
        question_id=question.question_id,
        q_type=question.q_type,
        sentence_segments=segs)

def get_vocab(self):
    """ Get the all-lowercased unique words for this corpus, including the
    train/dev/test files. """
    if not exists(self.dir):
        self.dir = join(config.CORPUS_DIR, self.NAME)
    voc_file = join(self.dir, self.VOCAB_FILE)
    if exists(voc_file):
        with open(voc_file, "r") as f:
            return [x.rstrip() for x in f]
    else:
        voc = set()
        for fn in [self.get_train, self.get_dev, self.get_test]:
            for question in fn():
                voc.update(x.lower() for x in question.question_tokens)
                for para in (question.distractors + question.supporting_facts):
                    voc.update(x.lower() for x in flatten_iterable(para.sentences))
        voc_list = sorted(voc)
        with open(voc_file, "w") as f:
            for word in voc_list:
                f.write(word)
                f.write("\n")
        return voc_list

def _build_gold_samples(self):
    gold_samples = []
    for question in self.questions:
        if question.q_type == 'comparison' or self.bridge_as_comparison:
            pars_order = [0, 1]
            self.random.shuffle(pars_order)
        else:
            if not self.label_by_span:
                pars_order = ([0, 1]
                              if question.gold_scores[0] > question.gold_scores[1]
                              else [1, 0])
            else:
                gold_idxs = self._get_no_span_containing_golds(question.question_id)
                pars_order = [0, 1] if 0 in gold_idxs else [1, 0]
                if len(gold_idxs) != 1:
                    # either both contain the answer or neither does,
                    # so they are regarded as equal
                    self.random.shuffle(pars_order)
        pars = [
            flatten_iterable(question.supporting_facts[i].sentences)
            for i in pars_order
        ]
        sentence_segs = [
            get_segments_from_sentences(question.supporting_facts[i].sentences)
            for i in pars_order
        ]
        gold_samples.append(
            IterativeQuestionAndParagraphs(
                question.question_tokens,
                pars,
                first_label=1,
                second_label=1,
                question_id=question.question_id,
                q_type=question.q_type,
                sentence_segments=sentence_segs))
    return gold_samples

def get_answers(question: str, top_title_tuples: List[Tuple[str]]):
    question_tok = tokenize_words(question)
    all_titles = list(set([title for titles in top_title_tuples for title in titles]))
    texts_tok = [
        x for x in tok_workers.imap(
            tokenize_sentences,
            [db.get_doc_sentences(title) for title in all_titles])
    ]
    title2tok_sents = {
        title: tok_sents
        for title, tok_sents in zip(all_titles, texts_tok)
    }
    questions = []
    for rank, title_tuple in enumerate(top_title_tuples):
        tokenized_sents = [
            sent for t in title_tuple for sent in title2tok_sents[t]
        ]
        sentence_segments, _ = get_segments_from_sentences_fix_sup(
            tokenized_sents, np.zeros(0))
        missing_sent_idx = [[
            i for i, sent in enumerate(title2tok_sents[title]) if len(sent) == 0
        ] for title in title_tuple]
        questions.append(
            RankedQAPair(
                question=question_tok,
                paragraphs=[flatten_iterable(tokenized_sents)],
                spans=np.zeros((0, 2), dtype=np.int32),
                question_id='bla',
                answer='noanswer',
                rank=rank,
                q_type='n/a',
                sentence_segments=[sentence_segments],
                par_titles_num_sents=[
                    (title,
                     sum(1 for sent in title2tok_sents[title] if len(sent) > 0))
                    for title in title_tuple
                ],
                missing_sent_idxs=missing_sent_idx,
                true_sp=[]))
    data = DummyDataset(questions, batcher)
    evaluation = evaluator_runner.run_evaluators(qa_sess, data, 'bla', None, {},
                                                 disable_tqdm=True)
    df = pd.DataFrame(evaluation.per_sample)
    df.sort_values(["question_id", "rank"], inplace=True, ascending=True)
    answer_dict, sp_dict = df_to_pred(df, None, return_results=True)
    title2idxs = {}
    for title, idx in sp_dict['bla']:
        if title not in title2idxs:
            title2idxs[title] = []
        title2idxs[title].append(idx)
    sp_titles_idxs = [(title, idxs) for title, idxs in title2idxs.items()]
    return answer_dict['bla'], sp_titles_idxs

def num_tokens(self):
    return len(flatten_iterable(self.sentences))

def run_evaluators(self, sess: tf.Session, dataset, name, n_sample, feed_dict,
                   disable_tqdm=False) -> Evaluation:
    all_tensors_needed = list(
        set(flatten_iterable(x.values() for x in self.tensors_needed)))

    tensors = {x: [] for x in all_tensors_needed}

    data_used = []

    if n_sample is None:
        batches, n_batches = dataset.get_epoch(), len(dataset)
    else:
        batches, n_batches = dataset.get_samples(n_sample)

    def enqueue_eval():
        try:
            for data in batches:
                encoded = self.model.encode(data, False)
                data_used.append(data)
                sess.run(self.enqueue_op, encoded)
        except Exception as e:
            sess.run(self.close_queue)  # crash the main thread
            raise e
        # we should run out of batches and exit gracefully

    th = Thread(target=enqueue_eval)
    th.daemon = True
    th.start()

    for _ in tqdm(range(n_batches), total=n_batches, desc=name, ncols=80,
                  disable=disable_tqdm):
        output = sess.run(all_tensors_needed, feed_dict=feed_dict)
        for i in range(len(all_tensors_needed)):
            tensors[all_tensors_needed[i]].append(output[i])

    th.join()

    if sess.run(self.queue_size) != 0:
        raise RuntimeError("All batches should have been consumed")

    # flatten the per-batch outputs
    for k in all_tensors_needed:
        v = tensors[k]
        if len(k.shape) == 0:
            v = np.array(v)  # list of scalars -> array
        elif any(x is None for x in k.shape.as_list()[1:]):
            # variable-sized tensors, so convert to a flat python list
            v = flatten_iterable(v)
        else:
            v = np.concatenate(v, axis=0)  # concat along the batch dim
        tensors[k] = v

    # flatten the data if it consists of batches
    if isinstance(data_used[0], list):
        data_used = flatten_iterable(data_used)

    if dataset.percent_filtered() is None:
        true_len = len(data_used)
    else:
        true_len = len(data_used) / (1 - dataset.percent_filtered())

    combined = None
    for ev, needed in zip(self.evaluators, self.tensors_needed):
        args = {k: tensors[v] for k, v in needed.items()}
        evaluation = ev.evaluate(data_used, true_len, **args)
        if combined is None:
            combined = evaluation
        else:
            combined.add(evaluation)

    return combined

def get_epoch(self, new_epoch=True):
    if self.fixed_dataset:
        new_epoch = False
    if not new_epoch and self.epoch_samples is not None:
        return self.batcher.get_epoch(self.epoch_samples)
    false_samples = []
    for question in self.questions:
        two_distractors = [
            flatten_iterable(x.sentences)
            for x in self.random.choice(question.distractors, size=2,
                                        replace=False)
        ]
        true_and_false_1 = [
            flatten_iterable(question.supporting_facts[0].sentences),
            two_distractors[0]
        ]
        true_and_false_2 = [
            flatten_iterable(question.supporting_facts[1].sentences),
            two_distractors[1]
        ]
        self.random.shuffle(true_and_false_1)
        self.random.shuffle(true_and_false_2)
        false_samples.append(
            BinaryQuestionAndParagraphs(question.question_tokens,
                                        true_and_false_1, 0,
                                        num_distractors=1,
                                        question_id=question.question_id,
                                        q_type=question.q_type))
        false_samples.append(
            BinaryQuestionAndParagraphs(question.question_tokens,
                                        true_and_false_2, 0,
                                        num_distractors=1,
                                        question_id=question.question_id,
                                        q_type=question.q_type))
        false_samples.append(
            BinaryQuestionAndParagraphs(
                question.question_tokens,
                [
                    flatten_iterable(x.sentences)
                    for x in self.random.choice(question.distractors, size=2,
                                                replace=False)
                ], 0,
                num_distractors=2,
                question_id=question.question_id,
                q_type=question.q_type))
        if self.add_gold_distractor:
            rand_q_idx = self.random.randint(len(self.gold_samples))
            while self.gold_samples[rand_q_idx].question_id == question.question_id:
                rand_q_idx = self.random.randint(len(self.gold_samples))
            selected_q = self.gold_samples[rand_q_idx]
            self.random.shuffle(selected_q.paragraphs)
            false_samples.append(
                BinaryQuestionAndParagraphs(
                    question.question_tokens,
                    selected_q.paragraphs,
                    label=0,
                    num_distractors=2,
                    question_id=question.question_id,
                    q_type=question.q_type))
    for gold in self.gold_samples:
        self.random.shuffle(gold.paragraphs)
    self.epoch_samples = self.gold_samples + false_samples
    np.random.shuffle(self.epoch_samples)
    return self.batcher.get_epoch(self.epoch_samples)

def truncate_paragraph(tokenized_sentences: List[List[str]], num_tokens):
    while len(flatten_iterable(tokenized_sentences)) > num_tokens:
        tokenized_sentences = tokenized_sentences[:-1]
    return tokenized_sentences

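# Example usage of truncate_paragraph: trailing sentences are dropped until
# the paragraph fits the token budget.
def _demo_truncate_paragraph():
    sents = [['a', 'b'], ['c', 'd'], ['e']]
    print(truncate_paragraph(sents, 3))  # -> [['a', 'b']] (2 tokens <= 3)
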
def main():
    parser = argparse.ArgumentParser(
        description='Full ranking evaluation on Hotpot')
    parser.add_argument('model', help='model directory to evaluate')
    parser.add_argument(
        'output', type=str,
        help="Store the per-paragraph results in csv format in this file, "
             "or the json prediction if in test mode")
    parser.add_argument('-n', '--sample_questions', type=int, default=None,
                        help="(for testing) run on a subset of questions")
    parser.add_argument(
        '-b', '--batch_size', type=int, default=64,
        help="Batch size; larger sizes can be faster but use more memory")
    parser.add_argument(
        '-s', '--step', default=None,
        help="Weights to load, can be a checkpoint step or 'latest'")
    parser.add_argument('-a', '--answer_bound', type=int, default=8,
                        help="Max answer span length")
    parser.add_argument('-c', '--corpus',
                        choices=["distractors", "gold", "hotpot_file",
                                 "retrieval_file", "top_titles"],
                        default="distractors")
    parser.add_argument('-t', '--tokens', type=int, default=None,
                        help="Max tokens per paragraph")
    parser.add_argument('--input_file', type=str, default=None)
    parser.add_argument('--docs_file', type=str, default=None)
    parser.add_argument('--num_workers', type=int, default=16,
                        help='Number of workers for tokenizing')
    parser.add_argument('--no_ema', action="store_true",
                        help="Don't use EMA weights even if they exist")
    parser.add_argument('--no_sp', action="store_true",
                        help="Don't predict supporting facts")
    parser.add_argument('--test_mode', action='store_true',
                        help="Produce a prediction file; no answers given")
    args = parser.parse_args()

    model_dir = ModelDir(args.model)
    batcher = ClusteredBatcher(args.batch_size, multiple_contexts_len,
                               truncate_batches=True)
    loader = ResourceLoader()

    if args.corpus not in {"distractors", "gold"} and args.input_file is None:
        raise ValueError("Must pass an input file if not using a precomputed dataset")

    if args.corpus in {"distractors", "gold"} and args.test_mode:
        raise ValueError("Test mode not available in 'distractors' or 'gold' mode")

    if args.corpus in {"distractors", "gold"}:
        corpus = HotpotQuestions()
        loader = corpus.get_resource_loader()
        questions = corpus.get_dev()
        question_preprocessor = HotpotTextLengthPreprocessorWithSpans(args.tokens)
        questions = [question_preprocessor.preprocess(x) for x in questions]
        questions = [q for q in questions if q is not None]

        if args.sample_questions:
            questions = sorted(questions, key=lambda x: x.question_id)
            np.random.RandomState(0).shuffle(questions)
            questions = questions[:args.sample_questions]

        data = HotpotFullQADistractorsDataset(questions, batcher)
        gold_idxs = set(data.gold_idxs)
        if args.corpus == 'gold':
            data.samples = [data.samples[i] for i in data.gold_idxs]
        qid2samples = {}
        qid2idx = {}
        for i, sample in enumerate(data.samples):
            key = sample.question_id
            if key in qid2samples:
                qid2samples[key].append(sample)
                qid2idx[key].append(i)
            else:
                qid2samples[key] = [sample]
                qid2idx[key] = [i]
        questions = []
        print("Ranking pairs...")
        gold_ranks = []
        for qid, samples in tqdm(qid2samples.items()):
            question = " ".join(samples[0].question)
            pars = [" ".join(x.paragraphs[0]) for x in samples]
            ranks = get_paragraph_ranks(question, pars)
            for sample, rank, idx in zip(samples, ranks, qid2idx[qid]):
                questions.append(
                    RankedQAPair(question=sample.question,
                                 paragraphs=sample.paragraphs,
                                 spans=np.zeros((0, 2), dtype=np.int32),
                                 question_id=sample.question_id,
                                 answer=sample.answer,
                                 rank=rank,
                                 q_type=sample.q_type,
                                 sentence_segments=sample.sentence_segments))
                if idx in gold_idxs:
                    gold_ranks.append(rank + 1)
        print(f"Mean rank: {np.mean(gold_ranks)}")
        ranks_counter = Counter(gold_ranks)
        for i in range(1, 11):
            print(f"Hits at {i}: {ranks_counter[i]}")
    elif args.corpus == 'hotpot_file':
        # a hotpot-json-format input file; we rank the pairs with tf-idf
        with open(args.input_file, 'r') as f:
            hotpot_data = json.load(f)

        if args.sample_questions:
            hotpot_data = sorted(hotpot_data, key=lambda x: x['_id'])
            np.random.RandomState(0).shuffle(hotpot_data)
            hotpot_data = hotpot_data[:args.sample_questions]

        title2sentences = {
            context[0]: context[1]
            for q in hotpot_data for context in q['context']
        }
        question_tok_texts = tokenize_texts([q['question'] for q in hotpot_data],
                                            num_workers=args.num_workers)
        sentences_tok = tokenize_texts(list(title2sentences.values()),
                                       num_workers=args.num_workers,
                                       sentences=True)
        if args.tokens is not None:
            sentences_tok = [truncate_paragraph(p, args.tokens)
                             for p in sentences_tok]
        title2tok_sents = {
            title: sentences
            for title, sentences in zip(title2sentences.keys(), sentences_tok)
        }
        questions = []
        for idx, question in enumerate(tqdm(hotpot_data, desc='tf-idf ranking')):
            q_titles = [title for title, _ in question['context']]
            par_pairs = [(title1, title2)
                         for i, title1 in enumerate(q_titles)
                         for title2 in q_titles[i + 1:]]
            if len(par_pairs) == 0:
                continue
            ranks = get_paragraph_ranks(
                question['question'],
                [' '.join(title2sentences[t1] + title2sentences[t2])
                 for t1, t2 in par_pairs])
            for rank, par_pair in zip(ranks, par_pairs):
                sent_tok_pair = title2tok_sents[par_pair[0]] + title2tok_sents[par_pair[1]]
                sentence_segments, _ = get_segments_from_sentences_fix_sup(
                    sent_tok_pair, np.zeros(0))
                missing_sent_idx = [[
                    i for i, sent in enumerate(title2tok_sents[title])
                    if len(sent) == 0
                ] for title in par_pair]
                questions.append(
                    RankedQAPair(
                        question=question_tok_texts[idx],
                        paragraphs=[flatten_iterable(sent_tok_pair)],
                        spans=np.zeros((0, 2), dtype=np.int32),
                        question_id=question['_id'],
                        answer='noanswer' if args.test_mode else question['answer'],
                        rank=rank,
                        q_type='null' if args.test_mode else question['type'],
                        sentence_segments=[sentence_segments],
                        par_titles_num_sents=[
                            (title,
                             sum(1 for sent in title2tok_sents[title] if len(sent) > 0))
                            for title in par_pair
                        ],
                        missing_sent_idxs=missing_sent_idx,
                        true_sp=[] if args.test_mode else question['supporting_facts']))
    elif args.corpus in ('retrieval_file', 'top_titles'):
        if args.docs_file is None:
            print("Using DB documents")
            doc_db = DocDB(config.DOC_DB, full_docs=False)
        else:
            with open(args.docs_file, 'r') as f:
                docs = json.load(f)
        with open(args.input_file, 'r') as f:
            retrieval_data = json.load(f)

        if args.sample_questions:
            retrieval_data = sorted(retrieval_data, key=lambda x: x['qid'])
            np.random.RandomState(0).shuffle(retrieval_data)
            retrieval_data = retrieval_data[:args.sample_questions]

        def parname_to_text(par_name):
            par_title = par_name_to_title(par_name)
            par_num = int(par_name.split('_')[-1])
            if args.docs_file is None:
                return doc_db.get_doc_sentences(par_title)
            return docs[par_title][par_num]

        if args.corpus == 'top_titles':
            print("Top TF-IDF!")
            for q in retrieval_data:
                top_titles = q['top_titles'][:10]
                q['paragraph_pairs'] = [(title1 + '_0', title2 + '_0')
                                        for i, title1 in enumerate(top_titles)
                                        for title2 in top_titles[i + 1:]]

        question_tok_texts = tokenize_texts(
            [q['question'] for q in retrieval_data],
            num_workers=args.num_workers)
        all_parnames = list(
            set([parname for q in retrieval_data
                 for pair in q['paragraph_pairs'] for parname in pair]))
        texts_tok = tokenize_texts([parname_to_text(x) for x in all_parnames],
                                   num_workers=args.num_workers,
                                   sentences=True)
        if args.tokens is not None:
            texts_tok = [truncate_paragraph(p, args.tokens) for p in texts_tok]
        parname2tok_text = {
            parname: text
            for parname, text in zip(all_parnames, texts_tok)
        }
        questions = []
        for idx, question in enumerate(retrieval_data):
            for rank, par_pair in enumerate(question['paragraph_pairs']):
                tok_pair = parname2tok_text[par_pair[0]] + parname2tok_text[par_pair[1]]
                sentence_segments, _ = get_segments_from_sentences_fix_sup(
                    tok_pair, np.zeros(0))
                missing_sent_idx = [[
                    i for i, sent in enumerate(parname2tok_text[parname])
                    if len(sent) == 0
                ] for parname in par_pair]
                questions.append(
                    RankedQAPair(
                        question=question_tok_texts[idx],
                        paragraphs=[flatten_iterable(tok_pair)],
                        spans=np.zeros((0, 2), dtype=np.int32),
                        question_id=question['qid'],
                        answer='noanswer' if args.test_mode else question['answers'][0],
                        rank=rank,
                        q_type='null' if args.test_mode else question['type'],
                        sentence_segments=[sentence_segments],
                        par_titles_num_sents=[
                            (par_name_to_title(parname),
                             sum(1 for sent in parname2tok_text[parname] if len(sent) > 0))
                            for parname in par_pair
                        ],
                        missing_sent_idxs=missing_sent_idx,
                        true_sp=[] if args.test_mode else question['supporting_facts']))
    else:
        raise NotImplementedError()

    data = DummyDataset(questions, batcher)
    evaluators = [
        RecordHotpotQAPrediction(args.answer_bound, True,
                                 sp_prediction=not args.no_sp)
    ]

    if args.step is not None:
        if args.step == "latest":
            checkpoint = model_dir.get_latest_checkpoint()
        else:
            checkpoint = model_dir.get_checkpoint(int(args.step))
    else:
        checkpoint = model_dir.get_best_weights()
        if checkpoint is not None:
            print("Using best weights")
        else:
            print("Using latest checkpoint")
            checkpoint = model_dir.get_latest_checkpoint()

    model = model_dir.get_model()

    evaluation = trainer.test(model, evaluators, {args.corpus: data}, loader,
                              checkpoint, not args.no_ema, 10)[args.corpus]

    print("Saving result")
    output_file = args.output

    df = pd.DataFrame(evaluation.per_sample)
    df.sort_values(["question_id", "rank"], inplace=True, ascending=True)
    group_by = ["question_id"]

    def get_ranked_scores(score_name):
        filtered_df = (df[df.type == 'comparison'] if "Cp" in score_name else
                       df[df.type == 'bridge'] if "Br" in score_name else df)
        target_prefix = ('joint' if 'joint' in score_name else
                         'sp' if 'sp' in score_name else 'text')
        target_score = f"{target_prefix}_{'em' if 'EM' in score_name else 'f1'}"
        return compute_ranked_scores_with_yes_no(
            filtered_df,
            span_q_col="span_question_scores",
            yes_no_q_col="yes_no_question_scores",
            yes_no_scores_col="yes_no_confidence_scores",
            span_scores_col="predicted_score",
            span_target_score=target_score,
            group_cols=group_by)

    if not args.test_mode:
        score_names = ["EM", "F1", "Br EM", "Br F1", "Cp EM", "Cp F1"]
        if not args.no_sp:
            score_names.extend([
                f"{prefix} {name}" for prefix in ['sp', 'joint']
                for name in score_names
            ])
        table = [["N Paragraphs"] + score_names]
        scores = [get_ranked_scores(score_name) for score_name in score_names]
        table += list([str(i + 1), *["%.4f" % x for x in score_vals]]
                      for i, score_vals in enumerate(zip(*scores)))
        print_table(table)
        df.to_csv(output_file, index=False)
    else:
        df_to_pred(df, output_file)