def train(model, criterion, optimizer, scheduler, ap, global_step):
    """Train the speaker encoder for one pass over the training loader.

    For every batch only the first element (the features) is used. The model
    outputs are regrouped into
    ``(c.num_speakers_in_batch, outputs.shape[0] // c.num_speakers_in_batch, -1)``
    and handed to ``criterion`` without labels. Gradients are clipped through
    ``check_update`` before the optimizer step.

    A running loss (``0.01 * newest + 0.99 * previous``, seeded with the first
    loss) is logged every ``c.steps_plot_stats`` steps, printed every
    ``c.print_step`` steps and passed to ``save_best_model`` after every step.

    Args:
        model: network being trained; put into train mode here.
        criterion: callable mapping the regrouped outputs to a scalar loss.
        optimizer: optimizer; its first param group supplies the logged lr.
        scheduler: stepped once per batch when ``c.lr_decay`` is set.
        ap: audio processor forwarded to ``setup_loader``.
        global_step: step counter at entry; incremented once per batch.

    Returns:
        tuple: ``(avg_loss, global_step)`` -- the running loss after the last
        batch (``0`` if the loader yielded nothing) and the updated counter.
    """
    loader = setup_loader(ap, is_val=False, verbose=True)
    model.train()
    lowest_loss = float('inf')
    running_loss = 0
    prev_end = time.time()

    for batch in loader:
        tic = time.time()
        features = batch[0]
        # time spent waiting for this batch since the previous step finished
        wait = time.time() - prev_end
        global_step += 1

        if c.lr_decay:
            scheduler.step()
        optimizer.zero_grad()

        if use_cuda:
            features = features.cuda(non_blocking=True)

        embeddings = model(features)
        n_spk = c.num_speakers_in_batch
        grouped = embeddings.view(n_spk, embeddings.shape[0] // n_spk, -1)
        loss = criterion(grouped)
        loss.backward()
        grad_norm, _ = check_update(model, c.grad_clip)
        optimizer.step()

        step_time = time.time() - tic

        latest = loss.item()
        if running_loss == 0:
            running_loss = latest
        else:
            running_loss = 0.01 * latest + 0.99 * running_loss
        lr = optimizer.param_groups[0]['lr']

        if global_step % c.steps_plot_stats == 0:
            tb_logger.tb_train_epoch_stats(global_step, {
                "loss": running_loss,
                "lr": lr,
                "grad_norm": grad_norm,
                "step_time": step_time,
            })
            # NOTE(review): the second argument to plot_embeddings is a
            # hard-coded 10 (marked FIXME upstream) -- confirm its meaning.
            tb_logger.tb_train_figures(global_step, {
                "UMAP Plot": plot_embeddings(embeddings.detach().cpu().numpy(), 10),
            })

        if global_step % c.print_step == 0:
            print(
                " | > Step:{} Loss:{:.5f} AvgLoss:{:.5f} GradNorm:{:.5f} "
                "StepTime:{:.2f} LoaderTime:{:.2f} LR:{:.6f}".format(
                    global_step, loss.item(), running_loss, grad_norm,
                    step_time, wait, lr),
                flush=True)

        # save_best_model returns the (possibly updated) best loss so far
        lowest_loss = save_best_model(model, optimizer, running_loss,
                                      lowest_loss, OUT_PATH, global_step)
        prev_end = time.time()

    return running_loss, global_step
def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler,
          ap, global_step, epoch, scaler, scaler_st):
    """Run one training epoch of a Tacotron-style model.

    The forward pass and loss run under ``torch.cuda.amp.autocast`` when
    ``c.mixed_precision`` is set. In that mode the main loss is scaled by
    ``scaler``, and (with ``c.separate_stopnet``) the stopnet loss by
    ``scaler_st``. Each scaler unscales and steps only its own optimizer.

    Args:
        data_loader: training loader; ``len(data_loader.dataset)`` is used for
            the progress estimate.
        model: Tacotron-style model exposing ``decoder.r`` and
            ``decoder.stopnet``.
        criterion: callable returning a dict of losses (``"loss"``,
            ``"postnet_loss"``, ``"decoder_loss"``, ``"stopnet_loss"``, ...).
        optimizer: optimizer for the main model.
        optimizer_st: optimizer for the stopnet, or a falsy value.
        scheduler: stepped once per batch when ``c.noam_schedule`` is set.
        ap: audio processor used for diagnostic plots and audio.
        global_step: step counter at entry; incremented once per batch.
        epoch: current epoch index (used for checkpoints and logging).
        scaler: GradScaler for the main loss (used when mixed precision).
        scaler_st: GradScaler for the stopnet loss (used when mixed precision
            and ``c.separate_stopnet``).

    Returns:
        tuple: ``(keep_avg.avg_values, global_step)``.

    Raises:
        RuntimeError: if the total loss contains NaN.
    """
    model.train()
    epoch_time = 0
    keep_avg = KeepAverage()
    if use_cuda:
        batch_n_iter = int(
            len(data_loader.dataset) / (c.batch_size * num_gpus))
    else:
        batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
    end_time = time.time()
    c_logger.print_train_start()
    for num_iter, data in enumerate(data_loader):
        start_time = time.time()

        # format data
        (
            text_input,
            text_lengths,
            mel_input,
            mel_lengths,
            linear_input,
            stop_targets,
            speaker_ids,
            speaker_embeddings,
            max_text_length,
            max_spec_length,
        ) = format_data(data)
        loader_time = time.time() - end_time

        global_step += 1

        # setup lr
        if c.noam_schedule:
            scheduler.step()

        optimizer.zero_grad()
        if optimizer_st:
            optimizer_st.zero_grad()

        with torch.cuda.amp.autocast(enabled=c.mixed_precision):
            # forward pass model
            if c.bidirectional_decoder or c.double_decoder_consistency:
                (
                    decoder_output,
                    postnet_output,
                    alignments,
                    stop_tokens,
                    decoder_backward_output,
                    alignments_backward,
                ) = model(
                    text_input,
                    text_lengths,
                    mel_input,
                    mel_lengths,
                    speaker_ids=speaker_ids,
                    speaker_embeddings=speaker_embeddings,
                )
            else:
                decoder_output, postnet_output, alignments, stop_tokens = model(
                    text_input,
                    text_lengths,
                    mel_input,
                    mel_lengths,
                    speaker_ids=speaker_ids,
                    speaker_embeddings=speaker_embeddings,
                )
                decoder_backward_output = None
                alignments_backward = None

            # set the [alignment] lengths wrt reduction factor for guided attention
            if mel_lengths.max() % model.decoder.r != 0:
                alignment_lengths = (
                    mel_lengths +
                    (model.decoder.r -
                     (mel_lengths.max() % model.decoder.r))) // model.decoder.r
            else:
                alignment_lengths = mel_lengths // model.decoder.r

            # compute loss
            loss_dict = criterion(
                postnet_output,
                decoder_output,
                mel_input,
                linear_input,
                stop_tokens,
                stop_targets,
                mel_lengths,
                decoder_backward_output,
                alignments,
                alignment_lengths,
                alignments_backward,
                text_lengths,
            )

        # check nan loss
        if torch.isnan(loss_dict["loss"]).any():
            raise RuntimeError(f"Detected NaN loss at step {global_step}.")

        # optimizer step
        if c.mixed_precision:
            # model optimizer step in mixed precision mode
            scaler.scale(loss_dict["loss"]).backward()
            scaler.unscale_(optimizer)
            optimizer, current_lr = adam_weight_decay(optimizer)
            grad_norm, _ = check_update(model, c.grad_clip, ignore_stopnet=True)
            scaler.step(optimizer)
            scaler.update()

            # stopnet optimizer step
            if c.separate_stopnet:
                # The stopnet loss is scaled by ``scaler_st``, so the same
                # scaler must unscale and step ``optimizer_st``. Unscaling with
                # ``scaler`` would divide by the wrong factor, and stepping
                # ``optimizer`` would update the main model a second time
                # while the stopnet was never updated.
                scaler_st.scale(loss_dict["stopnet_loss"]).backward()
                scaler_st.unscale_(optimizer_st)
                optimizer_st, _ = adam_weight_decay(optimizer_st)
                grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0)
                scaler_st.step(optimizer_st)
                scaler_st.update()
            else:
                grad_norm_st = 0
        else:
            # main model optimizer step
            loss_dict["loss"].backward()
            optimizer, current_lr = adam_weight_decay(optimizer)
            grad_norm, _ = check_update(model, c.grad_clip, ignore_stopnet=True)
            optimizer.step()

            # stopnet optimizer step
            if c.separate_stopnet:
                loss_dict["stopnet_loss"].backward()
                optimizer_st, _ = adam_weight_decay(optimizer_st)
                grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0)
                optimizer_st.step()
            else:
                grad_norm_st = 0

        # compute alignment error (the lower the better)
        align_error = 1 - alignment_diagonal_score(alignments)
        loss_dict["align_error"] = align_error

        step_time = time.time() - start_time
        epoch_time += step_time

        # aggregate losses from processes
        if num_gpus > 1:
            loss_dict["postnet_loss"] = reduce_tensor(
                loss_dict["postnet_loss"].data, num_gpus)
            loss_dict["decoder_loss"] = reduce_tensor(
                loss_dict["decoder_loss"].data, num_gpus)
            loss_dict["loss"] = reduce_tensor(loss_dict["loss"].data, num_gpus)
            loss_dict["stopnet_loss"] = (reduce_tensor(
                loss_dict["stopnet_loss"].data, num_gpus)
                                         if c.stopnet else loss_dict["stopnet_loss"])

        # detach loss values: plain numbers pass through, tensors via .item()
        loss_dict_new = dict()
        for key, value in loss_dict.items():
            if isinstance(value, (int, float)):
                loss_dict_new[key] = value
            else:
                loss_dict_new[key] = value.item()
        loss_dict = loss_dict_new

        # update avg stats
        update_train_values = dict()
        for key, value in loss_dict.items():
            update_train_values["avg_" + key] = value
        update_train_values["avg_loader_time"] = loader_time
        update_train_values["avg_step_time"] = step_time
        keep_avg.update_values(update_train_values)

        # print training progress
        if global_step % c.print_step == 0:
            log_dict = {
                "max_spec_length": [max_spec_length, 1],  # value, precision
                "max_text_length": [max_text_length, 1],
                "step_time": [step_time, 4],
                "loader_time": [loader_time, 2],
                "current_lr": current_lr,
            }
            c_logger.print_train_step(batch_n_iter, num_iter, global_step,
                                      log_dict, loss_dict, keep_avg.avg_values)

        if args.rank == 0:
            # Plot Training Iter Stats
            # reduce TB load
            if global_step % c.tb_plot_step == 0:
                iter_stats = {
                    "lr": current_lr,
                    "grad_norm": grad_norm,
                    "grad_norm_st": grad_norm_st,
                    "step_time": step_time,
                }
                iter_stats.update(loss_dict)
                tb_logger.tb_train_iter_stats(global_step, iter_stats)

            if global_step % c.save_step == 0:
                if c.checkpoint:
                    # save model
                    save_checkpoint(
                        model,
                        optimizer,
                        global_step,
                        epoch,
                        model.decoder.r,
                        OUT_PATH,
                        optimizer_st=optimizer_st,
                        model_loss=loss_dict["postnet_loss"],
                        characters=model_characters,
                        scaler=scaler.state_dict() if c.mixed_precision else None,
                    )

                # Diagnostic visualizations
                const_spec = postnet_output[0].data.cpu().numpy()
                gt_spec = (linear_input[0].data.cpu().numpy() if c.model in [
                    "Tacotron", "TacotronGST"
                ] else mel_input[0].data.cpu().numpy())
                align_img = alignments[0].data.cpu().numpy()

                figures = {
                    "prediction": plot_spectrogram(const_spec, ap, output_fig=False),
                    "ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False),
                    "alignment": plot_alignment(align_img, output_fig=False),
                }

                if c.bidirectional_decoder or c.double_decoder_consistency:
                    figures["alignment_backward"] = plot_alignment(
                        alignments_backward[0].data.cpu().numpy(),
                        output_fig=False)

                tb_logger.tb_train_figures(global_step, figures)

                # Sample audio
                if c.model in ["Tacotron", "TacotronGST"]:
                    train_audio = ap.inv_spectrogram(const_spec.T)
                else:
                    train_audio = ap.inv_melspectrogram(const_spec.T)
                tb_logger.tb_train_audios(global_step,
                                          {"TrainAudio": train_audio},
                                          c.audio["sample_rate"])
        end_time = time.time()

    # print epoch stats
    c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg)

    # Plot Epoch Stats
    if args.rank == 0:
        epoch_stats = {"epoch_time": epoch_time}
        epoch_stats.update(keep_avg.avg_values)
        tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
        if c.tb_model_param_stats:
            tb_logger.tb_model_weights(model, global_step)
    return keep_avg.avg_values, global_step
