def test_splitter(splitter: DocumentSplitter, n_sample, n_answer_spans, seed=None):
    """Fuzz-test a `DocumentSplitter` against randomly generated answer spans.

    Samples `n_sample` TriviaQA evidence documents, plants random fake answer
    spans in each, runs `splitter.split_annotated`, and verifies that every
    produced paragraph (a) respects `splitter.max_tokens`, (b) exactly matches
    the corresponding slice of the flattened document, and (c) carries only
    answer spans whose token text matches one of the planted answers.

    Args:
        splitter: the DocumentSplitter under test.
        n_sample: number of documents to sample from the corpus.
        n_answer_spans: unused; kept for interface compatibility.
        seed: optional RNG seed for reproducibility.

    Raises:
        ValueError: if any invariant above is violated.
    """
    rng = np.random.RandomState(seed)
    corpus = TriviaQaEvidenceCorpusTxt()
    docs = sorted(corpus.list_documents())
    rng.shuffle(docs)
    max_tokens = splitter.max_tokens
    read_n = splitter.reads_first_n
    for doc in docs[:n_sample]:
        text = corpus.get_document(doc, read_n)
        fake_answers = []
        offset = 0
        for para in text:
            flattened = flatten_iterable(para)
            # BUG FIX: the original used the global `np.random` state here
            # (np.random.choice / np.random.randint), which silently ignored
            # `seed` and made the test non-reproducible. Use the seeded `rng`.
            fake_answer_starts = rng.choice(
                len(flattened),
                min(len(flattened) // 2, rng.randint(5)),
                replace=False)
            # Cap each span at 30 tokens and clip it to the paragraph end.
            max_answer_lens = np.minimum(len(flattened) - fake_answer_starts, 30)
            fake_answer_ends = fake_answer_starts + np.floor(
                rng.uniform() * max_answer_lens).astype(np.int32)
            # Spans are stored as (start, end) pairs in whole-document offsets.
            fake_answers.append(
                np.concatenate([
                    np.expand_dims(fake_answer_starts, 1),
                    np.expand_dims(fake_answer_ends, 1)
                ], axis=1) + offset)
            offset += len(flattened)
        fake_answers = np.concatenate(fake_answers, axis=0)
        flattened = flatten_iterable(flatten_iterable(text))
        # Token-tuples of every planted answer, for the membership check below.
        answer_strs = set(tuple(flattened[s:e + 1]) for s, e in fake_answers)
        paragraphs = splitter.split_annotated(text, fake_answers)
        for para in paragraphs:
            text = flatten_iterable(para.text)
            if max_tokens is not None and len(text) > max_tokens:
                # Error-message typo fixed ("len len" -> "has len").
                raise ValueError(
                    "Paragraph has len %d, but max tokens was %d" %
                    (len(text), max_tokens))
            start, end = para.start, para.end
            if text != flattened[start:end]:
                raise ValueError(
                    "Paragraph is missing text, given bounds were %d-%d" %
                    (start, end))
            for s, e in para.answer_spans:
                if tuple(text[s:e + 1]) not in answer_strs:
                    print(s, e)
                    raise ValueError(
                        "Incorrect answer for paragraph %d-%d (%s)" %
                        (start, end, " ".join(text[s:e + 1])))
def show_paragraph_lengths():
    """Print the fraction of sampled paragraphs exceeding several length cutoffs.

    Samples up to 5000 evidence documents and, for each paragraph, sums the
    lengths of its sentences; then reports what fraction of paragraphs is
    longer than 400/500/600/700/800 tokens.
    """
    corpus = TriviaQaEvidenceCorpusTxt()
    docs = corpus.list_documents()
    np.random.shuffle(docs)
    lengths = []
    for doc in docs[:5000]:
        # One entry per paragraph: total token count across its sentences.
        for paragraph in corpus.get_document(doc):
            lengths.append(sum(len(sentence) for sentence in paragraph))
    lengths = np.array(lengths)
    n_total = len(lengths)
    for threshold in (400, 500, 600, 700, 800):
        print("Over %s: %.4f" % (threshold, (lengths > threshold).sum() / n_total))
def build_dataset(name: str, tokenizer, train_files: Dict[str, str],
                  answer_detector, n_process: int,
                  prune_unmapped_docs=True, sample=None):
    """Build and pickle an answer-annotated TriviaQA dataset.

    For each (split-name, question-file) pair in `train_files`, loads the
    questions, optionally prunes documents missing from the evidence file map,
    computes answer spans in parallel, sanity-checks the result, and pickles
    the questions to CORPUS_DIR/triviaqa/<name>/<split>.pkl. The accumulated
    document-id -> filename map is saved once at the end as file_map.json.

    Args:
        name: output corpus name (directory under CORPUS_DIR/triviaqa).
        tokenizer: tokenizer passed to compute_answer_spans_par.
        train_files: split name -> question-file path.
        answer_detector: answer detector passed to compute_answer_spans_par.
        n_process: number of worker processes for span computation.
        prune_unmapped_docs: drop question documents whose doc_id is not in
            the file map built while loading questions.
        sample: None for all questions; an int to take the first N of every
            split; or a dict mapping split name -> N.

    Raises:
        ValueError: if `sample` is neither None, int, nor dict.
        RuntimeError: if a mapped document ends up with no answer-span array.
    """
    out_dir = join(CORPUS_DIR, "triviaqa", name)
    if not exists(out_dir):
        mkdir(out_dir)
    # Accumulated across ALL splits; iter_trivia_question fills it in as a
    # side effect while loading questions.
    file_map = {}  # maps document_id -> filename
    # NOTE(review): this loop variable shadows the `name` parameter. It is
    # harmless because `out_dir` was already computed above, but renaming
    # the loop variable would be clearer.
    for name, filename in train_files.items():
        print("Loading %s questions" % name)
        if sample is None:
            questions = list(iter_trivia_question(filename, file_map, False))
        else:
            if isinstance(sample, int):
                # Take the first `sample` questions of every split.
                questions = list(
                    islice(iter_trivia_question(filename, file_map, False),
                           sample))
            elif isinstance(sample, dict):
                # Per-split sample sizes keyed by split name.
                questions = list(
                    islice(iter_trivia_question(filename, file_map, False),
                           sample[name]))
            else:
                raise ValueError()
        if prune_unmapped_docs:
            # Drop documents we have no evidence file for, so downstream
            # span computation only sees resolvable documents.
            for q in questions:
                if q.web_docs is not None:
                    q.web_docs = [
                        x for x in q.web_docs if x.doc_id in file_map
                    ]
                q.entity_docs = [
                    x for x in q.entity_docs if x.doc_id in file_map
                ]
        print("Adding answers for %s question" % name)
        corpus = TriviaQaEvidenceCorpusTxt(file_map)
        questions = compute_answer_spans_par(questions, corpus, tokenizer,
                                             answer_detector, n_process)
        for q in questions:
            # Sanity check, we should have answers for everything (even if of size 0)
            if q.answer is None:
                continue
            for doc in q.all_docs:
                if doc.doc_id in file_map:
                    if doc.answer_spans is None:
                        raise RuntimeError()
        print("Saving %s question" % name)
        with open(join(out_dir, name + ".pkl"), "wb") as f:
            pickle.dump(questions, f)
    print("Dumping file mapping")
    with open(join(out_dir, "file_map.json"), "w") as f:
        json.dump(file_map, f)
    print("Complete")
def __init__(self, corpus_name):
    """Load the corpus's saved file map and construct its evidence corpus.

    Reads file_map.json from CORPUS_DIR/triviaqa/<corpus_name>, applies NFD
    Unicode normalization to every filename, and builds a
    TriviaQaEvidenceCorpusTxt over the normalized map.
    """
    self.corpus_name = corpus_name
    self.dir = join(CORPUS_DIR, "triviaqa", corpus_name)
    with open(join(self.dir, "file_map.json"), "r") as f:
        raw_map = json.load(f)
    # NFD-normalize filenames so lookups match how files are stored on disk.
    normalized = {
        doc_id: unicodedata.normalize("NFD", path)
        for doc_id, path in raw_map.items()
    }
    self.evidence = TriviaQaEvidenceCorpusTxt(normalized)
def process(x, verbose=False):
    """Convert a chunk of TriviaQA questions into training instances.

    Args:
        x: tuple of (dataset, filemap, max_num_support, max_tokens, is_web) —
            packed into one argument so this function can be mapped over by a
            process pool.
        verbose: if True, print progress every 1000 questions.

    Returns:
        List of instances produced by `convert_triviaqa` for every question.
    """
    dataset, filemap, max_num_support, max_tokens, is_web = x
    instances = []
    corpus = TriviaQaEvidenceCorpusTxt(filemap)
    for i, q in enumerate(dataset):
        if verbose and i % 1000 == 0:
            print("%d/%d done" % (i, len(dataset)))
        # IDIOM FIX: extend() consumes an iterable directly; the original
        # wrapped convert_triviaqa's output in a redundant generator
        # expression (`x for x in ...`).
        instances.extend(
            convert_triviaqa(q, corpus, max_num_support, max_tokens, is_web))
    return instances