def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler, ap, global_step, epoch, criterion_gst=None, optimizer_gst=None):
    """Run one training epoch over the training data loader.

    Supports "Tacotron"/"TacotronGST" models (postnet output compared with
    the linear spectrogram) and other models (postnet output compared with
    the mel spectrogram), optional speaker embeddings
    (``c.use_speaker_embedding``), an optional separately optimized stopnet
    (``c.separate_stopnet``) and an optional text-GST loss (``c.text_gst``
    together with ``criterion_gst`` and ``optimizer_gst``).

    Args:
        model: model returning ``(decoder_output, postnet_output,
            alignments, stop_tokens, text_gst)``; must also expose
            ``model.gst`` and ``model.decoder.stopnet``.
        criterion: spectrogram loss; receives an extra lengths argument when
            ``c.loss_masking`` is set.
        criterion_st: stop-token loss, used only when ``c.stopnet``.
        optimizer: optimizer for the main model parameters.
        optimizer_st: stopnet optimizer; may be falsy when unused.
        scheduler: LR scheduler, stepped once per batch when ``c.lr_decay``.
        ap: audio processor, passed to ``setup_loader`` and used for
            spectrogram plots and audio reconstruction.
        global_step: step counter at the start of the epoch; incremented
            once per batch.
        epoch: epoch index; ``epoch == 0`` makes loader setup verbose.
        criterion_gst: optional loss between text-predicted and mel GST.
        optimizer_gst: optional optimizer stepped on the GST loss.

    Returns:
        ``(avg_postnet_loss, global_step)``. Loss averages are accumulated
        only on rank 0 (``args.rank == 0``), so other ranks return 0.
    """
    data_loader = setup_loader(ap, is_val=False, verbose=(epoch == 0))
    if c.use_speaker_embedding:
        speaker_mapping = load_speaker_mapping(OUT_PATH)
    model.train()
    epoch_time = 0
    avg_postnet_loss = 0
    avg_decoder_loss = 0
    avg_stop_loss = 0
    avg_gst_loss = 0
    avg_step_time = 0
    avg_loader_time = 0
    print("\n > Epoch {}/{}".format(epoch, c.epochs), flush=True)
    # batch_n_iter is only used for the progress display below.
    if use_cuda:
        batch_n_iter = int(len(data_loader.dataset) / (c.batch_size * num_gpus))
    else:
        batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
    end_time = time.time()
    for num_iter, data in enumerate(data_loader):
        start_time = time.time()

        # setup input data
        # Batch layout per the indexing below: text, text lengths, speaker
        # names, linear spec (Tacotron models only), mel spec, mel lengths,
        # stop targets.
        text_input = data[0]
        text_lengths = data[1]
        speaker_names = data[2]
        linear_input = data[3] if c.model in ["Tacotron", "TacotronGST"] else None
        mel_input = data[4]
        mel_lengths = data[5]
        stop_targets = data[6]
        avg_text_length = torch.mean(text_lengths.float())
        avg_spec_length = torch.mean(mel_lengths.float())
        # Time spent between the end of the previous iteration and now,
        # i.e. roughly the time the loader needed to produce this batch.
        loader_time = time.time() - end_time

        if c.use_speaker_embedding:
            speaker_ids = [speaker_mapping[speaker_name] for speaker_name in speaker_names]
            speaker_ids = torch.LongTensor(speaker_ids)
        else:
            speaker_ids = None

        # set stop targets view, we predict a single stop token per r frames prediction
        # Frames are grouped into chunks of c.r; a chunk is a stop target if
        # any of its frames is.
        stop_targets = stop_targets.view(text_input.shape[0], stop_targets.size(1) // c.r, -1)
        stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze(2)

        global_step += 1

        # setup lr
        # NOTE(review): scheduler.step() runs before optimizer.step(); since
        # PyTorch 1.1 the documented order is the opposite — confirm the
        # targeted PyTorch version.
        if c.lr_decay:
            scheduler.step()
        optimizer.zero_grad()
        if optimizer_gst:
            optimizer_gst.zero_grad()
        if optimizer_st:
            optimizer_st.zero_grad()

        # dispatch data to GPU
        if use_cuda:
            text_input = text_input.cuda(non_blocking=True)
            text_lengths = text_lengths.cuda(non_blocking=True)
            mel_input = mel_input.cuda(non_blocking=True)
            mel_lengths = mel_lengths.cuda(non_blocking=True)
            linear_input = linear_input.cuda(non_blocking=True) if c.model in ["Tacotron", "TacotronGST"] else None
            stop_targets = stop_targets.cuda(non_blocking=True)
            if speaker_ids is not None:
                speaker_ids = speaker_ids.cuda(non_blocking=True)

        # forward pass model
        decoder_output, postnet_output, alignments, stop_tokens, text_gst = model(
            text_input, text_lengths, mel_input, speaker_ids=speaker_ids)

        # loss computation
        # When the stopnet is disabled a CPU zero tensor stands in so that the
        # logging code below can still call .item() on it.
        stop_loss = criterion_st(stop_tokens, stop_targets) if c.stopnet else torch.zeros(1)
        gst_loss = torch.zeros(1)
        if c.loss_masking:
            decoder_loss = criterion(decoder_output, mel_input, mel_lengths)
            if c.model in ["Tacotron", "TacotronGST"]:
                postnet_loss = criterion(postnet_output, linear_input, mel_lengths)
            else:
                postnet_loss = criterion(postnet_output, mel_input, mel_lengths)
        else:
            decoder_loss = criterion(decoder_output, mel_input)
            if c.model in ["Tacotron", "TacotronGST"]:
                postnet_loss = criterion(postnet_output, linear_input)
            else:
                postnet_loss = criterion(postnet_output, mel_input)
        loss = decoder_loss + postnet_loss
        # The stop loss joins the main objective only when it is not trained
        # by its own optimizer.
        if not c.separate_stopnet and c.stopnet:
            loss += stop_loss

        if c.text_gst and criterion_gst and optimizer_gst:
            # Compare the text-predicted GST with the (detached) GST computed
            # from the target mel spectrogram.
            mel_gst, _ = model.gst(mel_input)
            gst_loss = criterion_gst(text_gst, mel_gst.squeeze().detach())
            # NOTE(review): this backward runs without retain_graph=True, and
            # optimizer_gst.step() updates parameters before loss.backward().
            # If text_gst shares graph nodes or parameters with the decoder
            # losses, loss.backward() below may raise — confirm.
            gst_loss.backward()
            optimizer_gst.step()

        loss.backward()
        optimizer, current_lr = weight_decay(optimizer, c.wd)
        grad_norm, _ = check_update(model, c.grad_clip)
        optimizer.step()

        # backpass and check the grad norm for stop loss
        if c.separate_stopnet:
            stop_loss.backward()
            optimizer_st, _ = weight_decay(optimizer_st, c.wd)
            grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0)
            optimizer_st.step()
        else:
            grad_norm_st = 0

        step_time = time.time() - start_time
        epoch_time += step_time

        if global_step % c.print_step == 0:
            print(
                " | > Step:{}/{} GlobalStep:{} TotalLoss:{:.5f} PostnetLoss:{:.5f} "
                "DecoderLoss:{:.5f} StopLoss:{:.5f} GSTLoss:{:.5f} GradNorm:{:.5f} "
                "GradNormST:{:.5f} AvgTextLen:{:.1f} AvgSpecLen:{:.1f} StepTime:{:.2f} "
                "LoaderTime:{:.2f} LR:{:.6f}".format(
                    num_iter, batch_n_iter, global_step, loss.item(),
                    postnet_loss.item(), decoder_loss.item(), stop_loss.item(),
                    gst_loss.item(), grad_norm, grad_norm_st, avg_text_length,
                    avg_spec_length, step_time, loader_time, current_lr),
                flush=True)

        # aggregate losses from processes
        if num_gpus > 1:
            postnet_loss = reduce_tensor(postnet_loss.data, num_gpus)
            decoder_loss = reduce_tensor(decoder_loss.data, num_gpus)
            gst_loss = reduce_tensor(gst_loss.data, num_gpus) if c.text_gst else gst_loss
            loss = reduce_tensor(loss.data, num_gpus)
            stop_loss = reduce_tensor(stop_loss.data, num_gpus) if c.stopnet else stop_loss

        # Only rank 0 accumulates averages and writes TensorBoard/checkpoints.
        if args.rank == 0:
            avg_postnet_loss += float(postnet_loss.item())
            avg_decoder_loss += float(decoder_loss.item())
            avg_stop_loss += stop_loss if isinstance(stop_loss, float) else float(stop_loss.item())
            avg_gst_loss += float(gst_loss.item())
            avg_step_time += step_time
            avg_loader_time += loader_time

            # Plot Training Iter Stats
            # reduce TB load
            # NOTE(review): the iteration tag is "loss_posnet" while the epoch
            # tag below is "loss_postnet"; kept as-is for log continuity.
            if global_step % 10 == 0:
                iter_stats = {"loss_posnet": postnet_loss.item(),
                              "loss_decoder": decoder_loss.item(),
                              "gst_loss" : gst_loss.item(),
                              "lr": current_lr,
                              "grad_norm": grad_norm,
                              "grad_norm_st": grad_norm_st,
                              "step_time": step_time}
                tb_logger.tb_train_iter_stats(global_step, iter_stats)

            if global_step % c.save_step == 0:
                if c.checkpoint:
                    # save model
                    save_checkpoint(model, optimizer, optimizer_st, optimizer_gst,
                                    postnet_loss.item(), OUT_PATH, global_step,
                                    epoch)

                # Diagnostic visualizations (first sample of the batch)
                const_spec = postnet_output[0].data.cpu().numpy()
                gt_spec = linear_input[0].data.cpu().numpy() if c.model in ["Tacotron", "TacotronGST"] else mel_input[0].data.cpu().numpy()
                align_img = alignments[0].data.cpu().numpy()

                figures = {
                    "prediction": plot_spectrogram(const_spec, ap),
                    "ground_truth": plot_spectrogram(gt_spec, ap),
                    "alignment": plot_alignment(align_img)
                }
                tb_logger.tb_train_figures(global_step, figures)

                # Sample audio
                if c.model in ["Tacotron", "TacotronGST"]:
                    train_audio = ap.inv_spectrogram(const_spec.T)
                else:
                    train_audio = ap.inv_mel_spectrogram(const_spec.T)
                tb_logger.tb_train_audios(global_step,
                                          {'TrainAudio': train_audio},
                                          c.audio["sample_rate"])
        end_time = time.time()

    # num_iter is the last zero-based batch index, so num_iter + 1 batches ran.
    avg_postnet_loss /= (num_iter + 1)
    avg_decoder_loss /= (num_iter + 1)
    avg_stop_loss /= (num_iter + 1)
    avg_gst_loss /= (num_iter + 1)
    # The GST loss is reported separately and not included in the total.
    avg_total_loss = avg_decoder_loss + avg_postnet_loss + avg_stop_loss
    avg_step_time /= (num_iter + 1)
    avg_loader_time /= (num_iter + 1)

    # print epoch stats
    print(
        " | > EPOCH END -- GlobalStep:{} AvgTotalLoss:{:.5f} "
        "AvgPostnetLoss:{:.5f} AvgDecoderLoss:{:.5f} AvgGSTLoss:{:.5f} "
        "AvgStopLoss:{:.5f} EpochTime:{:.2f} "
        "AvgStepTime:{:.2f} AvgLoaderTime:{:.2f}".format(global_step, avg_total_loss,
                                                         avg_postnet_loss, avg_decoder_loss,
                                                         avg_gst_loss, avg_stop_loss,
                                                         epoch_time, avg_step_time,
                                                         avg_loader_time),
        flush=True)

    # Plot Epoch Stats
    if args.rank == 0:
        # Plot Training Epoch Stats
        epoch_stats = {"loss_postnet": avg_postnet_loss,
                       "loss_decoder": avg_decoder_loss,
                       "stop_loss": avg_stop_loss,
                       "gst_loss" : avg_gst_loss,
                       "epoch_time": epoch_time}
        tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
        if c.tb_model_param_stats:
            tb_logger.tb_model_weights(model, global_step)
    return avg_postnet_loss, global_step
