예제 #1
0
def test_splitter(splitter: DocumentSplitter,
                  n_sample,
                  n_answer_spans,
                  seed=None):
    """Sanity-check a document splitter against randomly generated answer spans.

    For up to ``n_sample`` documents from the TriviaQA evidence corpus, random
    "fake" answer spans are generated per paragraph, the document is split with
    ``splitter.split_annotated`` and every resulting paragraph is validated.

    All randomness (document shuffling and span generation) is drawn from a
    single ``RandomState`` seeded with ``seed``, so a run with a fixed seed is
    reproducible.

    Args:
        splitter: Splitter under test; its ``max_tokens``, ``reads_first_n``
            and ``split_annotated`` attributes are used.
        n_sample: Maximum number of documents to check.
        n_answer_spans: Currently unused; the number of spans per paragraph is
            drawn from ``randint(5)`` capped at half the paragraph length.
            NOTE(review): presumably meant to bound the span count -- confirm.
        seed: Seed for the random number generator (``None`` for a random one).

    Raises:
        ValueError: If a produced paragraph exceeds ``max_tokens``, if its text
            does not match the document slice given by its bounds, or if one of
            its answer spans does not correspond to a generated answer.
    """
    rng = np.random.RandomState(seed)
    corpus = TriviaQaEvidenceCorpusTxt()
    # Sort before shuffling so the sampled documents depend only on the seed.
    docs = sorted(corpus.list_documents())
    rng.shuffle(docs)
    max_tokens = splitter.max_tokens
    read_n = splitter.reads_first_n
    for doc in docs[:n_sample]:
        text = corpus.get_document(doc, read_n)
        fake_answers = []
        offset = 0  # index of the current paragraph's first token in the document
        for para in text:
            flattened = flatten_iterable(para)
            n_tokens = len(flattened)
            # Draw from `rng` (not the global numpy state) so `seed` is honored.
            n_spans = min(n_tokens // 2, rng.randint(5))
            fake_answer_starts = rng.choice(n_tokens, n_spans, replace=False)
            # Spans never run past the paragraph end and cover at most 30 tokens.
            max_answer_lens = np.minimum(n_tokens - fake_answer_starts, 30)
            fake_answer_ends = fake_answer_starts + np.floor(
                rng.uniform() * max_answer_lens).astype(np.int32)
            # (n_spans, 2) array of inclusive [start, end] document-level indices.
            fake_answers.append(
                np.concatenate([
                    np.expand_dims(fake_answer_starts, 1),
                    np.expand_dims(fake_answer_ends, 1)
                ],
                               axis=1) + offset)
            offset += n_tokens

        fake_answers = np.concatenate(fake_answers, axis=0)
        flattened = flatten_iterable(flatten_iterable(text))
        answer_strs = set(tuple(flattened[s:e + 1]) for s, e in fake_answers)

        paragraphs = splitter.split_annotated(text, fake_answers)

        for para in paragraphs:
            # Separate name so the document-level `text` is not clobbered.
            para_tokens = flatten_iterable(para.text)
            if max_tokens is not None and len(para_tokens) > max_tokens:
                raise ValueError(
                    "Paragraph len len %d, but max tokens was %d" %
                    (len(para_tokens), max_tokens))
            start, end = para.start, para.end
            if para_tokens != flattened[start:end]:
                raise ValueError(
                    "Paragraph is missing text, given bounds were %d-%d" %
                    (start, end))
            for s, e in para.answer_spans:
                # Answer spans are paragraph-relative with an inclusive end.
                if tuple(para_tokens[s:e + 1]) not in answer_strs:
                    print(s, e)
                    raise ValueError(
                        "Incorrect answer for paragraph %d-%d (%s)" %
                        (start, end, " ".join(para_tokens[s:e + 1])))
예제 #2
0
def show_paragraph_lengths():
    """Print the fraction of paragraphs longer than several length thresholds.

    Shuffles the corpus' document list in place with the global numpy RNG,
    reads up to 5000 documents and, for each paragraph, sums ``len`` over the
    paragraph's items (presumably tokenized sentences -- verify against the
    corpus reader). The share of paragraphs whose length exceeds each of
    400, 500, 600, 700 and 800 is then printed.
    """
    corpus = TriviaQaEvidenceCorpusTxt()
    doc_ids = corpus.list_documents()
    np.random.shuffle(doc_ids)

    lengths = []
    for doc_id in doc_ids[:5000]:
        for paragraph in corpus.get_document(doc_id):
            lengths.append(sum(len(sentence) for sentence in paragraph))
    lengths = np.array(lengths)

    for threshold in (400, 500, 600, 700, 800):
        n_over = (lengths > threshold).sum()
        print("Over %s: %.4f" % (threshold, n_over / len(lengths)))
예제 #3
0
def build_dataset(name: str,
                  tokenizer,
                  train_files: Dict[str, str],
                  answer_detector,
                  n_process: int,
                  prune_unmapped_docs=True,
                  sample=None):
    """Preprocess TriviaQA question files and save them under the corpus dir.

    For every ``(split name, filename)`` pair in ``train_files`` the questions
    are loaded, optionally pruned of documents missing from the file map,
    annotated with answer spans and pickled to ``<out_dir>/<split name>.pkl``,
    where ``out_dir`` is ``CORPUS_DIR/triviaqa/<name>``. Finally the
    document-id -> filename map is written to ``<out_dir>/file_map.json``.

    Args:
        name: Name of the dataset; used as the output directory name.
        tokenizer: Passed through to ``compute_answer_spans_par``.
        train_files: Maps a split name to the question file to load.
        answer_detector: Passed through to ``compute_answer_spans_par``.
        n_process: Passed through to ``compute_answer_spans_par``.
        prune_unmapped_docs: If true, drop web/entity documents whose
            ``doc_id`` is not in the file map.
        sample: ``None`` to load every question, an ``int`` to load only the
            first ``sample`` questions of each file, or a ``dict`` mapping
            split name to the number of questions to load for that split.

    Raises:
        ValueError: If ``sample`` is not ``None``, an ``int`` or a ``dict``.
        RuntimeError: If an answerable question has a mapped document whose
            ``answer_spans`` is still ``None`` after answer detection.
    """
    # Local import keeps this block self-contained; makedirs also creates a
    # missing parent "triviaqa" directory and tolerates an existing out_dir.
    from os import makedirs

    out_dir = join(CORPUS_DIR, "triviaqa", name)
    makedirs(out_dir, exist_ok=True)

    # Maps document_id -> filename; presumably filled in by
    # iter_trivia_question as questions are read -- verify against that helper.
    file_map = {}

    # `split_name` rather than `name`, so the dataset name is not shadowed.
    for split_name, filename in train_files.items():
        print("Loading %s questions" % split_name)
        if sample is None:
            limit = None  # islice(..., None) consumes the whole iterator
        elif isinstance(sample, int):
            limit = sample
        elif isinstance(sample, dict):
            limit = sample[split_name]
        else:
            raise ValueError(
                "`sample` must be None, an int or a dict of split name -> int, "
                "got %r" % type(sample))
        questions = list(
            islice(iter_trivia_question(filename, file_map, False), limit))

        if prune_unmapped_docs:
            for q in questions:
                # web_docs may be None; entity_docs is always filtered.
                if q.web_docs is not None:
                    q.web_docs = [
                        x for x in q.web_docs if x.doc_id in file_map
                    ]
                q.entity_docs = [
                    x for x in q.entity_docs if x.doc_id in file_map
                ]

        print("Adding answers for %s question" % split_name)
        corpus = TriviaQaEvidenceCorpusTxt(file_map)
        questions = compute_answer_spans_par(questions, corpus, tokenizer,
                                             answer_detector, n_process)
        # Sanity check, we should have answers for everything (even if of size 0)
        for q in questions:
            if q.answer is None:
                continue
            for doc in q.all_docs:
                if doc.doc_id in file_map and doc.answer_spans is None:
                    raise RuntimeError(
                        "Document %s has no answer spans after answer "
                        "detection" % (doc.doc_id,))

        print("Saving %s question" % split_name)
        with open(join(out_dir, split_name + ".pkl"), "wb") as f:
            pickle.dump(questions, f)

    print("Dumping file mapping")
    with open(join(out_dir, "file_map.json"), "w") as f:
        json.dump(file_map, f)

    print("Complete")
예제 #4
0
 def __init__(self, corpus_name):
     """Load the stored file map of ``corpus_name`` and build its evidence corpus.

     Reads ``file_map.json`` from ``CORPUS_DIR/triviaqa/<corpus_name>``,
     NFD-normalizes every mapped value and hands the resulting mapping to
     ``TriviaQaEvidenceCorpusTxt``.
     """
     self.corpus_name = corpus_name
     self.dir = join(CORPUS_DIR, "triviaqa", corpus_name)
     map_path = join(self.dir, "file_map.json")
     with open(map_path, "r") as handle:
         raw_map = json.load(handle)
     # Keys are kept as-is; only the mapped values are normalized to NFD.
     normalized_map = {
         key: unicodedata.normalize("NFD", value)
         for key, value in raw_map.items()
     }
     self.evidence = TriviaQaEvidenceCorpusTxt(normalized_map)
예제 #5
0
def process(x, verbose=False):
    """Convert a batch of TriviaQA questions into instances.

    Args:
        x: A 5-tuple ``(dataset, filemap, max_num_support, max_tokens,
            is_web)``; packed into one argument, presumably so the function
            can be mapped over work items by a process pool -- verify against
            the caller.
        verbose: If true, print a progress line every 1000 questions.

    Returns:
        A list with everything ``convert_triviaqa`` yielded for each question,
        in dataset order.
    """
    dataset, filemap, max_num_support, max_tokens, is_web = x
    corpus = TriviaQaEvidenceCorpusTxt(filemap)
    converted = []
    for idx, question in enumerate(dataset):
        if verbose and idx % 1000 == 0:
            print("%d/%d done" % (idx, len(dataset)))
        converted.extend(
            convert_triviaqa(question, corpus, max_num_support, max_tokens,
                             is_web))
    return converted