Beispiel #1
0
class RandomParagraphSetDataset(Dataset):
    """
    Sample multiple paragraphs for each question and include them in the same batch.

    Every epoch, up to ``n_paragraphs`` paragraphs are picked for each question and the
    resulting elements are batched by a ``ClusteredBatcher`` keyed on ``n_context_words``.
    ``mode`` controls how the picked paragraphs become batch elements:

    * ``"flatten"``: each (question, paragraph) pair is an independent example.
    * ``"group"``: one example per picked paragraph; all examples of one question are
      placed in the same batch and share an integer ``group`` id.
    * ``"merge"``: the picked paragraphs are merged (``paragraphs[0].merge(...)``) into a
      single context per question.

    Any other ``mode`` raises ``RuntimeError`` once batches are requested.
    """

    def __init__(self,
                 questions: List[MultiParagraphQuestion], true_len: int, n_paragraphs: int,
                 batch_size: int, mode: str, force_answer: bool,
                 oversample_first_answer: List[int]):
        """
        :param questions: questions to sample from
        :param true_len: number of questions before any filtering, used by ``percent_filtered``
        :param n_paragraphs: maximum number of paragraphs picked per question per epoch
        :param batch_size: batch size handed to the ``ClusteredBatcher``
        :param mode: one of ``"flatten"``, ``"group"`` or ``"merge"``
        :param force_answer: if True, a question with more than ``n_paragraphs`` paragraphs
            always gets one paragraph containing an answer span, placed first
        :param oversample_first_answer: ``oversample_first_answer[k]`` extra copies of the
            k-th answer-containing paragraph are added to the pool the forced answer
            paragraph is drawn from; only supported together with ``force_answer``
        """
        self.mode = mode
        self.questions = questions
        self.force_answer = force_answer
        self.true_len = true_len
        self.n_paragraphs = n_paragraphs
        self.oversample_first_answer = oversample_first_answer
        # Number of (question, paragraph) pairs per epoch; the "flatten" epoch size
        self._n_pairs = sum(min(len(q.paragraphs), n_paragraphs) for q in questions)
        self.batcher = ClusteredBatcher(batch_size, lambda x: x.n_context_words, truncate_batches=True)

    def get_vocab(self):
        """Return the set of all elements of every question and paragraph text."""
        voc = set()
        for q in self.questions:
            voc.update(q.question)
            for para in q.paragraphs:
                voc.update(para.text)
        return voc

    def get_spec(self):
        """Return a ``ParagraphAndQuestionSpec`` with the longest question/paragraph lengths.

        The batch size is only fixed in ``"merge"`` mode; other modes report ``None``.
        """
        max_q_len = max(len(q.question) for q in self.questions)
        # NOTE(review): this is the longest *single* paragraph; in "merge" mode the merged
        # contexts may be longer -- confirm this bound is what callers expect.
        max_c_len = max(max(len(p.text) for p in q.paragraphs) for q in self.questions)
        return ParagraphAndQuestionSpec(self.batcher.get_fixed_batch_size() if self.mode == "merge" else None,
                                        max_q_len, max_c_len, None)

    def get_epoch(self):
        """Return a generator over one epoch of batches built from all questions."""
        return self._build_expanded_batches(self.questions)

    def _select_paragraphs(self, q):
        """Return the indices of the paragraphs of ``q`` to use for this epoch.

        :raises NotImplementedError: if oversampling is requested without ``force_answer``
        :raises ValueError: if ``force_answer`` is set but no paragraph of ``q`` has an answer
        """
        n_para = len(q.paragraphs)
        if n_para <= self.n_paragraphs:
            return np.arange(n_para)

        if not self.force_answer:
            if len(self.oversample_first_answer) == 0:
                return np.random.choice(n_para, self.n_paragraphs, replace=False)
            # Oversampling answer paragraphs is only implemented when one is forced in
            raise NotImplementedError()

        with_answer = [i for i, p in enumerate(q.paragraphs) if len(p.answer_spans) > 0]
        if len(with_answer) == 0:
            raise ValueError("force_answer is set but question %s has no paragraph "
                             "with an answer span" % q.question_id)

        # The k-th answer paragraph gets oversample_first_answer[k] extra chances to be picked
        pool = list(with_answer)
        for ix, n_extra in zip(with_answer, self.oversample_first_answer):
            pool += [ix] * n_extra
        answer_ix = pool[np.random.randint(len(pool))]

        # Fill the remaining slots from the other paragraphs; the answer paragraph goes first
        others = np.array([i for i in range(n_para) if i != answer_ix])
        rest = np.random.choice(others, min(len(others), self.n_paragraphs - 1), replace=False)
        return np.insert(rest, 0, answer_ix)

    def _build_expanded_batches(self, questions):
        """Yield batches (lists of QA pairs) for ``questions`` according to ``self.mode``.

        This is a generator: paragraph selection happens on the first ``next()`` call.
        """
        # Paragraphs are picked for every question up front so the batcher can
        # cluster on the length of the contexts that will actually be used
        out = []
        for q in questions:
            selected = self._select_paragraphs(q)
            if self.mode == "flatten":
                out.extend(q.paragraphs[i].build_qa_pair(q.question, q.question_id, q.answer_text)
                           for i in selected)
            else:
                out.append(ParagraphSelection(q, selected))

        out.sort(key=lambda x: x.n_context_words)

        if self.mode == "flatten":
            yield from self.batcher.get_epoch(out)
        elif self.mode == "group":
            group = 0
            for selection_batch in self.batcher.get_epoch(out):
                batch = []
                for selected in selection_batch:
                    q = selected.question
                    for i in selected.selection:
                        para = q.paragraphs[i]
                        batch.append(para.build_qa_pair(q.question, q.question_id, q.answer_text, group))
                    # All paragraphs of one question share the same group id
                    group += 1
                yield batch
        elif self.mode == "merge":
            for selection_batch in self.batcher.get_epoch(out):
                batch = []
                for selected in selection_batch:
                    q = selected.question
                    paras = [q.paragraphs[i] for i in selected.selection]
                    para = paras[0].merge(paras)
                    batch.append(para.build_qa_pair(q.question, q.question_id, q.answer_text))
                yield batch
        else:
            raise RuntimeError("Unknown mode: %s" % self.mode)

    def get_samples(self, n_examples):
        """Build batches from a random subset of at most ``n_examples`` questions.

        ``n_examples`` is clamped to the number of available questions.

        :return: tuple of (batch generator, number of batches it will yield)
        """
        n_examples = min(n_examples, len(self.questions))
        # Sample indices so numpy never has to turn the question objects into an array
        chosen = np.random.choice(len(self.questions), n_examples, replace=False)
        questions = [self.questions[i] for i in chosen]
        if self.mode == "flatten":
            n_pairs = sum(min(len(q.paragraphs), self.n_paragraphs) for q in questions)
            n_batches = self.batcher.epoch_size(n_pairs)
        else:
            n_batches = self.batcher.epoch_size(n_examples)
        return self._build_expanded_batches(questions), n_batches

    def percent_filtered(self):
        """Return the fraction of the ``true_len`` original questions that were filtered out."""
        if self.true_len == 0:
            return 0.0
        return (self.true_len - len(self.questions)) / self.true_len

    def __len__(self):
        """Number of batches in one epoch."""
        if self.mode == "flatten":
            return self.batcher.epoch_size(self._n_pairs)
        else:
            return self.batcher.epoch_size(len(self.questions))