def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
          ap, epoch, use_half=False):
    """Run one training epoch, optionally with half-precision inputs.

    Iterates once over the training loader, optimizes the spectrogram losses
    with ``optimizer`` and, when ``c.separate_stopnet`` is set, the stop-token
    loss with ``optimizer_st``. Progress goes to stdout, and on rank 0 also to
    TensorBoard and checkpoints.

    Args:
        model: model returning ``(decoder_output, postnet_output,
            alignments, stop_tokens)``; must expose ``model.decoder.stopnet``.
        criterion: spectrogram loss; receives an extra lengths argument when
            ``c.loss_masking`` is set.
        criterion_st: stop-token loss, used only when ``c.stopnet``.
        optimizer: optimizer for the main model parameters.
        optimizer_st: stopnet optimizer; may be falsy when unused.
        scheduler: LR scheduler, stepped once per batch when ``c.lr_decay``.
        ap: audio processor used for spectrogram plots and audio samples.
        epoch: epoch index; ``epoch == 0`` makes loader setup verbose and it
            contributes ``epoch * len(data_loader)`` to the global step.
        use_half: if True, mel inputs, mel lengths and stop targets are cast
            to ``torch.half`` and the spectrogram losses are multiplied by
            ``USE_HALF_LOSS_SCALE`` for the backward pass only.

    Returns:
        ``(avg_postnet_loss, current_step)``. Loss averages are accumulated
        only on rank 0 (``args.rank == 0``), so other ranks return 0. All
        printed, logged, checkpointed and returned losses are unscaled, so
        half- and full-precision runs report comparable values.
    """
    # Static loss scale for the spectrogram losses in half-precision mode
    # (presumably to keep small fp16 gradients from underflowing). It only
    # enters the backward objective, never the reported loss values.
    USE_HALF_LOSS_SCALE = 10.0

    data_loader = setup_loader(is_val=False, verbose=(epoch == 0),
                               use_half=use_half)
    model.train()
    is_tacotron = c.model == "Tacotron"
    epoch_time = 0
    avg_postnet_loss = 0
    avg_decoder_loss = 0
    avg_stop_loss = 0
    avg_step_time = 0
    print("\n > Epoch {}/{}".format(epoch, c.epochs), flush=True)
    # Only used for the progress display. Guard against num_gpus == 0
    # (CPU training), which previously raised ZeroDivisionError.
    if num_gpus > 0:
        batch_n_iter = int(len(data_loader.dataset) / (c.batch_size * num_gpus))
    else:
        batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
    start_time = time.time()
    for num_iter, data in enumerate(data_loader):
        # setup input data
        text_input = data[0]
        text_lengths = data[1]
        linear_input = data[2] if is_tacotron else None
        mel_input = data[3] if not use_half else data[3].type(torch.half)
        # NOTE(review): fp16 represents integers exactly only up to 2048, so
        # very long utterances may get rounded lengths here; kept because the
        # masked criterion may rely on this dtype — confirm.
        mel_lengths = data[4] if not use_half else data[4].type(torch.half)
        stop_targets = data[5]

        avg_text_length = torch.mean(text_lengths.float())
        avg_spec_length = torch.mean(mel_lengths.float())

        # set stop targets view, we predict a single stop token per r frames prediction
        # A chunk of c.r frames is a stop target if any of its frames is.
        stop_targets = stop_targets.view(text_input.shape[0],
                                         stop_targets.size(1) // c.r, -1)
        stop_targets = (stop_targets.sum(2) > 0.0).float()
        if use_half:
            stop_targets = stop_targets.type(torch.half)

        current_step = (num_iter + args.restore_step +
                        epoch * len(data_loader) + 1)

        # setup lr
        if c.lr_decay:
            scheduler.step()
        optimizer.zero_grad()
        if optimizer_st:
            optimizer_st.zero_grad()

        # dispatch data to GPU
        if use_cuda:
            text_input = text_input.cuda(non_blocking=True)
            text_lengths = text_lengths.cuda(non_blocking=True)
            mel_input = mel_input.cuda(non_blocking=True)
            mel_lengths = mel_lengths.cuda(non_blocking=True)
            linear_input = (linear_input.cuda(non_blocking=True)
                            if is_tacotron else None)
            stop_targets = stop_targets.cuda(non_blocking=True)

        decoder_output, postnet_output, alignments, stop_tokens = model(
            text_input, text_lengths, mel_input)

        # loss computation
        # When the stopnet is disabled a CPU zero tensor stands in so that the
        # logging code below can still call .item() on it.
        stop_loss = (criterion_st(stop_tokens, stop_targets)
                     if c.stopnet else torch.zeros(1))
        # Tacotron's postnet predicts linear spectrograms; other models mel.
        postnet_target = linear_input if is_tacotron else mel_input
        if c.loss_masking:
            decoder_loss = criterion(decoder_output, mel_input, mel_lengths)
            postnet_loss = criterion(postnet_output, postnet_target, mel_lengths)
        else:
            decoder_loss = criterion(decoder_output, mel_input)
            postnet_loss = criterion(postnet_output, postnet_target)

        # The stop loss joins the main objective only when it is not trained
        # by its own optimizer.
        stop_in_loss = not c.separate_stopnet and c.stopnet
        # Unscaled objective: this is what gets printed, logged and averaged.
        loss = decoder_loss + postnet_loss
        if stop_in_loss:
            loss += stop_loss

        if use_half:
            # Scaled objective used only for backward; the stop loss is
            # added unscaled.
            backward_loss = (decoder_loss * USE_HALF_LOSS_SCALE +
                             postnet_loss * USE_HALF_LOSS_SCALE)
            if stop_in_loss:
                backward_loss += stop_loss
        else:
            backward_loss = loss
        backward_loss.backward()
        optimizer, current_lr = weight_decay(optimizer, c.wd)
        # In half mode the gradients, grad_norm and the c.grad_clip threshold
        # all refer to the scaled objective.
        grad_norm, _ = check_update(model, c.grad_clip)
        optimizer.step()

        # backpass and check the grad norm for stop loss
        if c.separate_stopnet:
            stop_loss.backward()
            optimizer_st, _ = weight_decay(optimizer_st, c.wd)
            grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0)
            optimizer_st.step()
        else:
            grad_norm_st = 0

        # step_time covers data loading plus the training step.
        step_time = time.time() - start_time
        start_time = time.time()
        epoch_time += step_time

        if current_step % c.print_step == 0:
            print(
                " | > Step:{}/{} GlobalStep:{} TotalLoss:{:.5f} PostnetLoss:{:.5f} "
                "DecoderLoss:{:.5f} StopLoss:{:.5f} GradNorm:{:.5f} "
                "GradNormST:{:.5f} AvgTextLen:{:.1f} AvgSpecLen:{:.1f} StepTime:{:.2f} LR:{:.6f}"
                .format(num_iter, batch_n_iter, current_step, loss.item(),
                        postnet_loss.item(), decoder_loss.item(),
                        stop_loss.item(), grad_norm, grad_norm_st,
                        avg_text_length, avg_spec_length, step_time,
                        current_lr),
                flush=True)

        # aggregate losses from processes
        if num_gpus > 1:
            postnet_loss = reduce_tensor(postnet_loss.data, num_gpus)
            decoder_loss = reduce_tensor(decoder_loss.data, num_gpus)
            loss = reduce_tensor(loss.data, num_gpus)
            stop_loss = (reduce_tensor(stop_loss.data, num_gpus)
                         if c.stopnet else stop_loss)

        # Only rank 0 accumulates averages and writes TensorBoard/checkpoints.
        if args.rank == 0:
            avg_postnet_loss += float(postnet_loss.item())
            avg_decoder_loss += float(decoder_loss.item())
            avg_stop_loss += (stop_loss if isinstance(stop_loss, float)
                              else float(stop_loss.item()))
            avg_step_time += step_time

            # Plot Training Iter Stats
            # NOTE(review): the iteration tag is "loss_posnet" while the epoch
            # tag below is "loss_postnet"; kept as-is for log continuity.
            iter_stats = {
                "loss_posnet": postnet_loss.item(),
                "loss_decoder": decoder_loss.item(),
                "lr": current_lr,
                "grad_norm": grad_norm,
                "grad_norm_st": grad_norm_st,
                "step_time": step_time
            }
            tb_logger.tb_train_iter_stats(current_step, iter_stats)

            if current_step % c.save_step == 0:
                if c.checkpoint:
                    # save model
                    save_checkpoint(model, optimizer, optimizer_st,
                                    postnet_loss.item(), OUT_PATH,
                                    current_step, epoch)

                # Diagnostic visualizations (first sample of the batch), cast
                # to float32 so half tensors can be plotted and synthesized.
                const_spec = postnet_output[0].data.cpu().type(
                    torch.float).numpy()
                gt_spec = postnet_target[0].data.cpu().type(
                    torch.float).numpy()
                align_img = alignments[0].data.cpu().type(torch.float).numpy()

                figures = {
                    "prediction": plot_spectrogram(const_spec, ap),
                    "ground_truth": plot_spectrogram(gt_spec, ap),
                    "alignment": plot_alignment(align_img)
                }
                tb_logger.tb_train_figures(current_step, figures)

                # Sample audio
                if is_tacotron:
                    train_audio = ap.inv_spectrogram(const_spec.T)
                else:
                    train_audio = ap.inv_mel_spectrogram(const_spec.T)
                tb_logger.tb_train_audios(current_step,
                                          {'TrainAudio': train_audio},
                                          c.audio["sample_rate"])

    # num_iter is the last zero-based batch index, so num_iter + 1 batches ran.
    n_batches = num_iter + 1
    avg_postnet_loss /= n_batches
    avg_decoder_loss /= n_batches
    avg_stop_loss /= n_batches
    avg_total_loss = avg_decoder_loss + avg_postnet_loss + avg_stop_loss
    avg_step_time /= n_batches

    # print epoch stats
    print(" | > EPOCH END -- GlobalStep:{} AvgTotalLoss:{:.5f} "
          "AvgPostnetLoss:{:.5f} AvgDecoderLoss:{:.5f} "
          "AvgStopLoss:{:.5f} EpochTime:{:.2f} "
          "AvgStepTime:{:.2f}".format(current_step, avg_total_loss,
                                      avg_postnet_loss, avg_decoder_loss,
                                      avg_stop_loss, epoch_time, avg_step_time),
          flush=True)

    # Plot Epoch Stats
    if args.rank == 0:
        # Plot Training Epoch Stats
        epoch_stats = {
            "loss_postnet": avg_postnet_loss,
            "loss_decoder": avg_decoder_loss,
            "stop_loss": avg_stop_loss,
            "epoch_time": epoch_time
        }
        tb_logger.tb_train_epoch_stats(current_step, epoch_stats)
        if c.tb_model_param_stats:
            tb_logger.tb_model_weights(model, current_step)
    return avg_postnet_loss, current_step
