Ejemplo n.º 1
0
def get_datasets(
        meta_dir: str,
        batch_size: int,
        num_workers: int,
        fix_len: int = 0,
        skip_audio: bool = False) -> Tuple[SpeechDataLoader, SpeechDataLoader]:
    """Build the (train, valid) ``SpeechDataLoader`` pair from Maestro meta files.

    Args:
        meta_dir: directory that contains the Maestro frame meta files.
        batch_size: batch size used by both loaders.
        num_workers: number of data-loading worker processes.
        fix_len: fixed sample length forwarded to ``SpeechDataset``
            (0 presumably means "variable length" — confirm against SpeechDataset).
        skip_audio: if True, ``SpeechDataset`` skips loading raw audio.

    Returns:
        Tuple of ``(train_loader, valid_loader)``.
    """
    # TODO: update this function in general
    # Fix: the original assert message was never formatted, so it always
    # printed the literal '{}' instead of the offending path.
    assert os.path.isdir(meta_dir), '{} is not valid directory path!'.format(meta_dir)

    # frame_file_names[0] is skipped; [1] and [2] are the train/valid files.
    train_file, valid_file = MaestroMeta.frame_file_names[1:]

    # load meta file
    train_meta = MaestroMeta(os.path.join(meta_dir, train_file))
    valid_meta = MaestroMeta(os.path.join(meta_dir, valid_file))

    # create dataset
    train_dataset = SpeechDataset(train_meta,
                                  fix_len=fix_len,
                                  skip_audio=skip_audio)
    valid_dataset = SpeechDataset(valid_meta,
                                  fix_len=fix_len,
                                  skip_audio=skip_audio)

    # create data loader
    train_loader = SpeechDataLoader(train_dataset,
                                    batch_size=batch_size,
                                    num_workers=num_workers)
    valid_loader = SpeechDataLoader(valid_dataset,
                                    batch_size=batch_size,
                                    num_workers=num_workers)

    return train_loader, valid_loader
Ejemplo n.º 2
0
def get_datasets(
        meta_dir: str,
        batch_size: int,
        num_workers: int,
        fix_len: int = 0,
        audio_mask: bool = False) -> Tuple[SpeechDataLoader, SpeechDataLoader]:
    """Build the (train, valid) ``SpeechDataLoader`` pair from MedleyDB meta files.

    Args:
        meta_dir: directory that contains the MedleyDB frame meta files.
        batch_size: batch size used by both loaders.
        num_workers: number of data-loading worker processes.
        fix_len: fixed sample length forwarded to ``SpeechDataset``
            (0 presumably means "variable length" — confirm against SpeechDataset).
        audio_mask: forwarded to ``SpeechDataset`` to enable audio masking.

    Returns:
        Tuple of ``(train_loader, valid_loader)``; both loaders are created
        with ``is_bucket=False`` (no length-bucketed batching).
    """
    # Fix: the original assert message was never formatted, so it always
    # printed the literal '{}' instead of the offending path.
    assert os.path.isdir(meta_dir), '{} is not valid directory path!'.format(meta_dir)

    # frame_file_names[0] is skipped; [1] and [2] are the train/valid files.
    train_file, valid_file = MedleyDBMeta.frame_file_names[1:]

    # load meta file
    train_meta = MedleyDBMeta(os.path.join(meta_dir, train_file))
    valid_meta = MedleyDBMeta(os.path.join(meta_dir, valid_file))

    # create dataset
    train_dataset = SpeechDataset(train_meta,
                                  fix_len=fix_len,
                                  audio_mask=audio_mask)
    valid_dataset = SpeechDataset(valid_meta,
                                  fix_len=fix_len,
                                  audio_mask=audio_mask)

    # create data loader
    train_loader = SpeechDataLoader(train_dataset,
                                    batch_size=batch_size,
                                    num_workers=num_workers,
                                    is_bucket=False)
    valid_loader = SpeechDataLoader(valid_dataset,
                                    batch_size=batch_size,
                                    num_workers=num_workers,
                                    is_bucket=False)

    return train_loader, valid_loader
Ejemplo n.º 3
0
def get_datasets(
    meta_dir: str,
    batch_size: int,
    num_workers: int,
    fix_len: int = 0,
    skip_audio: bool = False,
    audio_mask: bool = False,
    skip_last_bucket: bool = True,
    n_buckets: int = 10,
    extra_features: Optional[List[Tuple[str, Callable]]] = None
) -> Tuple[SpeechDataLoader, SpeechDataLoader]:
    """Build the (train, valid) ``SpeechDataLoader`` pair from LibriTTS meta files.

    The train loader uses length-bucketed batching (``n_buckets``); the valid
    loader is created with ``is_bucket=False``.

    Args:
        meta_dir: directory that contains the LibriTTS frame meta files.
        batch_size: batch size used by both loaders.
        num_workers: number of data-loading worker processes.
        fix_len: fixed sample length forwarded to ``SpeechDataset``
            (0 presumably means "variable length" — confirm against SpeechDataset).
        skip_audio: if True, ``SpeechDataset`` skips loading raw audio.
        audio_mask: forwarded to ``SpeechDataset`` to enable audio masking.
        skip_last_bucket: forwarded to the train loader's bucketing logic.
        n_buckets: number of length buckets for the train loader.
        extra_features: optional list of ``(name, extractor)`` pairs forwarded
            to ``SpeechDataset``.

    Returns:
        Tuple of ``(train_loader, valid_loader)``.
    """
    # Fix: the original assert message was never formatted, so it always
    # printed the literal '{}' instead of the offending path.
    assert os.path.isdir(meta_dir), '{} is not valid directory path!'.format(meta_dir)

    # frame_file_names[0] is skipped; [1] and [2] are the train/valid files.
    train_file, valid_file = LibriTTSMeta.frame_file_names[1:]

    # load meta file
    train_meta = LibriTTSMeta(os.path.join(meta_dir, train_file))
    valid_meta = LibriTTSMeta(os.path.join(meta_dir, valid_file))

    # create dataset
    train_dataset = SpeechDataset(train_meta,
                                  fix_len=fix_len,
                                  skip_audio=skip_audio,
                                  audio_mask=audio_mask,
                                  extra_features=extra_features)
    valid_dataset = SpeechDataset(valid_meta,
                                  fix_len=fix_len,
                                  skip_audio=skip_audio,
                                  audio_mask=audio_mask,
                                  extra_features=extra_features)

    # create data loader
    train_loader = SpeechDataLoader(train_dataset,
                                    batch_size=batch_size,
                                    n_buckets=n_buckets,
                                    num_workers=num_workers,
                                    skip_last_bucket=skip_last_bucket)
    valid_loader = SpeechDataLoader(valid_dataset,
                                    batch_size=batch_size,
                                    is_bucket=False,
                                    num_workers=num_workers)

    return train_loader, valid_loader