Beispiel #1
0
def test_splitter(splitter: DocumentSplitter,
                  n_sample,
                  n_answer_spans,
                  seed=None):
    """Check that `splitter` preserves text and answer spans on sampled docs.

    Shuffles the TriviaQA evidence documents with `seed`, takes the first
    `n_sample`, plants random fake answer spans in each, splits it with
    `splitter.split_annotated` and verifies every resulting paragraph:

    * has at most `splitter.max_tokens` tokens (when that is not None),
    * equals the document's flattened tokens over `[para.start, para.end)`,
    * only carries answer spans whose tokens match one of the planted spans.

    :param splitter: the `DocumentSplitter` under test.
    :param n_sample: number of documents to check.
    :param n_answer_spans: currently unused; kept for interface compatibility.
    :param seed: seed for every random choice (document order and fake
        spans), so a failing run can be reproduced.
    :raises ValueError: on the first paragraph that fails a check.
    """
    rng = np.random.RandomState(seed)
    corpus = TriviaQaEvidenceCorpusTxt()
    docs = sorted(corpus.list_documents())
    rng.shuffle(docs)
    max_tokens = splitter.max_tokens
    read_n = splitter.reads_first_n

    for doc in docs[:n_sample]:
        text = corpus.get_document(doc, read_n)

        # Plant up to min(n_tokens // 2, 4) spans per paragraph, stored as
        # inclusive [start, end] token offsets into the whole document.
        # All randomness goes through `rng` so `seed` is honoured.
        per_para_answers = []
        offset = 0
        for para in text:
            n_tokens = len(flatten_iterable(para))
            n_spans = min(n_tokens // 2, rng.randint(5))
            if n_spans > 0:
                starts = rng.choice(n_tokens, n_spans, replace=False)
            else:
                starts = np.zeros(0, dtype=np.int64)
            # Keep each span inside its paragraph and at most 30 tokens long;
            # draw an independent length fraction per span.
            max_lens = np.minimum(n_tokens - starts, 30)
            ends = starts + np.floor(
                rng.uniform(size=len(starts)) * max_lens).astype(np.int64)
            per_para_answers.append(np.stack([starts, ends], axis=1) + offset)
            offset += n_tokens

        if per_para_answers:
            fake_answers = np.concatenate(per_para_answers, axis=0)
        else:
            # Document with no paragraphs: nothing to plant.
            fake_answers = np.zeros((0, 2), dtype=np.int64)

        doc_tokens = flatten_iterable(flatten_iterable(text))
        answer_strs = {tuple(doc_tokens[s:e + 1]) for s, e in fake_answers}

        paragraphs = splitter.split_annotated(text, fake_answers)

        for para in paragraphs:
            para_tokens = flatten_iterable(para.text)
            if max_tokens is not None and len(para_tokens) > max_tokens:
                raise ValueError(
                    "Paragraph len %d, but max tokens was %d" %
                    (len(para_tokens), max_tokens))
            start, end = para.start, para.end
            if para_tokens != doc_tokens[start:end]:
                raise ValueError(
                    "Paragraph is missing text, given bounds were %d-%d" %
                    (start, end))
            for s, e in para.answer_spans:
                answer = para_tokens[s:e + 1]
                if tuple(answer) not in answer_strs:
                    raise ValueError(
                        "Incorrect answer %d-%d for paragraph %d-%d (%s)" %
                        (s, e, start, end, " ".join(answer)))
Beispiel #2
0
def show_paragraph_lengths(n_docs=5000,
                           thresholds=(400, 500, 600, 700, 800),
                           seed=None):
    """Print the fraction of paragraphs longer than each token threshold.

    Samples up to `n_docs` documents from the TriviaQA evidence corpus.
    A paragraph's length is the total token count across its sentences.
    For each threshold, prints one line of the form "Over <threshold>: <fraction>".

    :param n_docs: maximum number of documents to sample (default 5000).
    :param thresholds: token counts to report on.
    :param seed: optional seed for the document shuffle; None uses NumPy's
        global random state, as the original did.
    """
    corpus = TriviaQaEvidenceCorpusTxt()
    # Sorting makes a private copy (the shuffle below is in place and must not
    # reorder anything the corpus may hold) and gives a stable starting order
    # so a given seed reproduces the same sample.
    docs = sorted(corpus.list_documents())
    rng = np.random if seed is None else np.random.RandomState(seed)
    rng.shuffle(docs)

    para_lens = []
    for doc in docs[:n_docs]:
        para_lens.extend(sum(len(sentence) for sentence in para)
                         for para in corpus.get_document(doc))
    para_lens = np.array(para_lens)

    if len(para_lens) == 0:
        # Avoid printing NaN (and a divide-by-zero warning) for an empty sample.
        print("No paragraphs found")
        return
    for threshold in thresholds:
        print("Over %s: %.4f" % (threshold, (para_lens > threshold).mean()))