def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader, global_step):
    """Train the speaker encoder for ``c.epochs`` epochs, evaluating after each.

    Each batch is regrouped so that samples of the same class are contiguous.
    It is then forwarded, and the loss is computed on outputs viewed as
    ``(c.num_classes_in_batch, outputs.shape[0] // c.num_classes_in_batch, -1)``
    together with the labels. When ``c.run_eval`` is set, the model is
    evaluated after every epoch and ``save_best_model`` is updated with the
    evaluation loss.

    Args:
        model: network being trained.
        optimizer: optimizer; its first param group supplies the logged lr.
        scheduler: stepped once per batch when ``c.lr_decay`` is set.
        criterion: callable taking ``(grouped_outputs, labels)``.
        data_loader: training loader yielding ``(inputs, labels)``.
        eval_data_loader: loader passed to ``evaluation``.
        global_step: step counter at entry; incremented once per batch.

    Returns:
        tuple: ``(best_loss, global_step)``. ``best_loss`` stays ``inf`` when
        evaluation never runs.
    """
    model.train()
    best_loss = float("inf")
    avg_loader_time = 0
    end_time = time.time()
    for epoch in range(c.epochs):
        tot_loss = 0
        epoch_time = 0
        # Counted explicitly so the epoch summary does not divide by zero (or
        # read an undefined grad_norm) when the loader yields no batches,
        # e.g. with drop_last=True on a dataset smaller than one batch.
        num_batches = 0
        grad_norm = 0
        for data in data_loader:
            start_time = time.time()

            # setup input data
            inputs, labels = data
            # Group the samples of each class in the batch: the sampler yields
            # [3,2,1,3,2,1] but the loss expects [3,3,2,2,1,1].
            # TODO: cover this regrouping with a unit test.
            labels = torch.transpose(
                labels.view(c.num_utter_per_class, c.num_classes_in_batch), 0,
                1).reshape(labels.shape)
            inputs = torch.transpose(
                inputs.view(c.num_utter_per_class, c.num_classes_in_batch, -1),
                0, 1).reshape(inputs.shape)

            loader_time = time.time() - end_time
            global_step += 1

            # setup lr
            if c.lr_decay:
                scheduler.step()
            optimizer.zero_grad()

            # dispatch data to GPU
            if use_cuda:
                inputs = inputs.cuda(non_blocking=True)
                labels = labels.cuda(non_blocking=True)

            # forward pass model
            outputs = model(inputs)

            # loss computation
            loss = criterion(
                outputs.view(c.num_classes_in_batch,
                             outputs.shape[0] // c.num_classes_in_batch, -1),
                labels)
            loss.backward()
            grad_norm, _ = check_update(model, c.grad_clip)
            optimizer.step()

            step_time = time.time() - start_time
            epoch_time += step_time

            # accumulate the total epoch loss
            tot_loss += loss.item()
            num_batches += 1

            # Averaged Loader Time (smoothed with weight 1 / num_loader_workers)
            num_loader_workers = c.num_loader_workers if c.num_loader_workers > 0 else 1
            avg_loader_time = (
                1 / num_loader_workers * loader_time +
                (num_loader_workers - 1) / num_loader_workers * avg_loader_time
                if avg_loader_time != 0 else loader_time)
            current_lr = optimizer.param_groups[0]["lr"]

            if global_step % c.steps_plot_stats == 0:
                # Plot Training Epoch Stats
                train_stats = {
                    "loss": loss.item(),
                    "lr": current_lr,
                    "grad_norm": grad_norm,
                    "step_time": step_time,
                    "avg_loader_time": avg_loader_time,
                }
                dashboard_logger.train_epoch_stats(global_step, train_stats)
                figures = {
                    "UMAP Plot": plot_embeddings(outputs.detach().cpu().numpy(),
                                                 c.num_classes_in_batch),
                }
                dashboard_logger.train_figures(global_step, figures)

            if global_step % c.print_step == 0:
                print(
                    " | > Step:{} Loss:{:.5f} GradNorm:{:.5f} "
                    "StepTime:{:.2f} LoaderTime:{:.2f} AvGLoaderTime:{:.2f} LR:{:.6f}"
                    .format(global_step, loss.item(), grad_norm, step_time,
                            loader_time, avg_loader_time, current_lr),
                    flush=True,
                )

            if global_step % c.save_step == 0:
                # save model
                save_checkpoint(model, optimizer, criterion, loss.item(),
                                OUT_PATH, global_step, epoch)

            end_time = time.time()

        epoch_avg_loss = tot_loss / num_batches if num_batches else 0.0
        print("")
        print(
            ">>> Epoch:{} AvgLoss: {:.5f} GradNorm:{:.5f} "
            "EpochTime:{:.2f} AvGLoaderTime:{:.2f} ".format(
                epoch, epoch_avg_loss, grad_norm, epoch_time, avg_loader_time),
            flush=True,
        )
        # evaluation
        if c.run_eval:
            model.eval()
            eval_loss = evaluation(model, criterion, eval_data_loader, global_step)
            print("\n\n")
            print("--> EVAL PERFORMANCE")
            print(
                " | > Epoch:{} AvgLoss: {:.5f} ".format(epoch, eval_loss),
                flush=True,
            )
            # save the best checkpoint
            best_loss = save_best_model(model, optimizer, criterion, eval_loss,
                                        best_loss, OUT_PATH, global_step, epoch)
            model.train()

    return best_loss, global_step
def train(model, criterion, optimizer, optimizer_st, scheduler, ap, global_step,
          epoch, amp, speaker_mapping=None):
    """Run one training epoch of a Tacotron-style model with optional apex amp.

    When ``amp`` is given, the main loss is backpropagated through
    ``amp.scale_loss``. Gradient clipping for each optimizer is applied to
    that optimizer's own parameters, obtained via ``amp.master_params``.

    Args:
        model: Tacotron-style model exposing ``decoder.r`` and
            ``decoder.stopnet``.
        criterion: callable returning a dict of losses.
        optimizer: optimizer for the main model.
        optimizer_st: optimizer for the stopnet, or a falsy value.
        scheduler: stepped once per batch when ``c.noam_schedule`` is set.
        ap: audio processor (loader setup, plots and audio).
        global_step: step counter at entry; incremented once per batch.
        epoch: current epoch index; the loader is verbose only for epoch 0.
        amp: apex ``amp`` module, or ``None`` to train in full precision.
        speaker_mapping: optional mapping forwarded to the loader and
            ``format_data``.

    Returns:
        tuple: ``(keep_avg.avg_values, global_step)``.
    """
    data_loader = setup_loader(ap,
                               model.decoder.r,
                               is_val=False,
                               verbose=(epoch == 0),
                               speaker_mapping=speaker_mapping)
    model.train()
    epoch_time = 0
    keep_avg = KeepAverage()
    if use_cuda:
        batch_n_iter = int(
            len(data_loader.dataset) / (c.batch_size * num_gpus))
    else:
        batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
    end_time = time.time()
    c_logger.print_train_start()
    for num_iter, data in enumerate(data_loader):
        start_time = time.time()

        # format data
        (text_input, text_lengths, mel_input, mel_lengths, linear_input,
         stop_targets, speaker_ids, speaker_embeddings, avg_text_length,
         avg_spec_length) = format_data(data, speaker_mapping)
        loader_time = time.time() - end_time

        global_step += 1

        # setup lr
        if c.noam_schedule:
            scheduler.step()
        optimizer.zero_grad()
        if optimizer_st:
            optimizer_st.zero_grad()

        # forward pass model
        if c.bidirectional_decoder or c.double_decoder_consistency:
            (decoder_output, postnet_output, alignments, stop_tokens,
             decoder_backward_output, alignments_backward) = model(
                 text_input,
                 text_lengths,
                 mel_input,
                 mel_lengths,
                 speaker_ids=speaker_ids,
                 speaker_embeddings=speaker_embeddings)
        else:
            decoder_output, postnet_output, alignments, stop_tokens = model(
                text_input,
                text_lengths,
                mel_input,
                mel_lengths,
                speaker_ids=speaker_ids,
                speaker_embeddings=speaker_embeddings)
            decoder_backward_output = None
            alignments_backward = None

        # set the [alignment] lengths wrt reduction factor for guided attention
        if mel_lengths.max() % model.decoder.r != 0:
            alignment_lengths = (
                mel_lengths +
                (model.decoder.r -
                 (mel_lengths.max() % model.decoder.r))) // model.decoder.r
        else:
            alignment_lengths = mel_lengths // model.decoder.r

        # compute loss
        loss_dict = criterion(postnet_output, decoder_output, mel_input,
                              linear_input, stop_tokens, stop_targets,
                              mel_lengths, decoder_backward_output,
                              alignments, alignment_lengths,
                              alignments_backward, text_lengths)

        # backward pass
        if amp is not None:
            with amp.scale_loss(loss_dict['loss'], optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss_dict['loss'].backward()

        optimizer, current_lr = adam_weight_decay(optimizer)
        if amp:
            amp_opt_params = amp.master_params(optimizer)
        else:
            amp_opt_params = None
        grad_norm, _ = check_update(model,
                                    c.grad_clip,
                                    ignore_stopnet=True,
                                    amp_opt_params=amp_opt_params)
        optimizer.step()

        # compute alignment error (the lower the better)
        align_error = 1 - alignment_diagonal_score(alignments)
        loss_dict['align_error'] = align_error

        # backpass and check the grad norm for stop loss
        if c.separate_stopnet:
            loss_dict['stopnet_loss'].backward()
            optimizer_st, _ = adam_weight_decay(optimizer_st)
            # Clip the stopnet optimizer's own parameters. Passing the main
            # optimizer's params here would re-clip the (already stepped)
            # main model, leave the stopnet gradients unclipped and report
            # the wrong norm as grad_norm_st.
            if amp:
                amp_opt_params = amp.master_params(optimizer_st)
            else:
                amp_opt_params = None
            grad_norm_st, _ = check_update(model.decoder.stopnet,
                                           1.0,
                                           amp_opt_params=amp_opt_params)
            optimizer_st.step()
        else:
            grad_norm_st = 0

        step_time = time.time() - start_time
        epoch_time += step_time

        # aggregate losses from processes
        if num_gpus > 1:
            loss_dict['postnet_loss'] = reduce_tensor(
                loss_dict['postnet_loss'].data, num_gpus)
            loss_dict['decoder_loss'] = reduce_tensor(
                loss_dict['decoder_loss'].data, num_gpus)
            loss_dict['loss'] = reduce_tensor(loss_dict['loss'].data, num_gpus)
            loss_dict['stopnet_loss'] = reduce_tensor(
                loss_dict['stopnet_loss'].data,
                num_gpus) if c.stopnet else loss_dict['stopnet_loss']

        # detach loss values: plain numbers pass through, tensors via .item()
        loss_dict_new = dict()
        for key, value in loss_dict.items():
            if isinstance(value, (int, float)):
                loss_dict_new[key] = value
            else:
                loss_dict_new[key] = value.item()
        loss_dict = loss_dict_new

        # update avg stats
        update_train_values = dict()
        for key, value in loss_dict.items():
            update_train_values['avg_' + key] = value
        update_train_values['avg_loader_time'] = loader_time
        update_train_values['avg_step_time'] = step_time
        keep_avg.update_values(update_train_values)

        # print training progress
        if global_step % c.print_step == 0:
            log_dict = {
                "avg_spec_length": [avg_spec_length, 1],  # value, precision
                "avg_text_length": [avg_text_length, 1],
                "step_time": [step_time, 4],
                "loader_time": [loader_time, 2],
                "current_lr": current_lr,
            }
            c_logger.print_train_step(batch_n_iter, num_iter, global_step,
                                      log_dict, loss_dict, keep_avg.avg_values)

        if args.rank == 0:
            # Plot Training Iter Stats
            # reduce TB load
            if global_step % c.tb_plot_step == 0:
                iter_stats = {
                    "lr": current_lr,
                    "grad_norm": grad_norm,
                    "grad_norm_st": grad_norm_st,
                    "step_time": step_time
                }
                iter_stats.update(loss_dict)
                tb_logger.tb_train_iter_stats(global_step, iter_stats)

            if global_step % c.save_step == 0:
                if c.checkpoint:
                    # save model
                    save_checkpoint(
                        model,
                        optimizer,
                        global_step,
                        epoch,
                        model.decoder.r,
                        OUT_PATH,
                        optimizer_st=optimizer_st,
                        model_loss=loss_dict['postnet_loss'],
                        amp_state_dict=amp.state_dict() if amp else None)

                # Diagnostic visualizations
                const_spec = postnet_output[0].data.cpu().numpy()
                gt_spec = linear_input[0].data.cpu().numpy() if c.model in [
                    "Tacotron", "TacotronGST"
                ] else mel_input[0].data.cpu().numpy()
                align_img = alignments[0].data.cpu().numpy()

                figures = {
                    "prediction": plot_spectrogram(const_spec, ap, output_fig=False),
                    "ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False),
                    "alignment": plot_alignment(align_img, output_fig=False),
                }

                if c.bidirectional_decoder or c.double_decoder_consistency:
                    figures["alignment_backward"] = plot_alignment(
                        alignments_backward[0].data.cpu().numpy(),
                        output_fig=False)

                tb_logger.tb_train_figures(global_step, figures)

                # Sample audio
                if c.model in ["Tacotron", "TacotronGST"]:
                    train_audio = ap.inv_spectrogram(const_spec.T)
                else:
                    train_audio = ap.inv_melspectrogram(const_spec.T)
                tb_logger.tb_train_audios(global_step,
                                          {'TrainAudio': train_audio},
                                          c.audio["sample_rate"])
        end_time = time.time()

    # print epoch stats
    c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg)

    # Plot Epoch Stats
    if args.rank == 0:
        epoch_stats = {"epoch_time": epoch_time}
        epoch_stats.update(keep_avg.avg_values)
        tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
        if c.tb_model_param_stats:
            tb_logger.tb_model_weights(model, global_step)
    return keep_avg.avg_values, global_step
def train(model, optimizer, scheduler, criterion, data_loader, global_step):
    """Train the speaker encoder until the loader ends or ``c.max_train_step``.

    Each batch ``(inputs, labels)`` is forwarded. The loss is computed on
    outputs viewed as
    ``(c.num_speakers_in_batch, outputs.shape[0] // c.num_speakers_in_batch, -1)``
    together with the labels. A running loss (``0.01 * newest + 0.99 *
    previous``, seeded with the first loss) is handed to ``save_best_model``
    every ``c.save_step`` steps and at ``c.max_train_step``. Training stops
    once that limit is reached.

    Args:
        model: network being trained.
        optimizer: optimizer; its first param group supplies the logged lr.
        scheduler: stepped once per batch when ``c.lr_decay`` is set.
        criterion: callable taking ``(grouped_outputs, labels)``.
        data_loader: loader yielding ``(inputs, labels)`` pairs.
        global_step: step counter at entry; incremented once per batch.

    Returns:
        tuple: ``(avg_loss, global_step)`` -- the running loss after the last
        processed batch (``0`` if none) and the updated counter.
    """
    model.train()
    best_loss = float("inf")
    avg_loss = 0
    avg_loader_time = 0
    end_time = time.time()
    for data in data_loader:
        start_time = time.time()

        # setup input data
        inputs, labels = data
        loader_time = time.time() - end_time
        global_step += 1

        # setup lr
        if c.lr_decay:
            scheduler.step()
        optimizer.zero_grad()

        # dispatch data to GPU
        if use_cuda:
            inputs = inputs.cuda(non_blocking=True)
            labels = labels.cuda(non_blocking=True)

        # forward pass model
        outputs = model(inputs)

        # loss computation
        loss = criterion(
            outputs.view(c.num_speakers_in_batch,
                         outputs.shape[0] // c.num_speakers_in_batch, -1),
            labels)
        loss.backward()
        grad_norm, _ = check_update(model, c.grad_clip)
        optimizer.step()

        step_time = time.time() - start_time

        # Averaged Loss and Averaged Loader Time
        avg_loss = 0.01 * loss.item() + 0.99 * avg_loss if avg_loss != 0 else loss.item()
        num_loader_workers = c.num_loader_workers if c.num_loader_workers > 0 else 1
        avg_loader_time = (
            1 / num_loader_workers * loader_time +
            (num_loader_workers - 1) / num_loader_workers * avg_loader_time
            if avg_loader_time != 0 else loader_time)
        current_lr = optimizer.param_groups[0]["lr"]

        if global_step % c.steps_plot_stats == 0:
            # Plot Training Epoch Stats
            train_stats = {
                "loss": avg_loss,
                "lr": current_lr,
                "grad_norm": grad_norm,
                "step_time": step_time,
                "avg_loader_time": avg_loader_time,
            }
            tb_logger.tb_train_epoch_stats(global_step, train_stats)
            figures = {
                # FIXME: not constant
                "UMAP Plot": plot_embeddings(outputs.detach().cpu().numpy(), 10),
            }
            tb_logger.tb_train_figures(global_step, figures)

        if global_step % c.print_step == 0:
            print(
                " | > Step:{} Loss:{:.5f} AvgLoss:{:.5f} GradNorm:{:.5f} "
                "StepTime:{:.2f} LoaderTime:{:.2f} AvGLoaderTime:{:.2f} LR:{:.6f}".format(
                    global_step, loss.item(), avg_loss, grad_norm, step_time,
                    loader_time, avg_loader_time, current_lr),
                flush=True,
            )

        if global_step >= c.max_train_step or global_step % c.save_step == 0:
            # save best model only
            best_loss = save_best_model(model, optimizer, criterion, avg_loss,
                                        best_loss, OUT_PATH, global_step)
            if global_step >= c.max_train_step:
                break

        end_time = time.time()

    return avg_loss, global_step
def train(model, criterion, optimizer, optimizer_st, scheduler, ap, global_step, epoch):
    """Run one training epoch of a Tacotron-style model (fp32 path).

    Running averages are kept in a ``KeepAverage`` under the ``avg_*`` keys
    registered at the start. Optional keys are added for the bidirectional
    decoder (``c.bidirectional_decoder``) and the guided-attention loss
    (``c.ga_alpha > 0``).

    Args:
        model: Tacotron-style model exposing ``decoder.r`` and
            ``decoder.stopnet``.
        criterion: callable returning a dict of losses.
        optimizer: optimizer for the main model.
        optimizer_st: optimizer for the stopnet, or a falsy value.
        scheduler: stepped once per batch when ``c.noam_schedule`` is set.
        ap: audio processor (loader setup, plots and audio).
        global_step: step counter at entry; incremented once per batch.
        epoch: current epoch index; the loader is verbose only for epoch 0.

    Returns:
        tuple: ``(keep_avg.avg_values, global_step)``.
    """
    data_loader = setup_loader(ap, model.decoder.r, is_val=False,
                               verbose=(epoch == 0))
    model.train()
    epoch_time = 0
    train_values = {
        'avg_postnet_loss': 0,
        'avg_decoder_loss': 0,
        'avg_stopnet_loss': 0,
        'avg_align_error': 0,
        'avg_step_time': 0,
        'avg_loader_time': 0
    }
    if c.bidirectional_decoder:
        train_values['avg_decoder_b_loss'] = 0  # decoder backward loss
        train_values['avg_decoder_c_loss'] = 0  # decoder consistency loss
    if c.ga_alpha > 0:
        train_values['avg_ga_loss'] = 0  # guided attention loss
    keep_avg = KeepAverage()
    keep_avg.add_values(train_values)
    if use_cuda:
        batch_n_iter = int(
            len(data_loader.dataset) / (c.batch_size * num_gpus))
    else:
        batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
    end_time = time.time()
    c_logger.print_train_start()
    for num_iter, data in enumerate(data_loader):
        start_time = time.time()

        # format data
        (text_input, text_lengths, mel_input, mel_lengths, linear_input,
         stop_targets, speaker_ids, avg_text_length,
         avg_spec_length) = format_data(data)
        loader_time = time.time() - end_time

        global_step += 1

        # setup lr
        if c.noam_schedule:
            scheduler.step()
        optimizer.zero_grad()
        if optimizer_st:
            optimizer_st.zero_grad()

        # forward pass model
        if c.bidirectional_decoder:
            (decoder_output, postnet_output, alignments, stop_tokens,
             decoder_backward_output, alignments_backward) = model(
                 text_input, text_lengths, mel_input, speaker_ids=speaker_ids)
        else:
            decoder_output, postnet_output, alignments, stop_tokens = model(
                text_input, text_lengths, mel_input, speaker_ids=speaker_ids)
            decoder_backward_output = None

        # set the alignment lengths wrt reduction factor for guided attention
        if mel_lengths.max() % model.decoder.r != 0:
            alignment_lengths = (
                mel_lengths +
                (model.decoder.r -
                 (mel_lengths.max() % model.decoder.r))) // model.decoder.r
        else:
            alignment_lengths = mel_lengths // model.decoder.r

        # compute loss
        loss_dict = criterion(postnet_output, decoder_output, mel_input,
                              linear_input, stop_tokens, stop_targets,
                              mel_lengths, decoder_backward_output,
                              alignments, alignment_lengths, text_lengths)
        if c.bidirectional_decoder:
            keep_avg.update_values({
                'avg_decoder_b_loss': loss_dict['decoder_backward_loss'].item(),
                'avg_decoder_c_loss': loss_dict['decoder_c_loss'].item()
            })
        if c.ga_alpha > 0:
            keep_avg.update_values(
                {'avg_ga_loss': loss_dict['ga_loss'].item()})

        # backward pass
        loss_dict['loss'].backward()
        optimizer, current_lr = adam_weight_decay(optimizer)
        grad_norm, _ = check_update(model, c.grad_clip, ignore_stopnet=True)
        optimizer.step()

        # compute alignment error (the lower the better)
        align_error = 1 - alignment_diagonal_score(alignments)
        keep_avg.update_value('avg_align_error', align_error)
        loss_dict['align_error'] = align_error

        # backpass and check the grad norm for stop loss
        if c.separate_stopnet:
            loss_dict['stopnet_loss'].backward()
            optimizer_st, _ = adam_weight_decay(optimizer_st)
            grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0)
            optimizer_st.step()
        else:
            grad_norm_st = 0

        step_time = time.time() - start_time
        epoch_time += step_time

        # The stopnet loss may be a plain Python number rather than a tensor.
        # Plain numbers have no ``.item()``, so convert them with ``float()``
        # and only call ``.item()`` on tensors.
        stopnet_loss = loss_dict['stopnet_loss']
        if isinstance(stopnet_loss, (int, float)):
            stopnet_loss_value = float(stopnet_loss)
        else:
            stopnet_loss_value = float(stopnet_loss.item())

        # update avg stats
        update_train_values = {
            'avg_postnet_loss': float(loss_dict['postnet_loss'].item()),
            'avg_decoder_loss': float(loss_dict['decoder_loss'].item()),
            'avg_stopnet_loss': stopnet_loss_value,
            'avg_step_time': step_time,
            'avg_loader_time': loader_time
        }
        keep_avg.update_values(update_train_values)

        if global_step % c.print_step == 0:
            c_logger.print_train_step(batch_n_iter, num_iter, global_step,
                                      avg_spec_length, avg_text_length,
                                      step_time, loader_time, current_lr,
                                      loss_dict, keep_avg.avg_values)

        # aggregate losses from processes
        if num_gpus > 1:
            loss_dict['postnet_loss'] = reduce_tensor(
                loss_dict['postnet_loss'].data, num_gpus)
            loss_dict['decoder_loss'] = reduce_tensor(
                loss_dict['decoder_loss'].data, num_gpus)
            loss_dict['loss'] = reduce_tensor(loss_dict['loss'].data, num_gpus)
            loss_dict['stopnet_loss'] = reduce_tensor(
                loss_dict['stopnet_loss'].data,
                num_gpus) if c.stopnet else loss_dict['stopnet_loss']

        if args.rank == 0:
            # Plot Training Iter Stats
            # reduce TB load
            if global_step % 10 == 0:
                iter_stats = {
                    "loss_posnet": loss_dict['postnet_loss'].item(),
                    "loss_decoder": loss_dict['decoder_loss'].item(),
                    "lr": current_lr,
                    "grad_norm": grad_norm,
                    "grad_norm_st": grad_norm_st,
                    "step_time": step_time
                }
                tb_logger.tb_train_iter_stats(global_step, iter_stats)

            if global_step % c.save_step == 0:
                if c.checkpoint:
                    # save model
                    save_checkpoint(model, optimizer, global_step, epoch,
                                    model.decoder.r, OUT_PATH,
                                    optimizer_st=optimizer_st,
                                    model_loss=loss_dict['postnet_loss'].item())

                # Diagnostic visualizations
                const_spec = postnet_output[0].data.cpu().numpy()
                gt_spec = linear_input[0].data.cpu().numpy() if c.model in [
                    "Tacotron", "TacotronGST"
                ] else mel_input[0].data.cpu().numpy()
                align_img = alignments[0].data.cpu().numpy()

                figures = {
                    "prediction": plot_spectrogram(const_spec, ap),
                    "ground_truth": plot_spectrogram(gt_spec, ap),
                    "alignment": plot_alignment(align_img),
                }

                if c.bidirectional_decoder:
                    figures["alignment_backward"] = plot_alignment(
                        alignments_backward[0].data.cpu().numpy())

                tb_logger.tb_train_figures(global_step, figures)

                # Sample audio
                if c.model in ["Tacotron", "TacotronGST"]:
                    train_audio = ap.inv_spectrogram(const_spec.T)
                else:
                    train_audio = ap.inv_melspectrogram(const_spec.T)
                tb_logger.tb_train_audios(global_step,
                                          {'TrainAudio': train_audio},
                                          c.audio["sample_rate"])
        end_time = time.time()

    # print epoch stats
    c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg)

    # Plot Epoch Stats
    if args.rank == 0:
        # Plot Training Epoch Stats
        epoch_stats = {
            "loss_postnet": keep_avg['avg_postnet_loss'],
            "loss_decoder": keep_avg['avg_decoder_loss'],
            "stopnet_loss": keep_avg['avg_stopnet_loss'],
            "alignment_score": keep_avg['avg_align_error'],
            "epoch_time": epoch_time
        }
        if c.ga_alpha > 0:
            epoch_stats['guided_attention_loss'] = keep_avg['avg_ga_loss']
        tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
        if c.tb_model_param_stats:
            tb_logger.tb_model_weights(model, global_step)
    return keep_avg.avg_values, global_step
