from collections import Counter

import numpy as np
from sklearn.feature_extraction.text import strip_accents_unicode
from tqdm import tqdm

# Project imports: module paths assume the document-qa (docqa) package layout.
from docqa.data_processing.document_splitter import ContainsQuestionWord, MergeParagraphs, \
    ShallowOpenWebRanker, TopTfIdf
from docqa.data_processing.text_utils import NltkPlusStopWords
from docqa.text_preprocessor import WithIndicators
from docqa.triviaqa.build_span_corpus import TriviaQaOpenDataset, TriviaQaWebDataset
from docqa.utils import bcolors, flatten_iterable


def show_web_paragraphs():
    # Show the two top TF-IDF-ranked paragraphs for random TriviaQA-web
    # (question, document) pairs, with answer spans highlighted in cyan and
    # question words highlighted in red.
    splitter = MergeParagraphs(400)
    stop = NltkPlusStopWords(True)
    ranker = TopTfIdf(stop, 6)
    stop_words = stop.words

    corpus = TriviaQaWebDataset()
    train = corpus.get_train()
    points = flatten_iterable([(q, d) for d in q.all_docs] for q in train)
    np.random.shuffle(points)

    for q, d in points:
        q_words = {strip_accents_unicode(w.lower()) for w in q.question}
        q_words = {x for x in q_words if x not in stop_words}

        doc = corpus.evidence.get_document(d.doc_id)
        doc = splitter.split_annotated(doc, d.answer_spans)
        ranked = ranker.dists(q.question, doc)
        if len(ranked) < 2 or len(ranked[1][0].answer_spans) == 0:
            continue

        print(" ".join(q.question))
        print(q.answer.all_answers)
        for i, (para, dist) in enumerate(ranked[0:2]):
            text = flatten_iterable(para.text)
            print("Start=%d, Rank=%d, Dist=%.4f" % (para.start, i, dist))
            if len(para.answer_spans) == 0:
                continue
            for s, e in para.answer_spans:
                text[s] = bcolors.CYAN + text[s]
                text[e] = text[e] + bcolors.ENDC
            for j, w in enumerate(text):
                if strip_accents_unicode(w.lower()) in q_words:
                    text[j] = bcolors.ERROR + text[j] + bcolors.ENDC
            print(" ".join(text))
        input()
def main():
    # Print paragraphs that still contain more than three answer spans after
    # TF-IDF pruning, marking each span with {{{ }}}.
    data = TriviaQaWebDataset()
    stop = NltkPlusStopWords()
    splitter = MergeParagraphs(400)
    selector = TopTfIdf(stop, 4)

    print("Loading data..")
    train = data.get_train()
    print("Start")
    for q in train:
        for doc in q.all_docs:
            if len(doc.answer_spans) > 3:
                paras = splitter.split_annotated(data.evidence.get_document(doc.doc_id),
                                                 doc.answer_spans)
                paras = selector.prune(q.question, paras)
                for para in paras:
                    if len(para.answer_spans) > 3:
                        print(q.question)
                        tokens = flatten_iterable(para.text)
                        for s, e in para.answer_spans:
                            tokens[s] = "{{{" + tokens[s]
                            tokens[e] = tokens[e] + "}}}"
                        print(" ".join(tokens))
                        input()
def check_preprocess():
    # Sanity-check WithIndicators: encoding a paragraph must preserve its text
    # (up to the inserted special tokens) and its answer spans.
    data = TriviaQaWebDataset()
    merge = MergeParagraphs(400)
    questions = data.get_dev()
    pre = WithIndicators(False)
    remove_cross = WithIndicators(True)
    rng = np.random.RandomState(0)
    rng.shuffle(questions)

    for q in tqdm(questions[:1000]):
        doc = rng.choice(q.all_docs, 1)[0]
        text = data.evidence.get_document(doc.doc_id, n_tokens=800)
        paras = merge.split_annotated(text, doc.answer_spans)
        para = paras[np.random.randint(0, len(paras))]
        built = pre.encode_extracted_paragraph(q.question, para)

        # The encoded text, minus special tokens, must match the original paragraph.
        expected_text = flatten_iterable(para.text)
        if expected_text != [x for x in built.text if x not in pre.special_tokens()]:
            raise ValueError()

        # The encoded answer spans, minus special tokens, must cover the same strings.
        expected = [expected_text[s:e + 1] for s, e in para.answer_spans]
        expected = Counter([tuple(x) for x in expected])
        actual = [tuple(built.text[s:e + 1]) for s, e in built.answer_spans]
        actual_cleaned = Counter(tuple(z for z in x if z not in pre.special_tokens())
                                 for x in actual)
        if actual_cleaned != expected:
            raise ValueError()

        # The cross-removing variant must keep only spans that contain no special tokens.
        r_built = remove_cross.encode_extracted_paragraph(q.question, para)
        rc = Counter(tuple(r_built.text[s:e + 1]) for s, e in r_built.answer_spans)
        removed = Counter()
        for w in actual:
            if all(x not in pre.special_tokens() for x in w):
                removed[w] += 1
        if rc != removed:
            raise ValueError()
def show_open_paragraphs(start: int, end: int):
    # Show the paragraphs ranked in positions [start, end) by ShallowOpenWebRanker
    # for open-domain TriviaQA questions, with answer spans highlighted in cyan
    # and question words highlighted in red.
    splitter = MergeParagraphs(400)
    stop = NltkPlusStopWords(True)
    ranker = ShallowOpenWebRanker(6)
    stop_words = stop.words

    print("Loading questions")
    corpus = TriviaQaOpenDataset()
    questions = corpus.get_dev()
    np.random.shuffle(questions)

    for q in questions:
        q_words = {strip_accents_unicode(w.lower()) for w in q.question}
        q_words = {x for x in q_words if x not in stop_words}

        paras = []
        for d in q.all_docs:
            doc = corpus.evidence.get_document(d.doc_id)
            paras += splitter.split_annotated(doc, d.answer_spans)

        ranked = ranker.prune(q.question, paras)
        if len(ranked) < start:
            continue
        ranked = ranked[start:end]

        print(" ".join(q.question))
        print(q.answer.all_answers)
        # Iterate over the slice itself rather than indexing it with the original
        # rank, which would run past the end of the sliced list.
        for i, para in enumerate(ranked):
            text = flatten_iterable(para.text)
            print("Start=%d, Rank=%d" % (para.start, start + i))
            if len(para.answer_spans) == 0:
                # print("No Answer!")
                continue
            for s, e in para.answer_spans:
                text[s] = bcolors.CYAN + text[s]
                text[e] = text[e] + bcolors.ENDC
            for j, w in enumerate(text):
                if strip_accents_unicode(w.lower()) in q_words:
                    text[j] = bcolors.ERROR + text[j] + bcolors.ENDC
            print(" ".join(text))
        input()
def show_stats():
    # For 1000 random (question, document) pairs, report how often the i-th
    # TF-IDF-ranked paragraph contains an answer span, and the distribution of
    # which ranks contain answers.
    splitter = MergeParagraphs(400)
    stop = NltkPlusStopWords(True)
    ranker = TopTfIdf(stop, 6)

    corpus = TriviaQaWebDataset()
    train = corpus.get_train()
    points = flatten_iterable([(q, d) for d in q.all_docs] for q in train)
    np.random.shuffle(points)

    counts = np.zeros(6)
    answers = np.zeros(6)
    n_answers = []
    points = points[:1000]

    for q, d in tqdm(points):
        doc = corpus.evidence.get_document(d.doc_id)
        doc = splitter.split_annotated(doc, d.answer_spans)
        ranked = ranker.prune(q.question, doc)
        counts[:len(ranked)] += 1
        for i, para in enumerate(ranked):
            if len(para.answer_spans) > 0:
                answers[i] += 1
        n_answers.append(tuple(i for i, x in enumerate(ranked)
                               if len(x.answer_spans) > 0))

    # Fraction of documents whose i-th ranked paragraph contains an answer.
    print(answers / counts)

    # Distribution over which sets of ranks contain answers; sets of more than
    # two ranks are lumped together as "other".
    c = Counter()
    other = 0
    for tup in n_answers:
        if len(tup) <= 2:
            c[tup] += 1
        else:
            other += 1
    for p in sorted(c.keys()):
        print(p, c.get(p) / len(points))
    print(other / len(points))
def contains_question_word():
    # Measure how often an answer-bearing paragraph shares at least one
    # non-stopword token with its question, and count which question words occur.
    data = TriviaQaWebDataset()
    stop = NltkPlusStopWords(punctuation=True).words
    doc_filter = ContainsQuestionWord(NltkPlusStopWords(punctuation=True))  # constructed but unused here
    splits = MergeParagraphs(400)
    # splits = Truncate(400)
    questions = data.get_dev()
    pairs = flatten_iterable([(q, doc) for doc in q.all_docs] for q in questions)
    pairs.sort(key=lambda x: (x[0].question_id, x[1].doc_id))
    np.random.RandomState(0).shuffle(questions)

    has_token = 0
    total = 0
    used = Counter()
    for q, doc in tqdm(pairs[:1000]):
        text = data.evidence.get_document(doc.doc_id, splits.reads_first_n)
        q_tokens = set(x.lower() for x in q.question)
        q_tokens -= stop
        for para in splits.split_annotated(text, doc.answer_spans):
            # if para.start == 0:
            #     continue
            if len(para.answer_spans) == 0:
                continue
            if any(x.lower() in q_tokens for x in flatten_iterable(para.text)):
                has_token += 1
                for x in flatten_iterable(para.text):
                    if x in q_tokens:
                        used[x] += 1
            # else:
            #     print_questions(q.question, q.answer.all_answers, para.text, para.answer_spans)
            #     input()
            total += 1

    for k, v in used.most_common(200):
        print("%s: %d" % (k, v))
    print(has_token / total)
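# The functions above are stand-alone inspection utilities. A minimal command-line
# entry point is sketched below so they can be run directly; this dispatcher is an
# assumption (the original module may not have exposed one), and the sub-command
# names are made up for illustration.
if __name__ == "__main__":
    import argparse

    # Zero-argument helpers, keyed by a hypothetical sub-command name.
    commands = {
        "web": show_web_paragraphs,
        "pruned-answers": main,
        "check-preprocess": check_preprocess,
        "stats": show_stats,
        "question-words": contains_question_word,
    }

    parser = argparse.ArgumentParser(description="TriviaQA paragraph inspection utilities")
    parser.add_argument("command", choices=sorted(commands) + ["open"])
    parser.add_argument("--start", type=int, default=0,
                        help="first rank to show (used only by the 'open' command)")
    parser.add_argument("--end", type=int, default=2,
                        help="end rank, exclusive (used only by the 'open' command)")
    args = parser.parse_args()

    if args.command == "open":
        show_open_paragraphs(args.start, args.end)
    else:
        commands[args.command]()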