def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler, ap, epoch):
    """Run one training epoch with linear + mel spectrogram losses.

    The linear-spectrogram loss blends a full-band term with a term restricted
    to the lowest ``n_priority_freq`` frequency bins (below about 3 kHz),
    weighted by ``c.loss_weight``. The stop-token loss always has its own
    backward pass and optimizer step (``optimizer_st``).

    Args:
        model: model called as ``model(text_input, mel_input, mask)`` and
            returning ``(mel_output, linear_output, alignments,
            stop_tokens)``; must expose ``model.decoder.stopnet``.
        criterion: masked spectrogram loss called with lengths.
        criterion_st: stop-token loss.
        optimizer: optimizer for the main model parameters.
        optimizer_st: optimizer for the stopnet.
        scheduler: LR scheduler, stepped once per batch when ``c.lr_decay``.
        ap: audio processor used for spectrogram plots and audio samples.
        epoch: epoch index; ``epoch == 0`` makes loader setup verbose and it
            contributes ``epoch * len(data_loader)`` to the global step.

    Returns:
        ``(avg_linear_loss, current_step)``. Loss averages are accumulated
        only on rank 0 (``args.rank == 0``), so other ranks return 0.
    """
    data_loader = setup_loader(is_val=False, verbose=(epoch == 0))
    model.train()
    epoch_time = 0
    avg_linear_loss = 0
    avg_mel_loss = 0
    avg_stop_loss = 0
    avg_step_time = 0
    print("\n > Epoch {}/{}".format(epoch, c.epochs), flush=True)
    # Number of linear-spectrogram bins below 3 kHz, assuming the num_freq
    # bins evenly span 0 .. sample_rate / 2 — TODO confirm against ap.
    n_priority_freq = int(3000 / (c.audio['sample_rate'] * 0.5) * c.audio['num_freq'])
    # batch_n_iter is only used for the progress display below.
    if num_gpus > 0:
        batch_n_iter = int(
            len(data_loader.dataset) / (c.batch_size * num_gpus))
    else:
        batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
    for num_iter, data in enumerate(data_loader):
        start_time = time.time()

        # setup input data
        text_input = data[0]
        text_lengths = data[1]
        linear_input = data[2]
        mel_input = data[3]
        mel_lengths = data[4]
        stop_targets = data[5]

        avg_text_length = torch.mean(text_lengths.float())
        avg_spec_length = torch.mean(mel_lengths.float())

        # set stop targets view, we predict a single stop token per r frames prediction
        # A chunk of c.r frames is a stop target if any of its frames is; the
        # result keeps a trailing singleton axis.
        stop_targets = stop_targets.view(text_input.shape[0],
                                         stop_targets.size(1) // c.r, -1)
        stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float()

        current_step = num_iter + args.restore_step + \
            epoch * len(data_loader) + 1

        # setup lr
        if c.lr_decay:
            scheduler.step()
        optimizer.zero_grad()
        optimizer_st.zero_grad()

        # dispatch data to GPU
        if use_cuda:
            text_input = text_input.cuda(non_blocking=True)
            text_lengths = text_lengths.cuda(non_blocking=True)
            mel_input = mel_input.cuda(non_blocking=True)
            mel_lengths = mel_lengths.cuda(non_blocking=True)
            linear_input = linear_input.cuda(non_blocking=True)
            stop_targets = stop_targets.cuda(non_blocking=True)

        # compute mask for padding
        # (built from text lengths; presumably masks padded input tokens)
        mask = sequence_mask(text_lengths)

        # forward pass
        mel_output, linear_output, alignments, stop_tokens = model(
            text_input, mel_input, mask)

        # loss computation
        stop_loss = criterion_st(stop_tokens, stop_targets)
        mel_loss = criterion(mel_output, mel_input, mel_lengths)
        # Blend of the full-band loss and the loss on the first
        # n_priority_freq bins of the third axis (assumed to be frequency).
        linear_loss = (1 - c.loss_weight) * criterion(linear_output, linear_input, mel_lengths)\
            + c.loss_weight * criterion(linear_output[:, :, :n_priority_freq],
                                        linear_input[:, :, :n_priority_freq],
                                        mel_lengths)
        loss = mel_loss + linear_loss

        # backpass and check the grad norm for spec losses
        # retain_graph=True keeps the forward graph alive for the later
        # stop_loss.backward(), which comes from the same forward pass.
        # NOTE(review): gradient clipping is hard-coded to 1.0 here.
        loss.backward(retain_graph=True)
        optimizer, current_lr = weight_decay(optimizer, c.wd)
        grad_norm, _ = check_update(model, 1.0)
        optimizer.step()

        # backpass and check the grad norm for stop loss
        stop_loss.backward()
        optimizer_st, _ = weight_decay(optimizer_st, c.wd)
        grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0)
        optimizer_st.step()

        step_time = time.time() - start_time
        epoch_time += step_time

        if current_step % c.print_step == 0:
            print(
                " | > Step:{}/{} GlobalStep:{} TotalLoss:{:.5f} LinearLoss:{:.5f} "
                "MelLoss:{:.5f} StopLoss:{:.5f} GradNorm:{:.5f} "
                "GradNormST:{:.5f} AvgTextLen:{:.1f} AvgSpecLen:{:.1f} StepTime:{:.2f} LR:{:.6f}"
                .format(num_iter, batch_n_iter, current_step, loss.item(),
                        linear_loss.item(), mel_loss.item(), stop_loss.item(),
                        grad_norm, grad_norm_st, avg_text_length,
                        avg_spec_length, step_time, current_lr),
                flush=True)

        # aggregate losses from processes
        if num_gpus > 1:
            linear_loss = reduce_tensor(linear_loss.data, num_gpus)
            mel_loss = reduce_tensor(mel_loss.data, num_gpus)
            loss = reduce_tensor(loss.data, num_gpus)
            stop_loss = reduce_tensor(stop_loss.data, num_gpus)

        # Only rank 0 accumulates averages and writes TensorBoard/checkpoints.
        if args.rank == 0:
            avg_linear_loss += float(linear_loss.item())
            avg_mel_loss += float(mel_loss.item())
            avg_stop_loss += stop_loss.item()
            avg_step_time += step_time

            # Plot Training Iter Stats
            # The linear loss is logged under the "loss_posnet" tag and the
            # mel loss under "loss_decoder".
            iter_stats = {
                "loss_posnet": linear_loss.item(),
                "loss_decoder": mel_loss.item(),
                "lr": current_lr,
                "grad_norm": grad_norm,
                "grad_norm_st": grad_norm_st,
                "step_time": step_time
            }
            tb_logger.tb_train_iter_stats(current_step, iter_stats)

            if current_step % c.save_step == 0:
                if c.checkpoint:
                    # save model
                    save_checkpoint(model, optimizer, optimizer_st,
                                    linear_loss.item(), OUT_PATH, current_step,
                                    epoch)

                # Diagnostic visualizations (first sample of the batch)
                const_spec = linear_output[0].data.cpu().numpy()
                gt_spec = linear_input[0].data.cpu().numpy()
                align_img = alignments[0].data.cpu().numpy()

                figures = {
                    "prediction": plot_spectrogram(const_spec, ap),
                    "ground_truth": plot_spectrogram(gt_spec, ap),
                    "alignment": plot_alignment(align_img)
                }
                tb_logger.tb_train_figures(current_step, figures)

                # Sample audio
                tb_logger.tb_train_audios(
                    current_step, {'TrainAudio': ap.inv_spectrogram(const_spec.T)},
                    c.audio["sample_rate"])

    # num_iter is the last zero-based batch index, so num_iter + 1 batches ran.
    avg_linear_loss /= (num_iter + 1)
    avg_mel_loss /= (num_iter + 1)
    avg_stop_loss /= (num_iter + 1)
    avg_total_loss = avg_mel_loss + avg_linear_loss + avg_stop_loss
    avg_step_time /= (num_iter + 1)

    # print epoch stats
    print(" | > EPOCH END -- GlobalStep:{} AvgTotalLoss:{:.5f} "
          "AvgLinearLoss:{:.5f} AvgMelLoss:{:.5f} "
          "AvgStopLoss:{:.5f} EpochTime:{:.2f} "
          "AvgStepTime:{:.2f}".format(current_step, avg_total_loss,
                                      avg_linear_loss, avg_mel_loss,
                                      avg_stop_loss, epoch_time, avg_step_time),
          flush=True)

    # Plot Epoch Stats
    if args.rank == 0:
        # Plot Training Epoch Stats
        epoch_stats = {
            "loss_postnet": avg_linear_loss,
            "loss_decoder": avg_mel_loss,
            "stop_loss": avg_stop_loss,
            "epoch_time": epoch_time
        }
        tb_logger.tb_train_epoch_stats(current_step, epoch_stats)
        if c.tb_model_param_stats:
            tb_logger.tb_model_weights(model, current_step)
    return avg_linear_loss, current_step