Пример #1
0
def split_process_dataset(args, transforms):
    """Build cached, transformed train/validation datasets for ``args.dataset``.

    * ``'ljspeech'``: the LJSPEECH corpus under ``args.file_path`` is randomly
      split, with ``int(len(corpus) * args.val_ratio)`` items going to
      validation and the remainder to training.
    * ``'libritts'``: the ``train-clean-100`` subset is the training set and
      ``dev-clean`` is the validation set.

    No data is downloaded. Each resulting split is wrapped in ``Processed``
    (with ``transforms``) and then in ``MapMemoryCache``.

    Returns:
        tuple: ``(train_dataset, val_dataset)``.

    Raises:
        ValueError: If ``args.dataset`` is neither ``'ljspeech'`` nor
            ``'libritts'``.
    """
    name = args.dataset

    if name == 'ljspeech':
        corpus = LJSPEECH(root=args.file_path, download=False)
        n_total = len(corpus)
        n_val = int(n_total * args.val_ratio)
        train_raw, val_raw = random_split(corpus, [n_total - n_val, n_val])
    elif name == 'libritts':
        train_raw = LIBRITTS(root=args.file_path, url='train-clean-100', download=False)
        val_raw = LIBRITTS(root=args.file_path, url='dev-clean', download=False)
    else:
        raise ValueError(
            f"Expected dataset: `ljspeech` or `libritts`, but found {args.dataset}"
        )

    # Apply the transforms first, then put the in-memory cache on top.
    processed = [Processed(subset, transforms) for subset in (train_raw, val_raw)]
    train_dataset, val_dataset = (MapMemoryCache(ds) for ds in processed)
    return train_dataset, val_dataset
Пример #2
0
def process(args):
    """Download LJSpeech and write one TSV audio manifest per split.

    The dataset is downloaded (if needed) under ``args.output_data_root``.
    Each utterance is assigned to a split by the part of its ID before the
    first ``-``: ``LJ001`` and ``LJ002`` go to ``test``, ``LJ003`` goes to
    ``dev``, and everything else goes to ``train`` (following FastSpeech's
    splits). For every split in ``SPLITS``, a ``{split}.audio.tsv`` file is
    written to ``args.output_manifest_root``. Its columns are ``id``,
    ``audio``, ``n_frames``, ``tgt_text``, ``speaker`` and ``src_text``.
    """
    out_root = Path(args.output_data_root).absolute()
    out_root.mkdir(parents=True, exist_ok=True)

    # Generate TSV manifest
    print("Generating manifest...")
    # Following FastSpeech's splits. The table is constant, so build it once
    # instead of once per utterance; IDs with other prefixes go to "train".
    prefix_to_split = {"LJ001": "test", "LJ002": "test", "LJ003": "dev"}

    dataset = LJSPEECH(out_root.as_posix(), download=True)
    # NOTE: relies on torchaudio's private `_path` / `_flist` attributes.
    wav_dir = Path(dataset._path)
    manifest_by_split = {split: defaultdict(list) for split in SPLITS}
    progress = tqdm(enumerate(dataset), total=len(dataset))
    for i, (waveform, _, utt, normalized_utt) in progress:
        # The first field of the i-th `_flist` entry is the utterance ID.
        sample_id = dataset._flist[i][0]
        split = prefix_to_split.get(sample_id.split("-")[0], "train")
        manifest = manifest_by_split[split]
        manifest["id"].append(sample_id)
        # pathlib join + as_posix avoids mixed separators on Windows.
        manifest["audio"].append((wav_dir / f"{sample_id}.wav").as_posix())
        manifest["n_frames"].append(len(waveform[0]))
        manifest["tgt_text"].append(normalized_utt)
        manifest["speaker"].append("ljspeech")
        manifest["src_text"].append(utt)

    manifest_root = Path(args.output_manifest_root).absolute()
    manifest_root.mkdir(parents=True, exist_ok=True)
    for split in SPLITS:
        save_df_to_tsv(pd.DataFrame.from_dict(manifest_by_split[split]),
                       manifest_root / f"{split}.audio.tsv")
