class TextMelLoader(torch.utils.data.Dataset):
    """Dataset of (text, mel-spectrogram) pairs.

    1) loads audio,text pairs listed in a filelist
    2) normalizes text with the configured cleaners and converts it to an
       integer symbol sequence (``torch.IntTensor``)
    3) computes mel-spectrograms from audio files, or loads precomputed ones
       from ``.npy`` files when ``hparams.load_mel_from_disk`` is set.
    """

    def __init__(self, audiopaths_and_text, hparams):
        self.audiopaths_and_text = load_filepaths_and_text(audiopaths_and_text)
        self.text_cleaners = hparams.text_cleaners
        self.max_wav_value = hparams.max_wav_value
        self.sampling_rate = hparams.sampling_rate
        self.load_mel_from_disk = hparams.load_mel_from_disk
        self.stft = layers.TacotronSTFT(
            hparams.filter_length, hparams.hop_length, hparams.win_length,
            hparams.n_mel_channels, hparams.sampling_rate, hparams.mel_fmin,
            hparams.mel_fmax)
        self.ct = CustomText(hparams.level)
        # NOTE: this seeds the *global* ``random`` module (a side effect other
        # code may rely on) before shuffling the entry list in place.
        random.seed(hparams.seed)
        random.shuffle(self.audiopaths_and_text)

    def get_mel_text_pair(self, audiopath_and_text):
        """Return ``(text_tensor, mel)`` for one filelist entry.

        Only the first two fields of the entry (path, text) are used.
        """
        audiopath, text = audiopath_and_text[0], audiopath_and_text[1]
        text = self.get_text(text)
        mel = self.get_mel(audiopath)
        return (text, mel)

    def get_mel(self, filename):
        """Compute (or load) the mel-spectrogram for ``filename``.

        Raises:
            ValueError: if the audio's sampling rate differs from the STFT's.
            AssertionError: if the mel's first dimension differs from
                ``self.stft.n_mel_channels``.
        """
        if not self.load_mel_from_disk:
            audio, sampling_rate = load_wav_to_torch(filename)
            if sampling_rate != self.stft.sampling_rate:
                # Two placeholders for two arguments; the previous message had
                # three, so a mismatch raised IndexError instead of ValueError.
                raise ValueError("{} SR doesn't match target {} SR".format(
                    sampling_rate, self.stft.sampling_rate))
            audio_norm = audio / self.max_wav_value
            # Add a leading batch dimension for the STFT module.
            audio_norm = audio_norm.unsqueeze(0)
            # torch.autograd.Variable is deprecated and was a no-op wrapper
            # here; plain tensors are passed directly.
            melspec = self.stft.mel_spectrogram(audio_norm)
            melspec = torch.squeeze(melspec, 0)
        else:
            melspec = torch.from_numpy(np.load(filename))
            assert melspec.size(0) == self.stft.n_mel_channels, (
                'Mel dimension mismatch: given {}, expected {}'.format(
                    melspec.size(0), self.stft.n_mel_channels))

        return melspec

    def get_text(self, text):
        """Convert raw text to an ``IntTensor`` of symbol ids via ``CustomText``."""
        text_norm = torch.IntTensor(
            self.ct.text_to_sequence(text, self.text_cleaners))
        return text_norm

    def __getitem__(self, index):
        return self.get_mel_text_pair(self.audiopaths_and_text[index])

    def __len__(self):
        return len(self.audiopaths_and_text)
