Example #1
    def __iter__(self):
        text = self.dataset[0].text
        TEXT = self.dataset.fields["text"]
        TEXT.eos_token = None
        text = text + ([TEXT.pad_token] * int(
            math.ceil(len(text) / self.batch_size) * self.batch_size -
            len(text)))
        data = TEXT.numericalize([text], device=self.device)
        data = (data.stack(("seqlen", "batch"), "flat")
                .split("flat", ("batch", "seqlen"), batch=self.batch_size)
                .transpose("seqlen", "batch"))

        dataset = Dataset(examples=self.dataset.examples,
                          fields=[("text", TEXT), ("target", TEXT)])
        while True:
            for i in range(0, len(self) * self.bptt_len, self.bptt_len):
                self.iterations += 1
                seq_len = min(self.bptt_len, len(data) - i - 1)
                yield Batch.fromvars(
                    dataset,
                    self.batch_size,
                    text=data.narrow("seqlen", i, seq_len),
                    target=data.narrow("seqlen", i + 1, seq_len),
                )

            if not self.repeat:
                return
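The stack/split/transpose chain above comes from the namedtensor API, but it is the usual language-model "batchify" reshape: pad the token stream to a multiple of batch_size, then lay it out so that each batch column is a contiguous slice of the corpus. A minimal sketch of the same arithmetic in plain PyTorch, using made-up toy values:

import math

import torch

# 10 token ids, batch_size 4: pad to the next multiple of 4, then reshape
# so each of the 4 columns is a contiguous run of text.
stream = list(range(10))
batch_size = 4
pad_id = 0  # stands in for TEXT.vocab.stoi[TEXT.pad_token]

pad_amount = math.ceil(len(stream) / batch_size) * batch_size - len(stream)
stream = stream + [pad_id] * pad_amount  # now 12 ids

data = torch.tensor(stream).view(batch_size, -1).t().contiguous()
print(data.shape)  # torch.Size([3, 4]) -> (seqlen, batch)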
Example #2
    def __iter__(self) -> Iterator[Batch]:
        """Same iterator almost as bucket iterator"""
        while True:
            self.init_epoch()
            for idx, minibatch in enumerate(self.batches):
                # fast-forward if loaded from state
                if self._iterations_this_epoch > idx:
                    continue
                self.iterations += 1
                self._iterations_this_epoch += 1
                if self.sort_within_batch:
                    if self.sort:
                        minibatch.reverse()
                    else:
                        minibatch.sort(key=self.sort_key, reverse=True)

                context, response, targets = self.process_minibatch(minibatch)
                for index in range(context.shape[0]):
                    # do not yield if the target is just padding (does not provide anything to training)
                    if (targets[index] == self.text_field.vocab.stoi[self.text_field.pad_token]).all():
                        continue
                    # skip examples with contexts that won't fit in gpu memory
                    if np.prod(context[:index + 1].shape) > self.max_context_size:
                        continue
                    yield Batch.fromvars(dataset=self.dataset, batch_size=len(minibatch),
                                         train=self.train,
                                         context=context[:index + 1],
                                         response=response[index],
                                         targets=targets[index]
                                         )
            if not self.repeat:
                # PEP 479: raising StopIteration inside a generator is an
                # error; end the generator with return instead
                return
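The per-row guards are the interesting part of this variant. A minimal sketch of the all-padding check, with hypothetical tensors:

import torch

pad_idx = 1
targets = torch.tensor([[5, 7, 1, 1],
                        [1, 1, 1, 1]])  # second row is pure padding

for row in targets:
    if (row == pad_idx).all():
        continue  # an all-pad target contributes nothing to training
    print(row.tolist())  # [5, 7, 1, 1]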
Example #3
    def __iter__(self) -> Iterator[Batch]:
        """Almost the same iterator as the bucket iterator."""
        while True:
            self.init_epoch()
            for idx, minibatch in enumerate(self.batches):
                # fast-forward if loaded from state
                if self._iterations_this_epoch > idx:
                    continue
                self.iterations += 1
                self._iterations_this_epoch += 1
                if self.sort_within_batch:
                    if self.sort:
                        minibatch.reverse()
                    else:
                        minibatch.sort(key=self.sort_key, reverse=True)

                context, response, targets = self.process_minibatch(minibatch)
                yield Batch.fromvars(dataset=self.dataset,
                                     batch_size=len(minibatch),
                                     train=self.train,
                                     context=context,
                                     response=response,
                                     targets=targets)
            if not self.repeat:
                return
Example #4
    def __iter__(self):
        text = self.dataset[0].text
        TEXT = self.dataset.fields['text']
        TEXT.eos_token = None

        num_batches = math.ceil(len(text) / self.batch_size)
        pad_amount = int(num_batches * self.batch_size - len(text))
        text += [TEXT.pad_token] * pad_amount

        data = TEXT.numericalize([text], device=self.device)
        data = data.stack(('seqlen', 'batch'), 'flat') \
                   .split('flat', ('batch', 'seqlen'), batch=self.batch_size) \
                   .transpose('seqlen', 'batch')

        fields = [('text', TEXT), ('target', TEXT)]
        dataset = Dataset(examples=self.dataset.examples, fields=fields)

        while True:
            for i in range(0, len(self) * self.bptt_len, self.bptt_len):
                self.iterations += 1
                seq_len = min(self.bptt_len, len(data) - i - 1)
                yield Batch.fromvars(dataset,
                                     self.batch_size,
                                     text=data.narrow('seqlen', i, seq_len),
                                     target=data.narrow(
                                         'seqlen', i + 1, seq_len))

            if not self.repeat:
                return
Example #5
    def __iter__(self):
        text = self.dataset[0].text
        TEXT = self.dataset.fields['text']
        TEXT.eos_token = None
        text = text + ([TEXT.pad_token] * int(
            math.ceil(len(text) / self.batch_size) * self.batch_size -
            len(text)))
        data = TEXT.numericalize([text], device=self.device)
        data = data.view(self.batch_size, -1).t().contiguous()
        dataset = Dataset(examples=self.dataset.examples,
                          fields=[('text', TEXT), ('target', TEXT)])
        while True:
            for i in range(0, len(self) * self.bptt_len, self.bptt_len):
                self.iterations += 1
                seq_len = min(self.bptt_len, len(data) - i - 1)
                batch_text = data[i:i + seq_len]
                batch_target = data[i + 1:i + 1 + seq_len]
                if TEXT.batch_first:
                    batch_text = batch_text.t().contiguous()
                    batch_target = batch_target.t().contiguous()
                yield Batch.fromvars(dataset,
                                     self.batch_size,
                                     text=batch_text,
                                     target=batch_target)
            if not self.repeat:
                return
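The min(self.bptt_len, len(data) - i - 1) clamp encodes the one-step shift between text and target: the target slice data[i + 1:i + 1 + seq_len] needs one extra position, so with L time steps the last usable text index is L - 2. A toy boundary case:

# L = 10 time steps, bptt_len = 4, last window starts at i = 8:
L, bptt_len, i = 10, 4, 8
seq_len = min(bptt_len, L - i - 1)                # 1, not 2
text_idx = list(range(i, i + seq_len))            # [8]
target_idx = list(range(i + 1, i + 1 + seq_len))  # [9]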
Example #6
    def consume_buffer(self):
        cur_text_buffer = self.get_contiguous_buffer()
        data, dataset = self.prepare_text_buffer(cur_text_buffer)
        t_len = self.get_len(cur_text_buffer)
        for batch_text, batch_target in self.consume_data(data, t_len):
            kwargs = {self.field_name: batch_text, 'target': batch_target}
            yield Batch.fromvars(
                dataset,
                self.batch_size,
                **kwargs
            )
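get_contiguous_buffer, prepare_text_buffer, get_len, and consume_data are helpers defined elsewhere in that project. Assuming consume_data follows the same BPTT pattern as the surrounding examples, it might look roughly like this sketch (an illustration, not the project's actual code):

def consume_data(data, t_len, bptt_len):
    # Yield (text, target) slice pairs shifted by one time step.
    for i in range(0, t_len, bptt_len):
        seq_len = min(bptt_len, t_len - i - 1)
        if seq_len <= 0:
            break
        yield data[i:i + seq_len], data[i + 1:i + 1 + seq_len]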
Example #7
    def __iter__(self):
        while True:
            self.init_epoch()
            for idx, minibatch in enumerate(self.batches):
                # fast-forward if loaded from state
                if self._iterations_this_epoch > idx:
                    continue
                self.iterations += 1
                self._iterations_this_epoch += 1
                if self.sort_within_batch:
                    # NOTE: `rnn.pack_padded_sequence` requires that a minibatch
                    # be sorted by decreasing order, which requires reversing
                    # relative to typical sort keys
                    if self.sort:
                        minibatch.reverse()
                    else:
                        minibatch.sort(key=self.sort_key, reverse=True)
                source_batch = [m.source for m in minibatch]
                source_mask = sequence_mask(source_batch)
                if not self.mode == "infer":
                    target_batch = [m.target for m in minibatch]
                    label_batch = [m.label for m in minibatch]
                    target_mask = sequence_mask(target_batch)
                    yield Batch.fromvars(
                        self.dataset,
                        self.batch_size,
                        source=postprocessing(source_batch, self.params),
                        source_mask=source_mask,
                        target=postprocessing(target_batch, self.params),
                        target_mask=target_mask,
                        label=postprocessing(label_batch, self.params))
                else:
                    yield Batch.fromvars(self.dataset,
                                         self.batch_size,
                                         source=postprocessing(
                                             source_batch, self.params),
                                         source_mask=source_mask)

            if not self.repeat:
                return
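sequence_mask and postprocessing are project-specific helpers not shown here. A common shape for a mask over variable-length examples, offered only as an assumption about what sequence_mask computes:

import torch

def sequence_mask(batch):
    # True where a real token exists, False over the padded tail.
    lengths = torch.tensor([len(seq) for seq in batch])
    max_len = int(lengths.max())
    return torch.arange(max_len).unsqueeze(0) < lengths.unsqueeze(1)

print(sequence_mask([[4, 9, 2], [7]]))
# tensor([[ True,  True,  True],
#         [ True, False, False]])

The NOTE in the code refers to PyTorch's rnn.pack_padded_sequence, which requires minibatches sorted by decreasing length unless called with enforce_sorted=False.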
Example #8
    def __iter__(self):
        text = getattr(self.dataset[0], self.field_name)
        data, dataset = self.prepare_text(text)
        while True:
            for i in range(0, len(self) * self.cur_bptt_len, self.cur_bptt_len):
                self.iterations += 1
                seq_len = min(self.cur_bptt_len, len(data) - i - 1)
                batch_text = data[i:i + seq_len]
                batch_target = data[i + 1:i + 1 + seq_len]
                if self.batch_first:
                    batch_text = batch_text.t().contiguous()
                    batch_target = batch_target.t().contiguous()
                yield Batch.fromvars(
                    dataset,
                    self.batch_size,
                    text=batch_text,
                    target=batch_target
                )
            if not self.repeat:
                return
Example #9
    def __iter__(self):
        text = self.dataset[0].text
        TEXT = self.dataset.fields['text']
        TEXT.eos_token = None
        pad_num = int(math.ceil(len(text) / self.batch_size) * self.batch_size
                      - len(text))
        text = text + ([TEXT.pad_token] * pad_num)
        data = TEXT.numericalize([text], device=self.device)
        data = data.view(self.batch_size, -1).contiguous()
        dataset = Dataset(examples=self.dataset.examples,
                          fields=[('text', TEXT), ('target', TEXT)])
        while True:
            for i in range(0, len(self) * self.bptt_len, self.bptt_len):
                self.iterations += 1
                # clamp so the shifted target slice stays in bounds on the
                # final chunk
                seq_len = min(self.bptt_len, data.size(1) - i - 1)
                yield Batch.fromvars(dataset,
                                     self.batch_size,
                                     text=data[:, i:i + seq_len],
                                     target=data[:, i + 1:i + 1 + seq_len])
            if not self.repeat:
                return
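Note the contrast with Example #5: without the final .t(), data stays in (batch, seqlen) layout, so the BPTT windows are column ranges rather than row ranges:

import torch

data = torch.arange(12).view(4, -1)  # (batch=4, seqlen=3), rows contiguous
text = data[:, 0:2]    # first two time steps of every row
target = data[:, 1:3]  # the same window shifted one step right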
Example #10
    def _transform(self, batch):

        src, src_lens = batch.src
        src_size = src.size()
        src = torch.LongTensor([
            self.src_b2s[i] for i in src.data.view(-1).tolist()
        ]).view(src_size)

        trg, trg_lens = batch.trg
        trg_size = trg.size()
        trg = torch.LongTensor([
            self.trg_b2s[i] for i in trg.data.view(-1).tolist()
        ]).view(trg_size)

        if self.use_cuda:
            src = src.cuda()
            trg = trg.cuda()

        return Batch.fromvars(batch.dataset,
                              batch.batch_size,
                              batch.train,
                              src=(src, src_lens),
                              trg=(trg, trg_lens))
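src_b2s and trg_b2s are id-to-id lookup tables built elsewhere; the names suggest a remap from one vocabulary's indices to another's. A hypothetical construction from two torchtext-style vocabs, with stand-in dicts:

big_itos = ['<unk>', '<pad>', 'the', 'cat', 'sat']  # index -> token
small_stoi = {'<unk>': 0, '<pad>': 1, 'the': 2}     # token -> index
unk_index = 0

src_b2s = [small_stoi.get(tok, unk_index) for tok in big_itos]
print(src_b2s)  # [0, 1, 2, 0, 0]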