Example No. 1
class AudioBatchData(Dataset):
    def __init__(self,
                 path,
                 sizeWindow,
                 seqNames,
                 phoneLabelsDict,
                 nSpeakers,
                 nProcessLoader=50,
                 MAX_SIZE_LOADED=4000000000):
        """
        Args:
            - path (string): path to the training dataset
            - sizeWindow (int): size of the sliding window
            - seqNames (list): sequences to load
            - phoneLabelsDict (dictionnary): if not None, a dictionnary with the
                                             following entries

                                             "step": size of a labelled window
                                             "$SEQ_NAME": list of phonem labels for
                                             the sequence $SEQ_NAME
           - nSpeakers (int): number of speakers to expect.
           - nProcessLoader (int): number of processes to call when loading the
                                   data from the disk
           - MAX_SIZE_LOADED (int): target maximal size of the floating array
                                    containing all loaded data.
        """
        self.MAX_SIZE_LOADED = MAX_SIZE_LOADED
        self.nProcessLoader = nProcessLoader
        self.dbPath = Path(path)
        self.sizeWindow = sizeWindow
        self.seqNames = [(s, self.dbPath / x) for s, x in seqNames]
        self.reload_pool = Pool(nProcessLoader)

        self.prepare()
        self.speakers = list(range(nSpeakers))
        self.data = []

        self.phoneSize = 0 if phoneLabelsDict is None else \
            phoneLabelsDict["step"]
        self.phoneStep = 0 if phoneLabelsDict is None else \
            self.sizeWindow // self.phoneSize

        self.phoneLabelsDict = deepcopy(phoneLabelsDict)
        self.loadNextPack(first=True)
        self.loadNextPack()
        self.doubleLabels = False

    def resetPhoneLabels(self, newPhoneLabels, step):
        self.phoneSize = step
        self.phoneStep = self.sizeWindow // self.phoneSize
        self.phoneLabelsDict = deepcopy(newPhoneLabels)
        self.loadNextPack()

    @staticmethod
    def splitSeqTags(seqName):
        path = os.path.normpath(seqName)
        return path.split(os.sep)

    def getSeqNames(self):
        return [str(x[1]) for x in self.seqNames]

    def clear(self):
        if 'data' in self.__dict__:
            del self.data
        if 'speakerLabel' in self.__dict__:
            del self.speakerLabel
        if 'phoneLabels' in self.__dict__:
            del self.phoneLabels
        if 'seqLabel' in self.__dict__:
            del self.seqLabel

    def prepare(self):
        random.shuffle(self.seqNames)
        start_time = time.time()

        print("Checking length...")
        allLength = self.reload_pool.map(extractLength, self.seqNames)

        self.packageIndex, self.totSize = [], 0
        start, packageSize = 0, 0
        for index, length in tqdm.tqdm(enumerate(allLength)):
            packageSize += length
            if packageSize > self.MAX_SIZE_LOADED:
                self.packageIndex.append([start, index])
                self.totSize += packageSize
                start, packageSize = index, 0

        if packageSize > 0:
            self.packageIndex.append([start, len(self.seqNames)])
            self.totSize += packageSize

        print(f"Done, elapsed: {time.time() - start_time:.3f} seconds")
        print(f'Scanned {len(self.seqNames)} sequences '
              f'in {time.time() - start_time:.2f} seconds')
        print(f"{len(self.packageIndex)} chunks computed")
        self.currentPack = -1
        self.nextPack = 0

    def getNPacks(self):
        return len(self.packageIndex)

    def loadNextPack(self, first=False):
        self.clear()
        if not first:
            self.currentPack = self.nextPack
            start_time = time.time()
            print('Joining pool')
            self.r.wait()
            print(f'Joined process, elapsed={time.time()-start_time:.3f} secs')
            self.nextData = self.r.get()
            self.parseNextDataBlock()
            del self.nextData
        self.nextPack = (self.currentPack + 1) % len(self.packageIndex)
        seqStart, seqEnd = self.packageIndex[self.nextPack]
        if self.nextPack == 0 and len(self.packageIndex) > 1:
            self.prepare()
        self.r = self.reload_pool.map_async(loadFile,
                                            self.seqNames[seqStart:seqEnd])

    def parseNextDataBlock(self):

        # Labels
        self.speakerLabel = [0]
        self.seqLabel = [0]
        self.phoneLabels = []
        speakerSize = 0
        indexSpeaker = 0

        # To accelerate the process a bit
        self.nextData.sort(key=lambda x: (x[0], x[1]))
        tmpData = []

        for speaker, seqName, seq in self.nextData:
            while self.speakers[indexSpeaker] < speaker:
                indexSpeaker += 1
                self.speakerLabel.append(speakerSize)
            if self.speakers[indexSpeaker] != speaker:
                raise ValueError(f'{speaker} invalid speaker')

            if self.phoneLabelsDict is not None:
                self.phoneLabels += self.phoneLabelsDict[seqName]
                newSize = len(self.phoneLabelsDict[seqName]) * self.phoneSize
                seq = seq[:newSize]

            sizeSeq = seq.size(0)
            tmpData.append(seq)
            self.seqLabel.append(self.seqLabel[-1] + sizeSeq)
            speakerSize += sizeSeq
            del seq

        self.speakerLabel.append(speakerSize)
        self.data = torch.cat(tmpData, dim=0)

    def getPhonem(self, idx):
        idPhone = idx // self.phoneSize
        return self.phoneLabels[idPhone:(idPhone + self.phoneStep)]

    def getSpeakerLabel(self, idx):
        idSpeaker = next(
            x[0] for x in enumerate(self.speakerLabel) if x[1] > idx) - 1
        return idSpeaker

    def __len__(self):
        return self.totSize // self.sizeWindow

    def __getitem__(self, idx):

        if idx < 0 or idx >= len(self.data) - self.sizeWindow - 1:
            print(idx)

        outData = self.data[idx:(self.sizeWindow + idx)].view(1, -1)
        label = torch.tensor(self.getSpeakerLabel(idx), dtype=torch.long)
        if self.phoneSize > 0:
            label_phone = torch.tensor(self.getPhonem(idx), dtype=torch.long)
            if not self.doubleLabels:
                label = label_phone
        else:
            label_phone = torch.zeros(1)

        if self.doubleLabels:
            return outData, label, label_phone

        return outData, label

    def getNSpeakers(self):
        return len(self.speakers)

    def getNSeqs(self):
        return len(self.seqLabel) - 1

    def getNLoadsPerEpoch(self):
        return len(self.packageIndex)

    def getBaseSampler(self, type, batchSize, offset):
        if type == "samespeaker":
            return SameSpeakerSampler(batchSize, self.speakerLabel,
                                      self.sizeWindow, offset)
        if type == "samesequence":
            return SameSpeakerSampler(batchSize, self.seqLabel,
                                      self.sizeWindow, offset)
        if type == "sequential":
            return SequentialSampler(len(self.data), self.sizeWindow, offset,
                                     batchSize)
        sampler = UniformAudioSampler(len(self.data), self.sizeWindow, offset)
        return BatchSampler(sampler, batchSize, True)

    def getDataLoader(self,
                      batchSize,
                      type,
                      randomOffset,
                      numWorkers=0,
                      onLoop=-1):
        r"""
        Get a batch sampler for the current dataset.
        Args:
            - batchSize (int): batch size
            - groupSize (int): in the case of type in ["samespeaker", "samesequence"]
            number of items sharing a same label in the group
            (see AudioBatchSampler)
            - type (string):
                type == "speaker": grouped sampler speaker-wise
                type == "sequence": grouped sampler sequence-wise
                type == "sequential": sequential sampling
                else: uniform random sampling of the full audio
                vector
            - randomOffset (bool): if True add a random offset to the sampler
                                   at the begining of each iteration
        """
        nLoops = len(self.packageIndex)
        totSize = self.totSize // (self.sizeWindow * batchSize)
        if onLoop >= 0:
            self.currentPack = onLoop - 1
            self.loadNextPack()
            nLoops = 1

        def samplerCall():
            offset = random.randint(0, self.sizeWindow // 2) \
                if randomOffset else 0
            return self.getBaseSampler(type, batchSize, offset)

        return AudioLoader(self, samplerCall, nLoops, self.loadNextPack,
                           totSize, numWorkers)
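
A minimal, hypothetical usage sketch (not part of the original example): the dataset path and seqNames list below are placeholders, and the AudioLoader returned by getDataLoader is assumed to be iterable as in the surrounding project.

# Hypothetical sketch: paths and sequence list are placeholders. seqNames pairs
# a speaker id with a path relative to the dataset root, as expected by
# AudioBatchData.__init__.
seqNames = [(0, "speaker0/seq_000.flac"), (1, "speaker1/seq_001.flac")]
dataset = AudioBatchData(path="/data/LibriSpeech/train-clean-100",
                         sizeWindow=20480,
                         seqNames=seqNames,
                         phoneLabelsDict=None,   # no phoneme labels
                         nSpeakers=2)

# "samespeaker" groups the items of a batch by speaker; randomOffset shifts
# the sampling grid at the start of every iteration.
loader = dataset.getDataLoader(batchSize=8,
                               type="samespeaker",
                               randomOffset=True)
for batch, label in loader:
    pass  # batch: (8, 1, 20480) audio windows, label: speaker ids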
Example No. 2
    def train(self, n_process=None):
        assert n_process is not None
        pool = Pool(processes=n_process)
        outer_lr_scheduler = PiecewiseSchedule(
            [(0, self.outer_lr),
             (int(self.outer_n_epoch / 2), self.outer_lr * 0.1)],
            outside_value=self.outer_lr * 0.1)
        evoluted_loss = loss_net(self.inner_args['nclass'])
        evoluted_loss = torch.nn.DataParallel(evoluted_loss).cuda()

        evoluted_loss.train()
        for epoch in range(self.outer_n_epoch):
            phi_state_dict = evoluted_loss.state_dict()
            phi_noise = []
            for i in range(self.outer_n_noise):
                noise = dict()
                for key in evoluted_loss.state_dict().keys():
                    noise[key] = self.outer_std * torch.randn_like(
                        evoluted_loss.state_dict()[key]) + phi_state_dict[key]
                phi_noise.append(noise)
            start_time = time.time()

            assert self.outer_n_worker % self.outer_n_noise == 0

            args = ((epoch, i, phi_noise[i], self.target_model,
                     self.train_dataset, self.test_dataset, self.inner_args)
                    for i in range(self.outer_n_noise))
            results = pool.map_async(inner_train, args)
            results = results.get()
            result = [np.mean(r['accuracy']) for r in results]

            print(
                'Epoch %d, mean accuracy: %f, max accuracy: %f(%d), min accuracy: %f(%d)'
                % (epoch, np.mean(result), np.max(result),
                   result.index(np.max(result)), np.min(result),
                   result.index(np.min(result))))

            print('All inner loops completed, returns gathered ({:.2f} sec).'.
                  format(time.time() - start_time))
            print('\n')
            for key in evoluted_loss.state_dict():
                grad = 0
                for i in range(self.outer_n_noise):
                    grad += result[i] * phi_noise[i][key].cpu(
                    )  # - self.outer_l2 * evoluted_loss.state_dict()[key].cpu()
                lr = outer_lr_scheduler.value(epoch)
                adam = Adam(shape=grad.shape, stepsize=lr)
                evoluted_loss.state_dict()[key] -= adam.step(
                    grad / self.outer_n_worker).cuda()

            if self.save_model:
                torch.save(
                    {
                        'epoch': epoch + 1,
                        'state_dict': evoluted_loss.state_dict()
                    }, self.log.path_model)
                self.log.log_tabular('epoch', epoch)
                for i in range(len(result)):
                    self.log.log_tabular('reward %d' % i, result[i])
                self.log.dump_tabular()
        pool.close()
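
The outer loop above is an evolution-strategies update: perturb the loss network's parameters with Gaussian noise, score each perturbation by running inner training, and combine the scores into a parameter step. Below is a self-contained, hypothetical sketch of that step with made-up names, and with the custom Adam optimizer of the original replaced by a plain gradient step for brevity.

import torch

# Hypothetical sketch of the ES step used in train(): draw Gaussian
# perturbations around the current parameters, score each one, and combine
# the scores into an update on every parameter tensor.
def es_step(phi, score_fn, n_noise, std, lr, n_worker):
    # phi: dict of float tensors; score_fn: returns e.g. the mean inner accuracy
    # obtained with a perturbed parameter set.
    perturbed = [{k: std * torch.randn_like(v) + v for k, v in phi.items()}
                 for _ in range(n_noise)]
    scores = [score_fn(p) for p in perturbed]
    new_phi = {}
    for key in phi:
        # Score-weighted combination of the perturbed parameters, mirroring the
        # per-key loop of train().
        grad = sum(s * p[key] for s, p in zip(scores, perturbed))
        new_phi[key] = phi[key] - lr * grad / n_worker
    return new_phi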
Example No. 3
class LibriSelectionDataset(Dataset):
    """LibriSpeech Selection data from sincnet paper."""
    def __init__(self,
                 sizeWindow=20480,
                 db_wav_root=DB_WAV_ROOT,
                 fps_list=str(),
                 label_path=str(),
                 nSpeakers=-1,
                 n_process_loader=50,
                 MAX_SIZE_LOADED=4000000000):
        """Init.
        
        Args:
            - sizeWindow (int): size of the sliding window
            - db_wav_path (str):
            - fps_list_path (str): 
            - label_path (str):
            - n_process_loader (int):
            - MAX_SIZE_LOADED (int): target maximal size of the floating array
                                    containing all loaded data.
                                    
        """
        self.MAX_SIZE_LOADED = MAX_SIZE_LOADED
        self.n_process_loader = n_process_loader
        self.db_wav_root = Path(db_wav_root)
        self.sizeWindow = sizeWindow
        """Parsing customized to Libri-selection dataset."""
        fps_name_only = get_fps_from_txt(fps_list)
        label_dict = np.load(label_path, allow_pickle=True)[()]
        self.all_labels_fps = [(label_dict[x], Path(db_wav_root) / Path(x))
                               for x in fps_name_only]

        self.reload_pool = Pool(n_process_loader)
        # Split the large set of files into packages and set
        # self.currentPack = -1 and self.nextPack = 0.
        self.prepare()

        if nSpeakers == -1:
            nSpeakers = len(set(label_dict.values()))
        self.speakers = list(range(nSpeakers))
        self.data = []

        self.loadNextPack(first=True)
        self.loadNextPack()

    def __len__(self):
        """Get length."""
        return self.totSize // self.sizeWindow

    def prepare(self):
        """Prepare."""
        random.shuffle(self.all_labels_fps)
        start_time = time.time()

        print("Checking length...")
        allLength = self.reload_pool.map(extractLength, self.all_labels_fps)

        self.packageIndex, self.totSize = [], 0
        start, packageSize = 0, 0
        for index, length in tqdm.tqdm(enumerate(allLength)):
            packageSize += length
            if packageSize > self.MAX_SIZE_LOADED:
                self.packageIndex.append([start, index])
                self.totSize += packageSize
                start, packageSize = index, 0

        if packageSize > 0:
            self.packageIndex.append([start, len(self.all_labels_fps)])
            self.totSize += packageSize

        print(f"Done, elapsed: {time.time() - start_time:.3f} seconds")
        print(f'Scanned {len(self.all_labels_fps)} sequences '
              f'in {time.time() - start_time:.2f} seconds')
        print(f"{len(self.packageIndex)} chunks computed")
        self.currentPack = -1
        self.nextPack = 0

    def clear(self):
        """Clear."""
        if 'data' in self.__dict__:
            del self.data
        if 'speakerLabel' in self.__dict__:
            del self.speakerLabel
        if 'seqLabel' in self.__dict__:
            del self.seqLabel

    def getNPacks(self):
        """Get N packs."""
        return len(self.packageIndex)

    def getNSeqs(self):
        """Get N seqs."""
        return len(self.seqLabel) - 1

    def getNLoadsPerEpoch(self):
        """Get N loads per epoch."""
        return len(self.packageIndex)

    def getSpeakerLabel(self, idx):
        idSpeaker = next(
            x[0] for x in enumerate(self.speakerLabel) if x[1] > idx) - 1
        return idSpeaker

    def loadNextPack(self, first=False):
        """Load next pack."""
        self.clear()
        if not first:
            self.currentPack = self.nextPack
            start_time = time.time()
            print('Joining pool')
            self.r.wait()
            print(f'Joined process, elapsed={time.time()-start_time:.3f} secs')
            self.nextData = self.r.get()
            self.parseNextDataBlock()
            del self.nextData
        self.nextPack = (self.currentPack + 1) % len(self.packageIndex)
        seqStart, seqEnd = self.packageIndex[self.nextPack]
        if self.nextPack == 0 and len(self.packageIndex) > 1:
            self.prepare()
        """map() blocks until complete, map_async() returns immediately and 
        schedules a callback to be run on the result."""
        self.r = self.reload_pool.map_async(
            loadFile, self.all_labels_fps[seqStart:seqEnd])
        """loadFile: return speaker, seqName, seq"""

    def parseNextDataBlock(self):
        """Parse next data block."""
        # Labels
        self.speakerLabel = [0]
        self.seqLabel = [0]
        speakerSize = 0
        indexSpeaker = 0

        # To accelerate the process a bit
        self.nextData.sort(key=lambda x: (x[0], x[1]))
        """
        nextData[0] = (1243, '4910-14124-0001-1',
                       tensor([-0.0089, -0.0084, -0.0079,  ..., -0.0015, -0.0056,  0.0047]))
        """
        tmpData = []

        for speaker, seqName, seq in self.nextData:
            while self.speakers[indexSpeaker] < speaker:
                indexSpeaker += 1
                self.speakerLabel.append(speakerSize)
            if self.speakers[indexSpeaker] != speaker:
                raise ValueError(f'{speaker} invalid speaker')

            sizeSeq = seq.size(0)
            tmpData.append(seq)
            self.seqLabel.append(self.seqLabel[-1] + sizeSeq)
            speakerSize += sizeSeq
            del seq

        self.speakerLabel.append(speakerSize)
        self.data = torch.cat(tmpData, dim=0)

    def __getitem__(self, idx):
        """Get item."""
        if idx < 0 or idx >= len(self.data) - self.sizeWindow - 1:
            print(idx)

        outData = self.data[idx:(self.sizeWindow + idx)].view(1, -1)
        label = torch.tensor(self.getSpeakerLabel(idx), dtype=torch.long)

        return outData, label

    def getBaseSampler(self, type, batchSize, offset):
        """Get base sampler."""
        if type == "samespeaker":
            return SameSpeakerSampler(batchSize, self.speakerLabel,
                                      self.sizeWindow, offset)
        if type == "samesequence":
            return SameSpeakerSampler(batchSize, self.seqLabel,
                                      self.sizeWindow, offset)
        if type == "sequential":
            return SequentialSampler(len(self.data), self.sizeWindow, offset,
                                     batchSize)
        sampler = UniformAudioSampler(len(self.data), self.sizeWindow, offset)
        return BatchSampler(sampler, batchSize, True)

    def getDataLoader(self,
                      batchSize,
                      type,
                      randomOffset,
                      numWorkers=0,
                      onLoop=-1):
        """Get a batch sampler for the current dataset.
        
        Args:
            - batchSize (int): batch size
            - groupSize (int): in the case of type in ["samespeaker", "samesequence"]
            number of items sharing a same label in the group
            (see AudioBatchSampler)
            - type (string):
                type == "samespeaker": grouped sampler speaker-wise
                type == "samesequence": grouped sampler sequence-wise
                type == "sequential": sequential sampling
                else: uniform random sampling of the full audio
                vector
            - randomOffset (bool): if True add a random offset to the sampler
                                   at the begining of each iteration
                                   
        """
        nLoops = len(self.packageIndex)
        totSize = self.totSize // (self.sizeWindow * batchSize)
        if onLoop >= 0:
            self.currentPack = onLoop - 1
            self.loadNextPack()
            nLoops = 1

        def samplerCall():
            offset = random.randint(0, self.sizeWindow // 2) \
                if randomOffset else 0
            return self.getBaseSampler(type, batchSize, offset)

        return AudioLoader(self, samplerCall, nLoops, self.loadNextPack,
                           totSize, numWorkers)
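
A hypothetical construction sketch: the file-list and label paths below are placeholders, and the label file is assumed to be a pickled dict mapping each listed file to a speaker id, as __init__ expects.

# Hypothetical sketch with placeholder paths.
dataset = LibriSelectionDataset(sizeWindow=20480,
                                db_wav_root="/data/Libri-selection/wav",
                                fps_list="/data/Libri-selection/train_fps.txt",
                                label_path="/data/Libri-selection/labels.npy",
                                nSpeakers=-1)          # infer from the labels

loader = dataset.getDataLoader(batchSize=32,
                               type="samesequence",
                               randomOffset=False,
                               numWorkers=2)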
Example No. 4
class DataLoader():
    def __init__(self, args):
        self.dir_bin = args.dir_bin
        line_load_list = self.dir_bin + 'line_load_list.t7'
        vocab_file = self.dir_bin + 'vocab.t7'
        assert os.path.isfile(self.dir_bin + 'specM.bin')
        assert os.path.isfile(self.dir_bin + 'specL.bin')
        assert os.path.isfile(self.dir_bin + 'text.bin')

        self.batch_size = args.batch_size
        self.trunc_size = args.trunc_size
        self.r_factor = args.r_factor
        self.dec_out_size = args.dec_out_size
        self.post_out_size = args.post_out_size
        self.shuffle_data = True if args.shuffle_data == 1 else False
        self.iter_per_epoch = None
        self.is_subbatch_end = True
        self.curr_split = None
        self.vocab_size = None

        self.process = None
        self.queue = Queue(maxsize=args.load_queue_size)
        self.n_workers = args.n_workers

        self.use_gpu = args.use_gpu
        self.num_gpu = len(args.gpu) if len(args.gpu) > 0 else 1
        self.pinned_memory = True if args.pinned_memory == 1 and self.use_gpu else False

        self.vocab_size = self.get_num_vocab(vocab_file)
        text_limit = args.text_limit
        wave_limit = args.wave_limit

        # col1: idx / col2: wave_length / col3: text_length
        # col4: offset_M / col5: offset_L / col6: offset_T
        self.load_list = torch.load(line_load_list)
        spec_len_list = self.load_list[:, 1].clone()
        text_len_list = self.load_list[:, 2].clone()

        # exclude files whose wave length exceeds wave_limit
        sort_length, sort_idx = spec_len_list.sort()
        text_len_list = torch.gather(text_len_list, 0, sort_idx)
        sort_idx = sort_idx.view(-1, 1).expand_as(self.load_list)
        self.load_list = torch.gather(self.load_list, 0, sort_idx)

        end_idx = sort_length.le(wave_limit).sum()
        spec_len_list = sort_length[:end_idx]
        text_len_list = text_len_list[:end_idx]
        self.load_list = self.load_list[:end_idx]

        # exclude files whose text length exceeds text_limit
        sort_length, sort_idx = text_len_list.sort()
        spec_len_list = torch.gather(spec_len_list, 0, sort_idx)
        sort_idx = sort_idx.view(-1, 1).expand_as(self.load_list)
        self.load_list = torch.gather(self.load_list, 0, sort_idx)

        end_idx = sort_length.le(text_limit).sum()
        end_idx = end_idx - (end_idx % self.batch_size)  # drop residual data
        text_len_list = sort_length[:end_idx]
        spec_len_list = spec_len_list[:end_idx]
        self.load_list = self.load_list[:end_idx]

        # sort by wave length
        _, sort_idx = spec_len_list.sort(0, descending=True)
        text_len_list = torch.gather(text_len_list, 0, sort_idx)
        sort_idx = sort_idx.view(-1, 1).expand_as(self.load_list)
        self.load_list = torch.gather(self.load_list, 0, sort_idx)

        # sort by text length in each batch (PackedSequence requires it)
        num_batches_per_epoch = self.load_list.size(0) // self.batch_size
        text_len_list = text_len_list.view(num_batches_per_epoch, -1)
        self.load_list = self.load_list.view(num_batches_per_epoch, -1,
                                             self.load_list.size(1))
        sort_length, sort_idx = text_len_list.sort(1, descending=True)
        sort_idx = sort_idx.view(num_batches_per_epoch, -1,
                                 1).expand_as(self.load_list)
        self.load_list = torch.gather(self.load_list, 1, sort_idx)

        # shuffle while preserving order in a batch
        if self.shuffle_data:
            _, sort_idx = torch.randn(num_batches_per_epoch).sort()
            sort_idx = sort_idx.view(-1, 1, 1).expand_as(self.load_list)
            self.load_list = torch.gather(self.load_list, 0,
                                          sort_idx)  # nbpe x N x 6

        self.load_list = self.load_list.long()

        # compute number of iterations needed
        spec_len_list = spec_len_list.view(num_batches_per_epoch, -1)
        spec_len_list, _ = spec_len_list.div(self.trunc_size).ceil().max(1)
        self.iter_per_epoch = int(spec_len_list.sum())

        # set split cursor
        self.split_sizes = {
            'train': self.load_list.size(0),
            'val': -1,
            'test': -1
        }
        self.split_cursor = {'train': 0, 'val': 0, 'test': 0}

    def next_batch(self, split):
        T, idx = self.trunc_size, self.split_cursor[split]

        # seek and load data from raw files
        if self.is_subbatch_end:
            self.is_subbatch_end = False
            self.subbatch_cursor = 0

            if self.curr_split != split:
                self.curr_split = split
                if self.process is not None:
                    self.process.terminate()
                self.process = Process(target=self.start_async_loader,
                                       args=(split, self.split_cursor[split]))
                self.process.start()

            self.len_text, self.len_wave, self.curr_text, self.curr_specM, self.curr_specL = self.queue.get(
            )
            self.split_cursor[split] = (idx + 1) % self.split_sizes[split]
            self.subbatch_max_len = self.len_wave.max()

        # Variables to return
        # +1 to length of y to consider shifting for target y
        subbatch_len_text = [x for x in self.len_text]
        subbatch_len_wave = [min(x, T) for x in self.len_wave]
        x_text = self.curr_text
        y_specM = self.curr_specM[:,
                                  self.subbatch_cursor:self.subbatch_cursor +
                                  max(subbatch_len_wave) + 1].contiguous()
        y_specL = self.curr_specL[:,
                                  self.subbatch_cursor:self.subbatch_cursor +
                                  max(subbatch_len_wave) + 1].contiguous()

        if self.use_gpu:
            if self.pinned_memory:
                x_text = x_text.pin_memory()
                y_specM = y_specM.pin_memory()
                y_specL = y_specL.pin_memory()

            x_text = x_text.cuda()
            y_specM = y_specM.cuda()
            y_specL = y_specL.cuda()

        # Advance split_cursor or Move on to the next batch
        if self.subbatch_cursor + T < self.subbatch_max_len:
            self.subbatch_cursor = self.subbatch_cursor + T
            self.len_wave.sub_(T).clamp_(min=0)
        else:
            self.is_subbatch_end = True

        # Don't compute for empty batch elements
        if subbatch_len_wave.count(0) > 0:
            self.len_wave_mask = [
                idx for idx, l in enumerate(subbatch_len_wave) if l > 0
            ]
            self.len_wave_mask = torch.LongTensor(self.len_wave_mask)
            if self.use_gpu:
                self.len_wave_mask = self.len_wave_mask.cuda()

            x_text = torch.index_select(x_text, 0, self.len_wave_mask)
            y_specM = torch.index_select(y_specM, 0, self.len_wave_mask)
            y_specL = torch.index_select(y_specL, 0, self.len_wave_mask)
            subbatch_len_text = [
                subbatch_len_text[idx] for idx in self.len_wave_mask
            ]
            subbatch_len_wave = [
                subbatch_len_wave[idx] for idx in self.len_wave_mask
            ]
        else:
            self.len_wave_mask = None

        return x_text, y_specM, y_specL, subbatch_len_wave, subbatch_len_text

    def start_async_loader(self, split, load_start_idx):
        # load batches to the queue asynchronously since it is a bottle-neck
        N, r = self.batch_size, self.r_factor
        load_curr_idx = load_start_idx

        while True:
            data_T, data_M, data_L, len_T, len_M = ([None for _ in range(N)]
                                                    for _ in range(5))
            # deploy workers to load data
            self.pool = Pool(self.n_workers)
            partial_func = partial(load_data_and_length, self.dir_bin,
                                   self.load_list[load_curr_idx])
            results = self.pool.map_async(func=partial_func, iterable=range(N))
            self.pool.close()
            self.pool.join()

            for result in results.get():
                data_M[result[0]] = result[1]
                data_L[result[0]] = result[2]
                data_T[result[0]] = result[3]
                len_T[result[0]] = result[4]
                len_M[result[0]] = result[5]

            # TODO: output size is not accurate.. //
            len_text = torch.IntTensor(len_T)
            len_wave = torch.Tensor(len_M).div(r).ceil().mul(
                r).int()  # consider r_factor
            curr_text = torch.LongTensor(N, len_text.max()).fill_(
                0)  # null-padding at tail
            curr_specM = torch.Tensor(N,
                                      len_wave.max() + 1,
                                      self.dec_out_size).fill_(
                                          0)  # null-padding at tail
            curr_specL = torch.Tensor(N,
                                      len_wave.max() + 1,
                                      self.post_out_size).fill_(
                                          0)  # null-padding at tail

            # fill the template tensors
            for j in range(N):
                curr_text[j, 0:data_T[j].size(0)].copy_(data_T[j])
                curr_specM[j, 0:data_M[j].size(0)].copy_(data_M[j])
                curr_specL[j, 0:data_L[j].size(0)].copy_(data_L[j])

            self.queue.put(
                (len_text, len_wave, curr_text, curr_specM, curr_specL))
            load_curr_idx = (load_curr_idx + 1) % self.split_sizes[split]

    def mask_prev_h(self, prev_h):
        if self.len_wave_mask is not None:
            if self.use_gpu:
                self.len_wave_mask = self.len_wave_mask.cuda()

            h_att, h_dec1, h_dec2 = prev_h
            h_att = torch.index_select(h_att.data, 1,
                                       self.len_wave_mask)  # batch dim is 1
            h_dec1 = torch.index_select(h_dec1.data, 1, self.len_wave_mask)
            h_dec2 = torch.index_select(h_dec2.data, 1, self.len_wave_mask)
            prev_h = (Variable(h_att), Variable(h_dec1), Variable(h_dec2))
        else:
            prev_h = prev_h

        return prev_h

    def get_num_vocab(self, vocab_file=None):
        if self.vocab_size:
            return self.vocab_size
        else:
            vocab_dict = torch.load(vocab_file)
            return len(vocab_dict) + 1  # +1 to consider null-padding
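
A hypothetical driving loop: the args namespace below only sets the fields that __init__ reads, with placeholder values, and the binary files under dir_bin are assumed to exist.

from argparse import Namespace

# Hypothetical configuration; every field is a placeholder, not the original
# setup of the project.
args = Namespace(dir_bin='data/bin/', batch_size=32, trunc_size=200,
                 r_factor=5, dec_out_size=80, post_out_size=1025,
                 shuffle_data=1, load_queue_size=8, n_workers=4,
                 use_gpu=False, gpu=[], pinned_memory=0,
                 text_limit=200, wave_limit=1000)
loader = DataLoader(args)

# Truncated-BPTT style loop: next_batch() keeps returning chunks of the same
# utterance batch until is_subbatch_end flips back to True.
for _ in range(loader.iter_per_epoch):
    x_text, y_specM, y_specL, len_wave, len_text = loader.next_batch('train')
    # ... forward / backward on the (x_text, y_specM, y_specL) chunk ...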
Example No. 5
class SequentialData(Dataset):
    def __init__(self,
                 path,
                 seqNames,
                 nProcessLoader=50,
                 MAX_SIZE_LOADED=4000000000):
        """
        Args:
            - path (string): path to the training dataset
            - seqNames (list): sequences to load
            - nProcessLoader (int): number of processes to call when loading the
                                    data from the disk
            - MAX_SIZE_LOADED (int): target maximal size of the floating array
                                     containing all loaded data.
        """
        self.MAX_SIZE_LOADED = MAX_SIZE_LOADED
        self.nProcessLoader = nProcessLoader
        self.dbPath = Path(path)
        self.seqNames = [self.dbPath / x for _, x in seqNames]
        self.reload_pool = Pool(nProcessLoader)

        self.prepare()
        self.data = []

        self.loadNextPack(first=True)
        self.loadNextPack()

    def getSeqNames(self):
        return [str(x) for x in self.seqNames]

    def clear(self):
        if 'data' in self.__dict__:
            del self.data

    def prepare(self):
        shuffle(self.seqNames)
        start_time = time.time()

        print("Checking length...")
        allShape = self.reload_pool.map(extractShape, self.seqNames)

        self.packageIndex, self.totSize = [], 0
        start, packageSize = 0, 0
        for index, shape in tqdm.tqdm(enumerate(allShape)):
            packageSize += shape[0]
            if packageSize * shape[1] > self.MAX_SIZE_LOADED:
                self.packageIndex.append([start, index])
                self.totSize += packageSize
                start, packageSize = index, 0

        if packageSize > 0:
            self.packageIndex.append([start, len(self.seqNames)])
            self.totSize += packageSize

        print(f"Done, elapsed: {time.time() - start_time:.3f} seconds")
        print(f'Scanned {len(self.seqNames)} sequences '
              f'in {time.time() - start_time:.2f} seconds')
        print(f"{len(self.packageIndex)} chunks computed")
        self.currentPack = -1
        self.nextPack = 0

    def getNPacks(self):
        return len(self.packageIndex)

    def loadNextPack(self, first=False):
        self.clear()
        if not first:
            self.currentPack = self.nextPack
            start_time = time.time()
            print('Joining pool')
            self.r.wait()
            print(f'Joined process, elapsed={time.time()-start_time:.3f} secs')
            self.nextData = self.r.get()
            self.parseNextDataBlock()
            del self.nextData
        self.nextPack = (self.currentPack + 1) % len(self.packageIndex)
        seqStart, seqEnd = self.packageIndex[self.nextPack]
        #if self.nextPack == 0 and len(self.packageIndex) > 1:
        #    self.prepare()
        self.r = self.reload_pool.map_async(loadFilePool,
                                            self.seqNames[seqStart:seqEnd])

    def parseNextDataBlock(self):
        # To accelerate the process a bit
        self.nextData.sort(key=lambda x: x[0])
        tmpData = []

        for seqName, seq in self.nextData:
            tmpData.append(seq)
            del seq

        self.data = torch.cat(tmpData, dim=0)

    def __len__(self):
        return self.totSize

    def __getitem__(self, idx):
        if idx < 0 or idx >= len(self.data):
            print(idx)
        outData = self.data[idx].view(1, -1)
        return outData

    def getNLoadsPerEpoch(self):
        return len(self.packageIndex)

    def getDataLoader(self, batchSize, numWorkers=0):
        r"""
        Get a batch sampler for the current dataset.
        Args:
            - batchSize (int): batch size
        """
        nLoops = len(self.packageIndex)
        totSize = self.totSize // batchSize

        def samplerCall():
            sampler = UniformAudioSampler(len(self.data))
            return BatchSampler(sampler, batchSize, True)

        return SequentialLoader(self, samplerCall, nLoops, self.loadNextPack,
                                totSize, numWorkers)
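
A minimal, hypothetical usage sketch: seqNames reuses the (tag, relative path) layout of the other examples (only the path part is used here), and the SequentialLoader returned by getDataLoader is assumed to be iterable.

# Hypothetical sketch with placeholder paths; only the second element of each
# seqNames pair is used by SequentialData.__init__.
seqNames = [(0, "features/seq_000.pt"), (0, "features/seq_001.pt")]
dataset = SequentialData(path="/data/features", seqNames=seqNames)

loader = dataset.getDataLoader(batchSize=16, numWorkers=0)
for batch in loader:
    pass  # each item is one frame of the concatenated data, shaped (1, featureDim)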