def train(model, criterion, optimizer, scheduler, ap, global_step, epoch, speaker_mapping=None):
    """Run one training epoch of a GlowTTS-style model.

    The forward pass and loss run under ``torch.cuda.amp.autocast`` when
    ``c.mixed_precision`` is set. Each batch produces exactly one optimizer
    update: gradients are clipped to ``c.grad_clip`` with
    ``torch.nn.utils.clip_grad_norm_``, then the optimizer is stepped (through
    the GradScaler in mixed-precision mode, so steps with inf/NaN gradients
    are skipped by the scaler).

    NOTE(review): the GradScaler is created per call, so its loss scale
    restarts every epoch and is not checkpointed.

    Args:
        model: flow-based TTS model with ``forward`` and ``inference``.
        criterion: callable returning a dict with ``"loss"``, ``"log_mle"``
            and ``"loss_dur"`` entries.
        optimizer: optimizer for the model.
        scheduler: stepped after each optimizer update when
            ``c.noam_schedule`` is set.
        ap: audio processor (loader setup, plots and audio).
        global_step: step counter at entry; incremented once per batch.
        epoch: current epoch index; the loader is verbose only for epoch 0.
        speaker_mapping: optional mapping forwarded to ``setup_loader``.

    Returns:
        tuple: ``(keep_avg.avg_values, global_step)``. The logged
        ``grad_norm`` is the total norm returned by ``clip_grad_norm_``
        (i.e. before clipping).
    """
    data_loader = setup_loader(ap, 1, is_val=False,
                               verbose=(epoch == 0), speaker_mapping=speaker_mapping)
    model.train()
    epoch_time = 0
    keep_avg = KeepAverage()
    if use_cuda:
        batch_n_iter = int(
            len(data_loader.dataset) / (c.batch_size * num_gpus))
    else:
        batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
    end_time = time.time()
    c_logger.print_train_start()
    scaler = torch.cuda.amp.GradScaler() if c.mixed_precision else None
    for num_iter, data in enumerate(data_loader):
        start_time = time.time()

        # format data
        (text_input, text_lengths, mel_input, mel_lengths, speaker_ids,
         avg_text_length, avg_spec_length, attn_mask) = format_data(data)

        loader_time = time.time() - end_time

        global_step += 1
        optimizer.zero_grad()

        # forward pass model
        with torch.cuda.amp.autocast(enabled=c.mixed_precision):
            z, logdet, y_mean, y_log_scale, alignments, o_dur_log, o_total_dur = model.forward(
                text_input, text_lengths, mel_input, mel_lengths, attn_mask,
                g=speaker_ids)

            # compute loss
            loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths,
                                  o_dur_log, o_total_dur, text_lengths)

        # Backward pass, clip, and a single optimizer step per batch. A second
        # clip + optimizer.step() here would apply every update twice and,
        # under AMP, bypass the scaler's inf/NaN step skipping.
        if c.mixed_precision:
            scaler.scale(loss_dict['loss']).backward()
            scaler.unscale_(optimizer)
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       c.grad_clip)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss_dict['loss'].backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       c.grad_clip)
            optimizer.step()

        # setup lr
        if c.noam_schedule:
            scheduler.step()

        # current_lr
        current_lr = optimizer.param_groups[0]['lr']

        # compute alignment error (the lower the better)
        align_error = 1 - alignment_diagonal_score(alignments)
        loss_dict['align_error'] = align_error

        step_time = time.time() - start_time
        epoch_time += step_time

        # aggregate losses from processes
        if num_gpus > 1:
            loss_dict['log_mle'] = reduce_tensor(loss_dict['log_mle'].data, num_gpus)
            loss_dict['loss_dur'] = reduce_tensor(loss_dict['loss_dur'].data, num_gpus)
            loss_dict['loss'] = reduce_tensor(loss_dict['loss'].data, num_gpus)

        # detach loss values: plain numbers pass through, tensors via .item()
        loss_dict_new = dict()
        for key, value in loss_dict.items():
            if isinstance(value, (int, float)):
                loss_dict_new[key] = value
            else:
                loss_dict_new[key] = value.item()
        loss_dict = loss_dict_new

        # update avg stats
        update_train_values = dict()
        for key, value in loss_dict.items():
            update_train_values['avg_' + key] = value
        update_train_values['avg_loader_time'] = loader_time
        update_train_values['avg_step_time'] = step_time
        keep_avg.update_values(update_train_values)

        # print training progress
        if global_step % c.print_step == 0:
            log_dict = {
                "avg_spec_length": [avg_spec_length, 1],  # value, precision
                "avg_text_length": [avg_text_length, 1],
                "step_time": [step_time, 4],
                "loader_time": [loader_time, 2],
                "current_lr": current_lr,
            }
            c_logger.print_train_step(batch_n_iter, num_iter, global_step,
                                      log_dict, loss_dict, keep_avg.avg_values)

        if args.rank == 0:
            # Plot Training Iter Stats
            # reduce TB load
            if global_step % c.tb_plot_step == 0:
                iter_stats = {
                    "lr": current_lr,
                    "grad_norm": grad_norm,
                    "step_time": step_time
                }
                iter_stats.update(loss_dict)
                tb_logger.tb_train_iter_stats(global_step, iter_stats)

            if global_step % c.save_step == 0:
                if c.checkpoint:
                    # save model
                    save_checkpoint(model, optimizer, global_step, epoch, 1, OUT_PATH,
                                    model_loss=loss_dict['loss'])

                # Diagnostic visualizations
                # direct pass on model for spec predictions
                target_speaker = None if speaker_ids is None else speaker_ids[:1]
                spec_pred, *_ = model.inference(text_input[:1], text_lengths[:1],
                                                g=target_speaker)

                spec_pred = spec_pred.permute(0, 2, 1)
                gt_spec = mel_input.permute(0, 2, 1)
                const_spec = spec_pred[0].data.cpu().numpy()
                gt_spec = gt_spec[0].data.cpu().numpy()
                align_img = alignments[0].data.cpu().numpy()

                figures = {
                    "prediction": plot_spectrogram(const_spec, ap),
                    "ground_truth": plot_spectrogram(gt_spec, ap),
                    "alignment": plot_alignment(align_img),
                }

                tb_logger.tb_train_figures(global_step, figures)

                # Sample audio
                train_audio = ap.inv_melspectrogram(const_spec.T)
                tb_logger.tb_train_audios(global_step,
                                          {'TrainAudio': train_audio},
                                          c.audio["sample_rate"])
        end_time = time.time()

    # print epoch stats
    c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg)

    # Plot Epoch Stats
    if args.rank == 0:
        epoch_stats = {"epoch_time": epoch_time}
        epoch_stats.update(keep_avg.avg_values)
        tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
        if c.tb_model_param_stats:
            tb_logger.tb_model_weights(model, global_step)
    return keep_avg.avg_values, global_step