class Dataset(Dataset):
    """Dataset over preprocessed utterances listed in a ``|``-separated metadata file.

    Each metadata line has the form ``basename|speaker|text|raw_text``.
    Per-utterance features are loaded from
    ``<preprocessed_path>/<kind>/<speaker>-<kind>-<basename>.npy`` for
    ``kind`` in ``mel``, ``pitch``, ``energy`` and ``duration``.

    ``collate_fn`` returns a *list* of padded batches (each built by
    ``reprocess``) of ``batch_size`` items each. Presumably the DataLoader
    hands it ``k * batch_size`` samples at a time -- confirm against caller.
    """

    def __init__(self, filename, preprocess_config, train_config,
                 model_config, sort=False, drop_last=False):
        """
        Args:
            filename: metadata file name, relative to the preprocessed path.
            preprocess_config: dict; reads ``["dataset"]``,
                ``["path"]["preprocessed_path"]`` and
                ``["preprocessing"]["text"]["text_cleaners"]``.
            train_config: dict; reads ``["optimizer"]["batch_size"]``.
            model_config: dict; reads ``["level"]`` for ``CustomText``.
            sort: if True, ``collate_fn`` orders items by descending text length.
            drop_last: if True, ``collate_fn`` discards a final partial batch.
        """
        self.dataset_name = preprocess_config["dataset"]
        self.preprocessed_path = preprocess_config["path"]["preprocessed_path"]
        self.cleaners = preprocess_config["preprocessing"]["text"][
            "text_cleaners"]
        self.batch_size = train_config["optimizer"]["batch_size"]
        self.ct = CustomText(model_config['level'])

        self.basename, self.speaker, self.text, self.raw_text = (
            self.process_meta(filename))

        # Explicit encoding so non-ASCII speaker names load identically on
        # every platform (the default text encoding is locale-dependent).
        speakers_path = os.path.join(self.preprocessed_path, "speakers.json")
        with open(speakers_path, encoding="utf-8") as f:
            self.speaker_map = json.load(f)

        self.sort = sort
        self.drop_last = drop_last

    def __len__(self):
        return len(self.text)

    def _load_feature(self, kind, speaker, basename):
        """Load ``<preprocessed_path>/<kind>/<speaker>-<kind>-<basename>.npy``."""
        path = os.path.join(
            self.preprocessed_path,
            kind,
            "{}-{}-{}.npy".format(speaker, kind, basename),
        )
        return np.load(path)

    def __getitem__(self, idx):
        """Return one utterance as a dict of id, speaker id, text and features."""
        basename = self.basename[idx]
        speaker = self.speaker[idx]
        speaker_id = self.speaker_map[speaker]
        raw_text = self.raw_text[idx]
        phone = np.array(
            self.ct.text_to_sequence(self.text[idx], self.cleaners))
        # Load order (mel, pitch, energy, duration) matches the dict below.
        mel = self._load_feature("mel", speaker, basename)
        pitch = self._load_feature("pitch", speaker, basename)
        energy = self._load_feature("energy", speaker, basename)
        duration = self._load_feature("duration", speaker, basename)

        sample = {
            "id": basename,
            "speaker": speaker_id,
            "text": phone,
            "raw_text": raw_text,
            "mel": mel,
            "pitch": pitch,
            "energy": energy,
            "duration": duration,
        }

        return sample

    def process_meta(self, filename):
        """Parse the metadata file into four parallel lists.

        Returns:
            (names, speakers, texts, raw_texts), one entry per non-blank line.

        Raises:
            ValueError: if a non-blank line does not have exactly four
                ``|``-separated fields.
        """
        name, speaker, text, raw_text = [], [], [], []
        meta_path = os.path.join(self.preprocessed_path, filename)
        with open(meta_path, "r", encoding="utf-8") as f:
            for line_no, line in enumerate(f, start=1):
                line = line.strip("\n")
                if not line.strip():
                    # Tolerate blank lines (e.g. a stray empty line at EOF).
                    continue
                fields = line.split("|")
                if len(fields) != 4:
                    raise ValueError(
                        "{}:{}: expected 4 '|'-separated fields "
                        "(basename|speaker|text|raw_text), got {}".format(
                            meta_path, line_no, len(fields)))
                n, s, t, r = fields
                name.append(n)
                speaker.append(s)
                text.append(t)
                raw_text.append(r)
        return name, speaker, text, raw_text

    def reprocess(self, data, idxs):
        """Assemble one padded batch from ``data[i]`` for ``i`` in ``idxs``.

        Returns a 12-tuple: (ids, raw_texts, speakers, texts, text_lens,
        max_text_len, mels, mel_lens, max_mel_len, pitches, energies,
        durations). Lengths are taken along axis 0 of each array; padding is
        delegated to ``pad_1D`` / ``pad_2D``.
        """
        samples = [data[idx] for idx in idxs]
        ids = [s["id"] for s in samples]
        speakers = [s["speaker"] for s in samples]
        texts = [s["text"] for s in samples]
        raw_texts = [s["raw_text"] for s in samples]
        mels = [s["mel"] for s in samples]
        pitches = [s["pitch"] for s in samples]
        energies = [s["energy"] for s in samples]
        durations = [s["duration"] for s in samples]

        text_lens = np.array([text.shape[0] for text in texts])
        mel_lens = np.array([mel.shape[0] for mel in mels])

        speakers = np.array(speakers)
        texts = pad_1D(texts)
        mels = pad_2D(mels)
        pitches = pad_1D(pitches)
        energies = pad_1D(energies)
        durations = pad_1D(durations)

        return (
            ids,
            raw_texts,
            speakers,
            texts,
            text_lens,
            max(text_lens),
            mels,
            mel_lens,
            max(mel_lens),
            pitches,
            energies,
            durations,
        )

    def collate_fn(self, data):
        """Split ``data`` into groups of ``batch_size`` and ``reprocess`` each.

        With ``sort`` set, items are ordered by descending text length first.
        A trailing partial group is kept unless ``drop_last`` is set.
        """
        data_size = len(data)

        if self.sort:
            len_arr = np.array([d["text"].shape[0] for d in data])
            idx_arr = np.argsort(-len_arr)
        else:
            idx_arr = np.arange(data_size)

        n_full = len(idx_arr) - (len(idx_arr) % self.batch_size)
        tail = idx_arr[n_full:]
        batches = idx_arr[:n_full].reshape((-1, self.batch_size)).tolist()
        if not self.drop_last and len(tail) > 0:
            batches.append(tail.tolist())

        return [self.reprocess(data, idx) for idx in batches]