def create_batches(self):
    """Populate ``self.batches`` with the batch stream for one pass.

    Training with raw examples yields singleton batches; regular training
    delegates to ``_pool`` (pool-sort-split-shuffle); evaluation keeps the
    original data order and materializes the batches in a list.
    """
    if not self.train:
        # Evaluation path: materialize batches eagerly, preserving order.
        collected = []
        eval_stream = batch_iter(
            self.data(),
            self.batch_size,
            batch_size_fn=self.batch_size_fn,
            batch_size_multiple=self.batch_size_multiple)
        for minibatch in eval_stream:
            # If it's a keyphrase dataset, targets must be preprocessed
            # before the batch is used downstream.
            if isinstance(self.dataset, KeyphraseDataset):
                minibatch = keyphrase_dataset.process_multiple_tgts(
                    minibatch, self.dataset.tgt_type)
            # @memray: deliberately NOT sorting here, to keep the original
            # order of test data (sorting inside a batch was considered
            # but left out).
            collected.append(minibatch)
        self.batches = collected
    elif self.yield_raw_example:
        # Raw-example mode: one example per batch, no size function.
        self.batches = batch_iter(
            self.data(), 1, batch_size_fn=None, batch_size_multiple=1)
    else:
        # Standard training: pool a big chunk, sort, split, shuffle.
        self.batches = _pool(
            self.data(),
            self.batch_size,
            self.batch_size_fn,
            self.batch_size_multiple,
            self.sort_key,
            self.random_shuffler,
            self.pool_factor,
            self.dataset)  # @memray: dataset passed so _pool can handle keyphrase targets
def _pool(data, batch_size, batch_size_fn, batch_size_multiple,
          sort_key, random_shuffler, pool_factor, dataset=None):
    """Yield shuffled mini-batches carved from larger sorted pools.

    Reads ``batch_size * pool_factor`` examples at a time, sorts each pool
    by ``sort_key`` (so similarly-sized examples batch together), splits it
    into final mini-batches, and yields those in shuffled order.

    ``dataset`` is optional; when it is a ``KeyphraseDataset`` the pool's
    targets are preprocessed before sorting/splitting (@memray).
    """
    pool_stream = torchtext.data.batch(
        data, batch_size * pool_factor, batch_size_fn=batch_size_fn)
    for pool in pool_stream:
        # If it's a keyphrase dataset, a target preprocessing step must
        # act before the pool is sorted and split.
        if dataset and isinstance(dataset, KeyphraseDataset):
            pool = keyphrase_dataset.process_multiple_tgts(
                pool, dataset.tgt_type)
        # @memray: split each big pool into final mini-batches; with
        # batch_size_fn=max_tok_len in training the real batch size is
        # counted as num_batch * max(#words in src/tgt). Sort first so
        # batches hold similar-length examples.
        minibatches = list(batch_iter(
            sorted(pool, key=sort_key),
            batch_size,
            batch_size_fn=batch_size_fn,
            batch_size_multiple=batch_size_multiple))
        # Shuffle the mini-batch order within this pool before yielding.
        yield from random_shuffler(minibatches)