from typing import List

import numpy as np

# `Dataset`, `ClusteredBatcher`, `MultiParagraphQuestion`, `ParagraphSelection`,
# `ParagraphAndQuestion`, `ParagraphAndQuestionSpec`, and `TokenSpans` are
# project classes assumed to be imported from elsewhere in the codebase.


class StratifiedParagraphSetDataset(Dataset):
    """
    Sample multiple paragraphs each epoch and include them in the same batch,
    but stratify the sampling across epochs
    """

    def __init__(self,
                 questions: List[MultiParagraphQuestion],
                 true_len: int,
                 batch_size: int,
                 force_answer: bool,
                 oversample_first_answer: List[int],
                 merge: bool):
        self.oversample_first_answer = oversample_first_answer
        self.questions = questions
        self.merge = merge
        self.true_len = true_len
        self.batcher = ClusteredBatcher(batch_size,
                                        lambda x: x.n_context_words,
                                        truncate_batches=True)
        self._order = []
        self._on = np.zeros(len(questions), dtype=np.int32)
        for q in questions:
            if len(q.paragraphs) == 1:
                self._order.append(np.zeros((1, 1), dtype=np.int32))
                continue
            if force_answer:
                sample1 = [i for i, p in enumerate(q.paragraphs)
                           if len(p.answer_spans) > 0]
            else:
                sample1 = list(range(len(q.paragraphs)))

            # Don't bother oversampling if there is only one answer paragraph
            if (len(self.oversample_first_answer) > 0 and
                    not (force_answer and len(sample1) == 1)):
                ix = 0
                for i, p in enumerate(q.paragraphs):
                    if len(p.answer_spans) > 0:
                        sample1 += [i] * self.oversample_first_answer[ix]
                        ix += 1
                        if ix >= len(self.oversample_first_answer):
                            break

            # Build all (anchor, other) paragraph pairs; epochs cycle through
            # this shuffled list so each pair is eventually seen
            permutations = []
            for i in sample1:
                for j in range(len(q.paragraphs)):
                    if j != i:
                        permutations.append((i, j))
            permutations = np.array(permutations, dtype=np.int32)
            np.random.shuffle(permutations)
            self._order.append(permutations)

    def get_vocab(self):
        voc = set()
        for q in self.questions:
            voc.update(q.question)
            for para in q.paragraphs:
                voc.update(para.text)
        return voc

    def get_spec(self):
        max_q_len = max(len(q.question) for q in self.questions)
        max_c_len = max(max(len(p.text) for p in q.paragraphs)
                        for q in self.questions)
        return ParagraphAndQuestionSpec(None, max_q_len, max_c_len, None)

    def get_epoch(self):
        return self._build_expanded_batches(self.questions)

    def _build_expanded_batches(self, questions):
        out = []
        for i, q in enumerate(questions):
            order = self._order[i]
            out.append(ParagraphSelection(q, order[self._on[i]]))
            self._on[i] += 1
            if self._on[i] == len(order):
                # Exhausted the permutations, reshuffle and start over
                self._on[i] = 0
                np.random.shuffle(order)

        out.sort(key=lambda x: x.n_context_words)

        group = 0
        for selection_batch in self.batcher.get_epoch(out):
            batch = []
            for selected in selection_batch:
                q = selected.question
                if self.merge:
                    paras = [q.paragraphs[i] for i in selected.selection]
                    # Sort the paragraphs by reading order, not rank order
                    paras.sort(key=lambda x: x.get_order())
                    answer_spans = []
                    text = []
                    for para in paras:
                        # Offset the answer spans by the preceding text length
                        answer_spans.append(len(text) + para.answer_spans)
                        text += para.text
                    batch.append(ParagraphAndQuestion(
                        text, q.question,
                        TokenSpans(q.answer_text, np.concatenate(answer_spans)),
                        q.question_id))
                else:
                    for i in selected.selection:
                        para = q.paragraphs[i]
                        batch.append(para.build_qa_pair(
                            q.question, q.question_id, q.answer_text, group))
                    group += 1
            yield batch

    def get_samples(self, n_examples):
        n_batches = self.batcher.epoch_size(n_examples)
        return self._build_expanded_batches(
            np.random.choice(self.questions, n_examples,
                             replace=False)), n_batches

    def percent_filtered(self):
        return (self.true_len - len(self.questions)) / self.true_len

    def __len__(self):
        return self.batcher.epoch_size(len(self.questions))
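
# Standalone illustration of the stratified cycling implemented above (a
# sketch, numpy only; not part of the original module). It builds the
# (anchor, other) permutation list for a hypothetical 3-paragraph question
# where only paragraph 0 has an answer and force_answer=True, then walks
# through it the way _build_expanded_batches does: consume the shuffled list
# one entry per epoch, reshuffle once exhausted.
def _demo_stratified_order():
    n_paragraphs, sample1 = 3, [0]  # only paragraph 0 carries an answer
    permutations = np.array([(i, j) for i in sample1
                             for j in range(n_paragraphs) if j != i],
                            dtype=np.int32)
    np.random.shuffle(permutations)
    on = 0
    for _ in range(4):  # four "epochs": cycle through, then reshuffle
        print(permutations[on])  # e.g. [0 2] -> train on paragraphs 0 and 2
        on += 1
        if on == len(permutations):
            on = 0
            np.random.shuffle(permutations)
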
class RandomParagraphSetDataset(Dataset):
    """
    Sample multiple paragraphs for each question and include them in the same batch
    """

    def __init__(self,
                 questions: List[MultiParagraphQuestion],
                 true_len: int,
                 n_paragraphs: int,
                 batch_size: int,
                 mode: str,
                 force_answer: bool,
                 oversample_first_answer: List[int]):
        self.mode = mode
        self.questions = questions
        self.force_answer = force_answer
        self.true_len = true_len
        self.n_paragraphs = n_paragraphs
        self.oversample_first_answer = oversample_first_answer
        self._n_pairs = sum(min(len(q.paragraphs), n_paragraphs)
                            for q in questions)
        self.batcher = ClusteredBatcher(batch_size,
                                        lambda x: x.n_context_words,
                                        truncate_batches=True)

    def get_vocab(self):
        voc = set()
        for q in self.questions:
            voc.update(q.question)
            for para in q.paragraphs:
                voc.update(para.text)
        return voc

    def get_spec(self):
        max_q_len = max(len(q.question) for q in self.questions)
        max_c_len = max(max(len(p.text) for p in q.paragraphs)
                        for q in self.questions)
        return ParagraphAndQuestionSpec(
            self.batcher.get_fixed_batch_size() if self.mode == "merge" else None,
            max_q_len, max_c_len, None)

    def get_epoch(self):
        return self._build_expanded_batches(self.questions)

    def _build_expanded_batches(self, questions):
        # We first pick the paragraph(s) for each question in the entire
        # training set, so we can cluster by context length accurately
        out = []
        for q in questions:
            if len(q.paragraphs) <= self.n_paragraphs:
                selected = np.arange(len(q.paragraphs))
            elif not self.force_answer and len(self.oversample_first_answer) == 0:
                selected = np.random.choice(len(q.paragraphs), self.n_paragraphs,
                                            replace=False)
            else:
                if not self.force_answer:
                    raise NotImplementedError(
                        "Oversampling without force_answer is not supported")
                with_answer = [i for i, p in enumerate(q.paragraphs)
                               if len(p.answer_spans) > 0]
                # Iterate over a copy, since `with_answer` is mutated in the loop
                for ix, over_sample in zip(list(with_answer),
                                           self.oversample_first_answer):
                    with_answer += [ix] * over_sample
                answer_selection = with_answer[np.random.randint(len(with_answer))]
                other = np.array([i for i, x in enumerate(q.paragraphs)
                                  if i != answer_selection])
                selected = np.random.choice(
                    other, min(len(other), self.n_paragraphs - 1), replace=False)
                selected = np.insert(selected, 0, answer_selection)

            if self.mode == "flatten":
                for i in selected:
                    out.append(q.paragraphs[i].build_qa_pair(
                        q.question, q.question_id, q.answer_text))
            else:
                out.append(ParagraphSelection(q, selected))

        out.sort(key=lambda x: x.n_context_words)

        if self.mode == "flatten":
            for batch in self.batcher.get_epoch(out):
                yield batch
        elif self.mode == "group":
            group = 0
            for selection_batch in self.batcher.get_epoch(out):
                batch = []
                for selected in selection_batch:
                    q = selected.question
                    for i in selected.selection:
                        para = q.paragraphs[i]
                        batch.append(para.build_qa_pair(
                            q.question, q.question_id, q.answer_text, group))
                    group += 1
                yield batch
        elif self.mode == "merge":
            for selection_batch in self.batcher.get_epoch(out):
                batch = []
                for selected in selection_batch:
                    q = selected.question
                    paras = [q.paragraphs[i] for i in selected.selection]
                    para = paras[0].merge(paras)
                    batch.append(para.build_qa_pair(
                        q.question, q.question_id, q.answer_text))
                yield batch
        else:
            raise RuntimeError("Unknown mode: %s" % self.mode)

    def get_samples(self, n_examples):
        questions = np.random.choice(self.questions, n_examples, replace=False)
        if self.mode == "flatten":
            n_batches = self.batcher.epoch_size(
                sum(min(len(q.paragraphs), self.n_paragraphs)
                    for q in questions))
        else:
            n_batches = self.batcher.epoch_size(n_examples)
        return self._build_expanded_batches(questions), n_batches

    def percent_filtered(self):
        return (self.true_len - len(self.questions)) / self.true_len

    def __len__(self):
        if self.mode == "flatten":
            return self.batcher.epoch_size(self._n_pairs)
        else:
            return self.batcher.epoch_size(len(self.questions))
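

# Standalone illustration of the oversampled answer selection used in
# _build_expanded_batches above (a sketch, numpy only; not part of the
# original module; `answer_flags` is invented for the demo). With
# oversample_first_answer=[2], the first answer-bearing paragraph gets two
# extra entries in the candidate list, so here it is drawn with probability
# 3/4 rather than 1/2. The chosen answer paragraph is always placed first,
# followed by randomly selected distractors, matching the "else" branch above.
def _demo_random_selection():
    answer_flags = [True, False, True, False]  # which paragraphs have answers
    oversample_first_answer = [2]
    with_answer = [i for i, has in enumerate(answer_flags) if has]
    for ix, over_sample in zip(list(with_answer), oversample_first_answer):
        with_answer += [ix] * over_sample
    # with_answer is now [0, 2, 0, 0]; paragraph 0 is favored 3:1 over 2
    answer_selection = with_answer[np.random.randint(len(with_answer))]
    other = np.array([i for i in range(len(answer_flags))
                      if i != answer_selection])
    selected = np.insert(np.random.choice(other, 2, replace=False),
                         0, answer_selection)
    print(selected)  # e.g. [0 3 1]: answer paragraph first, then two others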