def __iter__(self):
    text = self.dataset[0].text
    TEXT = self.dataset.fields["text"]
    TEXT.eos_token = None
    # Pad the corpus so its length is an exact multiple of batch_size.
    text = text + ([TEXT.pad_token] * int(
        math.ceil(len(text) / self.batch_size) * self.batch_size - len(text)))
    data = TEXT.numericalize([text], device=self.device)
    # Reshape the flat token stream into (seqlen, batch) named dimensions.
    data = (data.stack(("seqlen", "batch"), "flat")
            .split("flat", ("batch", "seqlen"), batch=self.batch_size)
            .transpose("seqlen", "batch"))
    dataset = Dataset(examples=self.dataset.examples,
                      fields=[("text", TEXT), ("target", TEXT)])
    while True:
        for i in range(0, len(self) * self.bptt_len, self.bptt_len):
            self.iterations += 1
            seq_len = min(self.bptt_len, len(data) - i - 1)
            # target is the input shifted one position to the right
            yield Batch.fromvars(
                dataset, self.batch_size,
                text=data.narrow("seqlen", i, seq_len),
                target=data.narrow("seqlen", i + 1, seq_len),
            )
        if not self.repeat:
            return
def __iter__(self) -> Iterator[Batch]:
    """Almost the same iterator as the bucket iterator."""
    while True:
        self.init_epoch()
        for idx, minibatch in enumerate(self.batches):
            # fast-forward if loaded from state
            if self._iterations_this_epoch > idx:
                continue
            self.iterations += 1
            self._iterations_this_epoch += 1
            if self.sort_within_batch:
                if self.sort:
                    minibatch.reverse()
                else:
                    minibatch.sort(key=self.sort_key, reverse=True)
            context, response, targets = self.process_minibatch(minibatch)
            for index in range(context.shape[0]):
                # do not yield if the target is just padding
                # (it contributes nothing to training)
                if (targets[index] == self.text_field.vocab.stoi[
                        self.text_field.pad_token]).all():
                    continue
                # skip examples with contexts that won't fit in GPU memory
                if np.prod(context[:index + 1].shape) > self.max_context_size:
                    continue
                yield Batch.fromvars(dataset=self.dataset,
                                     batch_size=len(minibatch),
                                     train=self.train,
                                     context=context[:index + 1],
                                     response=response[index],
                                     targets=targets[index])
        if not self.repeat:
            return
def __iter__(self) -> Iterator[Batch]:
    """Almost the same iterator as the bucket iterator."""
    while True:
        self.init_epoch()
        for idx, minibatch in enumerate(self.batches):
            # fast-forward if loaded from state
            if self._iterations_this_epoch > idx:
                continue
            self.iterations += 1
            self._iterations_this_epoch += 1
            if self.sort_within_batch:
                if self.sort:
                    minibatch.reverse()
                else:
                    minibatch.sort(key=self.sort_key, reverse=True)
            context, response, targets = self.process_minibatch(minibatch)
            yield Batch.fromvars(dataset=self.dataset,
                                 batch_size=len(minibatch),
                                 train=self.train,
                                 context=context,
                                 response=response,
                                 targets=targets)
        if not self.repeat:
            return
def __iter__(self):
    text = self.dataset[0].text
    TEXT = self.dataset.fields['text']
    TEXT.eos_token = None
    num_batches = math.ceil(len(text) / self.batch_size)
    pad_amount = int(num_batches * self.batch_size - len(text))
    text += [TEXT.pad_token] * pad_amount
    data = TEXT.numericalize([text], device=self.device)
    data = data.stack(('seqlen', 'batch'), 'flat') \
               .split('flat', ('batch', 'seqlen'), batch=self.batch_size) \
               .transpose('seqlen', 'batch')
    fields = [('text', TEXT), ('target', TEXT)]
    dataset = Dataset(examples=self.dataset.examples, fields=fields)
    while True:
        for i in range(0, len(self) * self.bptt_len, self.bptt_len):
            self.iterations += 1
            seq_len = min(self.bptt_len, len(data) - i - 1)
            yield Batch.fromvars(dataset, self.batch_size,
                                 text=data.narrow('seqlen', i, seq_len),
                                 target=data.narrow('seqlen', i + 1, seq_len))
        if not self.repeat:
            return
def __iter__(self):
    text = self.dataset[0].text
    TEXT = self.dataset.fields['text']
    TEXT.eos_token = None
    # Pad the corpus so its length is an exact multiple of batch_size.
    text = text + ([TEXT.pad_token] * int(
        math.ceil(len(text) / self.batch_size) * self.batch_size - len(text)))
    data = TEXT.numericalize([text], device=self.device)
    # Reshape to (seq_len, batch_size): each column is a contiguous chunk.
    data = data.view(self.batch_size, -1).t().contiguous()
    dataset = Dataset(examples=self.dataset.examples,
                      fields=[('text', TEXT), ('target', TEXT)])
    while True:
        for i in range(0, len(self) * self.bptt_len, self.bptt_len):
            self.iterations += 1
            seq_len = min(self.bptt_len, len(data) - i - 1)
            batch_text = data[i:i + seq_len]
            # target is the input shifted one token to the right
            batch_target = data[i + 1:i + 1 + seq_len]
            if TEXT.batch_first:
                batch_text = batch_text.t().contiguous()
                batch_target = batch_target.t().contiguous()
            yield Batch.fromvars(dataset, self.batch_size,
                                 text=batch_text, target=batch_target)
        if not self.repeat:
            return
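# Usage sketch (not taken from any snippet above): the four BPTT variants here
# mirror torchtext's legacy BPTTIterator, so consuming one would look roughly
# like this. Assumes the legacy torchtext API (<= 0.8, where Field,
# BPTTIterator, and WikiText2 live under torchtext.data / torchtext.datasets);
# WikiText2 would download data on first run.
import torch
from torchtext.data import Field, BPTTIterator
from torchtext.datasets import WikiText2

field = Field(lower=True, batch_first=False)
train, valid, test = WikiText2.splits(field)
field.build_vocab(train)

train_iter = BPTTIterator(train, batch_size=32, bptt_len=35,
                          device=torch.device('cpu'), repeat=False)
for batch in train_iter:
    # both tensors are (bptt_len, batch_size); target is text shifted by one
    text, target = batch.text, batch.target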
def consume_buffer(self):
    cur_text_buffer = self.get_contiguous_buffer()
    data, dataset = self.prepare_text_buffer(cur_text_buffer)
    t_len = self.get_len(cur_text_buffer)
    for batch_text, batch_target in self.consume_data(data, t_len):
        kwargs = {self.field_name: batch_text, 'target': batch_target}
        yield Batch.fromvars(dataset, self.batch_size, **kwargs)
def __iter__(self):
    while True:
        self.init_epoch()
        for idx, minibatch in enumerate(self.batches):
            # fast-forward if loaded from state
            if self._iterations_this_epoch > idx:
                continue
            self.iterations += 1
            self._iterations_this_epoch += 1
            if self.sort_within_batch:
                # NOTE: `rnn.pack_padded_sequence` requires that a minibatch
                # be sorted by decreasing order, which requires reversing
                # relative to typical sort keys
                if self.sort:
                    minibatch.reverse()
                else:
                    minibatch.sort(key=self.sort_key, reverse=True)
            source_batch = [m.source for m in minibatch]
            source_mask = sequence_mask(source_batch)
            if self.mode != "infer":
                target_batch = [m.target for m in minibatch]
                label_batch = [m.label for m in minibatch]
                target_mask = sequence_mask(target_batch)
                yield Batch.fromvars(
                    self.dataset, self.batch_size,
                    source=postprocessing(source_batch, self.params),
                    source_mask=source_mask,
                    target=postprocessing(target_batch, self.params),
                    target_mask=target_mask,
                    label=postprocessing(label_batch, self.params))
            else:
                yield Batch.fromvars(
                    self.dataset, self.batch_size,
                    source=postprocessing(source_batch, self.params),
                    source_mask=source_mask)
        if not self.repeat:
            return
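# Hedged sketch of a `sequence_mask` helper like the one called above (its
# real implementation is not shown in the snippet, so the name and exact
# semantics are assumptions): given variable-length sequences, return a
# (batch, max_len) boolean mask that is True on valid positions and False on
# padding.
import torch

def sequence_mask(sequences):
    lengths = torch.tensor([len(seq) for seq in sequences])
    max_len = int(lengths.max())
    # position j in row i is valid iff j < lengths[i]
    return torch.arange(max_len).unsqueeze(0) < lengths.unsqueeze(1)

# e.g. sequence_mask([[5, 2, 9], [7]]) ->
# tensor([[ True,  True,  True],
#         [ True, False, False]])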
def __iter__(self):
    text = getattr(self.dataset[0], self.field_name)
    data, dataset = self.prepare_text(text)
    while True:
        for i in range(0, len(self) * self.cur_bptt_len, self.cur_bptt_len):
            self.iterations += 1
            seq_len = min(self.cur_bptt_len, len(data) - i - 1)
            batch_text = data[i:i + seq_len]
            batch_target = data[i + 1:i + 1 + seq_len]
            if self.batch_first:
                batch_text = batch_text.t().contiguous()
                batch_target = batch_target.t().contiguous()
            yield Batch.fromvars(
                dataset, self.batch_size,
                text=batch_text, target=batch_target
            )
        if not self.repeat:
            return
def __iter__(self):
    text = self.dataset[0].text
    TEXT = self.dataset.fields['text']
    TEXT.eos_token = None
    pad_num = int(math.ceil(len(text) / self.batch_size) * self.batch_size
                  - len(text))
    text = text + ([TEXT.pad_token] * pad_num)
    data = TEXT.numericalize([text], device=self.device)
    # Batch-first layout: each row holds a contiguous chunk of the corpus,
    # so BPTT windows are sliced along dimension 1.
    data = data.view(self.batch_size, -1).contiguous()
    dataset = Dataset(examples=self.dataset.examples,
                      fields=[('text', TEXT), ('target', TEXT)])
    while True:
        for i in range(0, len(self) * self.bptt_len, self.bptt_len):
            self.iterations += 1
            seq_len = self.bptt_len
            yield Batch.fromvars(dataset, self.batch_size,
                                 text=data[:, i:i + seq_len],
                                 target=data[:, i + 1:i + 1 + seq_len])
        if not self.repeat:
            return
def _transform(self, batch):
    src, src_lens = batch.src
    src_size = src.size()
    # Remap every token id through the src_b2s lookup table, preserving shape.
    src = torch.LongTensor([
        self.src_b2s[i] for i in src.data.view(-1).tolist()
    ]).view(src_size)
    trg, trg_lens = batch.trg
    trg_size = trg.size()
    trg = torch.LongTensor([
        self.trg_b2s[i] for i in trg.data.view(-1).tolist()
    ]).view(trg_size)
    if self.use_cuda:
        src = src.cuda()
        trg = trg.cuda()
    return Batch.fromvars(batch.dataset, batch.batch_size, batch.train,
                          src=(src, src_lens), trg=(trg, trg_lens))
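# Sketch of what `Batch.fromvars` does across all the snippets above: it
# builds a Batch directly from already-numericalized tensors instead of from
# raw Examples. The dataset, tensors, and the 'src' field name here are made
# up for illustration; this assumes the legacy torchtext signature
# `Batch.fromvars(dataset, batch_size, train=None, **kwargs)`.
import torch
from torchtext.data import Batch, Dataset

# a None field is enough here: fromvars only reads the field names
dataset = Dataset(examples=[], fields=[('src', None)])
src = torch.randint(0, 100, (7, 4))                   # (seq_len, batch_size)
src_lens = torch.full((4,), 7, dtype=torch.long)
batch = Batch.fromvars(dataset, batch_size=4, src=(src, src_lens))
print(batch.batch_size, batch.src[0].shape)           # 4 torch.Size([7, 4])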