Beispiel #2
0
class StratifiedParagraphSetDataset(Dataset):
    """
    Sample multiple paragraphs each epoch and include them in the same batch,
    but stratify the sampling across epochs.

    At construction, every question with more than one paragraph gets a shuffled array
    of ordered paragraph-index pairs ``(i, j)`` with ``i != j``. ``i`` comes from the
    answer-containing paragraphs when ``force_answer`` is set, and from all paragraphs
    otherwise. The first answer paragraphs get extra copies when
    ``overample_first_answer`` is non-empty. Each time batches are built for a question,
    its next pair is used. Once a question's pairs are exhausted they are reshuffled
    and reused. Single-paragraph questions always use paragraph 0.

    With ``merge=True`` the selected paragraphs are concatenated in reading order into
    one context. Otherwise each paragraph becomes its own example, and the examples of
    one question share a ``group`` id within the same batch.
    """
    def __init__(self, questions: List[MultiParagraphQuestion], true_len: int,
                 batch_size: int, force_answer: bool,
                 overample_first_answer: List[int], merge: bool):
        """
        :param questions: questions to sample from
        :param true_len: number of questions before any filtering, used by ``percent_filtered``
        :param batch_size: batch size handed to the ``ClusteredBatcher``
        :param force_answer: if True, the first paragraph of every pair contains an answer span
        :param overample_first_answer: (sic, name kept for compatibility)
            ``overample_first_answer[k]`` extra copies of the k-th answer-containing
            paragraph are added to the candidates for the first paragraph of a pair
        :param merge: if True, merge the selected paragraphs into a single context
        """
        self.overample_first_answer = overample_first_answer
        self.questions = questions
        self.merge = merge
        self.true_len = true_len
        self.batcher = ClusteredBatcher(batch_size,
                                        lambda x: x.n_context_words,
                                        truncate_batches=True)
        # _order[k]: shuffled paragraph-index pairs for self.questions[k]
        # _on[k]: position of the next pair to use in _order[k]
        self._order = []
        self._on = np.zeros(len(questions), dtype=np.int32)
        for q in questions:
            if len(q.paragraphs) == 1:
                self._order.append(np.zeros((1, 1), dtype=np.int32))
                continue
            if force_answer:
                sample1 = [
                    i for i, p in enumerate(q.paragraphs)
                    if len(p.answer_spans) > 0
                ]
            else:
                sample1 = list(range(len(q.paragraphs)))

            if (len(self.overample_first_answer) > 0
                    and not (force_answer and len(sample1) == 1)
                ):  # don't bother if there only is one answer
                ix = 0
                for i, p in enumerate(q.paragraphs):
                    if len(p.answer_spans) > 0:
                        sample1 += [i] * self.overample_first_answer[ix]
                        ix += 1
                        if ix >= len(self.overample_first_answer):
                            break

            permutations = []
            for i in sample1:
                for j in range(len(q.paragraphs)):
                    if j != i:
                        permutations.append((i, j))
            permutations = np.array(permutations, dtype=np.int32)
            np.random.shuffle(permutations)
            self._order.append(permutations)

    def get_vocab(self):
        """Return the set of all elements of every question and paragraph text."""
        voc = set()
        for q in self.questions:
            voc.update(q.question)
            for para in q.paragraphs:
                voc.update(para.text)
        return voc

    def get_spec(self):
        """Return a ``ParagraphAndQuestionSpec`` with the longest question/paragraph lengths."""
        max_q_len = max(len(q.question) for q in self.questions)
        max_c_len = max(
            max(len(p.text) for p in q.paragraphs) for q in self.questions)
        return ParagraphAndQuestionSpec(None, max_q_len, max_c_len, None)

    def get_epoch(self):
        """Return a generator over one epoch of batches built from all questions."""
        return self._build_expanded_batches(self.questions)

    def _build_expanded_batches(self, questions, indices=None):
        """Yield batches for ``questions``, advancing each question's stratified pair cursor.

        :param questions: questions to build batches for
        :param indices: position of each question in ``self.questions``, used to look up
            its pair order and cursor. Defaults to ``0..len(questions)-1``, which is only
            correct when ``questions`` is ``self.questions`` itself.

        This is a generator: the cursors are advanced on the first ``next()`` call.
        """
        if indices is None:
            indices = range(len(questions))
        out = []
        for q, ix in zip(questions, indices):
            order = self._order[ix]
            out.append(ParagraphSelection(q, order[self._on[ix]]))
            self._on[ix] += 1
            if self._on[ix] == len(order):
                # Every pair has been used once: start a new, reshuffled cycle
                self._on[ix] = 0
                np.random.shuffle(order)

        out.sort(key=lambda x: x.n_context_words)

        group = 0
        for selection_batch in self.batcher.get_epoch(out):
            batch = []
            for selected in selection_batch:
                q = selected.question
                if self.merge:
                    paras = [q.paragraphs[i] for i in selected.selection]
                    # Sort paragraphs by reading order, not rank order
                    paras.sort(key=lambda x: x.get_order())
                    answer_spans = []
                    text = []
                    for para in paras:
                        # Shift this paragraph's spans by the tokens that precede it
                        answer_spans.append(len(text) + para.answer_spans)
                        text += para.text
                    batch.append(
                        ParagraphAndQuestion(
                            text, q.question,
                            TokenSpans(q.answer_text,
                                       np.concatenate(answer_spans)),
                            q.question_id))
                else:
                    for i in selected.selection:
                        para = q.paragraphs[i]
                        batch.append(
                            para.build_qa_pair(q.question, q.question_id,
                                               q.answer_text, group))
                    group += 1
            yield batch

    def get_samples(self, n_examples):
        """Build batches from a random subset of at most ``n_examples`` questions.

        ``n_examples`` is clamped to the number of available questions. The sampled
        questions' stratification cursors are advanced as in ``get_epoch``.

        :return: tuple of (batch generator, number of batches it will yield)
        """
        n_examples = min(n_examples, len(self.questions))
        # Sample indices (not question objects) so each question is paired with its own
        # entry in self._order / self._on rather than with the subset position.
        chosen = np.random.choice(len(self.questions), n_examples, replace=False)
        questions = [self.questions[i] for i in chosen]
        n_batches = self.batcher.epoch_size(n_examples)
        return self._build_expanded_batches(questions, chosen), n_batches

    def percent_filtered(self):
        """Return the fraction of the ``true_len`` original questions that were filtered out."""
        if self.true_len == 0:
            return 0.0
        return (self.true_len - len(self.questions)) / self.true_len

    def __len__(self):
        """Number of batches in one epoch."""
        return self.batcher.epoch_size(len(self.questions))