Exemplo n.º 1
0
def make_data_loader(data_root, config):
    """Build a paddle DataLoader over the LJSpeech dataset.

    Args:
        data_root: path to the LJSpeech corpus root directory.
        config: dict-like configuration with "meta_data", "transform",
            "train" and "model" sections (see the keys read below).

    Returns:
        A ``fluid.io.DataLoader`` yielding collated training batches,
        bucketed by text length and sharded across parallel trainers.
    """
    # construct meta data
    meta = LJSpeechMetaData(data_root)

    # drop examples whose transcript is shorter than the configured minimum
    min_text_length = config["meta_data"]["min_text_length"]
    meta = FilterDataset(meta, lambda x: len(x[2]) >= min_text_length)

    # transform meta data into model-ready features; CacheDataset memoizes
    # results so the (expensive) transform runs once per example
    c = config["transform"]
    transform = Transform(
        replace_pronunciation_prob=c["replace_pronunciation_prob"],
        sample_rate=c["sample_rate"],
        preemphasis=c["preemphasis"],
        n_fft=c["n_fft"],
        win_length=c["win_length"],
        hop_length=c["hop_length"],
        fmin=c["fmin"],
        fmax=c["fmax"],
        n_mels=c["n_mels"],
        min_level_db=c["min_level_db"],
        ref_level_db=c["ref_level_db"],
        max_norm=c["max_norm"],
        clip_norm=c["clip_norm"])
    ljspeech = CacheDataset(TransformDataset(meta, transform))

    # use meta data's text length as a sort key for the sampler
    batch_size = config["train"]["batch_size"]
    text_lengths = [len(example[2]) for example in meta]

    # NOTE(fix): the original also built a
    # PartialyRandomizedSimilarTimeLengthSampler here, which was immediately
    # overwritten by the BucketSampler below — dead code, removed.
    # BucketSampler shards the data across trainers for data parallelism.
    env = dg.parallel.ParallelEnv()
    num_trainers = env.nranks
    local_rank = env.local_rank
    sampler = BucketSampler(
        text_lengths, batch_size, num_trainers=num_trainers, rank=local_rank)

    # some model hyperparameters affect how we process data
    model_config = config["model"]
    collector = DataCollector(
        downsample_factor=model_config["downsample_factor"],
        r=model_config["outputs_per_step"])
    ljspeech_loader = DataCargo(
        ljspeech, batch_fn=collector, batch_size=batch_size, sampler=sampler)
    loader = fluid.io.DataLoader.from_generator(capacity=10, return_list=True)
    loader.set_batch_generator(
        ljspeech_loader, places=fluid.framework._current_expected_place())
    return loader
Exemplo n.º 2
0
    args = parser.parse_args()
    with open(args.config, 'rt') as f:
        config = ruamel.yaml.safe_load(f)

    ljspeech_meta = LJSpeechMetaData(args.data)

    data_config = config["data"]
    sample_rate = data_config["sample_rate"]
    n_fft = data_config["n_fft"]
    win_length = data_config["win_length"]
    hop_length = data_config["hop_length"]
    n_mels = data_config["n_mels"]
    train_clip_seconds = data_config["train_clip_seconds"]
    transform = Transform(sample_rate, n_fft, win_length, hop_length, n_mels)
    ljspeech = TransformDataset(ljspeech_meta, transform)

    valid_size = data_config["valid_size"]
    ljspeech_valid = SliceDataset(ljspeech, 0, valid_size)
    ljspeech_train = SliceDataset(ljspeech, valid_size, len(ljspeech))

    model_config = config["model"]
    n_loop = model_config["n_loop"]
    n_layer = model_config["n_layer"]
    filter_size = model_config["filter_size"]
    context_size = 1 + n_layer * sum([filter_size**i for i in range(n_loop)])
    print("context size is {} samples".format(context_size))
    train_batch_fn = DataCollector(context_size, sample_rate, hop_length,
                                   train_clip_seconds)
    valid_batch_fn = DataCollector(context_size,
                                   sample_rate,