Exemplo n.º 1
0
def create_train_dag(neural_factory,
                     neural_modules,
                     waveglow_params,
                     train_dataset,
                     batch_size,
                     checkpoint_save_freq,
                     cpu_per_dl=1):
    """Build the WaveGlow training DAG and the callbacks that go with it.

    Args:
        neural_factory: NeMo neural factory. Its ``world_size``, ``logger``,
            ``tb_writer`` and ``checkpoint_dir`` attributes are read here.
        neural_modules: ``(data_preprocessor, waveglow, waveglow_loss)``
            triple of already-constructed neural modules.
        waveglow_params: parsed model config. ``waveglow_params["AudioDataLayer"]``
            holds shared data-layer kwargs plus ``"train"`` and ``"eval"``
            sub-dicts whose entries override the shared ones per split.
        train_dataset: manifest path passed to the training data layer.
        batch_size: batch size passed to the data layer.
        checkpoint_save_freq: forwarded as ``step_freq`` to
            ``nemo.core.CheckpointCallback``.
        cpu_per_dl: forwarded as ``num_workers`` to the data layer.

    Returns:
        ``(loss_t, callbacks, steps_per_epoch)``: the WaveGlow loss tensor,
        ``[loss-logger callback, checkpoint callback]``, and
        ``len(data_layer) // (batch_size * neural_factory.world_size)``.
    """
    data_preprocessor, waveglow, waveglow_loss = neural_modules

    # Work on a deep copy so the caller's config is never touched. The
    # "train" overrides are taken from the copy (not from the original
    # config), so no nested values end up shared with waveglow_params.
    train_dl_params = copy.deepcopy(waveglow_params["AudioDataLayer"])
    train_dl_params.update(train_dl_params["train"])
    # The per-split sub-dicts are not data-layer kwargs; strip them.
    del train_dl_params["train"]
    del train_dl_params["eval"]

    data_layer = nemo_tts.AudioDataLayer(
        manifest_filepath=train_dataset,
        batch_size=batch_size,
        num_workers=cpu_per_dl,
        **train_dl_params,
    )

    N = len(data_layer)
    # Integer floor division: the global batch is spread over world_size
    # workers; no float round-trip is needed.
    steps_per_epoch = N // (batch_size * neural_factory.world_size)
    neural_factory.logger.info(f'Have {N} examples to train on.')

    # Train DAG: audio -> spectrogram -> WaveGlow -> loss.
    audio, audio_len = data_layer()
    spec_target, spec_target_len = data_preprocessor(
        input_signal=audio,
        length=audio_len)

    z, log_s_list, log_det_W_list = waveglow(
        mel_spectrogram=spec_target, audio=audio)
    loss_t = waveglow_loss(
        z=z,
        log_s_list=log_s_list,
        log_det_W_list=log_det_W_list)

    # Callbacks needed to print info to console and Tensorboard
    train_callback = nemo.core.SimpleLossLoggerCallback(
        tensors=[loss_t, z, spec_target, spec_target_len],
        print_func=lambda x: print(f"Loss: {x[0].data}"),
        log_to_tb_func=partial(
            waveglow_log_to_tb_func,
            log_images=False),
        tb_writer=neural_factory.tb_writer,
    )

    chpt_callback = nemo.core.CheckpointCallback(
        folder=neural_factory.checkpoint_dir,
        step_freq=checkpoint_save_freq)

    callbacks = [train_callback, chpt_callback]
    return loss_t, callbacks, steps_per_epoch
Exemplo n.º 2
0
def create_eval_dags(neural_factory,
                     neural_modules,
                     waveglow_params,
                     eval_datasets,
                     eval_batch_size,
                     eval_freq,
                     cpu_per_dl=1):
    """Build one WaveGlow evaluation DAG and evaluator callback per dataset.

    Args:
        neural_factory: NeMo neural factory. Its ``tb_writer`` is attached
            to every evaluator callback.
        neural_modules: ``(data_preprocessor, waveglow, waveglow_loss)``
            triple. The loss module is not used for evaluation.
        waveglow_params: parsed model config. ``waveglow_params["AudioDataLayer"]``
            holds shared data-layer kwargs plus ``"train"`` and ``"eval"``
            sub-dicts whose entries override the shared ones per split.
        eval_datasets: iterable of manifest paths, one DAG built per path.
        eval_batch_size: batch size passed to each eval data layer.
        eval_freq: forwarded as ``eval_step`` to each evaluator callback.
        cpu_per_dl: forwarded as ``num_workers`` to each data layer.

    Returns:
        List of ``nemo.core.EvaluatorCallback``, one per entry of
        ``eval_datasets``, in the same order.
    """
    data_preprocessor, waveglow, _ = neural_modules

    # Work on a deep copy so the caller's config is never touched. The
    # "eval" overrides are taken from the copy (not from the original
    # config), so no nested values end up shared with waveglow_params.
    eval_dl_params = copy.deepcopy(waveglow_params["AudioDataLayer"])
    eval_dl_params.update(eval_dl_params["eval"])
    # The per-split sub-dicts are not data-layer kwargs; strip them.
    del eval_dl_params["train"]
    del eval_dl_params["eval"]

    callbacks = []
    # assemble eval DAGs
    for eval_dataset in eval_datasets:
        data_layer_eval = nemo_tts.AudioDataLayer(
            manifest_filepath=eval_dataset,
            batch_size=eval_batch_size,
            num_workers=cpu_per_dl,
            **eval_dl_params,
        )

        audio, audio_len = data_layer_eval()
        spec_target, spec_target_len = data_preprocessor(
            input_signal=audio,
            length=audio_len)

        # Only the first WaveGlow output is consumed by the evaluator. It is
        # presumably the predicted audio in eval mode (TODO confirm against
        # the WaveGlow module). The two log-lists are unused here.
        audio_pred, _, _ = waveglow(
            mel_spectrogram=spec_target, audio=audio)

        # TensorBoard tag = manifest file name up to its FIRST ".",
        # e.g. "dev.clean.json" -> "dev".
        tagname = os.path.basename(eval_dataset).split(".")[0]
        eval_callback = nemo.core.EvaluatorCallback(
            eval_tensors=[audio_pred, spec_target, spec_target_len],
            user_iter_callback=waveglow_process_eval_batch,
            user_epochs_done_callback=lambda x: x,
            tb_writer_func=partial(
                waveglow_eval_log_to_tb_func,
                tag=tagname,
                mel_fb=data_preprocessor.filter_banks),
            eval_step=eval_freq,
            tb_writer=neural_factory.tb_writer)

        callbacks.append(eval_callback)
    return callbacks