def create_train_dag(neural_factory, neural_modules, waveglow_params,
                     train_dataset, batch_size, checkpoint_save_freq,
                     cpu_per_dl=1):
    """Assemble the WaveGlow training graph and its training callbacks.

    Args:
        neural_factory: Factory object. Its ``world_size``, ``logger``,
            ``tb_writer`` and ``checkpoint_dir`` attributes are read here.
        neural_modules: 3-tuple ``(data_preprocessor, waveglow, waveglow_loss)``.
        waveglow_params: Config mapping. Its ``"AudioDataLayer"`` section must
            contain ``"train"`` and ``"eval"`` sub-sections.
        train_dataset: Passed to the data layer as ``manifest_filepath``.
        batch_size: Batch size passed to the data layer.
        checkpoint_save_freq: ``step_freq`` for the checkpoint callback.
        cpu_per_dl: Number of data-loader workers (``num_workers``).

    Returns:
        tuple: ``(loss_t, callbacks, steps_per_epoch)``. ``callbacks`` is
        ``[loss-logger callback, checkpoint callback]`` and
        ``steps_per_epoch`` is ``int(len(data_layer) / (batch_size *
        world_size))``.
    """
    preprocessor, waveglow_model, loss_module = neural_modules

    # Start from the shared data-layer settings and overlay the train-specific
    # ones. Then drop both split sub-sections so they are not forwarded to the
    # AudioDataLayer constructor as keyword arguments.
    dl_section = waveglow_params["AudioDataLayer"]
    layer_kwargs = copy.deepcopy(dl_section)
    layer_kwargs.update(dl_section["train"])
    for split_key in ("train", "eval"):
        del layer_kwargs[split_key]

    data_layer = nemo_tts.AudioDataLayer(
        manifest_filepath=train_dataset,
        batch_size=batch_size,
        num_workers=cpu_per_dl,
        **layer_kwargs,
    )

    # The global batch spans all workers, so one epoch takes this many steps.
    num_examples = len(data_layer)
    global_batch = batch_size * neural_factory.world_size
    steps_per_epoch = int(num_examples / global_batch)
    neural_factory.logger.info(f"Have {num_examples} examples to train on.")

    # Train DAG: audio -> spectrogram target -> WaveGlow -> loss.
    audio, audio_len = data_layer()
    spec_target, spec_target_len = preprocessor(
        input_signal=audio, length=audio_len)
    z, log_s_list, log_det_w_list = waveglow_model(
        mel_spectrogram=spec_target, audio=audio)
    loss_t = loss_module(
        z=z, log_s_list=log_s_list, log_det_W_list=log_det_w_list)

    def _print_loss(tensors):
        # tensors[0] is loss_t (first entry of the logger's tensor list).
        print(f"Loss: {tensors[0].data}")

    # Console/TensorBoard logging plus periodic checkpointing.
    loss_logger = nemo.core.SimpleLossLoggerCallback(
        tensors=[loss_t, z, spec_target, spec_target_len],
        print_func=_print_loss,
        log_to_tb_func=partial(waveglow_log_to_tb_func, log_images=False),
        tb_writer=neural_factory.tb_writer,
    )
    checkpointer = nemo.core.CheckpointCallback(
        folder=neural_factory.checkpoint_dir,
        step_freq=checkpoint_save_freq,
    )
    return loss_t, [loss_logger, checkpointer], steps_per_epoch
def create_eval_dags(neural_factory, neural_modules, waveglow_params,
                     eval_datasets, eval_batch_size, eval_freq, cpu_per_dl=1):
    """Assemble one WaveGlow evaluation graph and evaluator callback per dataset.

    Args:
        neural_factory: Factory object. Its ``tb_writer`` attribute is read here.
        neural_modules: 3-tuple whose first two entries are
            ``(data_preprocessor, waveglow)``. The third entry is ignored.
        waveglow_params: Config mapping. Its ``"AudioDataLayer"`` section must
            contain ``"train"`` and ``"eval"`` sub-sections.
        eval_datasets: Iterable of values passed as ``manifest_filepath``, one
            data layer per entry.
        eval_batch_size: Batch size for every evaluation data layer.
        eval_freq: ``eval_step`` for every evaluator callback.
        cpu_per_dl: Number of data-loader workers (``num_workers``).

    Returns:
        list: One ``nemo.core.EvaluatorCallback`` per entry of
        ``eval_datasets``, in the same order. An empty iterable yields ``[]``.
    """
    data_preprocessor, waveglow, _ = neural_modules

    # Shared data-layer settings overlaid with the eval-specific ones. Both
    # split sub-sections are removed so they are not forwarded as kwargs.
    eval_dl_params = copy.deepcopy(waveglow_params["AudioDataLayer"])
    eval_dl_params.update(waveglow_params["AudioDataLayer"]["eval"])
    del eval_dl_params["train"]
    del eval_dl_params["eval"]

    callbacks = []
    used_tags = set()
    # Assemble one eval DAG per dataset.
    for eval_dataset in eval_datasets:
        data_layer_eval = nemo_tts.AudioDataLayer(
            manifest_filepath=eval_dataset,
            batch_size=eval_batch_size,
            num_workers=cpu_per_dl,
            **eval_dl_params,
        )
        audio, audio_len = data_layer_eval()
        spec_target, spec_target_len = data_preprocessor(
            input_signal=audio, length=audio_len)
        # Only the first WaveGlow output is evaluated here.
        audio_pred, _, _ = waveglow(
            mel_spectrogram=spec_target, audio=audio)

        # Tag is the file name up to its first ".", e.g. "a/dev.json" -> "dev".
        tagname = os.path.basename(eval_dataset).split(".")[0]
        # Different datasets can map to the same tag ("a/dev.json" and
        # "b/dev.json", or "dev.clean.json" and "dev.other.json"). Without
        # this, their callbacks would share one tag. The first occurrence
        # keeps the plain tag and later ones get "_1", "_2", ...
        if tagname in used_tags:
            suffix = 1
            while f"{tagname}_{suffix}" in used_tags:
                suffix += 1
            tagname = f"{tagname}_{suffix}"
        used_tags.add(tagname)

        # Create the corresponding eval callback.
        eval_callback = nemo.core.EvaluatorCallback(
            eval_tensors=[audio_pred, spec_target, spec_target_len],
            user_iter_callback=waveglow_process_eval_batch,
            user_epochs_done_callback=lambda x: x,
            tb_writer_func=partial(
                waveglow_eval_log_to_tb_func,
                tag=tagname,
                mel_fb=data_preprocessor.filter_banks),
            eval_step=eval_freq,
            tb_writer=neural_factory.tb_writer)
        callbacks.append(eval_callback)
    return callbacks