예제 #1
0
    def __iter__(self):
        """Yield batches epoch after epoch, reshaped for training or evaluation.

        Each pass of the outer loop calls ``self.init_epoch()`` and then walks
        ``self.batches``. Minibatches whose index is below
        ``self._iterations_this_epoch`` are skipped without being processed.

        Yields:
            In training mode (``self.train`` truthy), every item produced by
            ``self.__bptt_2_minibatch`` for the current batch. Otherwise the
            ``Batch`` itself, after its ``time`` attribute has been split into
            ``time`` / ``target_time`` and either ``bow`` has been split into
            ``bow`` / ``target_bow`` (when the dataset has a ``'bow'`` field)
            or ``text`` has been reduced to its elements 0 and 2.

        The generator finishes after one epoch unless ``self.repeat`` is
        truthy, in which case it starts a new epoch.
        """
        while True:
            self.init_epoch()
            for idx, minibatch in enumerate(self.batches):
                # Fast-forward first: a skipped minibatch is discarded, so
                # sorting it and building a Batch for it would be wasted work.
                if self._iterations_this_epoch > idx:
                    continue
                if self.sort_within_batch:
                    # NOTE: `rnn.pack_padded_sequence` requires that a minibatch
                    # be sorted by decreasing order, which requires reversing
                    # relative to typical sort keys
                    if self.sort:
                        minibatch.reverse()
                    else:
                        minibatch.sort(key=self.sort_key, reverse=True)
                batch = Batch(minibatch, self.dataset, self.device)
                self.iterations += 1
                self._iterations_this_epoch += 1
                # Open design question left by the author: should we have many
                # batches, or one long batch with many windows?
                batch_size: int = len(batch)

                dataset = Dataset(examples=self.dataset.examples, fields=self.fields)
                if self.train:
                    yield from self.__bptt_2_minibatch(batch, batch_size, dataset)
                else:
                    if 'bow' in dataset.fields:
                        # Index 1 on the third axis becomes the target,
                        # index 0 stays as the input.
                        batch.target_bow = batch.bow[:, :, 1]
                        batch.bow = batch.bow[:, :, 0]
                    else:
                        # Keep only elements 0 and 2 of the `text` tuple.
                        batch.text = (batch.text[0], batch.text[2])
                    # The last slice along the third axis of `time[0]` is the
                    # target; the first two slices remain the input, paired
                    # with the unchanged `time[1]`.
                    batch.target_time = batch.time[0][:, :, -1]
                    batch.time = (batch.time[0][:, :, :2], batch.time[1])
                    yield batch

            if not self.repeat:
                return
예제 #2
0
    def __iter__(self):
        """Yield batches epoch after epoch, reshaped for training or evaluation.

        Each pass of the outer loop calls ``self.init_epoch()`` and then walks
        ``self.batches``. Minibatches whose index is below
        ``self._iterations_this_epoch`` are skipped without being processed.

        Yields:
            In training mode (``self.train`` truthy), one lazy generator per
            minibatch; iterating it produces a ``Batch`` (built with
            ``Batch.fromvars``) for every window returned by
            ``self.__series_2_bptt``. Otherwise the ``Batch`` itself, after
            ``time`` has been split into ``time`` / ``target_time`` and
            ``mark`` into ``mark`` / ``target_mark``.

        The generator finishes after one epoch unless ``self.repeat`` is
        truthy, in which case it starts a new epoch.

        NOTE(review): in training mode each yielded item is itself a generator
        of batches rather than a single ``Batch``, so callers must iterate it.
        Confirm this nesting is what the training loop expects.
        """
        def bptt_windows(dataset, batch_size, windows):
            # A generator *function* binds `dataset` and `batch_size` when it
            # is called. An inline generator expression would look them up in
            # the enclosing scope only when consumed, and would pick up values
            # from a later minibatch if the outer iterator had advanced first.
            for ti, length, m in windows:
                yield Batch.fromvars(
                    dataset, batch_size,
                    time=(ti[:, :, :2], length),
                    mark=m[:, :, 0],
                    target_time=ti[:, :, -1],
                    target_mark=m[:, :, -1])

        while True:
            self.init_epoch()
            for idx, minibatch in enumerate(self.batches):
                # Fast-forward first: a skipped minibatch is discarded, so
                # sorting it and building a Batch for it would be wasted work.
                if self._iterations_this_epoch > idx:
                    continue
                if self.sort_within_batch:
                    # NOTE: `rnn.pack_padded_sequence` requires that a minibatch
                    # be sorted by decreasing order, which requires reversing
                    # relative to typical sort keys
                    if self.sort:
                        minibatch.reverse()
                    else:
                        minibatch.sort(key=self.sort_key, reverse=True)
                batch = Batch(minibatch, self.dataset, self.device)
                self.iterations += 1
                self._iterations_this_epoch += 1
                batch_size: int = len(batch)
                # Dataset describing the fields of the windowed batches:
                # the original `time` / `mark` fields plus two target fields
                # that do not use a vocabulary.
                dataset = Dataset(examples=self.dataset.examples, fields=[
                    ('time', self.dataset.fields['time']), ('mark', self.dataset.fields['mark']),
                    ('target_time', Field(use_vocab=False)), ('target_mark', Field(use_vocab=False))])
                if self.train:
                    seq_len, time, mark = self.__series_2_bptt(batch, batch_size)
                    # `zip` is built eagerly here (as before); the batches
                    # themselves are only created when the generator is consumed.
                    yield bptt_windows(dataset, batch_size, zip(time, seq_len, mark))
                else:
                    # Index 2 on the third axis of `time[0]` is the target;
                    # the first two slices remain the input, paired with the
                    # unchanged `time[1]`.
                    batch.target_time = batch.time[0][:, :, 2]
                    batch.time = (batch.time[0][:, :, :2], batch.time[1])
                    # Index 1 on the third axis of `mark` is the target,
                    # index 0 stays as the input.
                    batch.target_mark = batch.mark[:, :, 1]
                    batch.mark = batch.mark[:, :, 0]
                    yield batch

            if not self.repeat:
                return