# The snippets below are inspection/debugging utilities from the TriviaQA
# pipeline of the allenai/document-qa project. The imports were missing from
# the original listing; the module paths below are assumed from that
# repository and may need adjusting for your checkout.
from collections import Counter

import numpy as np
from sklearn.feature_extraction.text import strip_accents_unicode
from tqdm import tqdm

from docqa.data_processing.document_splitter import (
    ContainsQuestionWord, MergeParagraphs, ShallowOpenWebRanker, TopTfIdf)
from docqa.data_processing.text_utils import NltkPlusStopWords
from docqa.text_preprocessor import WithIndicators
from docqa.triviaqa.build_span_corpus import (
    TriviaQaOpenDataset, TriviaQaWebDataset)
from docqa.utils import flatten_iterable


class bcolors:
    # Minimal ANSI color codes; document-qa ships its own bcolors helper,
    # redefined here only so the snippets are self-contained.
    CYAN = "\033[96m"
    ERROR = "\033[91m"
    ENDC = "\033[0m"


def show_web_paragraphs():
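    """Interactively print the two top TF-IDF-ranked paragraphs for random
    TriviaQA-web (question, document) pairs, highlighting answer spans in
    cyan and shared question words in red. Press enter to advance."""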
    splitter = MergeParagraphs(400)
    stop = NltkPlusStopWords(True)
    ranker = TopTfIdf(stop, 6)
    stop_words = stop.words

    corpus = TriviaQaWebDataset()
    train = corpus.get_train()
    points = flatten_iterable([(q, d) for d in q.all_docs] for q in train)
    np.random.shuffle(points)

    for q, d in points:
        q_words = {strip_accents_unicode(w.lower()) for w in q.question}
        q_words = {x for x in q_words if x not in stop_words}

        doc = corpus.evidence.get_document(d.doc_id)
        doc = splitter.split_annotated(doc, d.answer_spans)
        ranked = ranker.dists(q.question, doc)
        if len(ranked) < 2 or len(ranked[1][0].answer_spans) == 0:
            continue
        print(" ".join(q.question))
        print(q.answer.all_answers)
        for i, (para, dist) in enumerate(ranked[0:2]):
            text = flatten_iterable(para.text)
            print("Start=%d, Rank=%d, Dist=%.4f" % (para.start, i, dist))
            if len(para.answer_spans) == 0:
                continue
            for s, e in para.answer_spans:
                text[s] = bcolors.CYAN + text[s]
                text[e] = text[e] + bcolors.ENDC
            for wi, w in enumerate(text):
                if strip_accents_unicode(w.lower()) in q_words:
                    text[wi] = bcolors.ERROR + text[wi] + bcolors.ENDC
            print(" ".join(text))
        input()
def main():
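    """Print TriviaQA-web paragraphs that keep more than 3 answer spans after
    TF-IDF pruning, marking each span with {{{...}}}. Press enter to advance."""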
    data = TriviaQaWebDataset()

    stop = NltkPlusStopWords()
    splitter = MergeParagraphs(400)
    selector = TopTfIdf(stop, 4)

    print("Loading data..")
    train = data.get_train()
    print("Start")
    for q in train:
        for doc in q.all_docs:
            if len(doc.answer_spans) > 3:
                paras = splitter.split_annotated(
                    data.evidence.get_document(doc.doc_id), doc.answer_spans)
                paras = selector.prune(q.question, paras)
                for para in paras:
                    if len(para.answer_spans) > 3:
                        print(" ".join(q.question))
                        text = flatten_iterable(para.text)
                        for s, e in para.answer_spans:
                            text[s] = "{{{" + text[s]
                            text[e] = text[e] + "}}}"
                        print(" ".join(text))
                        input()
def check_preprocess():
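    """Sanity-check WithIndicators preprocessing on 1000 random dev questions:
    the text must be unchanged apart from the inserted special tokens, answer
    spans must be preserved, and the remove-cross variant must drop exactly
    the spans that overlap a special token."""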
    data = TriviaQaWebDataset()
    merge = MergeParagraphs(400)
    questions = data.get_dev()
    pre = WithIndicators(False)
    remove_cross = WithIndicators(True)
    rng = np.random.RandomState(0)
    rng.shuffle(questions)

    for q in tqdm(questions[:1000]):
        doc = rng.choice(q.all_docs, 1)[0]
        text = data.evidence.get_document(doc.doc_id, n_tokens=800)
        paras = merge.split_annotated(text, doc.answer_spans)
        para = paras[rng.randint(0, len(paras))]
        built = pre.encode_extracted_paragraph(q.question, para)

        expected_text = flatten_iterable(para.text)
        if expected_text != [
                x for x in built.text if x not in pre.special_tokens()
        ]:
            raise ValueError("Paragraph text was altered by the indicators")

        expected = [expected_text[s:e + 1] for s, e in para.answer_spans]
        expected = Counter([tuple(x) for x in expected])

        actual = [tuple(built.text[s:e + 1]) for s, e in built.answer_spans]
        actual_cleaned = Counter(
            tuple(z for z in x if z not in pre.special_tokens())
            for x in actual)
        if actual_cleaned != expected:
            raise ValueError("Answer spans were not preserved")

        r_built = remove_cross.encode_extracted_paragraph(q.question, para)
        rc = Counter(
            tuple(r_built.text[s:e + 1]) for s, e in r_built.answer_spans)
        removed = Counter()
        for w in actual:
            if all(x not in pre.special_tokens() for x in w):
                removed[w] += 1

        if rc != removed:
            raise ValueError(
                "Only spans crossing special tokens should have been removed")
def show_open_paragraphs(start: int, end: int):
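    """Interactively print open-domain TriviaQA paragraphs ranked between
    `start` and `end` by ShallowOpenWebRanker, with answer spans and shared
    question words color-highlighted. Press enter to advance."""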
    splitter = MergeParagraphs(400)
    stop = NltkPlusStopWords(True)
    ranker = ShallowOpenWebRanker(6)
    stop_words = stop.words

    print("Loading train")
    corpus = TriviaQaOpenDataset()
    train = corpus.get_dev()
    np.random.shuffle(train)

    for q in train:
        q_words = {strip_accents_unicode(w.lower()) for w in q.question}
        q_words = {x for x in q_words if x not in stop_words}

        paras = []
        for d in q.all_docs:
            doc = corpus.evidence.get_document(d.doc_id)
            paras += splitter.split_annotated(doc, d.answer_spans)

        ranked = ranker.prune(q.question, paras)
        if len(ranked) < end:
            continue
        ranked = ranked[start:end]

        print(" ".join(q.question))
        print(q.answer.all_answers)
        for i, para in enumerate(ranked, start=start):
            text = flatten_iterable(para.text)
            print("Start=%d, Rank=%d" % (para.start, i))
            if len(para.answer_spans) == 0:
                # print("No Answer!")
                continue
            for s, e in para.answer_spans:
                text[s] = bcolors.CYAN + text[s]
                text[e] = text[e] + bcolors.ENDC
            for wi, w in enumerate(text):
                if strip_accents_unicode(w.lower()) in q_words:
                    text[wi] = bcolors.ERROR + text[wi] + bcolors.ENDC
            print(" ".join(text))
        input()
def show_stats():
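    """Estimate, over 1000 random (question, document) pairs, how often each
    TF-IDF rank position contains an answer span, and the distribution of
    which rank positions hold the answers."""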
    splitter = MergeParagraphs(400)
    stop = NltkPlusStopWords(True)
    ranker = TopTfIdf(stop, 6)

    corpus = TriviaQaWebDataset()
    train = corpus.get_train()
    points = flatten_iterable([(q, d) for d in q.all_docs] for q in train)
    np.random.shuffle(points)

    counts = np.zeros(6)
    answers = np.zeros(6)
    n_answers = []

    points = points[:1000]
    for q, d in tqdm(points):
        doc = corpus.evidence.get_document(d.doc_id)
        doc = splitter.split_annotated(doc, d.answer_spans)
        ranked = ranker.prune(q.question, doc)
        counts[:len(ranked)] += 1
        for i, para in enumerate(ranked):
            if len(para.answer_spans) > 0:
                answers[i] += 1
        n_answers.append(
            tuple(i for i, x in enumerate(ranked) if len(x.answer_spans) > 0))

    print(answers / counts)
    c = Counter()
    other = 0
    for tup in n_answers:
        if len(tup) <= 2:
            c[tup] += 1
        else:
            other += 1

    for p in sorted(c.keys()):
        print(p, c.get(p) / len(points))
    print(other / len(points))
def contains_question_word():
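    """Measure how often answer-bearing paragraphs share at least one
    non-stopword with the question, and which question words are matched
    most frequently."""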
    data = TriviaQaWebDataset()
    stop = NltkPlusStopWords(punctuation=True).words
    # ContainsQuestionWord is constructed for reference; the loop below
    # measures by hand what this filter would keep.
    doc_filter = ContainsQuestionWord(NltkPlusStopWords(punctuation=True))
    splits = MergeParagraphs(400)
    # splits = Truncate(400)
    questions = data.get_dev()
    pairs = flatten_iterable([(q, doc) for doc in q.all_docs]
                             for q in questions)
    pairs.sort(key=lambda x: (x[0].question_id, x[1].doc_id))
    np.random.RandomState(0).shuffle(pairs)
    has_token = 0
    total = 0
    used = Counter()

    for q, doc in tqdm(pairs[:1000]):
        text = data.evidence.get_document(doc.doc_id, splits.reads_first_n)
        q_tokens = set(x.lower() for x in q.question)
        q_tokens -= stop
        for para in splits.split_annotated(text, doc.answer_spans):
            # if para.start == 0:
            #     continue
            if len(para.answer_spans) == 0:
                continue
            if any(x.lower() in q_tokens for x in flatten_iterable(para.text)):
                has_token += 1
                for x in flatten_iterable(para.text):
                    if x.lower() in q_tokens:
                        used[x.lower()] += 1
            # else:
            #     print_questions(q.question, q.answer.all_answers, para.text, para.answer_spans)
            #     input()
            total += 1
    for k, v in used.most_common(200):
        print("%s: %d" % (k, v))
    print(has_token / total)
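

# Hypothetical entry point (not part of the original listing) showing how one
# of these inspection utilities might be invoked:
if __name__ == "__main__":
    show_stats()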