Example 1
import torchtext

# Assumed module layout, following OpenNMT-kpg conventions:
from onmt.inputters import keyphrase_dataset
from onmt.inputters.keyphrase_dataset import KeyphraseDataset


def create_batches(self):
    # Method of the ordered data iterator; called once per epoch.
    if self.train:
        if self.yield_raw_example:
            # Yield one raw example at a time, without pooling or sorting.
            self.batches = batch_iter(
                self.data(),
                1,
                batch_size_fn=None,
                batch_size_multiple=1)
        else:
            # Pool, sort, and shuffle mini-batches for training.
            self.batches = _pool(
                self.data(),
                self.batch_size,
                self.batch_size_fn,
                self.batch_size_multiple,
                self.sort_key,
                self.random_shuffler,
                self.pool_factor,
                self.dataset)  # @memray
    else:
        self.batches = []
        for b in batch_iter(
                self.data(),
                self.batch_size,
                batch_size_fn=self.batch_size_fn,
                batch_size_multiple=self.batch_size_multiple):
            # For keyphrase datasets, the targets must be preprocessed
            # before batching.
            if isinstance(self.dataset, KeyphraseDataset):
                b = keyphrase_dataset.process_multiple_tgts(
                    b, self.dataset.tgt_type)
            # @memray: to keep the original order of the test data, sort only
            # within each batch (is that necessary? why not sort once before
            # feeding the model?). Sorting is currently disabled:
            # self.batches.append(sorted(b, key=self.sort_key))
            self.batches.append(b)
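
Both branches build on batch_iter, which is not shown in this example. Below is a minimal sketch of a compatible implementation, under the assumption that it behaves like torchtext.data.batch but additionally trims each yielded batch so its length is a multiple of batch_size_multiple (useful, e.g., for fp16 training). The names mirror the call sites above; the real OpenNMT-py implementation differs in its details.

def batch_iter(data, batch_size, batch_size_fn=None, batch_size_multiple=1):
    # Sketch only: yield mini-batches holding roughly batch_size units, as
    # measured by batch_size_fn, keeping len(batch) a multiple of
    # batch_size_multiple whenever possible.
    if batch_size_fn is None:
        def batch_size_fn(new, count, sofar):
            return count  # default: one example counts as one unit
    minibatch, size_so_far = [], 0
    for ex in data:
        minibatch.append(ex)
        size_so_far = batch_size_fn(ex, len(minibatch), size_so_far)
        if size_so_far >= batch_size:
            overflow = len(minibatch) % batch_size_multiple
            if overflow == len(minibatch):
                continue  # too few examples to trim; keep accumulating
            yield minibatch[:len(minibatch) - overflow]
            # Carry the trimmed examples over into the next mini-batch.
            minibatch = minibatch[len(minibatch) - overflow:]
            size_so_far = 0
            for i, leftover in enumerate(minibatch):
                size_so_far = batch_size_fn(leftover, i + 1, size_so_far)
    if minibatch:
        yield minibatch
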
def _pool(data,
          batch_size,
          batch_size_fn,
          batch_size_multiple,
          sort_key,
          random_shuffler,
          pool_factor,
          dataset=None):
    # Read a large pool of pool_factor * batch_size examples at a time
    # (requires legacy torchtext, where torchtext.data.batch exists).
    for p in torchtext.data.batch(data,
                                  batch_size * pool_factor,
                                  batch_size_fn=batch_size_fn):
        # For keyphrase datasets, the targets must be preprocessed
        # before batching.
        if isinstance(dataset, KeyphraseDataset):
            p = keyphrase_dataset.process_multiple_tgts(p, dataset.tgt_type)
        # @memray: sort the pool, then split it into the final mini-batches.
        # During training, batch_size_fn=max_tok_len counts the effective
        # batch size as num_examples * max(#tokens in src/tgt).
        p_batch = list(
            batch_iter(sorted(p, key=sort_key),
                       batch_size,
                       batch_size_fn=batch_size_fn,
                       batch_size_multiple=batch_size_multiple))
        # Shuffle the order of the mini-batches (not the samples inside
        # them) before yielding.
        for b in random_shuffler(p_batch):
            yield b
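
The max_tok_len mentioned in the comment is a batch_size_fn that measures a batch in tokens rather than in examples. Here is a hedged sketch of such a function and of how _pool might be wired up; the field layout (ex.src[0], ex.tgt[0]), the closure-based state, and the parameter values are illustrative assumptions, not the repository's code.

import random


def make_tok_batch_size_fn():
    # Sketch of a token-count batch_size_fn in the spirit of max_tok_len:
    # the effective batch size is num_examples * longest sequence so far.
    state = {"max_len": 0}

    def batch_size_fn(new, count, sofar):
        if count == 1:
            state["max_len"] = 0  # first example of a new batch: reset
        # Assumes torchtext Examples whose src/tgt fields hold token lists.
        state["max_len"] = max(state["max_len"],
                               len(new.src[0]), len(new.tgt[0]) + 1)
        return count * state["max_len"]

    return batch_size_fn


# Illustrative wiring: pool 100 batches at a time, sort by source length,
# cut token-sized mini-batches, then shuffle their order within the pool.
for minibatch in _pool(dataset.examples,
                       batch_size=4096,  # in tokens, given the fn above
                       batch_size_fn=make_tok_batch_size_fn(),
                       batch_size_multiple=1,
                       sort_key=lambda ex: len(ex.src[0]),
                       random_shuffler=lambda xs: random.sample(xs, len(xs)),
                       pool_factor=100):
    pass  # build tensors / feed the model here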