def main(args):
    torch.manual_seed(0)

    # Get device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Get dataset
    dataset = Dataset("train.txt")
    loader = DataLoader(dataset,
                        batch_size=hp.batch_size**2,
                        shuffle=True,
                        collate_fn=dataset.collate_fn,
                        drop_last=True,
                        num_workers=0)

    # Define model
    if hp.use_spk_embed:
        n_spkers = len(dataset.spk_table.keys())
        model = nn.DataParallel(FastSpeech2(n_spkers=n_spkers)).to(device)
    else:
        model = nn.DataParallel(FastSpeech2()).to(device)
    print("Model Has Been Defined")
    num_param = utils.get_param_num(model)
    print('Number of FastSpeech2 Parameters:', num_param)

    # Optimizer and loss
    optimizer = torch.optim.Adam(model.parameters(),
                                 betas=hp.betas,
                                 eps=hp.eps,
                                 weight_decay=hp.weight_decay)
    scheduled_optim = ScheduledOptim(optimizer, hp.decoder_hidden,
                                     hp.n_warm_up_step, args.restore_step)
    Loss = FastSpeech2Loss().to(device)
    print("Optimizer and Loss Function Defined.")

    # Load checkpoint if exists
    checkpoint_path = os.path.join(hp.checkpoint_path)
    try:
        checkpoint = torch.load(
            os.path.join(checkpoint_path,
                         'checkpoint_{}.pth.tar'.format(args.restore_step)))
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("\n---Model Restored at Step {}---\n".format(args.restore_step))
    except:
        print("\n---Start New Training---\n")
        if not os.path.exists(checkpoint_path):
            os.makedirs(checkpoint_path)

    # Load vocoder
    if hp.vocoder == 'melgan':
        melgan = utils.get_melgan()
        # melgan.to(device)
    elif hp.vocoder == 'waveglow':
        waveglow = utils.get_waveglow()
        waveglow.to(device)

    # Init logger
    log_path = hp.log_path
    if not os.path.exists(log_path):
        os.makedirs(log_path)
        os.makedirs(os.path.join(log_path, 'train'))
        os.makedirs(os.path.join(log_path, 'validation'))
    train_logger = SummaryWriter(os.path.join(log_path, 'train'))
    val_logger = SummaryWriter(os.path.join(log_path, 'validation'))

    # Init synthesis directory
    synth_path = hp.synth_path
    if not os.path.exists(synth_path):
        os.makedirs(synth_path)

    # Init evaluation directory
    eval_path = hp.eval_path
    if not os.path.exists(eval_path):
        os.makedirs(eval_path)

    # Define Some Information
    Time = np.array([])
    Start = time.perf_counter()

    # Training
    model = model.train()
    for epoch in range(hp.epochs):
        # Get Training Loader
        total_step = hp.epochs * len(loader) * hp.batch_size

        for i, batchs in enumerate(loader):
            for j, data_of_batch in enumerate(batchs):
                start_time = time.perf_counter()

                current_step = i * hp.batch_size + j + args.restore_step + \
                    epoch * len(loader) * hp.batch_size + 1
                print("step : {}".format(current_step), end='\r', flush=True)

                ### Get Data ###
                if hp.use_spk_embed:
                    spk_ids = torch.tensor(
                        data_of_batch["spk_ids"]).to(torch.int64).to(device)
                else:
                    spk_ids = None
                text = torch.from_numpy(
                    data_of_batch["text"]).long().to(device)
                mel_target = torch.from_numpy(
                    data_of_batch["mel_target"]).float().to(device)
                D = torch.from_numpy(data_of_batch["D"]).long().to(device)
                log_D = torch.from_numpy(
                    data_of_batch["log_D"]).float().to(device)
                f0 = torch.from_numpy(data_of_batch["f0"]).float().to(device)
                energy = torch.from_numpy(
                    data_of_batch["energy"]).float().to(device)
                src_len = torch.from_numpy(
                    data_of_batch["src_len"]).long().to(device)
                mel_len = torch.from_numpy(
                    data_of_batch["mel_len"]).long().to(device)
                max_src_len = np.max(data_of_batch["src_len"]).astype(np.int32)
                max_mel_len = np.max(data_of_batch["mel_len"]).astype(np.int32)

                ### Forward ###
                mel_output, mel_postnet_output, log_duration_output, \
                    f0_output, energy_output, src_mask, mel_mask, _ = model(
                        text, src_len, mel_len, D, f0, energy,
                        max_src_len, max_mel_len, spk_ids)

                ### Cal Loss ###
                mel_loss, mel_postnet_loss, d_loss, f_loss, e_loss = Loss(
                    log_duration_output, log_D, f0_output, f0,
                    energy_output, energy, mel_output, mel_postnet_output,
                    mel_target, ~src_mask, ~mel_mask)
                total_loss = mel_loss + mel_postnet_loss + d_loss \
                    + 0.01 * f_loss + 0.1 * e_loss

                t_l = total_loss.item()
                m_l = mel_loss.item()
                m_p_l = mel_postnet_loss.item()
                d_l = d_loss.item()
                f_l = f_loss.item()
                e_l = e_loss.item()

                with open(os.path.join(log_path, "total_loss.txt"),
                          "a") as f_total_loss:
                    f_total_loss.write(str(t_l) + "\n")
                with open(os.path.join(log_path, "mel_loss.txt"),
                          "a") as f_mel_loss:
                    f_mel_loss.write(str(m_l) + "\n")
                with open(os.path.join(log_path, "mel_postnet_loss.txt"),
                          "a") as f_mel_postnet_loss:
                    f_mel_postnet_loss.write(str(m_p_l) + "\n")
                with open(os.path.join(log_path, "duration_loss.txt"),
                          "a") as f_d_loss:
                    f_d_loss.write(str(d_l) + "\n")
                with open(os.path.join(log_path, "f0_loss.txt"),
                          "a") as f_f_loss:
                    f_f_loss.write(str(f_l) + "\n")
                with open(os.path.join(log_path, "energy_loss.txt"),
                          "a") as f_e_loss:
                    f_e_loss.write(str(e_l) + "\n")

                ### Backward ###
                total_loss = total_loss / hp.acc_steps
                total_loss.backward()
                if current_step % hp.acc_steps != 0:
                    continue

                ### Update weights ###
                nn.utils.clip_grad_norm_(model.parameters(),
                                         hp.grad_clip_thresh)
                scheduled_optim.step_and_update_lr()
                scheduled_optim.zero_grad()

                ### Print ###
                if current_step % hp.log_step == 0:
                    Now = time.perf_counter()

                    str1 = "Epoch [{}/{}], Step [{}/{}]:".format(
                        epoch + 1, hp.epochs, current_step, total_step)
                    str2 = "Total Loss: {:.4f}, Mel Loss: {:.4f}, Mel PostNet Loss: {:.4f}, Duration Loss: {:.4f}, F0 Loss: {:.4f}, Energy Loss: {:.4f};".format(
                        t_l, m_l, m_p_l, d_l, f_l, e_l)
                    str3 = "Time Used: {:.3f}s, Estimated Time Remaining: {:.3f}s.".format(
                        (Now - Start),
                        (total_step - current_step) * np.mean(Time))

                    print("\n" + str1)
                    print(str2)
                    print(str3)

                    with open(os.path.join(log_path, "log.txt"),
                              "a") as f_log:
                        f_log.write(str1 + "\n")
                        f_log.write(str2 + "\n")
                        f_log.write(str3 + "\n")
                        f_log.write("\n")

                    train_logger.add_scalar('Loss/total_loss', t_l,
                                            current_step)
                    train_logger.add_scalar('Loss/mel_loss', m_l, current_step)
                    train_logger.add_scalar('Loss/mel_postnet_loss', m_p_l,
                                            current_step)
                    train_logger.add_scalar('Loss/duration_loss', d_l,
                                            current_step)
                    train_logger.add_scalar('Loss/F0_loss', f_l, current_step)
                    train_logger.add_scalar('Loss/energy_loss', e_l,
                                            current_step)

                ### Save model ###
                if current_step % hp.save_step == 0:
                    torch.save(
                        {
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict()
                        },
                        os.path.join(
                            checkpoint_path,
                            'checkpoint_{}.pth.tar'.format(current_step)))
                    print("save model at step {} ...".format(current_step))

                ### Synth ###
                if current_step % hp.synth_step == 0:
                    length = mel_len[0].item()
                    print("step: {} , length {}, {}".format(
                        current_step, length, mel_len))
                    mel_target_torch = mel_target[
                        0, :length].detach().unsqueeze(0).transpose(1, 2)
                    mel_target = mel_target[
                        0, :length].detach().cpu().transpose(0, 1)
                    mel_torch = mel_output[
                        0, :length].detach().unsqueeze(0).transpose(1, 2)
                    mel = mel_output[0, :length].detach().cpu().transpose(0, 1)
                    mel_postnet_torch = mel_postnet_output[
                        0, :length].detach().unsqueeze(0).transpose(1, 2)
                    mel_postnet = mel_postnet_output[
                        0, :length].detach().cpu().transpose(0, 1)
                    # Assumes hp.use_spk_embed is enabled; spk_ids is None otherwise.
                    spk_id = dataset.inv_spk_table[int(spk_ids[0])]

                    if hp.vocoder == 'melgan':
                        vocoder.melgan_infer(
                            mel_torch, melgan,
                            os.path.join(
                                hp.synth_path,
                                'step_{}_spk_{}_{}.wav'.format(
                                    current_step, spk_id, hp.vocoder)))
                        vocoder.melgan_infer(
                            mel_postnet_torch, melgan,
                            os.path.join(
                                hp.synth_path,
                                'step_{}_spk_{}_postnet_{}.wav'.format(
                                    current_step, spk_id, hp.vocoder)))
                        vocoder.melgan_infer(
                            mel_target_torch, melgan,
                            os.path.join(
                                hp.synth_path,
                                'step_{}_spk_{}_ground-truth_{}.wav'.format(
                                    current_step, spk_id, hp.vocoder)))
                    elif hp.vocoder == 'waveglow':
                        vocoder.waveglow_infer(
                            mel_torch, waveglow,
                            os.path.join(
                                hp.synth_path,
                                'step_{}_spk_{}_{}.wav'.format(
                                    current_step, spk_id, hp.vocoder)))
                        vocoder.waveglow_infer(
                            mel_postnet_torch, waveglow,
                            os.path.join(
                                hp.synth_path,
                                'step_{}_spk_{}_postnet_{}.wav'.format(
                                    current_step, spk_id, hp.vocoder)))
                        vocoder.waveglow_infer(
                            mel_target_torch, waveglow,
                            os.path.join(
                                hp.synth_path,
                                'step_{}_spk_{}_ground-truth_{}.wav'.format(
                                    current_step, spk_id, hp.vocoder)))

                    f0 = f0[0, :length].detach().cpu().numpy()
                    energy = energy[0, :length].detach().cpu().numpy()
                    f0_output = f0_output[0, :length].detach().cpu().numpy()
                    energy_output = energy_output[
                        0, :length].detach().cpu().numpy()

                    utils.plot_data(
                        [(mel_postnet.numpy(), f0_output, energy_output),
                         (mel_target.numpy(), f0, energy)],
                        ['Synthesized Spectrogram', 'Ground-Truth Spectrogram'],
                        filename=os.path.join(
                            synth_path, 'step_{}.png'.format(current_step)))

                ### Evaluation ###
                if current_step % hp.eval_step == 0:
                    model.eval()
                    with torch.no_grad():
                        d_l, f_l, e_l, m_l, m_p_l = evaluate(
                            model, current_step)
                        t_l = d_l + f_l + e_l + m_l + m_p_l

                        val_logger.add_scalar('Loss/total_loss', t_l,
                                              current_step)
                        val_logger.add_scalar('Loss/mel_loss', m_l,
                                              current_step)
                        val_logger.add_scalar('Loss/mel_postnet_loss', m_p_l,
                                              current_step)
                        val_logger.add_scalar('Loss/duration_loss', d_l,
                                              current_step)
                        val_logger.add_scalar('Loss/F0_loss', f_l,
                                              current_step)
                        val_logger.add_scalar('Loss/energy_loss', e_l,
                                              current_step)
                    model.train()

                ### Time ###
                end_time = time.perf_counter()
                Time = np.append(Time, end_time - start_time)
                if len(Time) == hp.clear_Time:
                    temp_value = np.mean(Time)
                    Time = np.delete(Time, [i for i in range(len(Time))],
                                     axis=None)
                    Time = np.append(Time, temp_value)
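The loop above scales the loss by 1 / hp.acc_steps and only updates the weights on every hp.acc_steps-th iteration, which is plain gradient accumulation. Below is a minimal, self-contained sketch of that pattern on a toy linear model; ACC_STEPS and the model itself are placeholders for illustration, not part of the script above.

import torch
import torch.nn as nn

ACC_STEPS = 4                      # stand-in for hp.acc_steps
model = nn.Linear(10, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for step in range(1, 17):          # pretend these are training iterations
    x = torch.randn(8, 10)
    y = torch.randn(8, 1)
    loss = nn.functional.mse_loss(model(x), y)

    # Scale so the accumulated gradients match one large batch, then accumulate.
    (loss / ACC_STEPS).backward()

    if step % ACC_STEPS != 0:      # keep accumulating until the boundary
        continue

    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    optimizer.zero_grad()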
def main(args): torch.manual_seed(0) # Get device # device = torch.device('cuda'if torch.cuda.is_available()else 'cpu') device = 'cuda' # Get dataset dataset = Dataset("train.txt") loader = DataLoaderX(dataset, batch_size=hp.batch_size * 4, shuffle=True, collate_fn=dataset.collate_fn, drop_last=True, num_workers=8) # Define model model = nn.DataParallel(FastSpeech2()).to(device) # model = FastSpeech2().to(device) print("Model Has Been Defined") num_param = utils.get_param_num(model) print('Number of FastSpeech2 Parameters:', num_param) # Optimizer and loss optimizer = torch.optim.Adam(model.parameters(), betas=hp.betas, eps=hp.eps, weight_decay=hp.weight_decay) scheduled_optim = ScheduledOptim(optimizer, hp.decoder_hidden, hp.n_warm_up_step, args.restore_step) Loss = FastSpeech2Loss().to(device) print("Optimizer and Loss Function Defined.") # Load checkpoint if exists checkpoint_path = os.path.join(hp.checkpoint_path) try: checkpoint = torch.load( os.path.join(checkpoint_path, 'checkpoint_{}.pth.tar'.format(args.restore_step))) model.load_state_dict(checkpoint['model']) optimizer.load_state_dict(checkpoint['optimizer']) print("\n---Model Restored at Step {}---\n".format(args.restore_step)) except: print("\n---Start New Training---\n") if not os.path.exists(checkpoint_path): os.makedirs(checkpoint_path) # Load vocoder if hp.vocoder == 'melgan': melgan = utils.get_melgan() melgan.to(device) elif hp.vocoder == 'waveglow': waveglow = utils.get_waveglow() waveglow.to(device) # Init logger log_path = hp.log_path if not os.path.exists(log_path): os.makedirs(log_path) os.makedirs(os.path.join(log_path, 'train')) os.makedirs(os.path.join(log_path, 'validation')) current_time = time.strftime("%Y-%m-%dT%H:%M", time.localtime()) train_logger = SummaryWriter(log_dir='log/train/' + current_time) val_logger = SummaryWriter(log_dir='log/validation/' + current_time) # Init synthesis directory synth_path = hp.synth_path if not os.path.exists(synth_path): os.makedirs(synth_path) # Define Some Information Time = np.array([]) Start = time.perf_counter() current_step0 = 0 # Training model = model.train() for epoch in range(hp.epochs): # Get Training Loader total_step = hp.epochs * len(loader) * hp.batch_size for i, batchs in enumerate(loader): for j, data_of_batch in enumerate(batchs): start_time = time.perf_counter() current_step = i * len(batchs) + j + args.restore_step + \ epoch * len(loader)*len(batchs) + 1 # Get Data condition = torch.from_numpy( data_of_batch["condition"]).long().to(device) mel_refer = torch.from_numpy( data_of_batch["mel_refer"]).float().to(device) if hp.vocoder == 'WORLD': ap_target = torch.from_numpy( data_of_batch["ap_target"]).float().to(device) sp_target = torch.from_numpy( data_of_batch["sp_target"]).float().to(device) else: mel_target = torch.from_numpy( data_of_batch["mel_target"]).float().to(device) D = torch.from_numpy(data_of_batch["D"]).long().to(device) log_D = torch.from_numpy( data_of_batch["log_D"]).float().to(device) #print(D,log_D) f0 = torch.from_numpy(data_of_batch["f0"]).float().to(device) energy = torch.from_numpy( data_of_batch["energy"]).float().to(device) src_len = torch.from_numpy( data_of_batch["src_len"]).long().to(device) mel_len = torch.from_numpy( data_of_batch["mel_len"]).long().to(device) max_src_len = np.max(data_of_batch["src_len"]).astype(np.int32) max_mel_len = np.max(data_of_batch["mel_len"]).astype(np.int32) if hp.vocoder == 'WORLD': # print(condition.shape,mel_refer.shape, src_len.shape, mel_len.shape, D.shape, f0.shape, energy.shape, 
max_src_len.shape, max_mel_len.shape) ap_output, sp_output, sp_postnet_output, log_duration_output, f0_output, energy_output, src_mask, ap_mask, sp_mask, variance_adaptor_output, decoder_output = model( condition, src_len, mel_len, D, f0, energy, max_src_len, max_mel_len) ap_loss, sp_loss, sp_postnet_loss, d_loss, f_loss, e_loss = Loss( log_duration_output, D, f0_output, f0, energy_output, energy, ap_output=ap_output, sp_output=sp_output, sp_postnet_output=sp_postnet_output, ap_target=ap_target, sp_target=sp_target, src_mask=src_mask, ap_mask=ap_mask, sp_mask=sp_mask) total_loss = ap_loss + sp_loss + sp_postnet_loss + d_loss + f_loss + e_loss else: mel_output, mel_postnet_output, log_duration_output, f0_output, energy_output, src_mask, mel_mask, _ = model( condition, mel_refer, src_len, mel_len, D, f0, energy, max_src_len, max_mel_len) mel_loss, mel_postnet_loss, d_loss, f_loss, e_loss = Loss( log_duration_output, log_D, f0_output, f0, energy_output, energy, mel_output=mel_output, mel_postnet_output=mel_postnet_output, mel_target=mel_target, src_mask=~src_mask, mel_mask=~mel_mask) total_loss = mel_loss + mel_postnet_loss + d_loss + f_loss + e_loss # Logger t_l = total_loss.item() if hp.vocoder == 'WORLD': ap_l = ap_loss.item() sp_l = sp_loss.item() sp_p_l = sp_postnet_loss.item() else: m_l = mel_loss.item() m_p_l = mel_postnet_loss.item() d_l = d_loss.item() f_l = f_loss.item() e_l = e_loss.item() # with open(os.path.join(log_path, "total_loss.txt"), "a") as f_total_loss: # f_total_loss.write(str(t_l)+"\n") # with open(os.path.join(log_path, "mel_loss.txt"), "a") as f_mel_loss: # f_mel_loss.write(str(m_l)+"\n") # with open(os.path.join(log_path, "mel_postnet_loss.txt"), "a") as f_mel_postnet_loss: # f_mel_postnet_loss.write(str(m_p_l)+"\n") # with open(os.path.join(log_path, "duration_loss.txt"), "a") as f_d_loss: # f_d_loss.write(str(d_l)+"\n") # with open(os.path.join(log_path, "f0_loss.txt"), "a") as f_f_loss: # f_f_loss.write(str(f_l)+"\n") # with open(os.path.join(log_path, "energy_loss.txt"), "a") as f_e_loss: # f_e_loss.write(str(e_l)+"\n") # Backward total_loss = total_loss / hp.acc_steps total_loss.backward() if current_step % hp.acc_steps != 0: continue # Clipping gradients to avoid gradient explosion nn.utils.clip_grad_norm_(model.parameters(), hp.grad_clip_thresh) # Update weights scheduled_optim.step_and_update_lr() scheduled_optim.zero_grad() # Print if current_step % hp.log_step == 0: Now = time.perf_counter() str1 = "Epoch[{}/{}],Step[{}/{}]:".format( epoch + 1, hp.epochs, current_step, total_step) if hp.vocoder == 'WORLD': str2 = "Loss:{:.4f},ap:{:.4f},sp:{:.4f},spPN:{:.4f},Dur:{:.4f},F0:{:.4f},Energy:{:.4f};".format( t_l, ap_l, sp_l, sp_p_l, d_l, f_l, e_l) else: str2 = "Loss:{:.4f},Mel:{:.4f},MelPN:{:.4f},Dur:{:.4f},F0:{:.4f},Energy:{:.4f};".format( t_l, m_l, m_p_l, d_l, f_l, e_l) str3 = "T:{:.1f}s,ETA:{:.1f}s.".format( (Now - Start) / (current_step - current_step0), (total_step - current_step) * np.mean(Time)) print("" + str1 + str2 + str3 + '') # with open(os.path.join(log_path, "log.txt"), "a") as f_log: # f_log.write(str1 + "\n") # f_log.write(str2 + "\n") # f_log.write(str3 + "\n") # f_log.write("\n") train_logger.add_scalar('Loss/total_loss', t_l, current_step) if hp.vocoder == 'WORLD': train_logger.add_scalar('Loss/ap_loss', ap_l, current_step) train_logger.add_scalar('Loss/sp_loss', sp_l, current_step) train_logger.add_scalar('Loss/sp_postnet_loss', sp_p_l, current_step) else: train_logger.add_scalar('Loss/mel_loss', m_l, current_step) 
train_logger.add_scalar('Loss/mel_postnet_loss', m_p_l, current_step) train_logger.add_scalar('Loss/duration_loss', d_l, current_step) train_logger.add_scalar('Loss/F0_loss', f_l, current_step) train_logger.add_scalar('Loss/energy_loss', e_l, current_step) if current_step % hp.save_step == 0 or current_step == 20: torch.save( { 'model': model.state_dict(), 'optimizer': optimizer.state_dict() }, os.path.join( checkpoint_path, 'checkpoint_{}.pth.tar'.format(current_step))) print("save model at step {} ...".format(current_step)) if current_step % hp.synth_step == 0 or current_step == 5: length = mel_len[0].item() if hp.vocoder == 'WORLD': ap_target_torch = ap_target[ 0, :length].detach().unsqueeze(0).transpose(1, 2) ap_torch = ap_output[0, :length].detach().unsqueeze( 0).transpose(1, 2) sp_target_torch = sp_target[ 0, :length].detach().unsqueeze(0).transpose(1, 2) sp_torch = sp_output[0, :length].detach().unsqueeze( 0).transpose(1, 2) sp_postnet_torch = sp_postnet_output[ 0, :length].detach().unsqueeze(0).transpose(1, 2) else: mel_target_torch = mel_target[ 0, :length].detach().unsqueeze(0).transpose(1, 2) mel_target = mel_target[ 0, :length].detach().cpu().transpose(0, 1) mel_torch = mel_output[0, :length].detach().unsqueeze( 0).transpose(1, 2) mel = mel_output[0, :length].detach().cpu().transpose( 0, 1) mel_postnet_torch = mel_postnet_output[ 0, :length].detach().unsqueeze(0).transpose(1, 2) mel_postnet = mel_postnet_output[ 0, :length].detach().cpu().transpose(0, 1) # Audio.tools.inv_mel_spec(mel, os.path.join(synth_path, "step_{}_griffin_lim.wav".format(current_step))) # Audio.tools.inv_mel_spec(mel_postnet, os.path.join(synth_path, "step_{}_postnet_griffin_lim.wav".format(current_step))) f0 = f0[0, :length].detach().cpu().numpy() energy = energy[0, :length].detach().cpu().numpy() f0_output = f0_output[0, :length].detach().cpu().numpy() energy_output = energy_output[ 0, :length].detach().cpu().numpy() if hp.vocoder == 'melgan': utils.melgan_infer( mel_torch, melgan, os.path.join( hp.synth_path, 'step_{}_{}.wav'.format( current_step, hp.vocoder))) utils.melgan_infer( mel_postnet_torch, melgan, os.path.join( hp.synth_path, 'step_{}_postnet_{}.wav'.format( current_step, hp.vocoder))) utils.melgan_infer( mel_target_torch, melgan, os.path.join( hp.synth_path, 'step_{}_ground-truth_{}.wav'.format( current_step, hp.vocoder))) elif hp.vocoder == 'waveglow': utils.waveglow_infer( mel_torch, waveglow, os.path.join( hp.synth_path, 'step_{}_{}.wav'.format( current_step, hp.vocoder))) utils.waveglow_infer( mel_postnet_torch, waveglow, os.path.join( hp.synth_path, 'step_{}_postnet_{}.wav'.format( current_step, hp.vocoder))) utils.waveglow_infer( mel_target_torch, waveglow, os.path.join( hp.synth_path, 'step_{}_ground-truth_{}.wav'.format( current_step, hp.vocoder))) elif hp.vocoder == 'WORLD': # ap=np.swapaxes(ap,0,1) # sp=np.swapaxes(sp,0,1) wav = utils.world_infer( np.swapaxes(ap_torch[0].cpu().numpy(), 0, 1), np.swapaxes(sp_postnet_torch[0].cpu().numpy(), 0, 1), f0_output) sf.write( os.path.join( hp.synth_path, 'step_{}_postnet_{}.wav'.format( current_step, hp.vocoder)), wav, 32000) wav = utils.world_infer( np.swapaxes(ap_target_torch[0].cpu().numpy(), 0, 1), np.swapaxes(sp_target_torch[0].cpu().numpy(), 0, 1), f0) sf.write( os.path.join( hp.synth_path, 'step_{}_ground-truth_{}.wav'.format( current_step, hp.vocoder)), wav, 32000) utils.plot_data([ (sp_postnet_torch[0].cpu().numpy(), f0_output, energy_output), (sp_target_torch[0].cpu().numpy(), f0, energy) ], ['Synthetized Spectrogram', 'Ground-Truth 
Spectrogram'], filename=os.path.join( synth_path, 'step_{}.png'.format(current_step))) plt.matshow(sp_postnet_torch[0].cpu().numpy()) plt.savefig( os.path.join(synth_path, 'sp_postnet_{}.png'.format(current_step))) plt.matshow(ap_torch[0].cpu().numpy()) plt.savefig( os.path.join(synth_path, 'ap_{}.png'.format(current_step))) plt.matshow( variance_adaptor_output[0].detach().cpu().numpy()) # plt.savefig(os.path.join(synth_path, 'va_{}.png'.format(current_step))) # plt.matshow(decoder_output[0].detach().cpu().numpy()) # plt.savefig(os.path.join(synth_path, 'encoder_{}.png'.format(current_step))) plt.cla() fout = open( os.path.join(synth_path, 'D_{}.txt'.format(current_step)), 'w') fout.write( str(log_duration_output[0].detach().cpu().numpy()) + '\n') fout.write(str(D[0].detach().cpu().numpy()) + '\n') fout.write( str(condition[0, :, 2].detach().cpu().numpy()) + '\n') fout.close() # if current_step % hp.eval_step == 0 or current_step==20: # model.eval() # with torch.no_grad(): # if hp.vocoder=='WORLD': # d_l, f_l, e_l, ap_l, sp_l, sp_p_l = evaluate(model, current_step) # t_l = d_l + f_l + e_l + ap_l + sp_l + sp_p_l # val_logger.add_scalar('valLoss/total_loss', t_l, current_step) # val_logger.add_scalar('valLoss/ap_loss', ap_l, current_step) # val_logger.add_scalar('valLoss/sp_loss', sp_l, current_step) # val_logger.add_scalar('valLoss/sp_postnet_loss', sp_p_l, current_step) # val_logger.add_scalar('valLoss/duration_loss', d_l, current_step) # val_logger.add_scalar('valLoss/F0_loss', f_l, current_step) # val_logger.add_scalar('valLoss/energy_loss', e_l, current_step) # else: # d_l, f_l, e_l, m_l, m_p_l = evaluate(model, current_step) # t_l = d_l + f_l + e_l + m_l + m_p_l # val_logger.add_scalar('valLoss/total_loss', t_l, current_step) # val_logger.add_scalar('valLoss/mel_loss', m_l, current_step) # val_logger.add_scalar('valLoss/mel_postnet_loss', m_p_l, current_step) # val_logger.add_scalar('valLoss/duration_loss', d_l, current_step) # val_logger.add_scalar('valLoss/F0_loss', f_l, current_step) # val_logger.add_scalar('valLoss/energy_loss', e_l, current_step) # model.train() # if current_step%10==0: # print(energy_output[0],energy[0]) end_time = time.perf_counter() Time = np.append(Time, end_time - start_time) if len(Time) == hp.clear_Time: temp_value = np.mean(Time) Time = np.delete(Time, [i for i in range(len(Time))], axis=None) Time = np.append(Time, temp_value)
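The training loops above track per-step wall time in a growing NumPy array and collapse it to its mean once it reaches hp.clear_Time entries, so the ETA estimate stays cheap without the array growing forever. A small sketch of that bookkeeping in isolation, with CLEAR_TIME standing in for hp.clear_Time and a sleep standing in for one training step:

import time
import numpy as np

CLEAR_TIME = 20
TOTAL_STEPS = 100
durations = np.array([])

for step in range(1, TOTAL_STEPS + 1):
    t0 = time.perf_counter()
    time.sleep(0.001)                              # placeholder for one training step
    durations = np.append(durations, time.perf_counter() - t0)

    # Collapse the window to its mean so the array never grows unbounded.
    if len(durations) == CLEAR_TIME:
        durations = np.array([np.mean(durations)])

    eta = (TOTAL_STEPS - step) * np.mean(durations)
    print("ETA: {:.3f}s".format(eta), end='\r')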
def main(args): # Get device device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # Define model model = nn.DataParallel(FastSpeech()).to(device) #tacotron2 = get_tacotron2() print("FastSpeech and Tacotron2 Have Been Defined") num_param = sum(param.numel() for param in model.parameters()) print('Number of FastSpeech Parameters:', num_param) # Get dataset dataset = FastSpeechDataset() # Optimizer and loss optimizer = torch.optim.Adam( model.parameters(), betas=(0.9, 0.98), eps=1e-9) scheduled_optim = ScheduledOptim(optimizer, hp.word_vec_dim, hp.n_warm_up_step, args.restore_step) fastspeech_loss = FastSpeechLoss().to(device) print("Defined Optimizer and Loss Function.") # Get training loader print("Get Training Loader") training_loader = DataLoader(dataset, batch_size=hp.batch_size, shuffle=True, collate_fn=collate_fn, drop_last=True, num_workers=cpu_count()) try: checkpoint = torch.load(os.path.join( hp.checkpoint_path, 'checkpoint_%d.pth.tar' % args.restore_step)) model.load_state_dict(checkpoint['model']) optimizer.load_state_dict(checkpoint['optimizer']) print("\n------Model Restored at Step %d------\n" % args.restore_step) except: print("\n------Start New Training------\n") if not os.path.exists(hp.checkpoint_path): os.mkdir(hp.checkpoint_path) # Init logger if not os.path.exists(hp.logger_path): os.mkdir(hp.logger_path) # Training model = model.train() total_step = hp.epochs * len(training_loader) Time = np.array(list()) Start = time.clock() summary = SummaryWriter() for epoch in range(hp.epochs): for i, data_of_batch in enumerate(training_loader): start_time = time.clock() current_step = i + args.restore_step + \ epoch * len(training_loader) + 1 # Init scheduled_optim.zero_grad() if not hp.pre_target: # Prepare Data src_seq = data_of_batch["texts"] src_pos = data_of_batch["pos"] mel_tgt = data_of_batch["mels"] src_seq = torch.from_numpy(src_seq).long().to(device) src_pos = torch.from_numpy(src_pos).long().to(device) mel_tgt = torch.from_numpy(mel_tgt).float().to(device) alignment_target = get_alignment( src_seq, tacotron2).float().to(device) # For Data Parallel mel_max_len = mel_tgt.size(1) else: # Prepare Data src_seq = data_of_batch["texts"] src_pos = data_of_batch["pos"] mel_tgt = data_of_batch["mels"] alignment_target = data_of_batch["alignment"] # print(alignment_target) # print(alignment_target.shape) # print(mel_tgt.shape) # print(src_seq.shape) # print(src_seq) src_seq = torch.from_numpy(src_seq).long().to(device) src_pos = torch.from_numpy(src_pos).long().to(device) mel_tgt = torch.from_numpy(mel_tgt).float().to(device) alignment_target = torch.from_numpy( alignment_target).float().to(device) # For Data Parallel mel_max_len = mel_tgt.size(1) # print(alignment_target.shape) # Forward mel_output, mel_output_postnet, duration_predictor_output = model( src_seq, src_pos, mel_max_length=mel_max_len, length_target=alignment_target) # Cal Loss mel_loss, mel_postnet_loss, duration_predictor_loss = fastspeech_loss( mel_output, mel_output_postnet, duration_predictor_output, mel_tgt, alignment_target) total_loss = mel_loss + mel_postnet_loss + duration_predictor_loss # Logger t_l = total_loss.item() m_l = mel_loss.item() m_p_l = mel_postnet_loss.item() d_p_l = duration_predictor_loss.item() with open(os.path.join("logger", "total_loss.txt"), "a") as f_total_loss: f_total_loss.write(str(t_l)+"\n") with open(os.path.join("logger", "mel_loss.txt"), "a") as f_mel_loss: f_mel_loss.write(str(m_l)+"\n") with open(os.path.join("logger", "mel_postnet_loss.txt"), "a") as 
f_mel_postnet_loss: f_mel_postnet_loss.write(str(m_p_l)+"\n") with open(os.path.join("logger", "duration_predictor_loss.txt"), "a") as f_d_p_loss: f_d_p_loss.write(str(d_p_l)+"\n") # Backward total_loss.backward() # Clipping gradients to avoid gradient explosion nn.utils.clip_grad_norm_(model.parameters(), hp.grad_clip_thresh) # Update weights if args.frozen_learning_rate: scheduled_optim.step_and_update_lr_frozen( args.learning_rate_frozen) else: scheduled_optim.step_and_update_lr() # Print if current_step % hp.log_step == 0: Now = time.clock() str1 = "Epoch [{}/{}], Step [{}/{}], Mel Loss: {:.4f}, Mel PostNet Loss: {:.4f};".format( epoch+1, hp.epochs, current_step, total_step, mel_loss.item(), mel_postnet_loss.item()) str2 = "Duration Predictor Loss: {:.4f}, Total Loss: {:.4f}.".format( duration_predictor_loss.item(), total_loss.item()) str3 = "Current Learning Rate is {:.6f}.".format( scheduled_optim.get_learning_rate()) str4 = "Time Used: {:.3f}s, Estimated Time Remaining: {:.3f}s.".format( (Now-Start), (total_step-current_step)*np.mean(Time)) print("\n" + str1) print(str2) print(str3) print(str4) with open(os.path.join("logger", "logger.txt"), "a") as f_logger: f_logger.write(str1 + "\n") f_logger.write(str2 + "\n") f_logger.write(str3 + "\n") f_logger.write(str4 + "\n") f_logger.write("\n") summary.add_scalar('loss', total_loss.item(), current_step) if current_step % hp.save_step == 0: torch.save({'model': model.state_dict(), 'optimizer': optimizer.state_dict( )}, os.path.join(hp.checkpoint_path, 'checkpoint_%d.pth.tar' % current_step)) print("save model at step %d ..." % current_step) end_time = time.clock() Time = np.append(Time, end_time - start_time) if len(Time) == hp.clear_Time: temp_value = np.mean(Time) Time = np.delete( Time, [i for i in range(len(Time))], axis=None) Time = np.append(Time, temp_value)
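ScheduledOptim is used throughout these scripts but its body is not shown here. In many FastSpeech-style repositories it applies the Transformer "Noam" warm-up schedule, lr = d_model^-0.5 * min(step^-0.5, step * warmup^-1.5). The sketch below assumes that behavior and is not the exact class used in this code:

import torch


class NoamOptim:
    """Wraps an optimizer; lr = d_model^-0.5 * min(n^-0.5, n * warmup^-1.5)."""

    def __init__(self, optimizer, d_model, n_warmup_steps, current_step=0):
        self.optimizer = optimizer
        self.d_model = d_model
        self.n_warmup_steps = n_warmup_steps
        self.n_current_steps = current_step

    def _lr_scale(self):
        n = max(self.n_current_steps, 1)
        return (self.d_model ** -0.5) * min(n ** -0.5,
                                            n * self.n_warmup_steps ** -1.5)

    def step_and_update_lr(self):
        self.n_current_steps += 1
        lr = self._lr_scale()
        for group in self.optimizer.param_groups:
            group['lr'] = lr
        self.optimizer.step()

    def zero_grad(self):
        self.optimizer.zero_grad()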
def main(args): torch.manual_seed(0) # Get device device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # Get dataset dataset = Dataset("train.txt") loader = DataLoader(dataset, batch_size=hp.batch_size**2, shuffle=True, collate_fn=dataset.collate_fn, drop_last=True, num_workers=0) # Define model model = nn.DataParallel(FastSpeech2()).to(device) print("Model Has Been Defined") num_param = utils.get_param_num(model) print('Number of FastSpeech2 Parameters:', num_param) # Optimizer and loss optimizer = torch.optim.Adam(model.parameters(), betas=hp.betas, eps=hp.eps, weight_decay=hp.weight_decay) scheduled_optim = ScheduledOptim(optimizer, hp.decoder_hidden, hp.n_warm_up_step, args.restore_step) Loss = FastSpeech2Loss().to(device) print("Optimizer and Loss Function Defined.") # Load checkpoint if exists checkpoint_path = os.path.join(hp.checkpoint_path) try: checkpoint = torch.load( os.path.join(checkpoint_path, 'checkpoint_{}.pth.tar'.format(args.restore_step))) model.load_state_dict(checkpoint['model']) optimizer.load_state_dict(checkpoint['optimizer']) print("\n---Model Restored at Step {}---\n".format(args.restore_step)) except: print("\n---Start New Training---\n") if not os.path.exists(checkpoint_path): os.makedirs(checkpoint_path) # read params mean_mel, std_mel = torch.tensor(np.load( os.path.join(hp.preprocessed_path, "mel_stat.npy")), dtype=torch.float).to(device) mean_f0, std_f0 = torch.tensor(np.load( os.path.join(hp.preprocessed_path, "f0_stat.npy")), dtype=torch.float).to(device) mean_energy, std_energy = torch.tensor(np.load( os.path.join(hp.preprocessed_path, "energy_stat.npy")), dtype=torch.float).to(device) mean_mel, std_mel = mean_mel.reshape(1, -1), std_mel.reshape(1, -1) mean_f0, std_f0 = mean_f0.reshape(1, -1), std_f0.reshape(1, -1) mean_energy, std_energy = mean_energy.reshape(1, -1), std_energy.reshape( 1, -1) # Load vocoder if hp.vocoder == 'vocgan': vocoder = utils.get_vocgan(ckpt_path=hp.vocoder_pretrained_model_path) vocoder.to(device) else: vocoder = None # Init logger log_path = hp.log_path if not os.path.exists(log_path): os.makedirs(log_path) os.makedirs(os.path.join(log_path, 'train')) os.makedirs(os.path.join(log_path, 'validation')) train_logger = SummaryWriter(os.path.join(log_path, 'train')) val_logger = SummaryWriter(os.path.join(log_path, 'validation')) # Define Some Information Time = np.array([]) Start = time.perf_counter() # Training model = model.train() for epoch in range(hp.epochs): # Get Training Loader total_step = hp.epochs * len(loader) * hp.batch_size for i, batchs in enumerate(loader): for j, data_of_batch in enumerate(batchs): start_time = time.perf_counter() current_step = i * hp.batch_size + j + args.restore_step + epoch * len( loader) * hp.batch_size + 1 # Get Data text = torch.from_numpy( data_of_batch["text"]).long().to(device) mel_target = torch.from_numpy( data_of_batch["mel_target"]).float().to(device) D = torch.from_numpy(data_of_batch["D"]).long().to(device) log_D = torch.from_numpy( data_of_batch["log_D"]).float().to(device) f0 = torch.from_numpy(data_of_batch["f0"]).float().to(device) energy = torch.from_numpy( data_of_batch["energy"]).float().to(device) src_len = torch.from_numpy( data_of_batch["src_len"]).long().to(device) mel_len = torch.from_numpy( data_of_batch["mel_len"]).long().to(device) max_src_len = np.max(data_of_batch["src_len"]).astype(np.int32) max_mel_len = np.max(data_of_batch["mel_len"]).astype(np.int32) # Forward mel_output, mel_postnet_output, log_duration_output, f0_output, energy_output, 
src_mask, mel_mask, _ = model( text, src_len, mel_len, D, f0, energy, max_src_len, max_mel_len) # Cal Loss mel_loss, mel_postnet_loss, d_loss, f_loss, e_loss = Loss( log_duration_output, log_D, f0_output, f0, energy_output, energy, mel_output, mel_postnet_output, mel_target, ~src_mask, ~mel_mask) total_loss = mel_loss + mel_postnet_loss + d_loss + f_loss + e_loss # Logger t_l = total_loss.item() m_l = mel_loss.item() m_p_l = mel_postnet_loss.item() d_l = d_loss.item() f_l = f_loss.item() e_l = e_loss.item() with open(os.path.join(log_path, "total_loss.txt"), "a") as f_total_loss: f_total_loss.write(str(t_l) + "\n") with open(os.path.join(log_path, "mel_loss.txt"), "a") as f_mel_loss: f_mel_loss.write(str(m_l) + "\n") with open(os.path.join(log_path, "mel_postnet_loss.txt"), "a") as f_mel_postnet_loss: f_mel_postnet_loss.write(str(m_p_l) + "\n") with open(os.path.join(log_path, "duration_loss.txt"), "a") as f_d_loss: f_d_loss.write(str(d_l) + "\n") with open(os.path.join(log_path, "f0_loss.txt"), "a") as f_f_loss: f_f_loss.write(str(f_l) + "\n") with open(os.path.join(log_path, "energy_loss.txt"), "a") as f_e_loss: f_e_loss.write(str(e_l) + "\n") # Backward total_loss = total_loss / hp.acc_steps total_loss.backward() if current_step % hp.acc_steps != 0: continue # Clipping gradients to avoid gradient explosion nn.utils.clip_grad_norm_(model.parameters(), hp.grad_clip_thresh) # Update weights scheduled_optim.step_and_update_lr() scheduled_optim.zero_grad() # Print if current_step % hp.log_step == 0: Now = time.perf_counter() str1 = "Epoch [{}/{}], Step [{}/{}]:".format( epoch + 1, hp.epochs, current_step, total_step) str2 = "Total Loss: {:.4f}, Mel Loss: {:.4f}, Mel PostNet Loss: {:.4f}, Duration Loss: {:.4f}, F0 Loss: {:.4f}, Energy Loss: {:.4f};".format( t_l, m_l, m_p_l, d_l, f_l, e_l) str3 = "Time Used: {:.3f}s, Estimated Time Remaining: {:.3f}s.".format( (Now - Start), (total_step - current_step) * np.mean(Time)) print("\n" + str1) print(str2) print(str3) with open(os.path.join(log_path, "log.txt"), "a") as f_log: f_log.write(str1 + "\n") f_log.write(str2 + "\n") f_log.write(str3 + "\n") f_log.write("\n") train_logger.add_scalar('Loss/total_loss', t_l, current_step) train_logger.add_scalar('Loss/mel_loss', m_l, current_step) train_logger.add_scalar('Loss/mel_postnet_loss', m_p_l, current_step) train_logger.add_scalar('Loss/duration_loss', d_l, current_step) train_logger.add_scalar('Loss/F0_loss', f_l, current_step) train_logger.add_scalar('Loss/energy_loss', e_l, current_step) if current_step % hp.save_step == 0: torch.save( { 'model': model.state_dict(), 'optimizer': optimizer.state_dict() }, os.path.join( checkpoint_path, 'checkpoint_{}.pth.tar'.format(current_step))) print("save model at step {} ...".format(current_step)) if current_step % hp.eval_step == 0: model.eval() with torch.no_grad(): d_l, f_l, e_l, m_l, m_p_l = evaluate( model, current_step, vocoder) t_l = d_l + f_l + e_l + m_l + m_p_l val_logger.add_scalar('Loss/total_loss', t_l, current_step) val_logger.add_scalar('Loss/mel_loss', m_l, current_step) val_logger.add_scalar('Loss/mel_postnet_loss', m_p_l, current_step) val_logger.add_scalar('Loss/duration_loss', d_l, current_step) val_logger.add_scalar('Loss/F0_loss', f_l, current_step) val_logger.add_scalar('Loss/energy_loss', e_l, current_step) model.train() end_time = time.perf_counter() Time = np.append(Time, end_time - start_time) if len(Time) == hp.clear_Time: temp_value = np.mean(Time) Time = np.delete(Time, [i for i in range(len(Time))], axis=None) Time = 
np.append(Time, temp_value)
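The listing above unpacks mean/std pairs from mel_stat.npy, f0_stat.npy and energy_stat.npy, which implies each file stores a [mean, std] stack along its first axis. The sketch below illustrates that assumed layout and one typical use (undoing z-score normalization of a mel batch); both the file layout and the de-normalization step are assumptions for illustration, not taken from the script:

import numpy as np
import torch

# Stand-in for np.load("mel_stat.npy"): first row = mean, second row = std.
stats = np.stack([np.zeros(80), np.ones(80)])
mean_mel, std_mel = torch.tensor(stats, dtype=torch.float)
mean_mel, std_mel = mean_mel.reshape(1, -1), std_mel.reshape(1, -1)

mel = torch.randn(4, 100, 80)                      # (batch, frames, n_mels)
mel_denorm = mel * std_mel + mean_mel              # broadcast over batch and time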
def main(args): # Get device device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # Define model print("Use Tacotron2") model = nn.DataParallel(Tacotron2(hp)).to(device) print("Model Has Been Defined") num_param = utils.get_param_num(model) print('Number of TTS Parameters:', num_param) # Get buffer print("Load data to buffer") buffer = get_data_to_buffer() # Optimizer and loss optimizer = torch.optim.Adam(model.parameters(), betas=(0.9, 0.98), eps=1e-9) scheduled_optim = ScheduledOptim(optimizer, hp.decoder_rnn_dim, hp.n_warm_up_step, args.restore_step) tts_loss = DNNLoss().to(device) print("Defined Optimizer and Loss Function.") # Load checkpoint if exists try: checkpoint = torch.load( os.path.join(hp.checkpoint_path, 'checkpoint_%d.pth.tar' % args.restore_step)) model.load_state_dict(checkpoint['model']) optimizer.load_state_dict(checkpoint['optimizer']) print("\n---Model Restored at Step %d---\n" % args.restore_step) except: print("\n---Start New Training---\n") if not os.path.exists(hp.checkpoint_path): os.mkdir(hp.checkpoint_path) # Init logger if not os.path.exists(hp.logger_path): os.mkdir(hp.logger_path) # Get dataset dataset = BufferDataset(buffer) # Get Training Loader training_loader = DataLoader(dataset, batch_size=hp.batch_expand_size * hp.batch_size, shuffle=True, collate_fn=collate_fn_tensor, drop_last=True, num_workers=0) total_step = hp.epochs * len(training_loader) * hp.batch_expand_size # Define Some Information Time = np.array([]) Start = time.perf_counter() # Training model = model.train() for epoch in range(hp.epochs): for i, batchs in enumerate(training_loader): # real batch start here for j, db in enumerate(batchs): start_time = time.perf_counter() current_step = i * hp.batch_expand_size + j + args.restore_step + \ epoch * len(training_loader) * hp.batch_expand_size + 1 # Init scheduled_optim.zero_grad() # Get Data character = db["text"].long().to(device) mel_target = db["mel_target"].float().to(device) mel_pos = db["mel_pos"].long().to(device) src_pos = db["src_pos"].long().to(device) max_mel_len = db["mel_max_len"] mel_target = mel_target.contiguous().transpose(1, 2) src_length = torch.max(src_pos, -1)[0] mel_length = torch.max(mel_pos, -1)[0] gate_target = mel_pos.eq(0).float() gate_target = gate_target[:, 1:] gate_target = F.pad(gate_target, (0, 1, 0, 0), value=1.) 
# Forward inputs = character, src_length, mel_target, max_mel_len, mel_length mel_output, mel_output_postnet, gate_output = model(inputs) # Cal Loss mel_loss, mel_postnet_loss, gate_loss \ = tts_loss(mel_output, mel_output_postnet, gate_output, mel_target, gate_target) total_loss = mel_loss + mel_postnet_loss + gate_loss # Logger t_l = total_loss.item() m_l = mel_loss.item() m_p_l = mel_postnet_loss.item() g_l = gate_loss.item() with open(os.path.join("logger", "total_loss.txt"), "a") as f_total_loss: f_total_loss.write(str(t_l) + "\n") with open(os.path.join("logger", "mel_loss.txt"), "a") as f_mel_loss: f_mel_loss.write(str(m_l) + "\n") with open(os.path.join("logger", "mel_postnet_loss.txt"), "a") as f_mel_postnet_loss: f_mel_postnet_loss.write(str(m_p_l) + "\n") with open(os.path.join("logger", "gate_loss.txt"), "a") as f_g_loss: f_g_loss.write(str(g_l) + "\n") # Backward total_loss.backward() # Clipping gradients to avoid gradient explosion nn.utils.clip_grad_norm_(model.parameters(), hp.grad_clip_thresh) # Update weights if args.frozen_learning_rate: scheduled_optim.step_and_update_lr_frozen( args.learning_rate_frozen) else: scheduled_optim.step_and_update_lr() # Print if current_step % hp.log_step == 0: Now = time.perf_counter() str1 = "Epoch [{}/{}], Step [{}/{}]:"\ .format(epoch + 1, hp.epochs, current_step, total_step) str2 = "Mel Loss: {:.4f}, Mel PostNet Loss: {:.4f}, Gate Loss: {:.4f};".format( m_l, m_p_l, g_l) str3 = "Current Learning Rate is {:.6f}."\ .format(scheduled_optim.get_learning_rate()) str4 = "Time Used: {:.3f}s, Estimated Time Remaining: {:.3f}s."\ .format((Now-Start), (total_step-current_step) * np.mean(Time, dtype=np.float32)) print("\n" + str1) print(str2) print(str3) print(str4) with open(os.path.join("logger", "logger.txt"), "a") as f_logger: f_logger.write(str1 + "\n") f_logger.write(str2 + "\n") f_logger.write(str3 + "\n") f_logger.write(str4 + "\n") f_logger.write("\n") if current_step % hp.save_step == 0: torch.save( { 'model': model.state_dict(), 'optimizer': optimizer.state_dict() }, os.path.join(hp.checkpoint_path, 'checkpoint_%d.pth.tar' % current_step)) print("save model at step %d ..." % current_step) end_time = time.perf_counter() Time = np.append(Time, end_time - start_time) if len(Time) == hp.clear_Time: temp_value = np.mean(Time) Time = np.delete(Time, [i for i in range(len(Time))], axis=None) Time = np.append(Time, temp_value)
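The gate (stop-token) target above is derived from mel_pos: padded frames are marked with 1, the mask is shifted left by one frame, and the last column is forced to 1 so the gate fires on the final real frame. A small worked example of exactly that construction:

import torch
import torch.nn.functional as F

mel_pos = torch.tensor([[1, 2, 3, 4, 0, 0],        # 4 real frames, 2 padded
                        [1, 2, 3, 4, 5, 6]])       # 6 real frames, no padding

gate_target = mel_pos.eq(0).float()                # 1.0 where padded
gate_target = gate_target[:, 1:]                   # shift left by one frame
gate_target = F.pad(gate_target, (0, 1, 0, 0), value=1.)

print(gate_target)
# tensor([[0., 0., 0., 1., 1., 1.],
#         [0., 0., 0., 0., 0., 1.]])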
def main(args): # Get device device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # Define model model = nn.DataParallel(FastSpeech()).to(device) print("Model Has Been Defined") num_param = utils.get_param_num(model) print('Number of FastSpeech Parameters:', num_param) # Get dataset dataset = FastSpeechDataset() # Optimizer and loss optimizer = torch.optim.Adam(model.parameters(), betas=(0.9, 0.98), eps=1e-9) scheduled_optim = ScheduledOptim(optimizer, hp.d_model, hp.n_warm_up_step, args.restore_step) fastspeech_loss = FastSpeechLoss().to(device) print("Defined Optimizer and Loss Function.") # Load checkpoint if exists try: checkpoint = torch.load( os.path.join(hp.checkpoint_path, 'checkpoint_%d.pth.tar' % args.restore_step)) model.load_state_dict(checkpoint['model']) optimizer.load_state_dict(checkpoint['optimizer']) print("\n---Model Restored at Step %d---\n" % args.restore_step) except: print("\n---Start New Training---\n") if not os.path.exists(hp.checkpoint_path): os.mkdir(hp.checkpoint_path) # Init logger if not os.path.exists(hp.logger_path): os.mkdir(hp.logger_path) # Define Some Information Time = np.array([]) Start = time.clock() # Training model = model.train() for epoch in range(hp.epochs): # Get Training Loader training_loader = DataLoader(dataset, batch_size=hp.batch_size**2, shuffle=True, collate_fn=collate_fn, drop_last=True, num_workers=0) total_step = hp.epochs * len(training_loader) * hp.batch_size for i, batchs in enumerate(training_loader): for j, data_of_batch in enumerate(batchs): start_time = time.clock() current_step = i * hp.batch_size + j + args.restore_step + \ epoch * len(training_loader)*hp.batch_size + 1 # Init scheduled_optim.zero_grad() # Get Data character = torch.from_numpy( data_of_batch["text"]).long().to(device) mel_target = torch.from_numpy( data_of_batch["mel_target"]).float().to(device) D = torch.from_numpy(data_of_batch["D"]).int().to(device) mel_pos = torch.from_numpy( data_of_batch["mel_pos"]).long().to(device) src_pos = torch.from_numpy( data_of_batch["src_pos"]).long().to(device) max_mel_len = data_of_batch["mel_max_len"] # Forward mel_output, mel_postnet_output, duration_predictor_output = model( character, src_pos, mel_pos=mel_pos, mel_max_length=max_mel_len, length_target=D) # print(mel_target.size()) # print(mel_output.size()) # Cal Loss mel_loss, mel_postnet_loss, duration_loss = fastspeech_loss( mel_output, mel_postnet_output, duration_predictor_output, mel_target, D) total_loss = mel_loss + mel_postnet_loss + duration_loss # Logger t_l = total_loss.item() m_l = mel_loss.item() m_p_l = mel_postnet_loss.item() d_l = duration_loss.item() with open(os.path.join("logger", "total_loss.txt"), "a") as f_total_loss: f_total_loss.write(str(t_l) + "\n") with open(os.path.join("logger", "mel_loss.txt"), "a") as f_mel_loss: f_mel_loss.write(str(m_l) + "\n") with open(os.path.join("logger", "mel_postnet_loss.txt"), "a") as f_mel_postnet_loss: f_mel_postnet_loss.write(str(m_p_l) + "\n") with open(os.path.join("logger", "duration_loss.txt"), "a") as f_d_loss: f_d_loss.write(str(d_l) + "\n") # Backward total_loss.backward() # Clipping gradients to avoid gradient explosion nn.utils.clip_grad_norm_(model.parameters(), hp.grad_clip_thresh) # Update weights if args.frozen_learning_rate: scheduled_optim.step_and_update_lr_frozen( args.learning_rate_frozen) else: scheduled_optim.step_and_update_lr() # Print if current_step % hp.log_step == 0: Now = time.clock() str1 = "Epoch [{}/{}], Step [{}/{}]:".format( epoch + 1, hp.epochs, current_step, 
total_step) str2 = "Mel Loss: {:.4f}, Mel PostNet Loss: {:.4f}, Duration Loss: {:.4f};".format( m_l, m_p_l, d_l) str3 = "Current Learning Rate is {:.6f}.".format( scheduled_optim.get_learning_rate()) str4 = "Time Used: {:.3f}s, Estimated Time Remaining: {:.3f}s.".format( (Now - Start), (total_step - current_step) * np.mean(Time)) print("\n" + str1) print(str2) print(str3) print(str4) with open(os.path.join("logger", "logger.txt"), "a") as f_logger: f_logger.write(str1 + "\n") f_logger.write(str2 + "\n") f_logger.write(str3 + "\n") f_logger.write(str4 + "\n") f_logger.write("\n") if current_step % hp.save_step == 0: torch.save( { 'model': model.state_dict(), 'optimizer': optimizer.state_dict() }, os.path.join(hp.checkpoint_path, 'checkpoint_%d.pth.tar' % current_step)) print("save model at step %d ..." % current_step) end_time = time.clock() Time = np.append(Time, end_time - start_time) if len(Time) == hp.clear_Time: temp_value = np.mean(Time) Time = np.delete(Time, [i for i in range(len(Time))], axis=None) Time = np.append(Time, temp_value)
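This listing (and the FastSpeech one before it) measures step time with time.clock(), which was deprecated in Python 3.3 and removed in Python 3.8. The other listings here already use time.perf_counter(); a minimal drop-in looks like this:

import time

start_time = time.perf_counter()
# ... one training step ...
end_time = time.perf_counter()
step_seconds = end_time - start_time               # wall-clock seconds per step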
                # Backward
                total_loss = total_loss / hp.acc_steps
                total_loss.backward()
                if current_step % hp.acc_steps != 0:
                    continue

                # Clipping gradients to avoid gradient explosion
                # nn.utils.clip_grad_norm_(
                #     model.parameters(), hp.grad_clip_thresh)

                if current_step == hp.n_warm_up_step:
                    optimizer.param_groups[0]['lr'] = 1e-4
                if current_step < hp.n_warm_up_step:
                    # Update weights
                    scheduled_optim.step_and_update_lr()
                    # optimizer.step()
                    scheduled_optim.zero_grad()
                else:
                    optimizer.step()
                    # optimizer.step()
                    optimizer.zero_grad()

                # Print
                if current_step % hp.log_step == 0:
                    Now = time.perf_counter()
                    str1 = "Epoch [{}/{}], Step [{}/{}]:".format(
                        epoch + 1, hp.epochs, current_step, total_step)
                    print(str1)
                    str2 = "Total Loss: {:.4f}, Mel Loss: {:.4f}, Mel PostNet Loss: {:.4f}, Duration Loss: {:.4f}".format(
                        t_l, m_l, m_p_l, d_l)
                    print(str2)
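The fragment above runs the warm-up scheduler until hp.n_warm_up_step and then pins the learning rate at 1e-4 while stepping the raw optimizer directly. As a point of comparison only, the same "ramp up, then hold a fixed rate" shape can be expressed with torch.optim.lr_scheduler.LambdaLR; this sketch uses a linear ramp, so it is not identical to the original warm-up curve, and N_WARM_UP_STEP is a placeholder:

import torch
import torch.nn as nn

N_WARM_UP_STEP = 4000                              # stand-in for hp.n_warm_up_step
model = nn.Linear(10, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=lambda step: min(1.0, (step + 1) / N_WARM_UP_STEP))

for step in range(10):
    optimizer.zero_grad()
    loss = model(torch.randn(8, 10)).pow(2).mean()
    loss.backward()
    optimizer.step()
    scheduler.step()                               # lr ramps to 1e-4, then stays there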
def main(args): torch.manual_seed(0) # Get device device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # Get dataset dataset = Dataset("train.txt") loader = DataLoader(dataset, batch_size=hp.batch_size**2, shuffle=True, collate_fn=dataset.collate_fn, drop_last=True, num_workers=0) # Define model model = nn.DataParallel(STYLER()).to(device) print("Model Has Been Defined") # Parameters num_param = utils.get_param_num(model) text_encoder = utils.get_param_num( model.module.style_modeling.style_encoder.text_encoder) audio_encoder = utils.get_param_num( model.module.style_modeling.style_encoder.audio_encoder) predictors = utils.get_param_num(model.module.style_modeling.duration_predictor)\ + utils.get_param_num(model.module.style_modeling.pitch_predictor)\ + utils.get_param_num(model.module.style_modeling.energy_predictor) decoder = utils.get_param_num(model.module.decoder) print('Number of Model Parameters :', num_param) print('Number of Text Encoder Parameters :', text_encoder) print('Number of Audio Encoder Parameters :', audio_encoder) print('Number of Predictor Parameters :', predictors) print('Number of Decoder Parameters :', decoder) # Optimizer and loss optimizer = torch.optim.Adam(model.parameters(), betas=hp.betas, eps=hp.eps, weight_decay=hp.weight_decay) scheduled_optim = ScheduledOptim(optimizer, hp.decoder_hidden, hp.n_warm_up_step, args.restore_step) Loss = STYLERLoss().to(device) DATLoss = DomainAdversarialTrainingLoss().to(device) print("Optimizer and Loss Function Defined.") # Load checkpoint if exists checkpoint_path = os.path.join(hp.checkpoint_path()) try: checkpoint = torch.load( os.path.join(checkpoint_path, 'checkpoint_{}.pth.tar'.format(args.restore_step))) model.load_state_dict(checkpoint['model']) optimizer.load_state_dict(checkpoint['optimizer']) print("\n---Model Restored at Step {}---\n".format(args.restore_step)) except: print("\n---Start New Training---\n") if not os.path.exists(checkpoint_path): os.makedirs(checkpoint_path) # Load vocoder vocoder = utils.get_vocoder() # Init logger log_path = hp.log_path() if not os.path.exists(log_path): os.makedirs(log_path) os.makedirs(os.path.join(log_path, 'train')) os.makedirs(os.path.join(log_path, 'validation')) train_logger = SummaryWriter(os.path.join(log_path, 'train')) val_logger = SummaryWriter(os.path.join(log_path, 'validation')) # Init synthesis directory synth_path = hp.synth_path() if not os.path.exists(synth_path): os.makedirs(synth_path) # Define Some Information Time = np.array([]) Start = time.perf_counter() # Training model = model.train() for epoch in range(hp.epochs): # Get Training Loader total_step = hp.epochs * len(loader) * hp.batch_size for i, batchs in enumerate(loader): for j, data_of_batch in enumerate(batchs): start_time = time.perf_counter() current_step = i*hp.batch_size + j + args.restore_step + \ epoch*len(loader)*hp.batch_size + 1 # Get Data text = torch.from_numpy( data_of_batch["text"]).long().to(device) mel_target = torch.from_numpy( data_of_batch["mel_target"]).float().to(device) mel_aug = torch.from_numpy( data_of_batch["mel_aug"]).float().to(device) D = torch.from_numpy(data_of_batch["D"]).long().to(device) log_D = torch.from_numpy( data_of_batch["log_D"]).float().to(device) f0 = torch.from_numpy(data_of_batch["f0"]).float().to(device) f0_norm = torch.from_numpy( data_of_batch["f0_norm"]).float().to(device) f0_norm_aug = torch.from_numpy( data_of_batch["f0_norm_aug"]).float().to(device) energy = torch.from_numpy( data_of_batch["energy"]).float().to(device) 
energy_input = torch.from_numpy( data_of_batch["energy_input"]).float().to(device) energy_input_aug = torch.from_numpy( data_of_batch["energy_input_aug"]).float().to(device) speaker_embed = torch.from_numpy( data_of_batch["speaker_embed"]).float().to(device) src_len = torch.from_numpy( data_of_batch["src_len"]).long().to(device) mel_len = torch.from_numpy( data_of_batch["mel_len"]).long().to(device) max_src_len = np.max(data_of_batch["src_len"]).astype(np.int32) max_mel_len = np.max(data_of_batch["mel_len"]).astype(np.int32) # Forward mel_outputs, mel_postnet_outputs, log_duration_output, f0_output, energy_output, src_mask, mel_mask, _, aug_posteriors = model( text, mel_target, mel_aug, f0_norm, energy_input, src_len, mel_len, D, f0, energy, max_src_len, max_mel_len, speaker_embed=speaker_embed) # Cal Loss Clean mel_output, mel_postnet_output = mel_outputs[ 0], mel_postnet_outputs[0] mel_loss, mel_postnet_loss, d_loss, f_loss, e_loss, classifier_loss_a = Loss( log_duration_output, log_D, f0_output, f0, energy_output, energy, mel_output, mel_postnet_output, mel_target, ~src_mask, ~mel_mask, src_len, mel_len,\ aug_posteriors, torch.zeros(mel_target.size(0)).long().to(device)) # Cal Loss Noisy mel_output_noisy, mel_postnet_output_noisy = mel_outputs[ 1], mel_postnet_outputs[1] mel_noisy_loss, mel_postnet_noisy_loss = Loss.cal_mel_loss( mel_output_noisy, mel_postnet_output_noisy, mel_aug, ~mel_mask) # Forward DAT enc_cat = model.module.style_modeling.style_encoder.encoder_input_cat( mel_aug, f0_norm_aug, energy_input_aug, mel_aug) duration_encoding, pitch_encoding, energy_encoding, _ = model.module.style_modeling.style_encoder.audio_encoder( enc_cat, mel_len, src_len, mask=None) aug_posterior_d = model.module.style_modeling.augmentation_classifier_d( duration_encoding) aug_posterior_p = model.module.style_modeling.augmentation_classifier_p( pitch_encoding) aug_posterior_e = model.module.style_modeling.augmentation_classifier_e( energy_encoding) # Cal Loss DAT classifier_loss_a_dat = DATLoss( (aug_posterior_d, aug_posterior_p, aug_posterior_e), torch.ones(mel_target.size(0)).long().to(device)) # Total loss total_loss = mel_loss + mel_postnet_loss + mel_noisy_loss + mel_postnet_noisy_loss + d_loss + f_loss + e_loss\ + hp.dat_weight*(classifier_loss_a + classifier_loss_a_dat) # Logger t_l = total_loss.item() m_l = mel_loss.item() m_p_l = mel_postnet_loss.item() m_n_l = mel_noisy_loss.item() m_p_n_l = mel_postnet_noisy_loss.item() d_l = d_loss.item() f_l = f_loss.item() e_l = e_loss.item() cl_a = classifier_loss_a.item() cl_a_dat = classifier_loss_a_dat.item() # Backward total_loss = total_loss / hp.acc_steps total_loss.backward() if current_step % hp.acc_steps != 0: continue # Clipping gradients to avoid gradient explosion nn.utils.clip_grad_norm_(model.parameters(), hp.grad_clip_thresh) # Update weights scheduled_optim.step_and_update_lr() scheduled_optim.zero_grad() # Print if current_step == 1 or current_step % hp.log_step == 0: Now = time.perf_counter() str1 = "Epoch [{}/{}], Step [{}/{}]:".format( epoch + 1, hp.epochs, current_step, total_step) str2 = "Total Loss: {:.4f}, Mel Loss: {:.4f}, Mel PostNet Loss: {:.4f}, Duration Loss: {:.4f}, F0 Loss: {:.4f}, Energy Loss: {:.4f};".format( t_l, m_l, m_p_l, d_l, f_l, e_l) str3 = "Time Used: {:.3f}s, Estimated Time Remaining: {:.3f}s.".format( (Now - Start), (total_step - current_step) * np.mean(Time)) print("\n" + str1) print(str2) print(str3) train_logger.add_scalar('Loss/total_loss', t_l, current_step) train_logger.add_scalar('Loss/mel_loss', 
m_l, current_step) train_logger.add_scalar('Loss/mel_postnet_loss', m_p_l, current_step) train_logger.add_scalar('Loss/mel_noisy_loss', m_n_l, current_step) train_logger.add_scalar('Loss/mel_postnet_noisy_loss', m_p_n_l, current_step) train_logger.add_scalar('Loss/duration_loss', d_l, current_step) train_logger.add_scalar('Loss/F0_loss', f_l, current_step) train_logger.add_scalar('Loss/energy_loss', e_l, current_step) train_logger.add_scalar('Loss/dat_clean_loss', cl_a, current_step) train_logger.add_scalar('Loss/dat_noisy_loss', cl_a_dat, current_step) if current_step % hp.save_step == 0: torch.save( { 'model': model.state_dict(), 'optimizer': optimizer.state_dict() }, os.path.join( checkpoint_path, 'checkpoint_{}.pth.tar'.format(current_step))) print("save model at step {} ...".format(current_step)) if current_step == 1 or current_step % hp.synth_step == 0: length = mel_len[0].item() mel_target_torch = mel_target[ 0, :length].detach().unsqueeze(0).transpose(1, 2) mel_aug_torch = mel_aug[0, :length].detach().unsqueeze( 0).transpose(1, 2) mel_target = mel_target[ 0, :length].detach().cpu().transpose(0, 1) mel_aug = mel_aug[0, :length].detach().cpu().transpose( 0, 1) mel_torch = mel_output[0, :length].detach().unsqueeze( 0).transpose(1, 2) mel_noisy_torch = mel_output_noisy[ 0, :length].detach().unsqueeze(0).transpose(1, 2) mel = mel_output[0, :length].detach().cpu().transpose(0, 1) mel_noisy = mel_output_noisy[ 0, :length].detach().cpu().transpose(0, 1) mel_postnet_torch = mel_postnet_output[ 0, :length].detach().unsqueeze(0).transpose(1, 2) mel_postnet_noisy_torch = mel_postnet_output_noisy[ 0, :length].detach().unsqueeze(0).transpose(1, 2) mel_postnet = mel_postnet_output[ 0, :length].detach().cpu().transpose(0, 1) mel_postnet_noisy = mel_postnet_output_noisy[ 0, :length].detach().cpu().transpose(0, 1) # Audio.tools.inv_mel_spec(mel, os.path.join( # synth_path, "step_{}_{}_griffin_lim.wav".format(current_step, "c"))) # Audio.tools.inv_mel_spec(mel_postnet, os.path.join( # synth_path, "step_{}_{}_postnet_griffin_lim.wav".format(current_step, "c"))) # Audio.tools.inv_mel_spec(mel_noisy, os.path.join( # synth_path, "step_{}_{}_griffin_lim.wav".format(current_step, "n"))) # Audio.tools.inv_mel_spec(mel_postnet_noisy, os.path.join( # synth_path, "step_{}_{}_postnet_griffin_lim.wav".format(current_step, "n"))) wav_mel = utils.vocoder_infer( mel_torch, vocoder, os.path.join( hp.synth_path(), 'step_{}_{}_{}.wav'.format(current_step, "c", hp.vocoder))) wav_mel_postnet = utils.vocoder_infer( mel_postnet_torch, vocoder, os.path.join( hp.synth_path(), 'step_{}_{}_postnet_{}.wav'.format( current_step, "c", hp.vocoder))) wav_ground_truth = utils.vocoder_infer( mel_target_torch, vocoder, os.path.join( hp.synth_path(), 'step_{}_{}_ground-truth_{}.wav'.format( current_step, "c", hp.vocoder))) wav_mel_noisy = utils.vocoder_infer( mel_noisy_torch, vocoder, os.path.join( hp.synth_path(), 'step_{}_{}_{}.wav'.format(current_step, "n", hp.vocoder))) wav_mel_postnet_noisy = utils.vocoder_infer( mel_postnet_noisy_torch, vocoder, os.path.join( hp.synth_path(), 'step_{}_{}_postnet_{}.wav'.format( current_step, "n", hp.vocoder))) wav_aug = utils.vocoder_infer( mel_aug_torch, vocoder, os.path.join( hp.synth_path(), 'step_{}_{}_ground-truth_{}.wav'.format( current_step, "n", hp.vocoder))) # Model duration prediction log_duration_output = log_duration_output[ 0, :src_len[0].item()].detach().cpu() # [seg_len] log_duration_output = torch.clamp(torch.round( torch.exp(log_duration_output) - hp.log_offset), min=0).int() 
model_duration = utils.get_alignment_2D( log_duration_output).T # [seg_len, mel_len] model_duration = utils.plot_alignment([model_duration]) # Model mel prediction f0 = f0[0, :length].detach().cpu().numpy() energy = energy[0, :length].detach().cpu().numpy() f0_output = f0_output[0, :length].detach().cpu().numpy() energy_output = energy_output[ 0, :length].detach().cpu().numpy() mel_predicted = utils.plot_data( [(mel_postnet.numpy(), f0_output, energy_output), (mel_target.numpy(), f0, energy)], [ 'Synthetized Spectrogram Clean', 'Ground-Truth Spectrogram' ], filename=os.path.join( synth_path, 'step_{}_{}.png'.format(current_step, "c"))) mel_noisy_predicted = utils.plot_data( [(mel_postnet_noisy.numpy(), f0_output, energy_output), (mel_aug.numpy(), f0, energy)], ['Synthetized Spectrogram Noisy', 'Aug Spectrogram'], filename=os.path.join( synth_path, 'step_{}_{}.png'.format(current_step, "n"))) # Normalize audio for tensorboard logger. See https://github.com/lanpa/tensorboardX/issues/511#issuecomment-537600045 wav_ground_truth = wav_ground_truth / max(wav_ground_truth) wav_mel = wav_mel / max(wav_mel) wav_mel_postnet = wav_mel_postnet / max(wav_mel_postnet) wav_aug = wav_aug / max(wav_aug) wav_mel_noisy = wav_mel_noisy / max(wav_mel_noisy) wav_mel_postnet_noisy = wav_mel_postnet_noisy / max( wav_mel_postnet_noisy) train_logger.add_image("model_duration", model_duration, current_step, dataformats='HWC') train_logger.add_image("mel_predicted/Clean", mel_predicted, current_step, dataformats='HWC') train_logger.add_image("mel_predicted/Noisy", mel_noisy_predicted, current_step, dataformats='HWC') train_logger.add_audio("Clean/wav_ground_truth", wav_ground_truth, current_step, sample_rate=hp.sampling_rate) train_logger.add_audio("Clean/wav_mel", wav_mel, current_step, sample_rate=hp.sampling_rate) train_logger.add_audio("Clean/wav_mel_postnet", wav_mel_postnet, current_step, sample_rate=hp.sampling_rate) train_logger.add_audio("Noisy/wav_aug", wav_aug, current_step, sample_rate=hp.sampling_rate) train_logger.add_audio("Noisy/wav_mel_noisy", wav_mel_noisy, current_step, sample_rate=hp.sampling_rate) train_logger.add_audio("Noisy/wav_mel_postnet_noisy", wav_mel_postnet_noisy, current_step, sample_rate=hp.sampling_rate) if current_step == 1 or current_step % hp.eval_step == 0: model.eval() with torch.no_grad(): d_l, f_l, e_l, cl_a, cl_a_dat, m_l, m_p_l, m_n_l, m_p_n_l = evaluate( model, current_step) t_l = d_l + f_l + e_l + m_l + m_p_l + m_n_l + m_p_n_l\ + hp.dat_weight*(cl_a + cl_a_dat) val_logger.add_scalar('Loss/total_loss', t_l, current_step) val_logger.add_scalar('Loss/mel_loss', m_l, current_step) val_logger.add_scalar('Loss/mel_postnet_loss', m_p_l, current_step) val_logger.add_scalar('Loss/mel_noisy_loss', m_n_l, current_step) val_logger.add_scalar('Loss/mel_postnet_noisy_loss', m_p_n_l, current_step) val_logger.add_scalar('Loss/duration_loss', d_l, current_step) val_logger.add_scalar('Loss/F0_loss', f_l, current_step) val_logger.add_scalar('Loss/energy_loss', e_l, current_step) val_logger.add_scalar('Loss/dat_clean_loss', cl_a, current_step) val_logger.add_scalar('Loss/dat_noisy_loss', cl_a_dat, current_step) model.train() end_time = time.perf_counter() Time = np.append(Time, end_time - start_time) if len(Time) == hp.clear_Time: temp_value = np.mean(Time) Time = np.delete(Time, [i for i in range(len(Time))], axis=None) Time = np.append(Time, temp_value)
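Every synthesis block in these scripts slices one utterance out of the padded batch and reshapes it before handing it to the vocoder or the plotting utilities. A shape walkthrough of that idiom, assuming the usual (batch, frames, n_mels) model output and a vocoder input layout of (1, n_mels, frames):

import torch

B, T, n_mels = 4, 500, 80
mel_postnet_output = torch.randn(B, T, n_mels)      # padded batch output
length = 350                                        # true (unpadded) length of item 0

# Vocoder input: one item, trimmed, batch dim restored, channels-first.
mel_torch = mel_postnet_output[0, :length].detach().unsqueeze(0).transpose(1, 2)
print(mel_torch.shape)                              # torch.Size([1, 80, 350])

# Plotting input: same item on CPU as (n_mels, frames).
mel_cpu = mel_postnet_output[0, :length].detach().cpu().transpose(0, 1)
print(mel_cpu.shape)                                # torch.Size([80, 350])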
def main(args):
    torch.manual_seed(0)

    # Get device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Get dataset
    dataset = Dataset("train.txt")
    loader = DataLoader(dataset,
                        batch_size=hp.batch_size**2,
                        shuffle=True,
                        collate_fn=dataset.collate_fn,
                        drop_last=True,
                        num_workers=0)

    # Define model
    model = nn.DataParallel(FastSpeech2()).to(device)
    print("Model Has Been Defined")
    num_param = utils.get_param_num(model)
    print('Number of FastSpeech2 Parameters:', num_param)

    # Optimizer and loss
    optimizer = torch.optim.Adam(model.parameters(),
                                 betas=hp.betas,
                                 eps=hp.eps,
                                 weight_decay=hp.weight_decay)
    scheduled_optim = ScheduledOptim(optimizer, hp.decoder_hidden,
                                     hp.n_warm_up_step, args.restore_step)
    Loss = FastSpeech2Loss().to(device)
    print("Optimizer and Loss Function Defined.")

    # Load checkpoint if exists; any load failure falls back to a fresh run.
    checkpoint_path = os.path.join(hp.checkpoint_path)
    try:
        checkpoint = torch.load(
            os.path.join(checkpoint_path,
                         'checkpoint_{}.pth.tar'.format(args.restore_step)))
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("\n---Model Restored at Step {}---\n".format(args.restore_step))
    except:
        print("\n---Start New Training---\n")
        if not os.path.exists(checkpoint_path):
            os.makedirs(checkpoint_path)

    # Load vocoder
    wave_glow = utils.get_WaveGlow()

    # Init logger
    log_path = hp.log_path
    if not os.path.exists(log_path):
        os.makedirs(log_path)
    logger = SummaryWriter(log_path)

    # Init synthesis directory
    synth_path = hp.synth_path
    if not os.path.exists(synth_path):
        os.makedirs(synth_path)

    # Define Some Information
    Time = np.array([])
    Start = time.perf_counter()

    # Training
    model = model.train()
    for epoch in range(hp.epochs):
        # Get Training Loader
        total_step = hp.epochs * len(loader) * hp.batch_size

        for i, batchs in enumerate(loader):
            for j, data_of_batch in enumerate(batchs):
                start_time = time.perf_counter()

                current_step = i * hp.batch_size + j + args.restore_step + \
                    epoch * len(loader) * hp.batch_size + 1

                # Init
                scheduled_optim.zero_grad()

                # Get Data
                text = torch.from_numpy(data_of_batch["text"]).long().to(device)
                mel_target = torch.from_numpy(data_of_batch["mel_target"]).float().to(device)
                D = torch.from_numpy(data_of_batch["D"]).int().to(device)
                f0 = torch.from_numpy(data_of_batch["f0"]).float().to(device)
                energy = torch.from_numpy(data_of_batch["energy"]).float().to(device)
                mel_pos = torch.from_numpy(data_of_batch["mel_pos"]).long().to(device)
                src_pos = torch.from_numpy(data_of_batch["src_pos"]).long().to(device)
                src_len = torch.from_numpy(data_of_batch["src_len"]).long().to(device)
                mel_len = torch.from_numpy(data_of_batch["mel_len"]).long().to(device)
                max_len = max(data_of_batch["mel_len"]).astype(np.int16)

                # Forward
                mel_output, mel_postnet_output, duration_output, f0_output, energy_output = model(
                    text, src_pos, mel_pos, max_len, D, f0, energy)

                # Cal Loss
                mel_loss, mel_postnet_loss, d_loss, f_loss, e_loss = Loss(
                    duration_output, D, f0_output, f0, energy_output, energy,
                    mel_output, mel_postnet_output, mel_target, src_len, mel_len)
                total_loss = mel_loss + mel_postnet_loss + d_loss + f_loss + e_loss

                # Logger
                t_l = total_loss.item()
                m_l = mel_loss.item()
                m_p_l = mel_postnet_loss.item()
                d_l = d_loss.item()
                f_l = f_loss.item()
                e_l = e_loss.item()

                with open(os.path.join(log_path, "total_loss.txt"), "a") as f_total_loss:
                    f_total_loss.write(str(t_l) + "\n")
                with open(os.path.join(log_path, "mel_loss.txt"), "a") as f_mel_loss:
                    f_mel_loss.write(str(m_l) + "\n")
                with open(os.path.join(log_path, "mel_postnet_loss.txt"), "a") as f_mel_postnet_loss:
                    f_mel_postnet_loss.write(str(m_p_l) + "\n")
                with open(os.path.join(log_path, "duration_loss.txt"), "a") as f_d_loss:
                    f_d_loss.write(str(d_l) + "\n")
                with open(os.path.join(log_path, "f0_loss.txt"), "a") as f_f_loss:
                    f_f_loss.write(str(f_l) + "\n")
                with open(os.path.join(log_path, "energy_loss.txt"), "a") as f_e_loss:
                    f_e_loss.write(str(e_l) + "\n")

                # Backward
                total_loss.backward()

                # Clipping gradients to avoid gradient explosion
                nn.utils.clip_grad_norm_(model.parameters(), hp.grad_clip_thresh)

                # Update weights
                scheduled_optim.step_and_update_lr()

                # Print
                if current_step % hp.log_step == 0:
                    Now = time.perf_counter()

                    str1 = "Epoch [{}/{}], Step [{}/{}]:".format(
                        epoch + 1, hp.epochs, current_step, total_step)
                    str2 = "Total Loss: {:.4f}, Mel Loss: {:.4f}, Mel PostNet Loss: {:.4f}, Duration Loss: {:.4f}, F0 Loss: {:.4f}, Energy Loss: {:.4f};".format(
                        t_l, m_l, m_p_l, d_l, f_l, e_l)
                    str3 = "Time Used: {:.3f}s, Estimated Time Remaining: {:.3f}s.".format(
                        (Now - Start), (total_step - current_step) * np.mean(Time))

                    print("\n" + str1)
                    print(str2)
                    print(str3)

                    with open(os.path.join(log_path, "log.txt"), "a") as f_log:
                        f_log.write(str1 + "\n")
                        f_log.write(str2 + "\n")
                        f_log.write(str3 + "\n")
                        f_log.write("\n")

                    logger.add_scalars('Loss/total_loss', {'training': t_l}, current_step)
                    logger.add_scalars('Loss/mel_loss', {'training': m_l}, current_step)
                    logger.add_scalars('Loss/mel_postnet_loss', {'training': m_p_l}, current_step)
                    logger.add_scalars('Loss/duration_loss', {'training': d_l}, current_step)
                    logger.add_scalars('Loss/F0_loss', {'training': f_l}, current_step)
                    logger.add_scalars('Loss/energy_loss', {'training': e_l}, current_step)

                if current_step % hp.save_step == 0:
                    torch.save(
                        {
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict()
                        },
                        os.path.join(checkpoint_path,
                                     'checkpoint_{}.pth.tar'.format(current_step)))
                    print("save model at step {} ...".format(current_step))

                if current_step % hp.synth_step == 0:
                    length = mel_len[0].item()
                    mel_target_torch = mel_target[0, :length].detach().unsqueeze(0).transpose(1, 2)
                    mel_target = mel_target[0, :length].detach().cpu().transpose(0, 1)
                    mel_torch = mel_output[0, :length].detach().unsqueeze(0).transpose(1, 2)
                    mel = mel_output[0, :length].detach().cpu().transpose(0, 1)
                    mel_postnet_torch = mel_postnet_output[0, :length].detach().unsqueeze(0).transpose(1, 2)
                    mel_postnet = mel_postnet_output[0, :length].detach().cpu().transpose(0, 1)

                    Audio.tools.inv_mel_spec(
                        mel,
                        os.path.join(synth_path,
                                     "step_{}_griffin_lim.wav".format(current_step)))
                    Audio.tools.inv_mel_spec(
                        mel_postnet,
                        os.path.join(synth_path,
                                     "step_{}_postnet_griffin_lim.wav".format(current_step)))
                    waveglow.inference.inference(
                        mel_torch, wave_glow,
                        os.path.join(synth_path,
                                     "step_{}_waveglow.wav".format(current_step)))
                    waveglow.inference.inference(
                        mel_postnet_torch, wave_glow,
                        os.path.join(synth_path,
                                     "step_{}_postnet_waveglow.wav".format(current_step)))
                    waveglow.inference.inference(
                        mel_target_torch, wave_glow,
                        os.path.join(synth_path,
                                     "step_{}_ground-truth_waveglow.wav".format(current_step)))

                    f0 = f0[0, :length].detach().cpu().numpy()
                    energy = energy[0, :length].detach().cpu().numpy()
                    f0_output = f0_output[0, :length].detach().cpu().numpy()
                    energy_output = energy_output[0, :length].detach().cpu().numpy()

                    utils.plot_data(
                        [(mel_postnet.numpy(), f0_output, energy_output),
                         (mel_target.numpy(), f0, energy)],
                        ['Synthesized Spectrogram', 'Ground-Truth Spectrogram'],
                        filename=os.path.join(synth_path,
                                              'step_{}.png'.format(current_step)))

                if current_step % hp.eval_step == 0:
                    model.eval()
                    with torch.no_grad():
                        d_l, f_l, e_l, m_l, m_p_l = evaluate(model, current_step)
                        t_l = d_l + f_l + e_l + m_l + m_p_l

                        logger.add_scalars('Loss/total_loss', {'validation': t_l}, current_step)
                        logger.add_scalars('Loss/mel_loss', {'validation': m_l}, current_step)
                        logger.add_scalars('Loss/mel_postnet_loss', {'validation': m_p_l}, current_step)
                        logger.add_scalars('Loss/duration_loss', {'validation': d_l}, current_step)
                        logger.add_scalars('Loss/F0_loss', {'validation': f_l}, current_step)
                        logger.add_scalars('Loss/energy_loss', {'validation': e_l}, current_step)

                    model.train()

                end_time = time.perf_counter()
                Time = np.append(Time, end_time - start_time)
                if len(Time) == hp.clear_Time:
                    temp_value = np.mean(Time)
                    Time = np.delete(Time, [i for i in range(len(Time))], axis=None)
                    Time = np.append(Time, temp_value)
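The ScheduledOptim wrapper used above (constructed with the decoder hidden size, warm-up steps, and restore step) is not defined in this listing; the following is a minimal sketch, assuming it implements the usual Transformer-style (Noam) warm-up schedule, and is not the project's actual class.

# Hedged sketch of a Noam-style scheduled optimizer wrapper (assumption).
import torch

class NoamScheduledOptim:
    def __init__(self, optimizer, d_model, n_warmup_steps, current_step=0):
        self._optimizer = optimizer
        self.init_lr = d_model ** -0.5
        self.n_warmup_steps = n_warmup_steps
        self.n_current_steps = current_step

    def zero_grad(self):
        self._optimizer.zero_grad()

    def get_learning_rate(self):
        return self._optimizer.param_groups[0]['lr']

    def step_and_update_lr(self):
        # lr = d_model^-0.5 * min(step^-0.5, step * warmup^-1.5), set before stepping.
        self.n_current_steps += 1
        lr = self.init_lr * min(self.n_current_steps ** -0.5,
                                self.n_current_steps * self.n_warmup_steps ** -1.5)
        for group in self._optimizer.param_groups:
            group['lr'] = lr
        self._optimizer.step()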
def main(args):
    # Get device
    # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    device = 'cuda'

    # Define model
    model = FastSpeech().to(device)
    print("Model Has Been Defined")
    num_param = utils.get_param_num(model)
    print('Number of FastSpeech Parameters:', num_param)

    current_time = time.strftime("%Y-%m-%dT%H:%M", time.localtime())
    writer = SummaryWriter(log_dir='log/' + current_time)

    optimizer = torch.optim.Adam(model.parameters(), betas=(0.9, 0.98), eps=1e-9)

    # Load checkpoint if exists (checkpoint.txt stores the latest saved step)
    try:
        checkpoint_in = open(os.path.join(hp.checkpoint_path, 'checkpoint.txt'), 'r')
        args.restore_step = int(checkpoint_in.readline().strip())
        checkpoint_in.close()
        checkpoint = torch.load(
            os.path.join(hp.checkpoint_path,
                         'checkpoint_%08d.pth' % args.restore_step))
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("\n---Model Restored at Step %d---\n" % args.restore_step)
    except:
        print("\n---Start New Training---\n")
        if not os.path.exists(hp.checkpoint_path):
            os.mkdir(hp.checkpoint_path)

    # Get dataset
    dataset = FastSpeechDataset()

    # Optimizer and loss
    scheduled_optim = ScheduledOptim(optimizer, hp.d_model, hp.n_warm_up_step,
                                     args.restore_step)
    fastspeech_loss = FastSpeechLoss().to(device)
    print("Defined Optimizer and Loss Function.")

    # Init logger
    if not os.path.exists(hp.logger_path):
        os.mkdir(hp.logger_path)

    # Define Some Information
    Time = np.array([])
    Start = time.perf_counter()

    # Training
    model = model.train()
    t_l = 0.0
    for epoch in range(hp.epochs):
        # Get Training Loader
        training_loader = DataLoader(dataset,
                                     batch_size=hp.batch_size**2,
                                     shuffle=True,
                                     collate_fn=collate_fn,
                                     drop_last=True,
                                     num_workers=0)
        total_step = hp.epochs * len(training_loader) * hp.batch_size

        for i, batchs in enumerate(training_loader):
            for j, data_of_batch in enumerate(batchs):
                start_time = time.perf_counter()

                current_step = i * hp.batch_size + j + args.restore_step + \
                    epoch * len(training_loader) * hp.batch_size + 1

                # Init
                scheduled_optim.zero_grad()

                # Get Data
                condition1 = torch.from_numpy(data_of_batch["condition1"]).long().to(device)
                condition2 = torch.from_numpy(data_of_batch["condition2"]).long().to(device)
                mel_target = torch.from_numpy(data_of_batch["mel_target"]).long().to(device)
                norm_f0 = torch.from_numpy(data_of_batch["norm_f0"]).long().to(device)
                mel_in = torch.from_numpy(data_of_batch["mel_in"]).float().to(device)
                D = torch.from_numpy(data_of_batch["D"]).int().to(device)
                mel_pos = torch.from_numpy(data_of_batch["mel_pos"]).long().to(device)
                src_pos = torch.from_numpy(data_of_batch["src_pos"]).long().to(device)
                lens = data_of_batch["lens"]
                max_mel_len = data_of_batch["mel_max_len"]
                # print(condition1, condition2)

                # Forward
                mel_output = model(src_seq1=condition1,
                                   src_seq2=condition2,
                                   mel_in=mel_in,
                                   src_pos=src_pos,
                                   mel_pos=mel_pos,
                                   mel_max_length=max_mel_len,
                                   length_target=D)
                # print(mel_target.size())
                # print(mel_output)
                # print(mel_postnet_output)

                # Cal Loss
                # mel_loss, mel_postnet_loss = fastspeech_loss(mel_output, mel_postnet_output, mel_target)
                # print(mel_output.shape, mel_target.shape)
                Loss = torch.nn.CrossEntropyLoss()
                predict = mel_output.transpose(1, 2)
                target1 = mel_target.long().squeeze()
                target2 = norm_f0.long().squeeze()
                target = ((target1 + target2) / 2).long().squeeze()
                # print(predict.shape, target.shape)
                # print(target.float().mean())
                losses = []
                # print(lens, target)
                for index in range(predict.shape[0]):
                    # Cross-entropy is computed per utterance, masked to its true length.
                    # print(predict[i, :, :lens[i]].shape, target[i, :lens[i]].shape)
                    losses.append(
                        Loss(predict[index, :, :lens[index]].transpose(0, 1),
                             target[index, :lens[index]]).unsqueeze(0))
                    # losses.append(0.5 * Loss(predict[index, :, :lens[index]].transpose(0, 1),
                    #                          target2[index, :lens[index]]).unsqueeze(0))
                total_loss = torch.cat(losses).mean()
                t_l += total_loss.item()
                # assert np.isnan(t_l) == False

                with open(os.path.join("logger", "total_loss.txt"), "a") as f_total_loss:
                    f_total_loss.write(str(t_l) + "\n")

                # Backward
                if not np.isnan(t_l):
                    total_loss.backward()
                else:
                    print(condition1, condition2, D)

                # Clipping gradients to avoid gradient explosion
                nn.utils.clip_grad_norm_(model.parameters(), hp.grad_clip_thresh)

                # Update weights
                if args.frozen_learning_rate:
                    scheduled_optim.step_and_update_lr_frozen(args.learning_rate_frozen)
                else:
                    scheduled_optim.step_and_update_lr()

                # Print
                if current_step % hp.log_step == 0:
                    Now = time.perf_counter()

                    str1 = "Epoch[{}/{}] Step[{}/{}]:".format(
                        epoch + 1, hp.epochs, current_step, total_step)
                    str2 = "Loss:{:.4f} ".format(t_l / hp.log_step)
                    str3 = "LR:{:.6f}".format(scheduled_optim.get_learning_rate())
                    str4 = "T: {:.1f}s ETR:{:.1f}s.".format(
                        (Now - Start), (total_step - current_step) * np.mean(Time))

                    print('\r' + str1 + ' ' + str2 + ' ' + str3 + ' ' + str4, end='')

                    writer.add_scalar('loss', t_l / hp.log_step, current_step)
                    writer.add_scalar('learning rate',
                                      scheduled_optim.get_learning_rate(),
                                      current_step)

                    if hp.gpu_log_step != -1 and current_step % hp.gpu_log_step == 0:
                        os.system('nvidia-smi')

                    with open(os.path.join("logger", "logger.txt"), "a") as f_logger:
                        f_logger.write(str1 + "\n")
                        f_logger.write(str2 + "\n")
                        f_logger.write(str3 + "\n")
                        f_logger.write(str4 + "\n")
                        f_logger.write("\n")

                    t_l = 0.0

                if current_step % hp.fig_step == 0 or current_step == 20:
                    f = plt.figure()
                    plt.matshow(mel_output[0].cpu().detach().numpy())
                    plt.savefig('out_predicted.png')

                    plt.matshow(
                        F.softmax(predict, dim=1).transpose(1, 2)[0].cpu().detach().numpy())
                    plt.savefig('out_predicted_softmax.png')
                    writer.add_figure('predict', f, current_step)
                    plt.cla()

                    f = plt.figure(figsize=(8, 6))
                    # plt.matshow(mel_target[0].cpu().detach().numpy())
                    # x = np.arange(mel_target.shape[1])
                    # y = sample_from_discretized_mix_logistic(mel_output.transpose(1, 2)).cpu().detach().numpy()[0]
                    # plt.plot(x, y)
                    sample = []
                    p = F.softmax(predict, dim=1).transpose(1, 2)[0].detach().cpu().numpy()
                    for index in range(p.shape[0]):
                        sample.append(np.random.choice(200, 1, p=p[index]))
                    sample = np.array(sample)
                    plt.plot(np.arange(sample.shape[0]), sample,
                             color='grey', linewidth='1')

                    for index in range(D.shape[1]):
                        x = np.arange(D[0][index].cpu().numpy()) + D[0][:index].cpu().numpy().sum()
                        y = np.arange(D[0][index].detach().cpu().numpy())
                        if condition2[0][index].cpu().numpy() != 0:
                            y.fill((condition2[0][index].cpu().numpy() - 40.0) * 5)
                        plt.plot(x, y, color='blue')

                    plt.plot(np.arange(target.shape[1]),
                             target[0].squeeze().detach().cpu().numpy(),
                             color='red',
                             linewidth='1')
                    plt.savefig('out_target.png', dpi=300)
                    writer.add_figure('target', f, current_step)
                    plt.cla()
                    plt.close("all")

                if current_step % hp.save_step == 0:
                    print("save model at step %d ..." % current_step, end='')
                    torch.save(
                        {
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict()
                        },
                        os.path.join(hp.checkpoint_path,
                                     'checkpoint_%08d.pth' % current_step))
                    checkpoint_out = open(
                        os.path.join(hp.checkpoint_path, 'checkpoint.txt'), 'w')
                    checkpoint_out.write(str(current_step))
                    checkpoint_out.close()
                    # os.system('python savefig.py')
                    print('save completed')

                end_time = time.perf_counter()
                Time = np.append(Time, end_time - start_time)
                if len(Time) == hp.clear_Time:
                    temp_value = np.mean(Time)
                    Time = np.delete(Time, [i for i in range(len(Time))], axis=None)
                    Time = np.append(Time, temp_value)
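This script records the most recent step in checkpoint.txt next to the checkpoint_%08d.pth files; a small helper following the same convention could centralize the restore logic. The sketch below is an assumption about that layout, not part of the original code.

# Hedged sketch: restore the latest checkpoint recorded in checkpoint.txt (assumption).
import os
import torch

def load_latest_checkpoint(model, optimizer, checkpoint_dir):
    # Returns the restored step, or 0 when no checkpoint.txt is present.
    txt_path = os.path.join(checkpoint_dir, 'checkpoint.txt')
    if not os.path.exists(txt_path):
        return 0
    with open(txt_path, 'r') as f:
        step = int(f.readline().strip())
    checkpoint = torch.load(
        os.path.join(checkpoint_dir, 'checkpoint_%08d.pth' % step))
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    return step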