Пример #3
0
def main(args):
    """Vocode the first LJSpeech utterance with WaveRNN and save the result.

    LJSpeech is downloaded (if needed) into the current directory. The first
    waveform is converted to a normalized mel spectrogram. The WaveRNN model
    loaded via ``args.checkpoint_name`` is optionally TorchScript-compiled
    (``args.jit``) and synthesizes audio from it. The audio is written to
    ``args.output_wav_path`` at the source sample rate.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    waveform, sample_rate, _, _ = LJSPEECH("./", download=True)[0]

    # Mel settings for the model input; presumably these must match the
    # settings the checkpoint was trained with — verify before changing.
    mel_kwargs = {
        'sample_rate': sample_rate,
        'n_fft': 2048,
        'f_min': 40.,
        'n_mels': 80,
        'win_length': 1100,
        'hop_length': 275,
        'mel_scale': 'slaney',
        'norm': 'slaney',
        'power': 1,
    }
    transforms = torch.nn.Sequential(
        MelSpectrogram(**mel_kwargs),
        NormalizeDB(min_level_db=-100, normalization=True),
    )
    mel_specgram = transforms(waveform)

    wavernn_model = wavernn(args.checkpoint_name).eval().to(device)
    wavernn_inference_model = WaveRNNInferenceWrapper(wavernn_model)

    if args.jit:
        wavernn_inference_model = torch.jit.script(wavernn_inference_model)

    with torch.no_grad():
        output = wavernn_inference_model(mel_specgram.to(device),
                                         mulaw=(not args.no_mulaw),
                                         batched=(not args.no_batch_inference),
                                         timesteps=args.batch_timesteps,
                                         overlap=args.batch_overlap,)

    # Inference runs on `device`, so the result may be a CUDA tensor. Some
    # torchaudio backends convert to NumPy when saving, which fails for CUDA
    # tensors. `.cpu()` is a no-op when the tensor is already on the CPU.
    torchaudio.save(args.output_wav_path, output.cpu(), sample_rate=sample_rate)
Пример #4
0
def split_process_ljspeech(args, transforms):
    """Randomly split LJSpeech into cached, transformed train/val datasets.

    Loads the LJSPEECH corpus under ``args.file_path`` without downloading it.
    ``int(len(corpus) * args.val_ratio)`` randomly chosen items form the
    validation set and the rest form the training set. Each subset is wrapped
    in ``Processed`` with ``transforms`` and then in ``MapMemoryCache``.

    Returns:
        tuple: ``(train_dataset, val_dataset)``.
    """
    corpus = LJSPEECH(root=args.file_path, download=False)
    n_total = len(corpus)
    n_val = int(n_total * args.val_ratio)
    subsets = random_split(corpus, [n_total - n_val, n_val])

    # Transform first, then cache the transformed items in memory.
    processed = [Processed(subset, transforms) for subset in subsets]
    train_dataset, val_dataset = (MapMemoryCache(ds) for ds in processed)
    return train_dataset, val_dataset
Пример #5
0
def main(args):
    """Launch training, in-process or as one spawned process per CUDA device.

    Steps:

    1. Seed ``torch`` and ``random` with 0.
    2. Set the ``MASTER_ADDR`` / ``MASTER_PORT`` environment variables.
       Explicit ``args`` values win. Otherwise an existing environment value
       is kept, falling back to ``localhost`` / ``17778``. These are
       presumably consumed by ``torch.distributed`` inside ``train``.
    3. For ``'ljspeech'``, download the dataset into ``args.dataset_path`` if
       ``LJSpeech-1.1`` is missing there.
    4. Call ``train(rank, world_size, args)``.
    """
    logger.info(f"Start time: {datetime.now()}")

    torch.manual_seed(0)
    random.seed(0)

    # os.environ only accepts str values; coerce so that e.g. an int port
    # from the CLI does not raise TypeError.
    if args.master_addr is not None:
        os.environ['MASTER_ADDR'] = str(args.master_addr)
    else:
        os.environ.setdefault('MASTER_ADDR', 'localhost')

    if args.master_port is not None:
        os.environ['MASTER_PORT'] = str(args.master_port)
    else:
        os.environ.setdefault('MASTER_PORT', '17778')

    device_counts = torch.cuda.device_count()

    logger.info(f"# available GPUs: {device_counts}")

    # Download the dataset if it is not already downloaded.
    if args.dataset == 'ljspeech':
        if not os.path.exists(os.path.join(args.dataset_path, 'LJSpeech-1.1')):
            from torchaudio.datasets import LJSPEECH
            LJSPEECH(root=args.dataset_path, download=True)

    if device_counts <= 1:
        # Zero or one device: run in-process. With zero devices,
        # mp.spawn(nprocs=0) would presumably start no workers and return
        # without training. Calling train directly surfaces any problem.
        train(0, 1, args)
    else:
        mp.spawn(train,
                 args=(
                     device_counts,
                     args,
                 ),
                 nprocs=device_counts,
                 join=True)

    logger.info(f"End time: {datetime.now()}")
Пример #6
0
def split_process_dataset(
    dataset: str,
    file_path: str,
    val_ratio: float,
    transforms: Callable,
    text_preprocessor: Callable[[str], List[int]],
) -> Tuple[torch.utils.data.Dataset, torch.utils.data.Dataset]:
    """Returns the training and validation datasets.

    The LJSpeech corpus under ``file_path`` is loaded without downloading it.
    It is randomly split so that ``int(len(data) * val_ratio)`` items form the
    validation set and the rest form the training set. The split uses torch's
    default random generator, so it is reproducible only if that generator is
    seeded. Each split is wrapped in ``Processed`` (with ``transforms`` and
    ``text_preprocessor``) and then in ``MapMemoryCache``.

    Args:
        dataset (str): The dataset to use. Available options: [`'ljspeech'`]
        file_path (str): Path to the data.
        val_ratio (float): Fraction of the dataset to use for validation.
        transforms (callable): A function/transform that takes in a waveform and
            returns a transformed waveform (mel spectrogram in this example).
        text_preprocessor (callable): A function that takes in a string and
            returns a list of integers representing each of the symbols in the string.

    Returns:
        train_dataset (`torch.utils.data.Dataset`): The training set.
        val_dataset (`torch.utils.data.Dataset`): The validation set.

    Raises:
        ValueError: If ``dataset`` is not ``'ljspeech'``.
    """
    if dataset == 'ljspeech':
        data = LJSPEECH(root=file_path, download=False)

        val_length = int(len(data) * val_ratio)
        lengths = [len(data) - val_length, val_length]
        train_dataset, val_dataset = random_split(data, lengths)
    else:
        raise ValueError(f"Expected datasets: `ljspeech`, but found {dataset}")

    train_dataset = Processed(train_dataset, transforms, text_preprocessor)
    val_dataset = Processed(val_dataset, transforms, text_preprocessor)

    train_dataset = MapMemoryCache(train_dataset)
    val_dataset = MapMemoryCache(val_dataset)

    return train_dataset, val_dataset