def test_splitter(splitter: DocumentSplitter, n_sample, n_answer_spans, seed=None):
    """Fuzz-test a `DocumentSplitter` against randomly generated answer spans.

    Samples `n_sample` TriviaQA evidence documents, plants random fake answer
    spans in each, runs `splitter.split_annotated`, and verifies that every
    produced paragraph (1) respects `splitter.max_tokens`, (2) exactly matches
    the document text at its reported `start`/`end` bounds, and (3) only
    carries answer spans whose token text matches one of the planted answers.

    :param splitter: the `DocumentSplitter` under test
    :param n_sample: number of documents to sample from the corpus
    :param n_answer_spans: unused (kept for interface compatibility — TODO confirm intent)
    :param seed: optional RNG seed; when given, the run is fully reproducible
    :raises ValueError: if the splitter violates any of the checks above
    """
    rng = np.random.RandomState(seed)
    corpus = TriviaQaEvidenceCorpusTxt()
    # Sort before shuffling so the shuffle (and thus the sample) is
    # deterministic for a fixed seed regardless of listing order.
    docs = sorted(corpus.list_documents())
    rng.shuffle(docs)
    max_tokens = splitter.max_tokens
    read_n = splitter.reads_first_n
    for doc in docs[:n_sample]:
        text = corpus.get_document(doc, read_n)
        fake_answers = []
        offset = 0
        for para in text:
            flattened = flatten_iterable(para)
            # BUGFIX: use the seeded `rng` (was `np.random.*`), otherwise the
            # global RNG is consulted and `seed` does not make runs reproducible.
            fake_answer_starts = rng.choice(
                len(flattened),
                min(len(flattened) // 2, rng.randint(5)),
                replace=False)
            # Cap each span at 30 tokens and clip to the paragraph end.
            max_answer_lens = np.minimum(len(flattened) - fake_answer_starts, 30)
            fake_answer_ends = fake_answer_starts + np.floor(
                rng.uniform() * max_answer_lens).astype(np.int32)
            # Record spans as (start, end) pairs in document-level coordinates.
            fake_answers.append(
                np.concatenate([
                    np.expand_dims(fake_answer_starts, 1),
                    np.expand_dims(fake_answer_ends, 1)
                ], axis=1) + offset)
            offset += len(flattened)
        fake_answers = np.concatenate(fake_answers, axis=0)
        flattened = flatten_iterable(flatten_iterable(text))
        # The set of token tuples a valid paragraph answer span may resolve to.
        answer_strs = set(tuple(flattened[s:e + 1]) for s, e in fake_answers)

        paragraphs = splitter.split_annotated(text, fake_answers)
        for para in paragraphs:
            para_text = flatten_iterable(para.text)
            if max_tokens is not None and len(para_text) > max_tokens:
                # (also fixes the "len len" typo in the original message)
                raise ValueError(
                    "Paragraph len %d, but max tokens was %d" %
                    (len(para_text), max_tokens))
            start, end = para.start, para.end
            if para_text != flattened[start:end]:
                raise ValueError(
                    "Paragraph is missing text, given bounds were %d-%d" %
                    (start, end))
            for s, e in para.answer_spans:
                if tuple(para_text[s:e + 1]) not in answer_strs:
                    print(s, e)
                    raise ValueError(
                        "Incorrect answer for paragraph %d-%d (%s)" %
                        (start, end, " ".join(para_text[s:e + 1])))
def show_paragraph_lengths():
    """Sample up to 5000 evidence documents and print, for several token-count
    thresholds, the fraction of paragraphs whose total length exceeds each one.
    """
    corpus = TriviaQaEvidenceCorpusTxt()
    docs = corpus.list_documents()
    np.random.shuffle(docs)

    lengths = []
    for doc in docs[:5000]:
        # Each document is a list of paragraphs; each paragraph a list of
        # sentences. A paragraph's length is its total sentence-token count.
        for paragraph in corpus.get_document(doc):
            lengths.append(sum(len(sentence) for sentence in paragraph))

    lengths = np.array(lengths)
    n_total = len(lengths)
    for threshold in [400, 500, 600, 700, 800]:
        frac = (lengths > threshold).sum() / n_total
        print("Over %s: %.4f" % (threshold, frac))