def split_process_dataset(args, transforms):
    if args.dataset == 'ljspeech':
        data = LJSPEECH(root=args.file_path, download=False)

        val_length = int(len(data) * args.val_ratio)
        lengths = [len(data) - val_length, val_length]
        train_dataset, val_dataset = random_split(data, lengths)
    elif args.dataset == 'libritts':
        train_dataset = LIBRITTS(root=args.file_path, url='train-clean-100', download=False)
        val_dataset = LIBRITTS(root=args.file_path, url='dev-clean', download=False)
    else:
        raise ValueError(
            f"Expected dataset: `ljspeech` or `libritts`, but found {args.dataset}"
        )

    train_dataset = Processed(train_dataset, transforms)
    val_dataset = Processed(val_dataset, transforms)

    train_dataset = MapMemoryCache(train_dataset)
    val_dataset = MapMemoryCache(val_dataset)

    return train_dataset, val_dataset
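# Hedged usage sketch for split_process_dataset: the argparse.Namespace below
# stands in for this example's CLI arguments, and the path/ratio values are
# illustrative assumptions, not values fixed by the script.
def _demo_split_process_dataset():
    import argparse
    from torchaudio.transforms import MelSpectrogram

    args = argparse.Namespace(dataset='ljspeech', file_path='./data', val_ratio=0.1)
    train_set, val_set = split_process_dataset(args, MelSpectrogram(sample_rate=22050))
    print(f"train: {len(train_set)} samples, val: {len(val_set)} samples")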
def process(args):
    out_root = Path(args.output_data_root).absolute()
    out_root.mkdir(parents=True, exist_ok=True)

    # Generate TSV manifest
    print("Generating manifest...")
    # following FastSpeech's splits
    dataset = LJSPEECH(out_root.as_posix(), download=True)
    id_to_split = {}
    for x in dataset._flist:
        id_ = x[0]
        speaker = id_.split("-")[0]
        id_to_split[id_] = {
            "LJ001": "test", "LJ002": "test", "LJ003": "dev"
        }.get(speaker, "train")
    manifest_by_split = {split: defaultdict(list) for split in SPLITS}
    progress = tqdm(enumerate(dataset), total=len(dataset))
    for i, (waveform, _, utt, normalized_utt) in progress:
        sample_id = dataset._flist[i][0]
        split = id_to_split[sample_id]
        manifest_by_split[split]["id"].append(sample_id)
        audio_path = f"{dataset._path}/{sample_id}.wav"
        manifest_by_split[split]["audio"].append(audio_path)
        manifest_by_split[split]["n_frames"].append(len(waveform[0]))
        manifest_by_split[split]["tgt_text"].append(normalized_utt)
        manifest_by_split[split]["speaker"].append("ljspeech")
        manifest_by_split[split]["src_text"].append(utt)

    manifest_root = Path(args.output_manifest_root).absolute()
    manifest_root.mkdir(parents=True, exist_ok=True)
    for split in SPLITS:
        save_df_to_tsv(
            pd.DataFrame.from_dict(manifest_by_split[split]),
            manifest_root / f"{split}.audio.tsv"
        )
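# Hedged helper sketch: reads one of the manifests generated above back into a
# DataFrame. The column names ("id", "audio", "n_frames", "tgt_text",
# "speaker", "src_text") match those populated in process(); reading with
# sep="\t" assumes save_df_to_tsv writes plain tab-separated output.
def _load_manifest(manifest_root, split="train"):
    import pandas as pd
    return pd.read_csv(f"{manifest_root}/{split}.audio.tsv", sep="\t")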
def main(args):
    device = "cuda" if torch.cuda.is_available() else "cpu"

    waveform, sample_rate, _, _ = LJSPEECH("./", download=True)[0]

    mel_kwargs = {
        'sample_rate': sample_rate,
        'n_fft': 2048,
        'f_min': 40.,
        'n_mels': 80,
        'win_length': 1100,
        'hop_length': 275,
        'mel_scale': 'slaney',
        'norm': 'slaney',
        'power': 1,
    }
    transforms = torch.nn.Sequential(
        MelSpectrogram(**mel_kwargs),
        NormalizeDB(min_level_db=-100, normalization=True),
    )
    mel_specgram = transforms(waveform)

    wavernn_model = wavernn(args.checkpoint_name).eval().to(device)
    wavernn_inference_model = WaveRNNInferenceWrapper(wavernn_model)

    if args.jit:
        wavernn_inference_model = torch.jit.script(wavernn_inference_model)

    with torch.no_grad():
        output = wavernn_inference_model(
            mel_specgram.to(device),
            mulaw=(not args.no_mulaw),
            batched=(not args.no_batch_inference),
            timesteps=args.batch_timesteps,
            overlap=args.batch_overlap,
        )

    torchaudio.save(args.output_wav_path, output, sample_rate=sample_rate)
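# NormalizeDB is imported from this example's processing utilities. As a
# reference point only, a minimal transform with the same constructor
# signature might look like the sketch below (an assumption about the
# interface, not the canonical implementation): clamp the magnitude
# spectrogram away from zero, move to a log10 scale, and optionally rescale
# the [min_level_db, 0] dB range onto [0, 1].
class _NormalizeDBSketch(torch.nn.Module):
    def __init__(self, min_level_db, normalization):
        super().__init__()
        self.min_level_db = min_level_db
        self.normalization = normalization

    def forward(self, specgram):
        # avoid log of zero before converting to a log-amplitude scale
        log_spec = torch.log10(torch.clamp(specgram, min=1e-5))
        if self.normalization:
            # 20 * log10(x) is the dB value; map [min_level_db, 0] dB to [0, 1]
            return torch.clamp(
                (self.min_level_db - 20 * log_spec) / self.min_level_db, min=0, max=1
            )
        return log_spec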
def split_process_ljspeech(args, transforms):
    data = LJSPEECH(root=args.file_path, download=False)

    val_length = int(len(data) * args.val_ratio)
    lengths = [len(data) - val_length, val_length]
    train_dataset, val_dataset = random_split(data, lengths)

    train_dataset = Processed(train_dataset, transforms)
    val_dataset = Processed(val_dataset, transforms)

    train_dataset = MapMemoryCache(train_dataset)
    val_dataset = MapMemoryCache(val_dataset)

    return train_dataset, val_dataset
def main(args):
    logger.info("Start time: {}".format(str(datetime.now())))

    torch.manual_seed(0)
    random.seed(0)

    if args.master_addr is not None:
        os.environ['MASTER_ADDR'] = args.master_addr
    elif 'MASTER_ADDR' not in os.environ:
        os.environ['MASTER_ADDR'] = 'localhost'

    if args.master_port is not None:
        os.environ['MASTER_PORT'] = args.master_port
    elif 'MASTER_PORT' not in os.environ:
        os.environ['MASTER_PORT'] = '17778'

    device_counts = torch.cuda.device_count()
    logger.info(f"# available GPUs: {device_counts}")

    # download the dataset if it is not already downloaded
    if args.dataset == 'ljspeech':
        if not os.path.exists(os.path.join(args.dataset_path, 'LJSpeech-1.1')):
            from torchaudio.datasets import LJSPEECH
            LJSPEECH(root=args.dataset_path, download=True)

    if device_counts == 1:
        train(0, 1, args)
    else:
        mp.spawn(train, args=(device_counts, args,), nprocs=device_counts, join=True)

    logger.info(f"End time: {datetime.now()}")
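# Hedged launch sketch: the flag names below mirror the argument attributes
# read in main() (dataset, dataset_path, master_addr, master_port), but the
# exact CLI spelling depends on this example's argument parser.
#
#   python main.py --dataset ljspeech --dataset-path ./data \
#       --master-addr localhost --master-port 17778
#
# With multiple visible GPUs, main() spawns one train() process per device
# via mp.spawn; with a single GPU it calls train(0, 1, args) directly.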
def split_process_dataset(
    dataset: str,
    file_path: str,
    val_ratio: float,
    transforms: Callable,
    text_preprocessor: Callable[[str], List[int]],
) -> Tuple[torch.utils.data.Dataset, torch.utils.data.Dataset]:
    """Returns the training and validation datasets.

    Args:
        dataset (str): The dataset to use. Available options: [`'ljspeech'`]
        file_path (str): Path to the data.
        val_ratio (float): Fraction of the data held out for validation.
        transforms (callable): A function/transform that takes in a waveform and
            returns a transformed waveform (mel spectrogram in this example).
        text_preprocessor (callable): A function that takes in a string and
            returns a list of integers representing each of the symbols in
            the string.

    Returns:
        train_dataset (`torch.utils.data.Dataset`): The training set.
        val_dataset (`torch.utils.data.Dataset`): The validation set.
    """
    if dataset == 'ljspeech':
        data = LJSPEECH(root=file_path, download=False)

        val_length = int(len(data) * val_ratio)
        lengths = [len(data) - val_length, val_length]
        train_dataset, val_dataset = random_split(data, lengths)
    else:
        raise ValueError(f"Expected dataset: `ljspeech`, but found {dataset}")

    train_dataset = Processed(train_dataset, transforms, text_preprocessor)
    val_dataset = Processed(val_dataset, transforms, text_preprocessor)

    train_dataset = MapMemoryCache(train_dataset)
    val_dataset = MapMemoryCache(val_dataset)

    return train_dataset, val_dataset
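# Hedged usage sketch for the typed split_process_dataset: char_ids below is
# a stand-in text_preprocessor (the real example maps text through its own
# symbol table), and the path/ratio values are illustrative.
def _demo_typed_split():
    from torchaudio.transforms import MelSpectrogram

    def char_ids(text: str) -> List[int]:
        # naive character-level encoding, for illustration only
        return [ord(c) for c in text]

    train_set, val_set = split_process_dataset(
        dataset='ljspeech',
        file_path='./data',
        val_ratio=0.1,
        transforms=MelSpectrogram(sample_rate=22050),
        text_preprocessor=char_ids,
    )
    return train_set, val_set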