def __iter__(self):
    """Iterate over batches, one epoch at a time.

    Every epoch begins with ``self.init_epoch()`` and then walks
    ``self.batches``. Minibatches whose index is below
    ``self._iterations_this_epoch`` are skipped (presumably to fast-forward
    an epoch that was already partially consumed -- confirm against
    ``init_epoch``). The skip happens before sorting and ``Batch``
    construction so no work is spent on minibatches that are discarded.

    Yields:
        When ``self.train`` is truthy: every item produced by
        ``self.__bptt_2_minibatch(batch, batch_size, dataset)``.

        Otherwise: the ``Batch`` itself, after its fields are split into
        input/target attributes:

        * if the dataset has a ``'bow'`` field, ``target_bow`` takes index 1
          and ``bow`` index 0 of the last axis of ``batch.bow``;
        * if not, ``text`` keeps elements 0 and 2 of ``batch.text``;
        * ``target_time`` takes the last index and ``time`` the first two
          indices of the last axis of ``batch.time[0]``; ``batch.time[1]``
          is carried over unchanged.

    The generator ends after one epoch unless ``self.repeat`` is truthy, in
    which case it starts a new epoch and never terminates on its own.
    """
    while True:
        self.init_epoch()
        for idx, minibatch in enumerate(self.batches):
            # Skip already-counted minibatches before doing any per-batch
            # work (sorting, tensor construction) on them.
            if self._iterations_this_epoch > idx:
                continue

            if self.sort_within_batch:
                # NOTE: `rnn.pack_padded_sequence` requires that a minibatch
                # be sorted by decreasing order, which requires reversing
                # relative to typical sort keys
                if self.sort:
                    minibatch.reverse()
                else:
                    minibatch.sort(key=self.sort_key, reverse=True)

            batch = Batch(minibatch, self.dataset, self.device)
            self.iterations += 1
            self._iterations_this_epoch += 1

            # Open question from the original author: should we have many
            # batches, or one long batch with many windows?
            batch_size: int = len(batch)
            dataset = Dataset(examples=self.dataset.examples, fields=self.fields)

            if self.train:
                yield from self.__bptt_2_minibatch(batch, batch_size, dataset)
                continue

            # Evaluation: split each packed field into input and target.
            if 'bow' in dataset.fields:
                batch.target_bow = batch.bow[:, :, 1]
                batch.bow = batch.bow[:, :, 0]
            else:
                batch.text = (batch.text[0], batch.text[2])
            # NOTE(review): the time split is applied for both the bow and
            # the text case; the original formatting was lost, so confirm
            # this nesting is the intended one.
            batch.target_time = batch.time[0][:, :, -1]
            batch.time = (batch.time[0][:, :, :2], batch.time[1])
            yield batch

        if not self.repeat:
            return
def __iter__(self):
    """Iterate over time/mark batches, one epoch at a time.

    Every epoch begins with ``self.init_epoch()`` and then walks
    ``self.batches``. Minibatches whose index is below
    ``self._iterations_this_epoch`` are skipped (presumably to fast-forward
    an epoch that was already partially consumed -- confirm against
    ``init_epoch``). The skip happens before sorting and ``Batch``
    construction so no work is spent on minibatches that are discarded.

    Yields:
        When ``self.train`` is truthy: ONE generator per minibatch. That
        generator lazily builds a ``Batch.fromvars(...)`` for each
        ``(time, length, mark)`` triple returned by
        ``self.__series_2_bptt(batch, batch_size)``, with

        * ``time``        = (first two indices of the last axis, length)
        * ``mark``        = index 0 of the last axis
        * ``target_time`` = last index of the last axis of the time tensor
        * ``target_mark`` = last index of the last axis of the mark tensor

        Otherwise: the ``Batch`` itself, after ``time``/``mark`` are split
        into input and target attributes (``target_time`` from index 2 and
        ``target_mark`` from index 1 of the last axis).

    The generator ends after one epoch unless ``self.repeat`` is truthy, in
    which case it starts a new epoch and never terminates on its own.
    """
    while True:
        self.init_epoch()
        for idx, minibatch in enumerate(self.batches):
            # Skip already-counted minibatches before doing any per-batch
            # work (sorting, tensor construction) on them.
            if self._iterations_this_epoch > idx:
                continue

            if self.sort_within_batch:
                # NOTE: `rnn.pack_padded_sequence` requires that a minibatch
                # be sorted by decreasing order, which requires reversing
                # relative to typical sort keys
                if self.sort:
                    minibatch.reverse()
                else:
                    minibatch.sort(key=self.sort_key, reverse=True)

            batch = Batch(minibatch, self.dataset, self.device)
            self.iterations += 1
            self._iterations_this_epoch += 1

            if self.train:
                # Only the training path needs the derived dataset, so it
                # is built here rather than for every evaluation batch too.
                batch_size: int = len(batch)
                dataset = Dataset(
                    examples=self.dataset.examples,
                    fields=[
                        ('time', self.dataset.fields['time']),
                        ('mark', self.dataset.fields['mark']),
                        ('target_time', Field(use_vocab=False)),
                        ('target_mark', Field(use_vocab=False)),
                    ],
                )
                seq_len, time_windows, mark_windows = self.__series_2_bptt(
                    batch, batch_size)
                # NOTE(review): this yields the generator object itself (one
                # item per minibatch), not its elements -- confirm callers
                # expect that rather than `yield from`.
                # The generator is lazy and reads `dataset` / `batch_size`
                # when consumed, so it should be exhausted before this
                # iterator is advanced again.
                yield (
                    Batch.fromvars(
                        dataset,
                        batch_size,
                        time=(time_win[:, :, :2], win_len),
                        mark=mark_win[:, :, 0],
                        target_time=time_win[:, :, -1],
                        target_mark=mark_win[:, :, -1],
                    )
                    for time_win, win_len, mark_win
                    in zip(time_windows, seq_len, mark_windows)
                )
            else:
                # Evaluation: split each packed field into input and target.
                batch.target_time = batch.time[0][:, :, 2]
                batch.time = (batch.time[0][:, :, :2], batch.time[1])
                batch.target_mark = batch.mark[:, :, 1]
                batch.mark = batch.mark[:, :, 0]
                yield batch

        if not self.repeat:
            return