class AudioBatchData(Dataset):
    """Audio dataset streamed pack-by-pack from disk.

    The whole dataset may not fit in memory, so sequences are grouped into
    "packs" of at most MAX_SIZE_LOADED samples (see prepare()). One pack is
    held in self.data while the following pack is loaded asynchronously by a
    multiprocessing pool (see loadNextPack()).
    """

    def __init__(self,
                 path,
                 sizeWindow,
                 seqNames,
                 phoneLabelsDict,
                 nSpeakers,
                 nProcessLoader=50,
                 MAX_SIZE_LOADED=4000000000):
        """
        Args:
            - path (string): path to the training dataset
            - sizeWindow (int): size of the sliding window
            - seqNames (list): sequences to load, as (speaker, fileName) pairs
            - phoneLabelsDict (dictionary): if not None, a dictionary with the
              following entries:
                  "step": size of a labelled window
                  "$SEQ_NAME": list of phoneme labels for the sequence
                  $SEQ_NAME
            - nSpeakers (int): number of speakers to expect.
            - nProcessLoader (int): number of processes to call when loading
              the data from the disk
            - MAX_SIZE_LOADED (int): target maximal size of the floating array
              containing all loaded data.
        """
        self.MAX_SIZE_LOADED = MAX_SIZE_LOADED
        self.nProcessLoader = nProcessLoader
        self.dbPath = Path(path)
        self.sizeWindow = sizeWindow
        # Resolve every sequence path against the dataset root.
        self.seqNames = [(s, self.dbPath / x) for s, x in seqNames]
        self.reload_pool = Pool(nProcessLoader)

        self.prepare()
        self.speakers = list(range(nSpeakers))
        self.data = []

        # phoneSize == 0 disables phone labelling entirely; phoneStep is the
        # number of labelled sub-windows inside one sliding window.
        self.phoneSize = 0 if phoneLabelsDict is None else \
            phoneLabelsDict["step"]
        self.phoneStep = 0 if phoneLabelsDict is None else \
            self.sizeWindow // self.phoneSize

        self.phoneLabelsDict = deepcopy(phoneLabelsDict)
        # First call only schedules the initial async load; the second call
        # waits for it, parses it, and schedules the following pack.
        self.loadNextPack(first=True)
        self.loadNextPack()
        self.doubleLabels = False

    def resetPhoneLabels(self, newPhoneLabels, step):
        """Replace the phone labelling and reload the current pack."""
        self.phoneSize = step
        self.phoneStep = self.sizeWindow // self.phoneSize
        self.phoneLabelsDict = deepcopy(newPhoneLabels)
        self.loadNextPack()

    @staticmethod
    def splitSeqTags(seqName):
        """Split a sequence path into its path components.

        Fix: declared as @staticmethod — the original definition took no
        `self` parameter, so calling it through an instance passed the
        instance as `seqName` and raised a TypeError.
        """
        path = os.path.normpath(seqName)
        return path.split(os.sep)

    def getSeqNames(self):
        """Return the full path of every sequence as a string."""
        return [str(x[1]) for x in self.seqNames]

    def clear(self):
        """Drop the currently loaded pack and its label indexes (if set)."""
        if 'data' in self.__dict__:
            del self.data
        if 'speakerLabel' in self.__dict__:
            del self.speakerLabel
        if 'phoneLabels' in self.__dict__:
            del self.phoneLabels
        if 'seqLabel' in self.__dict__:
            del self.seqLabel

    def prepare(self):
        """Shuffle the sequences and split them into packs.

        packageIndex receives [start, end) sequence indices per pack, each
        pack totalling at most MAX_SIZE_LOADED samples; totSize accumulates
        the overall number of samples.
        """
        random.shuffle(self.seqNames)
        start_time = time.time()

        print("Checking length...")
        allLength = self.reload_pool.map(extractLength, self.seqNames)

        self.packageIndex, self.totSize = [], 0
        start, packageSize = 0, 0
        for index, length in tqdm.tqdm(enumerate(allLength)):
            packageSize += length
            if packageSize > self.MAX_SIZE_LOADED:
                self.packageIndex.append([start, index])
                self.totSize += packageSize
                start, packageSize = index, 0

        if packageSize > 0:
            self.packageIndex.append([start, len(self.seqNames)])
            self.totSize += packageSize

        print(f"Done, elapsed: {time.time() - start_time:.3f} seconds")
        print(f'Scanned {len(self.seqNames)} sequences '
              f'in {time.time() - start_time:.2f} seconds')
        print(f"{len(self.packageIndex)} chunks computed")
        self.currentPack = -1
        self.nextPack = 0

    def getNPacks(self):
        """Number of packs the dataset was split into."""
        return len(self.packageIndex)

    def loadNextPack(self, first=False):
        """Swap in the pack loaded in the background, schedule the next one.

        Args:
            - first (bool): when True there is no pending async result yet,
              so only schedule the initial load.
        """
        self.clear()
        if not first:
            self.currentPack = self.nextPack
            start_time = time.time()
            print('Joining pool')
            # Block until the background load of the next pack completes.
            self.r.wait()
            print(f'Joined process, elapsed={time.time()-start_time:.3f} secs')
            self.nextData = self.r.get()
            self.parseNextDataBlock()
            del self.nextData
        self.nextPack = (self.currentPack + 1) % len(self.packageIndex)
        seqStart, seqEnd = self.packageIndex[self.nextPack]
        if self.nextPack == 0 and len(self.packageIndex) > 1:
            # One full cycle over the packs is done: reshuffle and repack.
            self.prepare()
        self.r = self.reload_pool.map_async(loadFile,
                                            self.seqNames[seqStart:seqEnd])

    def parseNextDataBlock(self):
        """Concatenate the freshly loaded sequences into one flat tensor.

        For the current pack this also builds:
            - speakerLabel: cumulative sample offsets per speaker
            - seqLabel: cumulative sample offsets per sequence
            - phoneLabels: flat list of phone labels (when enabled)
        """
        # Labels
        self.speakerLabel = [0]
        self.seqLabel = [0]
        self.phoneLabels = []
        speakerSize = 0
        indexSpeaker = 0

        # To accelerate the process a bit
        self.nextData.sort(key=lambda x: (x[0], x[1]))
        tmpData = []

        for speaker, seqName, seq in self.nextData:
            while self.speakers[indexSpeaker] < speaker:
                indexSpeaker += 1
                self.speakerLabel.append(speakerSize)
            if self.speakers[indexSpeaker] != speaker:
                raise ValueError(f'{speaker} invalid speaker')

            if self.phoneLabelsDict is not None:
                self.phoneLabels += self.phoneLabelsDict[seqName]
                # Trim the waveform to the labelled region.
                newSize = len(self.phoneLabelsDict[seqName]) * self.phoneSize
                seq = seq[:newSize]

            sizeSeq = seq.size(0)
            tmpData.append(seq)
            self.seqLabel.append(self.seqLabel[-1] + sizeSeq)
            speakerSize += sizeSeq
            del seq

        self.speakerLabel.append(speakerSize)
        self.data = torch.cat(tmpData, dim=0)

    def getPhonem(self, idx):
        """Return the phoneStep phone labels covering the window at idx."""
        idPhone = idx // self.phoneSize
        return self.phoneLabels[idPhone:(idPhone + self.phoneStep)]

    def getSpeakerLabel(self, idx):
        """Return the index of the speaker owning sample idx."""
        idSpeaker = next(
            x[0] for x in enumerate(self.speakerLabel) if x[1] > idx) - 1
        return idSpeaker

    def __len__(self):
        return self.totSize // self.sizeWindow

    def __getitem__(self, idx):
        # Out-of-range indices are only reported, not rejected; the slice
        # below would then be shorter than sizeWindow.
        if idx < 0 or idx >= len(self.data) - self.sizeWindow - 1:
            print(idx)

        outData = self.data[idx:(self.sizeWindow + idx)].view(1, -1)
        label = torch.tensor(self.getSpeakerLabel(idx), dtype=torch.long)
        if self.phoneSize > 0:
            label_phone = torch.tensor(self.getPhonem(idx), dtype=torch.long)
            if not self.doubleLabels:
                label = label_phone
        else:
            label_phone = torch.zeros(1)

        if self.doubleLabels:
            return outData, label, label_phone

        return outData, label

    def getNSpeakers(self):
        return len(self.speakers)

    def getNSeqs(self):
        return len(self.seqLabel) - 1

    def getNLoadsPerEpoch(self):
        return len(self.packageIndex)

    def getBaseSampler(self, type, batchSize, offset):
        # NOTE: `type` shadows the builtin but is part of the public
        # interface, so it is kept.
        if type == "samespeaker":
            return SameSpeakerSampler(batchSize, self.speakerLabel,
                                      self.sizeWindow, offset)
        if type == "samesequence":
            return SameSpeakerSampler(batchSize, self.seqLabel,
                                      self.sizeWindow, offset)
        if type == "sequential":
            return SequentialSampler(len(self.data), self.sizeWindow,
                                     offset, batchSize)
        sampler = UniformAudioSampler(len(self.data), self.sizeWindow,
                                      offset)
        return BatchSampler(sampler, batchSize, True)

    def getDataLoader(self, batchSize, type, randomOffset, numWorkers=0,
                      onLoop=-1):
        r"""
        Get a batch sampler for the current dataset.
        Args:
            - batchSize (int): batch size
            - groupSize (int): in the case of type in ["samespeaker",
              "samesequence"] number of items sharing a same label in the
              group (see AudioBatchSampler)
            - type (string):
                type == "speaker": grouped sampler speaker-wise
                type == "sequence": grouped sampler sequence-wise
                type == "sequential": sequential sampling
                else: uniform random sampling of the full audio
                vector
            - randomOffset (bool): if True add a random offset to the sampler
                                   at the begining of each iteration
        """
        nLoops = len(self.packageIndex)
        totSize = self.totSize // (self.sizeWindow * batchSize)
        if onLoop >= 0:
            # Pin the loader to a single pack.
            self.currentPack = onLoop - 1
            self.loadNextPack()
            nLoops = 1

        def samplerCall():
            offset = random.randint(0, self.sizeWindow // 2) \
                if randomOffset else 0
            return self.getBaseSampler(type, batchSize, offset)

        return AudioLoader(self, samplerCall, nLoops, self.loadNextPack,
                           totSize, numWorkers)
def train(self, n_process=None):
    """Run the evolution-strategies (ES) outer training loop.

    Each epoch draws self.outer_n_noise Gaussian perturbations of the
    evolved loss network's weights, evaluates every perturbation with an
    inner training run (inner_train) executed in a worker pool, and moves
    the weights using the accuracy-weighted perturbations through a
    per-parameter Adam step.

    NOTE(review): method of an outer-loop trainer class not visible here —
    reads self.outer_lr, self.outer_n_epoch, self.outer_n_noise,
    self.outer_std, self.outer_n_worker, self.inner_args, self.target_model,
    self.train_dataset, self.test_dataset, self.save_model, self.log.

    Args:
        - n_process (int): number of worker processes (required; asserted
          non-None below).
    """
    assert n_process is not None
    pool = Pool(processes=n_process)
    # Outer learning rate decays linearly to 10% by mid-training and stays
    # there afterwards.
    outer_lr_scheduler = PiecewiseSchedule(
        [(0, self.outer_lr),
         (int(self.outer_n_epoch / 2), self.outer_lr * 0.1)],
        outside_value=self.outer_lr * 0.1)
    evoluted_loss = loss_net(self.inner_args['nclass'])
    evoluted_loss = torch.nn.DataParallel(evoluted_loss).cuda()
    evoluted_loss.train()
    for epoch in range(self.outer_n_epoch):
        # Build outer_n_noise perturbed copies of the current weights:
        # phi_noise[i][key] = phi[key] + outer_std * N(0, I).
        phi_state_dict = evoluted_loss.state_dict()
        phi_noise = []
        for i in range(self.outer_n_noise):
            noise = dict()
            for key in evoluted_loss.state_dict().keys():
                noise[key] = self.outer_std * torch.randn_like(
                    evoluted_loss.state_dict()[key]) + phi_state_dict[key]
            phi_noise.append(noise)

        start_time = time.time()
        assert self.outer_n_worker % self.outer_n_noise == 0
        # One inner training run per perturbation, fanned out to the pool.
        args = ((epoch, i, phi_noise[i], self.target_model,
                 self.train_dataset, self.test_dataset, self.inner_args)
                for i in range(self.outer_n_noise))
        results = pool.map_async(inner_train, args)
        results = results.get()
        # Per-perturbation fitness: mean inner-run accuracy.
        result = [np.mean(r['accuracy']) for r in results]
        print(
            'Epoch %d, mean accuracy: %f, max accuracy: %f(%d), min accuracy: %f(%d)'
            % (epoch, np.mean(result), np.max(result),
               result.index(np.max(result)), np.min(result),
               result.index(np.min(result))))
        print('All inner loops completed, returns gathered ({:.2f} sec).'.
              format(time.time() - start_time))
        print('\n')

        # ES update: accuracy-weighted sum of the perturbed weights, scaled
        # by 1/outer_n_worker and applied in place via a fresh per-parameter
        # Adam optimizer (state_dict() tensors are mutated directly).
        for key in evoluted_loss.state_dict():
            grad = 0
            for i in range(self.outer_n_noise):
                grad += result[i] * phi_noise[i][key].cpu(
                )  # - self.outer_l2 * evoluted_loss.state_dict()[key].cpu()
            lr = outer_lr_scheduler.value(epoch)
            adam = Adam(shape=grad.shape, stepsize=lr)
            evoluted_loss.state_dict()[key] -= adam.step(
                grad / self.outer_n_worker).cuda()

        # Checkpoint and log the per-perturbation rewards for this epoch.
        if self.save_model:
            torch.save(
                {
                    'epoch': epoch + 1,
                    'state_dict': evoluted_loss.state_dict()
                }, self.log.path_model)
        self.log.log_tabular('epoch', epoch)
        for i in range(len(result)):
            self.log.log_tabular('reward %d' % i, result[i])
        self.log.dump_tabular()
    pool.close()
class LibriSelectionDataset(Dataset):
    """LibriSpeech Selection data from sincnet paper.

    Pack-based loader: sequences are grouped into packs of at most
    MAX_SIZE_LOADED samples; one pack is held in memory while the next is
    loaded asynchronously by a multiprocessing pool.
    """

    def __init__(self,
                 sizeWindow=20480,
                 db_wav_root=DB_WAV_ROOT,
                 fps_list=str(),
                 label_path=str(),
                 nSpeakers=-1,
                 n_process_loader=50,
                 MAX_SIZE_LOADED=4000000000):
        """Init.

        Args:
            - sizeWindow (int): size of the sliding window
            - db_wav_root (str): root directory of the wav files
            - fps_list (str): path of the text file listing the files to load
            - label_path (str): path of the numpy file mapping file name to
              speaker label (loaded with allow_pickle)
            - nSpeakers (int): number of speakers; -1 means infer it from the
              label dictionary
            - n_process_loader (int): number of loader processes
            - MAX_SIZE_LOADED (int): target maximal size of the floating
              array containing all loaded data.
        """
        self.MAX_SIZE_LOADED = MAX_SIZE_LOADED
        self.n_process_loader = n_process_loader
        self.db_wav_root = Path(db_wav_root)
        self.sizeWindow = sizeWindow

        """Parsing customized to Libri-selection dataset."""
        fps_name_only = get_fps_from_txt(fps_list)
        # np.load(...)[()] unwraps the 0-d object array into the stored dict.
        label_dict = np.load(label_path, allow_pickle=True)[()]
        # (speaker_label, absolute_path) pairs for every listed file.
        self.all_labels_fps = [(label_dict[x], Path(db_wav_root) / Path(x))
                               for x in fps_name_only]

        self.reload_pool = Pool(n_process_loader)
        self.prepare(
        )  # Split large number of files into packages, and set {self.currentPack=-1, self.nextPack=0}
        if nSpeakers == -1:
            nSpeakers = len(set(label_dict.values()))
        self.speakers = list(range(nSpeakers))
        self.data = []

        # First call only schedules an async load; the second waits for it,
        # parses it, and schedules the following pack.
        self.loadNextPack(first=True)
        self.loadNextPack()

    def __len__(self):
        """Number of sliding windows available over the whole dataset."""
        return self.totSize // self.sizeWindow

    def prepare(self):
        """Shuffle the file list and split it into packs.

        packageIndex receives [start, end) file indices per pack, each pack
        totalling at most MAX_SIZE_LOADED samples; totSize accumulates the
        overall number of samples.
        """
        random.shuffle(self.all_labels_fps)
        start_time = time.time()

        print("Checking length...")
        allLength = self.reload_pool.map(extractLength, self.all_labels_fps)

        self.packageIndex, self.totSize = [], 0
        start, packageSize = 0, 0
        for index, length in tqdm.tqdm(enumerate(allLength)):
            packageSize += length
            if packageSize > self.MAX_SIZE_LOADED:
                self.packageIndex.append([start, index])
                self.totSize += packageSize
                start, packageSize = index, 0

        if packageSize > 0:
            self.packageIndex.append([start, len(self.all_labels_fps)])
            self.totSize += packageSize

        print(f"Done, elapsed: {time.time() - start_time:.3f} seconds")
        print(f'Scanned {len(self.all_labels_fps)} sequences '
              f'in {time.time() - start_time:.2f} seconds')
        print(f"{len(self.packageIndex)} chunks computed")
        self.currentPack = -1
        self.nextPack = 0

    def clear(self):
        """Drop the currently loaded pack and its label indexes (if set)."""
        if 'data' in self.__dict__:
            del self.data
        if 'speakerLabel' in self.__dict__:
            del self.speakerLabel
        if 'seqLabel' in self.__dict__:
            del self.seqLabel

    def getNPacks(self):
        """Number of packs the dataset was split into."""
        return len(self.packageIndex)

    def getNSeqs(self):
        """Number of sequences in the currently loaded pack."""
        return len(self.seqLabel) - 1

    def getNLoadsPerEpoch(self):
        """Number of pack loads needed to cover the full dataset."""
        return len(self.packageIndex)

    def getSpeakerLabel(self, idx):
        """Return the index of the speaker owning sample idx."""
        idSpeaker = next(
            x[0] for x in enumerate(self.speakerLabel) if x[1] > idx) - 1
        return idSpeaker

    def loadNextPack(self, first=False):
        """Load next pack.

        Swaps in the pack loaded in the background and schedules the next
        asynchronous load. When first=True there is no pending result yet,
        so only the initial load is scheduled.
        """
        self.clear()
        if not first:
            self.currentPack = self.nextPack
            start_time = time.time()
            print('Joining pool')
            # Block until the background load of the next pack completes.
            self.r.wait()
            print(f'Joined process, elapsed={time.time()-start_time:.3f} secs')
            self.nextData = self.r.get()
            self.parseNextDataBlock()
            del self.nextData
        self.nextPack = (self.currentPack + 1) % len(self.packageIndex)
        seqStart, seqEnd = self.packageIndex[self.nextPack]
        if self.nextPack == 0 and len(self.packageIndex) > 1:
            # One full cycle over the packs is done: reshuffle and repack.
            self.prepare()
        """map() blocks until complete, map_async() returns immediately and
        schedules a callback to be run on the result."""
        self.r = self.reload_pool.map_async(
            loadFile, self.all_labels_fps[seqStart:seqEnd])
        """loadFile: return speaker, seqName, seq"""

    def parseNextDataBlock(self):
        """Parse next data block.

        Concatenates the freshly loaded sequences into one flat tensor and
        builds the cumulative sample offsets per speaker (speakerLabel) and
        per sequence (seqLabel) for the current pack.
        """
        # Labels
        self.speakerLabel = [0]
        self.seqLabel = [0]
        speakerSize = 0
        indexSpeaker = 0

        # To accelerate the process a bit
        self.nextData.sort(key=lambda x: (x[0], x[1]))
        """
        nextData[0] = (1243,
                       '4910-14124-0001-1',
                       tensor([-0.0089, -0.0084, -0.0079,  ..., -0.0015,
                               -0.0056,  0.0047]))
        """
        tmpData = []

        for speaker, seqName, seq in self.nextData:
            while self.speakers[indexSpeaker] < speaker:
                indexSpeaker += 1
                self.speakerLabel.append(speakerSize)
            if self.speakers[indexSpeaker] != speaker:
                raise ValueError(f'{speaker} invalid speaker')

            sizeSeq = seq.size(0)
            tmpData.append(seq)
            self.seqLabel.append(self.seqLabel[-1] + sizeSeq)
            speakerSize += sizeSeq
            del seq

        self.speakerLabel.append(speakerSize)
        self.data = torch.cat(tmpData, dim=0)

    def __getitem__(self, idx):
        """Get item.

        Returns (window, speaker_label) where window is a (1, sizeWindow)
        slice of the flat audio tensor starting at sample idx.
        """
        # Out-of-range indices are only reported, not rejected.
        if idx < 0 or idx >= len(self.data) - self.sizeWindow - 1:
            print(idx)

        outData = self.data[idx:(self.sizeWindow + idx)].view(1, -1)
        label = torch.tensor(self.getSpeakerLabel(idx), dtype=torch.long)

        return outData, label

    def getBaseSampler(self, type, batchSize, offset):
        """Get base sampler.

        type selects the grouping strategy; any unrecognised value falls
        back to uniform random sampling over the loaded audio.
        """
        if type == "samespeaker":
            return SameSpeakerSampler(batchSize, self.speakerLabel,
                                      self.sizeWindow, offset)
        if type == "samesequence":
            return SameSpeakerSampler(batchSize, self.seqLabel,
                                      self.sizeWindow, offset)
        if type == "sequential":
            return SequentialSampler(len(self.data), self.sizeWindow,
                                     offset, batchSize)
        sampler = UniformAudioSampler(len(self.data), self.sizeWindow,
                                      offset)
        return BatchSampler(sampler, batchSize, True)

    def getDataLoader(self, batchSize, type, randomOffset, numWorkers=0,
                      onLoop=-1):
        """Get a batch sampler for the current dataset.

        Args:
            - batchSize (int): batch size
            - groupSize (int): in the case of type in ["samespeaker",
              "samesequence"] number of items sharing a same label in the
              group (see AudioBatchSampler)
            - type (string):
                type == "samespeaker": grouped sampler speaker-wise
                type == "samesequence": grouped sampler sequence-wise
                type == "sequential": sequential sampling
                else: uniform random sampling of the full audio
                vector
            - randomOffset (bool): if True add a random offset to the sampler
                                   at the begining of each iteration
        """
        nLoops = len(self.packageIndex)
        totSize = self.totSize // (self.sizeWindow * batchSize)
        if onLoop >= 0:
            # Pin the loader to a single pack.
            self.currentPack = onLoop - 1
            self.loadNextPack()
            nLoops = 1

        def samplerCall():
            offset = random.randint(0, self.sizeWindow // 2) \
                if randomOffset else 0
            return self.getBaseSampler(type, batchSize, offset)

        return AudioLoader(self, samplerCall, nLoops, self.loadNextPack,
                           totSize, numWorkers)
class DataLoader():
    """Batch loader for binarized TTS data (specM/specL spectrograms + text).

    Reads pre-binarized files from dir_bin, filters and sorts the index
    (load_list), and serves truncated sub-batches through an asynchronous
    loader process feeding a bounded queue.
    """

    def __init__(self, args):
        self.dir_bin = args.dir_bin
        line_load_list = self.dir_bin + 'line_load_list.t7'
        vocab_file = self.dir_bin + 'vocab.t7'
        # The three binarized payload files must already exist.
        assert os.path.isfile(self.dir_bin + 'specM.bin')
        assert os.path.isfile(self.dir_bin + 'specL.bin')
        assert os.path.isfile(self.dir_bin + 'text.bin')

        self.batch_size = args.batch_size
        self.trunc_size = args.trunc_size          # BPTT truncation length
        self.r_factor = args.r_factor              # decoder reduction factor
        self.dec_out_size = args.dec_out_size      # specM feature size
        self.post_out_size = args.post_out_size    # specL feature size
        self.shuffle_data = True if args.shuffle_data == 1 else False
        self.iter_per_epoch = None
        self.is_subbatch_end = True
        self.curr_split = None
        self.vocab_size = None
        self.process = None                        # async loader process
        self.queue = Queue(maxsize=args.load_queue_size)
        self.n_workers = args.n_workers

        self.use_gpu = args.use_gpu
        self.num_gpu = len(args.gpu) if len(args.gpu) > 0 else 1
        self.pinned_memory = True if args.pinned_memory == 1 and self.use_gpu else False

        self.vocab_size = self.get_num_vocab(vocab_file)
        text_limit = args.text_limit
        wave_limit = args.wave_limit

        # col1: idx / col2: wave_length / col3: text_length
        # col4: offset_M / col5: offset_L / col6: offset_T
        self.load_list = torch.load(line_load_list)
        spec_len_list = self.load_list[:, 1].clone()
        text_len_list = self.load_list[:, 2].clone()

        # exclude files whose wave length exceeds wave_limit
        sort_length, sort_idx = spec_len_list.sort()
        text_len_list = torch.gather(text_len_list, 0, sort_idx)
        sort_idx = sort_idx.view(-1, 1).expand_as(self.load_list)
        self.load_list = torch.gather(self.load_list, 0, sort_idx)
        end_idx = sort_length.le(wave_limit).sum()
        spec_len_list = sort_length[:end_idx]
        text_len_list = text_len_list[:end_idx]
        self.load_list = self.load_list[:end_idx]

        # exclude files whose text length exceeds text_limit
        sort_length, sort_idx = text_len_list.sort()
        spec_len_list = torch.gather(spec_len_list, 0, sort_idx)
        sort_idx = sort_idx.view(-1, 1).expand_as(self.load_list)
        self.load_list = torch.gather(self.load_list, 0, sort_idx)
        end_idx = sort_length.le(text_limit).sum()
        end_idx = end_idx - (end_idx % self.batch_size)  # drop residual data
        text_len_list = sort_length[:end_idx]
        spec_len_list = spec_len_list[:end_idx]
        self.load_list = self.load_list[:end_idx]

        # sort by wave length
        _, sort_idx = spec_len_list.sort(0, descending=True)
        text_len_list = torch.gather(text_len_list, 0, sort_idx)
        sort_idx = sort_idx.view(-1, 1).expand_as(self.load_list)
        self.load_list = torch.gather(self.load_list, 0, sort_idx)

        # sort by text length in each batch (PackedSequence requires it)
        num_batches_per_epoch = self.load_list.size(0) // self.batch_size
        text_len_list = text_len_list.view(num_batches_per_epoch, -1)
        self.load_list = self.load_list.view(num_batches_per_epoch, -1,
                                             self.load_list.size(1))
        sort_length, sort_idx = text_len_list.sort(1, descending=True)
        sort_idx = sort_idx.view(num_batches_per_epoch, -1,
                                 1).expand_as(self.load_list)
        self.load_list = torch.gather(self.load_list, 1, sort_idx)

        # shuffle while preserving order in a batch
        if self.shuffle_data:
            _, sort_idx = torch.randn(num_batches_per_epoch).sort()
            sort_idx = sort_idx.view(-1, 1, 1).expand_as(self.load_list)
            self.load_list = torch.gather(self.load_list, 0,
                                          sort_idx)  # nbpe x N x 6
        self.load_list = self.load_list.long()

        # compute number of iterations needed
        spec_len_list = spec_len_list.view(num_batches_per_epoch, -1)
        spec_len_list, _ = spec_len_list.div(self.trunc_size).ceil().max(1)
        self.iter_per_epoch = int(spec_len_list.sum())

        # set split cursor
        self.split_sizes = {
            'train': self.load_list.size(0),
            'val': -1,
            'test': -1
        }
        self.split_cursor = {'train': 0, 'val': 0, 'test': 0}

    def next_batch(self, split):
        """Return the next truncated sub-batch for the given split.

        Returns:
            (x_text, y_specM, y_specL, subbatch_len_wave, subbatch_len_text)
        A full batch is consumed in slices of trunc_size frames; when the
        current batch is exhausted, the next one is fetched from the queue.
        """
        T, idx = self.trunc_size, self.split_cursor[split]

        # seek and load data from raw files
        if self.is_subbatch_end:
            self.is_subbatch_end = False
            self.subbatch_cursor = 0

            # (Re)start the async loader when switching splits.
            if self.curr_split != split:
                self.curr_split = split
                if self.process is not None:
                    self.process.terminate()
                self.process = Process(target=self.start_async_loader,
                                       args=(split,
                                             self.split_cursor[split]))
                self.process.start()

            self.len_text, self.len_wave, self.curr_text, self.curr_specM, self.curr_specL = self.queue.get(
            )
            self.split_cursor[split] = (idx + 1) % self.split_sizes[split]
            self.subbatch_max_len = self.len_wave.max()

        # Variables to return
        # +1 to length of y to consider shifting for target y
        subbatch_len_text = [x for x in self.len_text]
        subbatch_len_wave = [min(x, T) for x in self.len_wave]
        x_text = self.curr_text
        y_specM = self.curr_specM[:, self.subbatch_cursor:
                                  self.subbatch_cursor +
                                  max(subbatch_len_wave) + 1].contiguous()
        y_specL = self.curr_specL[:, self.subbatch_cursor:
                                  self.subbatch_cursor +
                                  max(subbatch_len_wave) + 1].contiguous()

        if self.use_gpu:
            if self.pinned_memory:
                x_text = x_text.pin_memory()
                y_specM = y_specM.pin_memory()
                y_specL = y_specL.pin_memory()
            x_text = x_text.cuda()
            y_specM = y_specM.cuda()
            y_specL = y_specL.cuda()

        # Advance split_cursor or Move on to the next batch
        if self.subbatch_cursor + T < self.subbatch_max_len:
            self.subbatch_cursor = self.subbatch_cursor + T
            self.len_wave.sub_(T).clamp_(min=0)
        else:
            self.is_subbatch_end = True

        # Don't compute for empty batch elements
        if subbatch_len_wave.count(0) > 0:
            self.len_wave_mask = [
                idx for idx, l in enumerate(subbatch_len_wave) if l > 0
            ]
            self.len_wave_mask = torch.LongTensor(self.len_wave_mask)
            if self.use_gpu:
                self.len_wave_mask = self.len_wave_mask.cuda()

            x_text = torch.index_select(x_text, 0, self.len_wave_mask)
            y_specM = torch.index_select(y_specM, 0, self.len_wave_mask)
            y_specL = torch.index_select(y_specL, 0, self.len_wave_mask)

            subbatch_len_text = [
                subbatch_len_text[idx] for idx in self.len_wave_mask
            ]
            subbatch_len_wave = [
                subbatch_len_wave[idx] for idx in self.len_wave_mask
            ]
        else:
            self.len_wave_mask = None

        return x_text, y_specM, y_specL, subbatch_len_wave, subbatch_len_text

    def start_async_loader(self, split, load_start_idx):
        # load batches to the queue asynchronously since it is a bottle-neck
        # Runs forever in a child process; each iteration reads one batch
        # worth of files, pads them into fixed-size tensors, and puts them
        # on the bounded queue (blocking when the queue is full).
        N, r = self.batch_size, self.r_factor
        load_curr_idx = load_start_idx

        while True:
            data_T, data_M, data_L, len_T, len_M = ([None
                                                     for _ in range(N)]
                                                    for _ in range(5))

            # deploy workers to load data
            self.pool = Pool(self.n_workers)
            partial_func = partial(load_data_and_length, self.dir_bin,
                                   self.load_list[load_curr_idx])
            results = self.pool.map_async(func=partial_func,
                                          iterable=range(N))
            self.pool.close()
            self.pool.join()

            # result = (slot, specM, specL, text, text_len, wave_len)
            for result in results.get():
                data_M[result[0]] = result[1]
                data_L[result[0]] = result[2]
                data_T[result[0]] = result[3]
                len_T[result[0]] = result[4]
                len_M[result[0]] = result[5]

            # TODO: output size is not accurate.. //
            len_text = torch.IntTensor(len_T)
            len_wave = torch.Tensor(len_M).div(r).ceil().mul(
                r).int()  # consider r_factor
            curr_text = torch.LongTensor(N, len_text.max()).fill_(
                0)  # null-padding at tail
            curr_specM = torch.Tensor(N,
                                      len_wave.max() + 1,
                                      self.dec_out_size).fill_(
                                          0)  # null-padding at tail
            curr_specL = torch.Tensor(N,
                                      len_wave.max() + 1,
                                      self.post_out_size).fill_(
                                          0)  # null-padding at tail

            # fill the template tensors
            for j in range(N):
                curr_text[j, 0:data_T[j].size(0)].copy_(data_T[j])
                curr_specM[j, 0:data_M[j].size(0)].copy_(data_M[j])
                curr_specL[j, 0:data_L[j].size(0)].copy_(data_L[j])

            self.queue.put(
                (len_text, len_wave, curr_text, curr_specM, curr_specL))
            load_curr_idx = (load_curr_idx + 1) % self.split_sizes[split]

    def mask_prev_h(self, prev_h):
        """Drop the hidden-state rows of batch elements masked out by the
        last next_batch() call (len_wave_mask), keeping states aligned with
        the filtered batch."""
        if self.len_wave_mask is not None:
            if self.use_gpu:
                self.len_wave_mask = self.len_wave_mask.cuda()

            h_att, h_dec1, h_dec2 = prev_h
            h_att = torch.index_select(h_att.data, 1,
                                       self.len_wave_mask)  # batch idx is
            h_dec1 = torch.index_select(h_dec1.data, 1, self.len_wave_mask)
            h_dec2 = torch.index_select(h_dec2.data, 1, self.len_wave_mask)
            prev_h = (Variable(h_att), Variable(h_dec1), Variable(h_dec2))
        else:
            prev_h = prev_h
        return prev_h

    def get_num_vocab(self, vocab_file=None):
        """Return the vocabulary size, loading it from vocab_file on the
        first call and caching it afterwards."""
        if self.vocab_size:
            return self.vocab_size
        else:
            vocab_dict = torch.load(vocab_file)
            return len(vocab_dict) + 1  # +1 to consider null-padding
class SequentialData(Dataset):
    """Sequential feature dataset streamed pack-by-pack from disk.

    Files are grouped into packs whose total element count stays below
    MAX_SIZE_LOADED; one pack lives in self.data while the following pack
    is fetched asynchronously by a multiprocessing pool.
    """

    def __init__(self,
                 path,
                 seqNames,
                 nProcessLoader=50,
                 MAX_SIZE_LOADED=4000000000):
        """
        Args:
            - path (string): path to the training dataset
            - seqNames (list): sequences to load
            - nProcessLoader (int): number of processes to call when loading
              the data from the disk
            - MAX_SIZE_LOADED (int): target maximal size of the floating
              array containing all loaded data.
        """
        self.MAX_SIZE_LOADED = MAX_SIZE_LOADED
        self.nProcessLoader = nProcessLoader
        self.dbPath = Path(path)
        # Keep only the file component of each (speaker, file) pair and
        # anchor it at the dataset root.
        self.seqNames = [self.dbPath / item for _, item in seqNames]
        self.reload_pool = Pool(nProcessLoader)

        self.prepare()
        self.data = []
        # First call just schedules the initial async load; the second one
        # waits for it and schedules the pack after that.
        self.loadNextPack(first=True)
        self.loadNextPack()

    def getSeqNames(self):
        """Full path of every sequence, as strings."""
        return list(map(str, self.seqNames))

    def clear(self):
        """Release the currently loaded pack, if any."""
        self.__dict__.pop('data', None)

    def prepare(self):
        """Shuffle the sequences and partition them into packs.

        packageIndex holds [start, end) sequence indices per pack; totSize
        accumulates the total number of rows across all sequences.
        """
        shuffle(self.seqNames)
        clock = time.time()

        print("Checking length...")
        shapes = self.reload_pool.map(extractShape, self.seqNames)

        self.packageIndex, self.totSize = [], 0
        packStart, rowCount = 0, 0
        for position, dims in tqdm.tqdm(enumerate(shapes)):
            rowCount += dims[0]
            # rows * width approximates the pack's memory footprint.
            if rowCount * dims[1] > self.MAX_SIZE_LOADED:
                self.packageIndex.append([packStart, position])
                self.totSize += rowCount
                packStart, rowCount = position, 0

        if rowCount > 0:
            self.packageIndex.append([packStart, len(self.seqNames)])
            self.totSize += rowCount

        print(f"Done, elapsed: {time.time() - clock:.3f} seconds")
        print(f'Scanned {len(self.seqNames)} sequences '
              f'in {time.time() - clock:.2f} seconds')
        print(f"{len(self.packageIndex)} chunks computed")
        self.currentPack = -1
        self.nextPack = 0

    def getNPacks(self):
        """Number of packs the dataset was split into."""
        return len(self.packageIndex)

    def loadNextPack(self, first=False):
        """Swap in the pack loaded in the background, schedule the next one.

        Args:
            - first (bool): when True there is no pending async result yet,
              so only the initial load is scheduled.
        """
        self.clear()
        if not first:
            self.currentPack = self.nextPack
            clock = time.time()
            print('Joining pool')
            # Block until the background load has finished.
            self.r.wait()
            print(f'Joined process, elapsed={time.time()-clock:.3f} secs')
            self.nextData = self.r.get()
            self.parseNextDataBlock()
            del self.nextData
        self.nextPack = (self.currentPack + 1) % len(self.packageIndex)
        seqStart, seqEnd = self.packageIndex[self.nextPack]
        # NOTE: unlike the other loaders, the per-cycle reshuffle
        # (self.prepare() when wrapping back to pack 0) is deliberately
        # disabled here.
        self.r = self.reload_pool.map_async(loadFilePool,
                                            self.seqNames[seqStart:seqEnd])

    def parseNextDataBlock(self):
        """Concatenate the freshly loaded sequences into one tensor."""
        # Sorting by name keeps the concatenation order deterministic.
        self.nextData.sort(key=lambda entry: entry[0])
        self.data = torch.cat([features for _, features in self.nextData],
                              dim=0)

    def __len__(self):
        return self.totSize

    def __getitem__(self, idx):
        # Out-of-range indices are only reported, not rejected.
        if idx < 0 or idx >= len(self.data):
            print(idx)

        return self.data[idx].view(1, -1)

    def getNLoadsPerEpoch(self):
        """Number of pack loads needed to cover the full dataset."""
        return len(self.packageIndex)

    def getDataLoader(self, batchSize, numWorkers=0):
        r"""
        Get a batch sampler for the current dataset.
        Args:
            - batchSize (int): batch size
        """
        nLoops = len(self.packageIndex)
        totSize = self.totSize // batchSize

        def makeSampler():
            base = UniformAudioSampler(len(self.data))
            return BatchSampler(base, batchSize, True)

        return SequentialLoader(self, makeSampler, nLoops,
                                self.loadNextPack, totSize, numWorkers)