def __init__(self, dataset, batch_size=None, shuffle=False, sampler=None,
             last_batch=None, collate_fn=_batchify, batch_sampler=None):
    """Store the dataset and collate function and set up the batch sampler.

    Either ``batch_sampler`` is supplied directly, or one is built from
    ``batch_size`` plus an index sampler (``sampler`` if given, otherwise a
    random or sequential sampler over ``len(dataset)`` depending on
    ``shuffle``).

    Args:
        dataset: Sized dataset; ``len(dataset)`` sizes the default sampler.
        batch_size: Samples per batch; required unless ``batch_sampler``
            is given.
        shuffle: Choose ``RandomSampler`` over ``SequentialSampler`` when no
            ``sampler`` is supplied; must be falsy if ``sampler`` is given.
        sampler: Explicit index sampler wrapped by ``BatchSampler``.
        last_batch: Forwarded to ``BatchSampler``; falls back to ``'keep'``
            when falsy.
        collate_fn: Stored as ``self.collate_fn`` (presumably merges a list
            of samples into one batch -- confirm against its callers).
        batch_sampler: Pre-built batch sampler; mutually exclusive with
            ``batch_size``, ``shuffle``, ``sampler`` and ``last_batch``.

    Raises:
        ValueError: If the sampling options are missing or conflict.
    """
    self._dataset = dataset
    self.collate_fn = collate_fn

    # Caller provided the batching strategy: reject any option that would
    # otherwise have been used to build one.
    if batch_sampler is not None:
        has_conflict = (batch_size is not None or shuffle
                        or sampler is not None or last_batch is not None)
        if has_conflict:
            raise ValueError("batch_size, shuffle, sampler and last_batch must "
                             "not be specified if batch_sampler is specified.")
        self._batch_sampler = batch_sampler
        return

    if batch_size is None:
        raise ValueError("batch_size must be specified unless "
                         "batch_sampler is specified")

    if sampler is None:
        sampler_cls = (_sampler.RandomSampler if shuffle
                       else _sampler.SequentialSampler)
        sampler = sampler_cls(len(dataset))
    elif shuffle:
        raise ValueError(
            "shuffle must not be specified if sampler is specified")

    keep_policy = last_batch or 'keep'
    self._batch_sampler = _sampler.BatchSampler(sampler, batch_size, keep_policy)
def __init__(
        self,
        dataset,
        batch_size=None,
        shuffle=False,
        sampler=None,
        last_batch=None,
        batch_sampler=None,
        batchify_fn=None,
        inter_batchify_fn=None,  # for internal dataloader
        part_num=20,  # for part loader
        num_workers=0):
    """Create a loader that stages fixed-size batches in pre-allocated buffers.

    The sampling options are validated and turned into a ``BatchSampler``.
    The first sample of ``dataset`` is then probed to size data and label
    buffers that hold ``part_num`` batches each. These buffers are allocated
    through ``init_data_shm`` / ``init_label_shm``, and ``init_cache_num``
    and ``init_qs`` complete the setup.

    Args:
        dataset: Sized, indexable dataset. ``dataset[0]`` must unpack into
            ``(data, label)``, where ``data`` has a ``.shape``.
        batch_size: Samples per batch. Required, because the buffers are
            sized by it.
        shuffle: Use ``RandomSampler`` instead of ``SequentialSampler`` when
            ``sampler`` is not given. Must be falsy if ``sampler`` is given.
        sampler: Explicit index sampler wrapped by ``BatchSampler``.
        last_batch: Forwarded to ``BatchSampler``. Falls back to ``'keep'``
            when falsy.
        batch_sampler: Pre-built batch sampler. It is mutually exclusive with
            ``batch_size``, and ``batch_size`` is required, so passing
            ``batch_sampler`` always raises ``ValueError``.
        batchify_fn: Required. Stored as ``self.batchify_fn``.
        inter_batchify_fn: Stored as-is (used by the internal dataloader).
        part_num: Number of batches the shared buffers hold at once.
        num_workers: Worker count. Also used as ``self.cache_n``.

    Raises:
        ValueError: If sampling options are missing or conflict, if
            ``batchify_fn`` is missing, or if ``batch_size`` is unavailable.
    """
    self._dataset = dataset
    if batch_sampler is None:
        if batch_size is None:
            raise ValueError("batch_size must be specified unless "
                             "batch_sampler is specified")
        if sampler is None:
            if shuffle:
                sampler = _sampler.RandomSampler(len(dataset))
            else:
                sampler = _sampler.SequentialSampler(len(dataset))
        elif shuffle:
            raise ValueError(
                "shuffle must not be specified if sampler is specified")
        batch_sampler = _sampler.BatchSampler(
            sampler, batch_size, last_batch if last_batch else 'keep')
    elif batch_size is not None or shuffle or sampler is not None or \
            last_batch is not None:
        raise ValueError("batch_size, shuffle, sampler and last_batch must "
                         "not be specified if batch_sampler is specified.")
    self._batch_sampler = batch_sampler
    self._num_workers = num_workers

    # There is no default batchify function; callers must supply one.
    # ValueError subclasses Exception, so existing `except Exception` still works.
    if batchify_fn is None:
        raise ValueError('no batchify_fn is specified')
    self.batchify_fn = batchify_fn
    self.inter_batchify_fn = inter_batchify_fn
    self.batch_size = batch_size
    self.shuffle = shuffle

    # The buffers below are shaped per batch, so a batch size is mandatory.
    # Without it, the division below used to fail with an opaque TypeError.
    if self.batch_size is None:
        raise ValueError("batch_size is required by this loader to size its "
                         "shared batch buffers; batch_sampler alone is not "
                         "supported")

    # NOTE(review): `/` is true division on Python 3, so `num` is a float
    # whenever len(dataset) is not a multiple of batch_size. Confirm whether
    # floor division (`//`) was intended before changing it.
    self.num = len(self._dataset) / self.batch_size

    self.cache_n = self._num_workers
    self.cache_i = 0
    # Reset before init_cache_num(), which fills it in (defined elsewhere).
    self.cache_num = None
    self.init_cache_num()

    # Probe the first sample to learn the per-sample data shape.
    self.data, self.label = self._dataset[0]
    self.batch_data_shape = (self.batch_size, ) + self.data.shape
    self.batch_label_shape = (self.batch_size, )
    self.part_num = part_num  # number of batches held in shared memory at once
    self.init_data_shm((self.part_num, ) + self.batch_data_shape,
                       np.float32, None)
    self.init_label_shm((self.part_num, ) + self.batch_label_shape,
                        np.float32, None)
    self.init_qs()